Compare commits
1050 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7c3f5431ba | ||
|
|
d98cf16a4c | ||
|
|
2c3c3ae546 | ||
|
|
905eef48e3 | ||
|
|
b31b520c7c | ||
|
|
17aee086a3 | ||
|
|
c1756e5767 | ||
|
|
2920279c64 | ||
|
|
1f0f985b01 | ||
|
|
0762c81633 | ||
|
|
28ef301ccc | ||
|
|
26c6a2950f | ||
|
|
5082876de3 | ||
|
|
e50e7ad3d5 | ||
|
|
45a4a6b6da | ||
|
|
02918b7267 | ||
|
|
6c662a36c1 | ||
|
|
b78fe3822a | ||
|
|
35eda37e83 | ||
|
|
176a8e7067 | ||
|
|
61d4f1fd4b | ||
|
|
121b68995e | ||
|
|
d11f1d8dae | ||
|
|
c0ef2b5064 | ||
|
|
2a7308363e | ||
|
|
dc0c556f96 | ||
|
|
ba2ee1c0aa | ||
|
|
0f8b550d68 | ||
|
|
ed1fc98821 | ||
|
|
fa53b468fd | ||
|
|
4e2533d320 | ||
|
|
388ae49e55 | ||
|
|
f3f347dcba | ||
|
|
655be3519c | ||
|
|
06df2940af | ||
|
|
4149549e42 | ||
|
|
da351991f8 | ||
|
|
3305152e50 | ||
|
|
bea7bae674 | ||
|
|
45773d38ed | ||
|
|
8d4c176314 | ||
|
|
9ca5c87c4c | ||
|
|
36a6f00e5f | ||
|
|
e24a5b4cb5 | ||
|
|
f88031b0c9 | ||
|
|
830151e6da | ||
|
|
1e14fba81a | ||
|
|
7b8800c4eb | ||
|
|
8f4625f53b | ||
|
|
1e5f243edb | ||
|
|
e5eab2af34 | ||
|
|
c10973e160 | ||
|
|
b1e4bff3ec | ||
|
|
c1202cda63 | ||
|
|
32d6cd7776 | ||
|
|
2f78d30e93 | ||
|
|
33407c9f0d | ||
|
|
d2d5ef1c5c | ||
|
|
98d8eaee02 | ||
|
|
10b9228060 | ||
|
|
5872f1e017 | ||
|
|
5073f21002 | ||
|
|
69aaf09ac8 | ||
|
|
6e61ee81d8 | ||
|
|
cfd05a8d17 | ||
|
|
29845fcc4c | ||
|
|
e204b180a8 | ||
|
|
563972fd29 | ||
|
|
cbe94b84fc | ||
|
|
aa6f73574d | ||
|
|
94f0419ef7 | ||
|
|
cefd2d7f49 | ||
|
|
81e1e545fb | ||
|
|
d516920e72 | ||
|
|
2171372246 | ||
|
|
d2df4d0cce | ||
|
|
6ab90fc123 | ||
|
|
1a84ebbb1e | ||
|
|
c9c0352369 | ||
|
|
9903b028a3 | ||
|
|
49def5d883 | ||
|
|
6975525b70 | ||
|
|
fbc4f8527b | ||
|
|
90cb5a1951 | ||
|
|
ac71d9f034 | ||
|
|
64bcbc9fc0 | ||
|
|
9e7d46f956 | ||
|
|
e911896cfb | ||
|
|
9c6d66093f | ||
|
|
b2e39b9701 | ||
|
|
e95ad4049b | ||
|
|
1df49d1d6f | ||
|
|
b71000e2f3 | ||
|
|
47e6ed455e | ||
|
|
92592fb9d9 | ||
|
|
02a9769b35 | ||
|
|
7640f11bfc | ||
|
|
be8a0991ed | ||
|
|
9fa44dbcfa | ||
|
|
61aac9c80c | ||
|
|
60af83cfee | ||
|
|
cf64e6c231 | ||
|
|
2cae941bae | ||
|
|
bc0784f41d | ||
|
|
b711140f26 | ||
|
|
c57d75e01a | ||
|
|
1d766001bb | ||
|
|
0759a11a85 | ||
|
|
cb749a38ab | ||
|
|
369eab18ab | ||
|
|
73edeae013 | ||
|
|
7d46314dc8 | ||
|
|
d5a53a89eb | ||
|
|
a85bc510dd | ||
|
|
2beea7d218 | ||
|
|
a93cd3dd5f | ||
|
|
6c1f540170 | ||
|
|
d026a9f009 | ||
|
|
a8e7dadd39 | ||
|
|
2f8d921adf | ||
|
|
0c6e526f94 | ||
|
|
b1e3018b6b | ||
|
|
87f05fce66 | ||
|
|
1b37530c96 | ||
|
|
db4d02c2e2 | ||
|
|
fd7811402b | ||
|
|
eb0325e627 | ||
|
|
842c3c8ea9 | ||
|
|
8b4b04ec09 | ||
|
|
9f32c9280f | ||
|
|
4fcd09cfa8 | ||
|
|
7a8d65d37d | ||
|
|
23129a9ba2 | ||
|
|
7f791e730b | ||
|
|
f7e296b349 | ||
|
|
712d4acaaa | ||
|
|
74a5c01f21 | ||
|
|
3ba8724d77 | ||
|
|
6313a7d8a9 | ||
|
|
432a3f520c | ||
|
|
191b3e42d4 | ||
|
|
a27f05fcb4 | ||
|
|
2f33e0b873 | ||
|
|
f0359467f1 | ||
|
|
d1db8cf2c8 | ||
|
|
b1985ed2ce | ||
|
|
140ddc70e6 | ||
|
|
d7fd616470 | ||
|
|
3ccbef141e | ||
|
|
e92fbb0443 | ||
|
|
bd270aed68 | ||
|
|
28d7864393 | ||
|
|
b5d8173ee3 | ||
|
|
17d62a9af7 | ||
|
|
d89fb863ed | ||
|
|
a21ad77820 | ||
|
|
f86c8e8cab | ||
|
|
cb12cbdd3d | ||
|
|
6661fa996c | ||
|
|
c19bca798b | ||
|
|
8f98b411db | ||
|
|
a8aa03847e | ||
|
|
1bfd747cc6 | ||
|
|
ae06d945a7 | ||
|
|
9f41d5f34d | ||
|
|
ef61c52908 | ||
|
|
d8842ef274 | ||
|
|
c88fdaf353 | ||
|
|
af295da871 | ||
|
|
083235a2fe | ||
|
|
2a3a5f7eb2 | ||
|
|
77c48f280f | ||
|
|
0ee1eb2f9f | ||
|
|
c2b20365bb | ||
|
|
cfdc7e4452 | ||
|
|
2363f61aa9 | ||
|
|
557ac6f9fa | ||
|
|
a49b871cf9 | ||
|
|
a0d6b3efba | ||
|
|
6cabf07bc0 | ||
|
|
a15444ee8c | ||
|
|
ceb5f5669e | ||
|
|
25b75e05e4 | ||
|
|
4d214bb5c1 | ||
|
|
7cbaed8c6c | ||
|
|
2915fdf665 | ||
|
|
a66c385b08 | ||
|
|
4dace7c5d8 | ||
|
|
8ebf087dbf | ||
|
|
2fa8bda5bb | ||
|
|
a5ae833945 | ||
|
|
d21d42b312 | ||
|
|
78575f0f0a | ||
|
|
8ccd292d16 | ||
|
|
2534f59398 | ||
|
|
5c60dbe2b1 | ||
|
|
c99ecde15f | ||
|
|
219f3403d9 | ||
|
|
00f417bad6 | ||
|
|
81649f053b | ||
|
|
e5bde50f2d | ||
|
|
0321e00b0d | ||
|
|
09528e3292 | ||
|
|
e7412a9cbf | ||
|
|
01efe5f869 | ||
|
|
28a178a55c | ||
|
|
88f130014c | ||
|
|
af258c590c | ||
|
|
b0eb5733be | ||
|
|
fe35bfba37 | ||
|
|
7cfbc4ab8f | ||
|
|
7a9d4f0abd | ||
|
|
6f6a5b565c | ||
|
|
e57deb873c | ||
|
|
0f692b1608 | ||
|
|
8c03e79f99 | ||
|
|
71290f0929 | ||
|
|
22364ef7de | ||
|
|
2cc1eb1abc | ||
|
|
90dbcbb4e2 | ||
|
|
66503d58be | ||
|
|
8e10f0ce2b | ||
|
|
f51f510f2e | ||
|
|
c44f085b47 | ||
|
|
a35f36eeaf | ||
|
|
14564c392a | ||
|
|
76e05ea749 | ||
|
|
ab599dceed | ||
|
|
4c37604445 | ||
|
|
bb74018d19 | ||
|
|
575289e5bc | ||
|
|
e89da2a7b4 | ||
|
|
bd34959f68 | ||
|
|
622dcf8fd5 | ||
|
|
9e315739b7 | ||
|
|
7b01adc5df | ||
|
|
432fc47443 | ||
|
|
d8fba44c5e | ||
|
|
e29d3d8c01 | ||
|
|
e678413214 | ||
|
|
eaa9d9d087 | ||
|
|
9e3cc076b7 | ||
|
|
3bb01fa52c | ||
|
|
008e49d144 | ||
|
|
4e275384b0 | ||
|
|
63ec99f67a | ||
|
|
14a8bb57df | ||
|
|
7512bfc710 | ||
|
|
3c3b6dadc3 | ||
|
|
cd722a0e39 | ||
|
|
a1b5d0a100 | ||
|
|
69d3ae709c | ||
|
|
67ef993d61 | ||
|
|
20f49890ad | ||
|
|
3e4917f0a1 | ||
|
|
99ee75aec6 | ||
|
|
1674653a42 | ||
|
|
d2f7e55bf5 | ||
|
|
9f31df7f3a | ||
|
|
b8c1b53d67 | ||
|
|
2495837791 | ||
|
|
b6562e3c47 | ||
|
|
c57da046ee | ||
|
|
ff63134c14 | ||
|
|
3f5210c587 | ||
|
|
3df5e7b9b9 | ||
|
|
225db66738 | ||
|
|
383ebb8f57 | ||
|
|
e1bed60f1f | ||
|
|
edbb856023 | ||
|
|
98d3ab646f | ||
|
|
81be556f1b | ||
|
|
f45a085469 | ||
|
|
210cc58cc3 | ||
|
|
1063b11ef6 | ||
|
|
a4e999c47f | ||
|
|
543e01c301 | ||
|
|
14e0aa3ec5 | ||
|
|
1a8a171f8b | ||
|
|
f1954f9a43 | ||
|
|
441b148501 | ||
|
|
bd0f30b81c | ||
|
|
ad14e9bf40 | ||
|
|
6f71301aaf | ||
|
|
5f0d601baa | ||
|
|
f234a5bcc2 | ||
|
|
ab677ea100 | ||
|
|
f3ad53e949 | ||
|
|
d324cfa84d | ||
|
|
dd4319d72a | ||
|
|
1f2de3d3d8 | ||
|
|
72702beb0b | ||
|
|
adb0cbc5dd | ||
|
|
6a503b82c3 | ||
|
|
28a87351f1 | ||
|
|
bcc97378b0 | ||
|
|
eb8a138713 | ||
|
|
dcd7dcbbdf | ||
|
|
1538759ba7 | ||
|
|
30e8ea7fd8 | ||
|
|
879b7b582c | ||
|
|
8ba4236402 | ||
|
|
5eef8fa9b9 | ||
|
|
d03d035437 | ||
|
|
68e8e1f70b | ||
|
|
7acb45b157 | ||
|
|
c36142deaf | ||
|
|
5fd6e316fa | ||
|
|
39a9d7765a | ||
|
|
7cfcba29a6 | ||
|
|
9bf8aadca9 | ||
|
|
714d4af63d | ||
|
|
8203fdb4f0 | ||
|
|
5e1e2d1a4f | ||
|
|
2f941de65b | ||
|
|
777c503002 | ||
|
|
e9b23f68fd | ||
|
|
efa45e6203 | ||
|
|
638f55f83c | ||
|
|
8b2fc29d5b | ||
|
|
b516fb0550 | ||
|
|
efef34c01e | ||
|
|
5f1dfa7599 | ||
|
|
8e9c7544cf | ||
|
|
4e3d5641c8 | ||
|
|
20b760529e | ||
|
|
a55a07c5ff | ||
|
|
94ee8ea297 | ||
|
|
ec5d71d0e1 | ||
|
|
d121d08d05 | ||
|
|
be08f4a558 | ||
|
|
010f082fbb | ||
|
|
073cdf6d51 | ||
|
|
4df8606ab6 | ||
|
|
71442d26ec | ||
|
|
4f5528869c | ||
|
|
f16feff17b | ||
|
|
71b233fe5f | ||
|
|
770dec9ed6 | ||
|
|
2ca95a988e | ||
|
|
d8aae538cd | ||
|
|
cf1e7ee08a | ||
|
|
d14513ddfd | ||
|
|
9a9017bc6c | ||
|
|
3c9b654713 | ||
|
|
80d2ad40bc | ||
|
|
31670e75e5 | ||
|
|
ed6011a2be | ||
|
|
cdded38ade | ||
|
|
f536f24833 | ||
|
|
f5bff00b1f | ||
|
|
27c9717445 | ||
|
|
863a1ba8ef | ||
|
|
cb04dd2b83 | ||
|
|
8c7cf51958 | ||
|
|
244fb1fed6 | ||
|
|
25f7a68a13 | ||
|
|
62d8cf79ef | ||
|
|
646b18d910 | ||
|
|
2f81b2e381 | ||
|
|
1f5a7e7885 | ||
|
|
80fca470f2 | ||
|
|
6e9d9ac856 | ||
|
|
8d6fada1eb | ||
|
|
3e715399a1 | ||
|
|
81cc8831f9 | ||
|
|
f7370044a7 | ||
|
|
51b015a629 | ||
|
|
392af7a553 | ||
|
|
d2dd07bad7 | ||
|
|
cebcd6925a | ||
|
|
e7b4357fc7 | ||
|
|
dc279dde4a | ||
|
|
c0810a674f | ||
|
|
0760cabbbe | ||
|
|
3b149c520b | ||
|
|
3d19fc89ff | ||
|
|
cd1b1919f4 | ||
|
|
0ed646eb27 | ||
|
|
c0c5859c99 | ||
|
|
a47121b849 | ||
|
|
d9dd20e89a | ||
|
|
ed4609ebe5 | ||
|
|
e24225c828 | ||
|
|
01ef86d658 | ||
|
|
cd4802da04 | ||
|
|
2aca65780f | ||
|
|
2c435f7387 | ||
|
|
cc1afd1a9c | ||
|
|
6f098cdba6 | ||
|
|
d03e9fb90a | ||
|
|
9f2966abe9 | ||
|
|
4e28ea1883 | ||
|
|
289214e85c | ||
|
|
a20d98bf93 | ||
|
|
7c3d98acbe | ||
|
|
7311786f48 | ||
|
|
82de9c926e | ||
|
|
7fd86d4de3 | ||
|
|
724da29e2a | ||
|
|
54113d7b94 | ||
|
|
66396e8290 | ||
|
|
72be76215f | ||
|
|
ace86703a9 | ||
|
|
7b25495463 | ||
|
|
3d4b651c1f | ||
|
|
d305ae064d | ||
|
|
ac4f3d8907 | ||
|
|
af2687771b | ||
|
|
a67b7f909a | ||
|
|
f9c3e4cdb0 | ||
|
|
dc62c1f8d4 | ||
|
|
0441b51a68 | ||
|
|
5c0c9f687e | ||
|
|
e049c54043 | ||
|
|
99e47540d5 | ||
|
|
8e1885ffeb | ||
|
|
8501a0c205 | ||
|
|
797f2a3173 | ||
|
|
1057b4bc35 | ||
|
|
efc0116595 | ||
|
|
cdc560fad0 | ||
|
|
75a2803710 | ||
|
|
fb3169faa4 | ||
|
|
d587bd837e | ||
|
|
b9fab74edc | ||
|
|
50c22bbadb | ||
|
|
d0b10b9195 | ||
|
|
50a296de20 | ||
|
|
c8fe4f4a3c | ||
|
|
a8ba0720af | ||
|
|
745a01246c | ||
|
|
bee5d3550f | ||
|
|
1789393151 | ||
|
|
345afe1338 | ||
|
|
65428aa49f | ||
|
|
b251ee9322 | ||
|
|
04f00682a0 | ||
|
|
90dcda1475 | ||
|
|
f1ee4eb89f | ||
|
|
343fc22168 | ||
|
|
00ef0d7e3d | ||
|
|
f2deaf6199 | ||
|
|
617a2c010e | ||
|
|
c79e38e044 | ||
|
|
38eae1d1ee | ||
|
|
7e4c89b0cb | ||
|
|
14c29f07bd | ||
|
|
825e3dbcf5 | ||
|
|
8275130f04 | ||
|
|
2c47abea95 | ||
|
|
85aa28d724 | ||
|
|
53a3736b04 | ||
|
|
86ba3c230e | ||
|
|
8d21126bd6 | ||
|
|
74ded91976 | ||
|
|
7c27520d57 | ||
|
|
b54bbc4c5a | ||
|
|
3e09a4ddd4 | ||
|
|
f93f04a536 | ||
|
|
b93f30b809 | ||
|
|
95bd2f26a5 | ||
|
|
7cfcf056f9 | ||
|
|
96b565e1e8 | ||
|
|
9d7ad7a18f | ||
|
|
9838c2758b | ||
|
|
1b1f5f5a5e | ||
|
|
0f95f62aa1 | ||
|
|
9405ba7871 | ||
|
|
ccb95f803c | ||
|
|
dae745d925 | ||
|
|
791db65526 | ||
|
|
60b2ff0a7a | ||
|
|
e6c8507379 | ||
|
|
420db5416e | ||
|
|
6e03218d54 | ||
|
|
5e4bd36b26 | ||
|
|
bbc039366e | ||
|
|
e1ec7dbbba | ||
|
|
075b008740 | ||
|
|
b2c382fa01 | ||
|
|
02e2e617f5 | ||
|
|
c5f9b5861f | ||
|
|
2dace4c697 | ||
|
|
c7891385ca | ||
|
|
2059ddcadf | ||
|
|
ba1b68df20 | ||
|
|
bfc8024119 | ||
|
|
f26cf6ed6f | ||
|
|
403b61836d | ||
|
|
b5af7d1eb9 | ||
|
|
f453af6e4c | ||
|
|
f2be55bd8e | ||
|
|
d241dd17ca | ||
|
|
cecafdfe6c | ||
|
|
6fecfd1a0e | ||
|
|
64245d001c | ||
|
|
7d92965cae | ||
|
|
b4fa08c4e2 | ||
|
|
d4e9566851 | ||
|
|
a26b494f7f | ||
|
|
b84e22e41f | ||
|
|
cee6efab19 | ||
|
|
30f71cb550 | ||
|
|
771e755a78 | ||
|
|
16ec462abd | ||
|
|
ca55465d3c | ||
|
|
7098c98dde | ||
|
|
f56355da89 | ||
|
|
422160debd | ||
|
|
8062cf406a | ||
|
|
0e802232ec | ||
|
|
f650a9205d | ||
|
|
c85dbb2347 | ||
|
|
a6a79128c8 | ||
|
|
42839627e8 | ||
|
|
e7f35098e4 | ||
|
|
267e68a894 | ||
|
|
b32b444438 | ||
|
|
522d0f8313 | ||
|
|
5715e5de67 | ||
|
|
cc6b05e8b3 | ||
|
|
417747d5d0 | ||
|
|
a34f439226 | ||
|
|
b7ca014fd0 | ||
|
|
fa098d585a | ||
|
|
c35a14e3ec | ||
|
|
60651736a5 | ||
|
|
581f9b7bd3 | ||
|
|
124eb04807 | ||
|
|
1d561da7fb | ||
|
|
16e3cd0784 | ||
|
|
a6d91933dc | ||
|
|
445c40f758 | ||
|
|
725a841a3b | ||
|
|
f77c453843 | ||
|
|
ba6718d5bc | ||
|
|
cdb7a1b3fa | ||
|
|
a03c79b89d | ||
|
|
98800d3426 | ||
|
|
a616adaac4 | ||
|
|
ffb5605c99 | ||
|
|
621b556856 | ||
|
|
a3ffecbb2a | ||
|
|
ea64cebe2a | ||
|
|
e79487dd5f | ||
|
|
7fe1c1ec89 | ||
|
|
ab2bbff369 | ||
|
|
ec32825309 | ||
|
|
fd0c182087 | ||
|
|
49fcff1daf | ||
|
|
33b64ddf39 | ||
|
|
4c447aa648 | ||
|
|
ccbfc3d274 | ||
|
|
f83fe43bbb | ||
|
|
19022d67f8 | ||
|
|
58a815dd6b | ||
|
|
1ce95c473d | ||
|
|
eb365e398d | ||
|
|
bc9fe82860 | ||
|
|
b3cd9bf2b9 | ||
|
|
c5c2b829ec | ||
|
|
9713f96401 | ||
|
|
11f35ebf96 | ||
|
|
7d403aa181 | ||
|
|
64af810a4a | ||
|
|
30821905af | ||
|
|
a9dbff756b | ||
|
|
a6aba10d3d | ||
|
|
9c276c37fe | ||
|
|
6ab6c0fd4c | ||
|
|
b6b0fe3fff | ||
|
|
0d5825bda9 | ||
|
|
cdfb64631a | ||
|
|
d161c281c8 | ||
|
|
8fed5bf2a1 | ||
|
|
98d2e9bd27 | ||
|
|
a03af55edd | ||
|
|
86e2fd9aee | ||
|
|
97bd0e5e58 | ||
|
|
ceaba21986 | ||
|
|
172a77d942 | ||
|
|
4f9d2d2a7d | ||
|
|
8c929f6e05 | ||
|
|
3319b71f5b | ||
|
|
46ec028a5b | ||
|
|
0ce0ef3e5c | ||
|
|
375b071cb2 | ||
|
|
29e1417ff2 | ||
|
|
75db2bd366 | ||
|
|
60ca1efbda | ||
|
|
2692e4978b | ||
|
|
91982eb002 | ||
|
|
bb1dec76fa | ||
|
|
f618b8fcdc | ||
|
|
9147cab75b | ||
|
|
5f07bcc8e6 | ||
|
|
705cf2ea1b | ||
|
|
42c4394484 | ||
|
|
221221a3c1 | ||
|
|
9564166297 | ||
|
|
f5cf3c3c8e | ||
|
|
18f919fb6b | ||
|
|
0924835253 | ||
|
|
20d2e5c578 | ||
|
|
907801605c | ||
|
|
93bc684e8c | ||
|
|
a76c98d57e | ||
|
|
d937a800d0 | ||
|
|
d16f3a227f | ||
|
|
80c9a3eeda | ||
|
|
e68173b451 | ||
|
|
40c27d87f5 | ||
|
|
3c13b5049d | ||
|
|
8288d5e51f | ||
|
|
6e1449900a | ||
|
|
4ffbb18ab4 | ||
|
|
b27271b7a3 | ||
|
|
ebb6665f64 | ||
|
|
e4e5731ffd | ||
|
|
2ab5810f13 | ||
|
|
af934c5d09 | ||
|
|
1e0cf7c112 | ||
|
|
46859c93c9 | ||
|
|
ea1f9cb3b2 | ||
|
|
1641549016 | ||
|
|
716a5dbb8a | ||
|
|
af98cb11c5 | ||
|
|
9a4c2cf341 | ||
|
|
2bc3bcd102 | ||
|
|
d6c663f79d | ||
|
|
9ed86e5f53 | ||
|
|
303e0bc037 | ||
|
|
2cc24019f9 | ||
|
|
83ce774d19 | ||
|
|
2b4ee13b5e | ||
|
|
3a964561f0 | ||
|
|
6959f86632 | ||
|
|
537d373e10 | ||
|
|
cceadf222c | ||
|
|
cf5a4af623 | ||
|
|
39aea11c22 | ||
|
|
c2f1227700 | ||
|
|
900f14d37c | ||
|
|
598249b1d6 | ||
|
|
7ed15bdf04 | ||
|
|
2fc0ec0f72 | ||
|
|
5e9c2a669b | ||
|
|
b310521884 | ||
|
|
288945bf7e | ||
|
|
4fc07cff36 | ||
|
|
b884fe0e86 | ||
|
|
855858c236 | ||
|
|
c11a2a5419 | ||
|
|
773a6572af | ||
|
|
88ad373c9b | ||
|
|
51666464b9 | ||
|
|
5af9cf2f52 | ||
|
|
12c4ae4b10 | ||
|
|
4e1bef414a | ||
|
|
e896c18644 | ||
|
|
c852685e74 | ||
|
|
1e99797df8 | ||
|
|
52a4c986a8 | ||
|
|
c501728204 | ||
|
|
6b067fa6a7 | ||
|
|
a1cd5c53a9 | ||
|
|
a46d487e03 | ||
|
|
3deb6d3ab3 | ||
|
|
af34cdd5d2 | ||
|
|
6e1393235a | ||
|
|
343e0b54b9 | ||
|
|
ecb70cb6f7 | ||
|
|
ca50618af6 | ||
|
|
29c07ba83e | ||
|
|
45fbb83a9f | ||
|
|
ae7ba2df25 | ||
|
|
c3ef57cc32 | ||
|
|
7bb4ca5a14 | ||
|
|
063783d81d | ||
|
|
42116c9b65 | ||
|
|
a36e11973d | ||
|
|
5125568ea2 | ||
|
|
0fa164e50d | ||
|
|
cf814e81ee | ||
|
|
43a45f18ce | ||
|
|
ad51381063 | ||
|
|
0b0e4ce904 | ||
|
|
6a3e04d688 | ||
|
|
4107a17370 | ||
|
|
06b4d8f169 | ||
|
|
1c0c820746 | ||
|
|
d061403a28 | ||
|
|
5c092321a6 | ||
|
|
bdd3f61c1f | ||
|
|
8023557d6e | ||
|
|
074b0ced7a | ||
|
|
3864b1ac9b | ||
|
|
6e9b43457d | ||
|
|
ca1aec8920 | ||
|
|
acac580862 | ||
|
|
673e1b2980 | ||
|
|
f62157be72 | ||
|
|
f894ecf3b6 | ||
|
|
66dd4e28ad | ||
|
|
939dc1b0fb | ||
|
|
56bf5d38a1 | ||
|
|
d09b70b295 | ||
|
|
205180387a | ||
|
|
39c8cfeda5 | ||
|
|
f38a329be5 | ||
|
|
a0cd069539 | ||
|
|
bf306a2f01 | ||
|
|
c31f93a8d1 | ||
|
|
4730ab6309 | ||
|
|
1ae78ca98c | ||
|
|
d2379da478 | ||
|
|
0f64981b20 | ||
|
|
0002e49bb5 | ||
|
|
db13a60274 | ||
|
|
db0f11a359 | ||
|
|
ac7f43520b | ||
|
|
f67b9f5f6e | ||
|
|
c75156c4ce | ||
|
|
10270b5595 | ||
|
|
f7458572ed | ||
|
|
d57b7222b2 | ||
|
|
62e70a673a | ||
|
|
5e9eba6478 | ||
|
|
cb02dfe1a4 | ||
|
|
b50739e1af | ||
|
|
8da1b0212d | ||
|
|
ca1f2acb33 | ||
|
|
c15f966669 | ||
|
|
7705b8781a | ||
|
|
b2502746f0 | ||
|
|
ab68094386 | ||
|
|
bbec701223 | ||
|
|
b29d14e600 | ||
|
|
86e51c5cd1 | ||
|
|
cb8267be3f | ||
|
|
eaed43915c | ||
|
|
bd91fd2c38 | ||
|
|
1203b214cd | ||
|
|
c3fec15f11 | ||
|
|
0545653494 | ||
|
|
db2989bdb4 | ||
|
|
587bd00a19 | ||
|
|
960ff438e8 | ||
|
|
98e7ea85d3 | ||
|
|
2549e44710 | ||
|
|
4d32b563ca | ||
|
|
3a4b732977 | ||
|
|
500909a28e | ||
|
|
07753eb25b | ||
|
|
c6eaf3d010 | ||
|
|
6723fe8271 | ||
|
|
3348b70435 | ||
|
|
35a8527c16 | ||
|
|
7afc475290 | ||
|
|
789bceaa3a | ||
|
|
abbc043969 | ||
|
|
654e5762f1 | ||
|
|
507c3e3629 | ||
|
|
991dfeb2f2 | ||
|
|
26482fc2d3 | ||
|
|
e0ce6d9688 | ||
|
|
946595216a | ||
|
|
864b6bc56d | ||
|
|
6ea5b7581f | ||
|
|
f70b8f0c10 | ||
|
|
1593bcb537 | ||
|
|
bf7fc02c8d | ||
|
|
143702b92b | ||
|
|
c5ccc1a084 | ||
|
|
2ecb52a9b2 | ||
|
|
6439917cbe | ||
|
|
d21c18f657 | ||
|
|
25ef0039e4 | ||
|
|
e6981290bc | ||
|
|
75c3d8abbd | ||
|
|
d88683f498 | ||
|
|
40b9aa3a4c | ||
|
|
b6d1515d58 | ||
|
|
e01d4264e3 | ||
|
|
2117b65487 | ||
|
|
a7823b352f | ||
|
|
c543b62a08 | ||
|
|
3923b87f08 | ||
|
|
b7ecdadb83 | ||
|
|
5ff121e1ed | ||
|
|
f486e5448f | ||
|
|
c5aae98558 | ||
|
|
6d8a3b9897 | ||
|
|
6d98780e19 | ||
|
|
3ad2c46f3f | ||
|
|
a730cee7fd | ||
|
|
77c823c100 | ||
|
|
124f21c67a | ||
|
|
e46cf20dd3 | ||
|
|
4bef5e8313 | ||
|
|
22e93b0af4 | ||
|
|
5aeca9662b | ||
|
|
b996cf1f05 | ||
|
|
878a106877 | ||
|
|
45d36f86fd | ||
|
|
b108ae403a | ||
|
|
887ed66768 | ||
|
|
dac840a887 | ||
|
|
238de4ba8c | ||
|
|
9a7bdade43 | ||
|
|
aa84556204 | ||
|
|
6b68069fcd | ||
|
|
42c7034fb2 | ||
|
|
060c7e0145 | ||
|
|
b5b085dfb1 | ||
|
|
fc06ce9d7f | ||
|
|
d8d81b05a7 | ||
|
|
a60f42b1f2 | ||
|
|
6e18be88d0 | ||
|
|
b45e439c48 | ||
|
|
b87061c18c | ||
|
|
f78aca7752 | ||
|
|
3ccca2aa10 | ||
|
|
6d7c40eb76 | ||
|
|
da4cd7fb65 | ||
|
|
c97cda6b84 | ||
|
|
7a7fd4167a | ||
|
|
dffc1a43d5 | ||
|
|
36897fea1e | ||
|
|
c7b34735f0 | ||
|
|
5b07176c88 | ||
|
|
474b40d660 | ||
|
|
a62901b948 | ||
|
|
25d8746327 | ||
|
|
aff1698223 | ||
|
|
7f8941745f | ||
|
|
b858401098 | ||
|
|
d5a158b80f | ||
|
|
f315f284aa | ||
|
|
c367f5009d | ||
|
|
6db1e63bda | ||
|
|
e22ab2ede6 | ||
|
|
b7d7e0b682 | ||
|
|
96bba15f2f | ||
|
|
fcf965a595 | ||
|
|
e1a20d3c22 | ||
|
|
2abd7d8c5d | ||
|
|
5b8f73cdd7 | ||
|
|
7fd765421f | ||
|
|
d9d94af022 | ||
|
|
790b924e57 | ||
|
|
4a62f877df | ||
|
|
ac47c57bb7 | ||
|
|
3ace4199a1 | ||
|
|
e6bd7524c1 | ||
|
|
699c86e8c1 | ||
|
|
f40fa0ecea | ||
|
|
626f94686b | ||
|
|
752d13b1b1 | ||
|
|
54c0dc1b2b | ||
|
|
c5bc709898 | ||
|
|
ccdbb01513 | ||
|
|
5206d750ac | ||
|
|
a800e3df67 | ||
|
|
ccb1f87a20 | ||
|
|
c111da4681 | ||
|
|
9cc4e97a53 | ||
|
|
dca1c0b0f3 | ||
|
|
f06be6ed21 | ||
|
|
3c8ec2f42e | ||
|
|
7e193f7f52 | ||
|
|
7069b02929 | ||
|
|
66995db927 | ||
|
|
c36054ca1b | ||
|
|
3e07fbf3dc | ||
|
|
bf3fbe3e96 | ||
|
|
0a93d22bc8 | ||
|
|
f5b3d94d16 | ||
|
|
4d1a6994aa | ||
|
|
05c686782c | ||
|
|
85609ea742 | ||
|
|
20dabc0615 | ||
|
|
356dd9bc2b | ||
|
|
cd5d7534c4 | ||
|
|
b4f12fc933 | ||
|
|
cbea387ce0 | ||
|
|
345b155374 | ||
|
|
29d216950e | ||
|
|
321b04772c | ||
|
|
5b924aee98 | ||
|
|
46d44e3405 | ||
|
|
4d5332fe25 | ||
|
|
18bd4c54f4 | ||
|
|
31c7768ca0 | ||
|
|
6ec643e9d1 | ||
|
|
2b39f6f61c | ||
|
|
bf3ca13961 | ||
|
|
82026370ec | ||
|
|
6d49bf5346 | ||
|
|
67431d87fb | ||
|
|
fdf55221e6 | ||
|
|
07f277dd3b | ||
|
|
cf8f0603ca | ||
|
|
5592408ab8 | ||
|
|
a01617b45c | ||
|
|
7abb4087b3 | ||
|
|
dff15cf27a | ||
|
|
aa858137e5 | ||
|
|
45cb143202 | ||
|
|
7a9c6ab8c4 | ||
|
|
e2c26c292d | ||
|
|
be7c3fd00e | ||
|
|
7e5461a2cf | ||
|
|
6ee9010645 | ||
|
|
a23d5be056 | ||
|
|
97a6a1fdc2 | ||
|
|
c8f567347b | ||
|
|
74c1e7f69e | ||
|
|
15a5fc0cae | ||
|
|
f07c54d47c | ||
|
|
70446be108 | ||
|
|
d6d21fca56 | ||
|
|
8d7273924f | ||
|
|
ea64afbaa7 | ||
|
|
45da9837ec | ||
|
|
8c19b7d163 | ||
|
|
ab227a08d0 | ||
|
|
40d6e77964 | ||
|
|
9326e3f1b0 | ||
|
|
0e1eb3daf6 | ||
|
|
05daac12ed | ||
|
|
c5b24b4764 | ||
|
|
cc16548e5f | ||
|
|
291d65bb3e | ||
|
|
bd3ad03da6 | ||
|
|
5fa6788357 | ||
|
|
c5c5a98ac4 | ||
|
|
a1151143cf | ||
|
|
f5024984f7 | ||
|
|
f4880fd90d | ||
|
|
0ae61d5865 | ||
|
|
d3bd775a79 | ||
|
|
da546cfe7f | ||
|
|
a211933e83 | ||
|
|
1d40b5a821 | ||
|
|
33836daeb7 | ||
|
|
d921b0f6bd | ||
|
|
0607b95df6 | ||
|
|
0de6d0e046 | ||
|
|
98427345cf | ||
|
|
9fedaa9f77 | ||
|
|
bf4c2ecd33 | ||
|
|
f8c18cc1e0 | ||
|
|
458b900412 | ||
|
|
192c776e0b | ||
|
|
5cdec18863 | ||
|
|
15f856f951 | ||
|
|
01d52cef74 | ||
|
|
95563c8659 | ||
|
|
31d8c40eca | ||
|
|
56001ed272 | ||
|
|
d916fda04c | ||
|
|
cfae655068 | ||
|
|
5596565ec4 | ||
|
|
afa1aa5d93 | ||
|
|
e98c3d8393 | ||
|
|
6687b816f0 | ||
|
|
ea8035e854 | ||
|
|
54b0171d49 | ||
|
|
676d4277b9 | ||
|
|
a4b1da3ca2 | ||
|
|
9e9c16e770 | ||
|
|
dc87006fed | ||
|
|
b9b260f26a | ||
|
|
33fd6a5016 | ||
|
|
97cbccc2ba | ||
|
|
1ee4685d5d | ||
|
|
aba18232b1 | ||
|
|
0a02441b75 | ||
|
|
1be5b4c7ff | ||
|
|
a0ce0cf18a | ||
|
|
7c54e5d093 | ||
|
|
b825e51dab | ||
|
|
589855c393 | ||
|
|
4c546f2f53 | ||
|
|
3753fce912 | ||
|
|
4c02857ec5 | ||
|
|
33f87ff7d7 | ||
|
|
784dcf2a9a | ||
|
|
43ee943acb | ||
|
|
a769fd7d13 | ||
|
|
2c4fd00b16 | ||
|
|
264771fe98 | ||
|
|
ecd92dafef | ||
|
|
c8b6e4bea3 | ||
|
|
3756cb766e | ||
|
|
068d9ca60b | ||
|
|
93f632d8b8 | ||
|
|
bb44ce7e74 | ||
|
|
6986c8d8f7 | ||
|
|
fe95506db4 | ||
|
|
310ed76b18 | ||
|
|
98830d147f | ||
|
|
19c9177d7b | ||
|
|
f41c5f97f6 | ||
|
|
648c125697 | ||
|
|
0dc2b89897 | ||
|
|
83745f83a5 | ||
|
|
2f91fe4535 | ||
|
|
739f09059e | ||
|
|
c86f9f0f5f | ||
|
|
9470ca6bc5 | ||
|
|
2a92c4d5de | ||
|
|
bb6e892657 | ||
|
|
c9079b9299 | ||
|
|
b6963c1bf9 | ||
|
|
9c29df47bb | ||
|
|
fc146d3d00 | ||
|
|
1bf5a21678 | ||
|
|
011542dc2b | ||
|
|
489784104e | ||
|
|
3860634fd2 | ||
|
|
709c324e18 | ||
|
|
b75d24d92c | ||
|
|
ed80e9424c | ||
|
|
2fe1f2060a | ||
|
|
c6df820164 | ||
|
|
d6239822db | ||
|
|
bced9ffff9 | ||
|
|
d7d1c1544a | ||
|
|
7c1e8ce48c | ||
|
|
44dbe475af | ||
|
|
bd24cf3ea4 | ||
|
|
b493a808fe | ||
|
|
54035d108d | ||
|
|
c5e8bc7e20 | ||
|
|
3bbb4779a3 | ||
|
|
1b3963ebea | ||
|
|
e8ffebc006 | ||
|
|
2ca95eaa9f | ||
|
|
0dc5b4cdfc | ||
|
|
cc6cd96d8e | ||
|
|
4244d37625 | ||
|
|
0b766095d4 | ||
|
|
39693a27e3 | ||
|
|
b03fe438d0 | ||
|
|
7f56824b42 | ||
|
|
627da3a2bc | ||
|
|
e6b69042de |
@@ -17,4 +17,8 @@ ENV/
|
||||
.conda/
|
||||
README*.md
|
||||
dashboard/
|
||||
data/
|
||||
data/
|
||||
changelogs/
|
||||
tests/
|
||||
.ruff_cache/
|
||||
.astrbot
|
||||
15
.github/FUNDING.yml
vendored
Normal file
15
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||
patreon: # Replace with a single Patreon username
|
||||
open_collective: astrbot
|
||||
ko_fi: # Replace with a single Ko-fi username
|
||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
||||
polar: # Replace with a single Polar username
|
||||
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
|
||||
thanks_dev: # Replace with a single thanks.dev username
|
||||
custom: ['https://afdian.com/a/astrbot_team']
|
||||
80
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
80
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
@@ -1,40 +1,56 @@
|
||||
name: '🥳 发布插件'
|
||||
title: "[Plugin] 插件名"
|
||||
name: 🥳 发布插件
|
||||
description: 提交插件到插件市场
|
||||
labels: [ "plugin-publish" ]
|
||||
title: "[Plugin] 插件名"
|
||||
labels: ["plugin-publish"]
|
||||
assignees: []
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
欢迎发布插件到插件市场!请确保您的插件经过**完整的**测试。
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 插件仓库
|
||||
description: 插件的 GitHub 仓库链接
|
||||
placeholder: >
|
||||
如 https://github.com/Soulter/astrbot-github-cards
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 描述
|
||||
value: |
|
||||
插件名:
|
||||
插件作者:
|
||||
插件简介:
|
||||
支持的消息平台:(必填,如 QQ、微信、飞书)
|
||||
标签:(可选)
|
||||
社交链接:(可选, 将会在插件市场作者名称上作为可点击的链接)
|
||||
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。如果您不知道支持哪些消息平台,请填写测试过的消息平台。
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Code of Conduct
|
||||
options:
|
||||
- label: >
|
||||
我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||
required: true
|
||||
欢迎发布插件到插件市场!
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: "❤️"
|
||||
value: |
|
||||
## 插件基本信息
|
||||
|
||||
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
|
||||
|
||||
不熟悉 JSON ?现在可以从 [这里](https://plugins.astrbot.app/#/submit) 获取你的 JSON 啦!获取到了记得复制粘贴过来哦!
|
||||
|
||||
- type: textarea
|
||||
id: plugin-info
|
||||
attributes:
|
||||
label: 插件信息
|
||||
description: 请在下方代码块中填写您的插件信息,确保反引号包裹了JSON
|
||||
value: |
|
||||
```json
|
||||
{
|
||||
"name": "插件名",
|
||||
"desc": "插件介绍",
|
||||
"author": "作者名",
|
||||
"repo": "插件仓库链接",
|
||||
"tags": [],
|
||||
"social_link": ""
|
||||
}
|
||||
```
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
## 检查
|
||||
|
||||
- type: checkboxes
|
||||
id: checks
|
||||
attributes:
|
||||
label: 插件检查清单
|
||||
description: 请确认以下所有项目
|
||||
options:
|
||||
- label: 我的插件经过完整的测试
|
||||
required: true
|
||||
- label: 我的插件不包含恶意代码
|
||||
required: true
|
||||
- label: 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||
required: true
|
||||
|
||||
11
.github/PULL_REQUEST_TEMPLATE.md
vendored
11
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,5 +1,5 @@
|
||||
<!-- 如果有的话,指定这个 PR 要解决的 ISSUE -->
|
||||
修复了 #XYZ
|
||||
解决了 #XYZ
|
||||
|
||||
### Motivation
|
||||
|
||||
@@ -10,5 +10,10 @@
|
||||
<!--简单解释你的改动-->
|
||||
|
||||
### Check
|
||||
- [ ] 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
|
||||
- [ ] 我新增/修复/优化的功能经过良好的测试
|
||||
|
||||
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容-->
|
||||
|
||||
- [ ] 😊 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
|
||||
- [ ] 👀 我的更改经过良好的测试
|
||||
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt` 和 `pyproject.toml` 文件相应位置。
|
||||
- [ ] 😮 我的更改没有引入恶意代码
|
||||
|
||||
63
.github/copilot-instructions.md
vendored
Normal file
63
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
# AstrBot Development Instructions
|
||||
|
||||
AstrBot is a multi-platform LLM chatbot and development framework written in Python with a Vue.js dashboard. It supports multiple messaging platforms (QQ, Telegram, Discord, etc.) and various LLM providers (OpenAI, Anthropic, Google Gemini, etc.).
|
||||
|
||||
Always reference these instructions first and fallback to search or bash commands only when you encounter unexpected information that does not match the info here.
|
||||
|
||||
## Working Effectively
|
||||
|
||||
### Bootstrap and Install Dependencies
|
||||
- **Python 3.10+ required** - Check `.python-version` file
|
||||
- Install UV package manager: `pip install uv`
|
||||
- Install project dependencies: `uv sync` -- takes 6-7 minutes. NEVER CANCEL. Set timeout to 10+ minutes.
|
||||
- Create required directories: `mkdir -p data/plugins data/config data/temp`
|
||||
|
||||
### Running the Application
|
||||
- Run main application: `uv run main.py` -- starts in ~3 seconds
|
||||
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
|
||||
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
|
||||
|
||||
### Dashboard Build (Vue.js/Node.js)
|
||||
- **Prerequisites**: Node.js 20+ and npm 10+ required
|
||||
- Navigate to dashboard: `cd dashboard`
|
||||
- Install dashboard dependencies: `npm install` -- takes 2-3 minutes. NEVER CANCEL. Set timeout to 5+ minutes.
|
||||
- Build dashboard: `npm run build` -- takes 25-30 seconds. NEVER CANCEL.
|
||||
- Dashboard creates optimized production build in `dashboard/dist/`
|
||||
|
||||
### Testing
|
||||
- Do not generate test files for now.
|
||||
|
||||
### Code Quality and Linting
|
||||
- Install ruff linter: `uv add --dev ruff`
|
||||
- Check code style: `uv run ruff check .` -- takes <1 second
|
||||
- Check formatting: `uv run ruff format --check .` -- takes <1 second
|
||||
- Fix formatting: `uv run ruff format .`
|
||||
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
|
||||
|
||||
### Plugin Development
|
||||
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
|
||||
- Plugin system supports function tools and message handlers
|
||||
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
|
||||
|
||||
### Common Issues and Workarounds
|
||||
- **Dashboard download fails**: Known issue with "division by zero" error - application still works
|
||||
- **Import errors in tests**: Ensure `uv run` is used to run tests in proper environment
|
||||
=- **Build timeouts**: Always set appropriate timeouts (10+ minutes for uv sync, 5+ minutes for npm install)
|
||||
|
||||
## CI/CD Integration
|
||||
- GitHub Actions workflows in `.github/workflows/`
|
||||
- Docker builds supported via `Dockerfile`
|
||||
- Pre-commit hooks enforce ruff formatting and linting
|
||||
|
||||
## Docker Support
|
||||
- Primary deployment method: `docker run soulter/astrbot:latest`
|
||||
- Compose file available: `compose.yml`
|
||||
- Exposes ports: 6185 (WebUI), 6195 (WeChat), 6199 (QQ), etc.
|
||||
- Volume mount required: `./data:/AstrBot/data`
|
||||
|
||||
## Multi-language Support
|
||||
- Documentation in Chinese (README.md), English (README_en.md), Japanese (README_ja.md)
|
||||
- UI supports internationalization
|
||||
- Default language is Chinese
|
||||
|
||||
Remember: This is a production chatbot framework with real users. Always test thoroughly and ensure changes don't break existing functionality.
|
||||
13
.github/dependabot.yml
vendored
Normal file
13
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
# Keep GitHub Actions up to date with GitHub's Dependabot...
|
||||
# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
|
||||
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: github-actions
|
||||
directory: /
|
||||
groups:
|
||||
github-actions:
|
||||
patterns:
|
||||
- "*" # Group all Actions updates into a single larger pull request
|
||||
schedule:
|
||||
interval: weekly
|
||||
65
.github/workflows/auto_release.yml
vendored
65
.github/workflows/auto_release.yml
vendored
@@ -7,13 +7,13 @@ on:
|
||||
name: Auto Release
|
||||
|
||||
jobs:
|
||||
build:
|
||||
build-and-publish-to-github-release:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Dashboard Build
|
||||
run: |
|
||||
@@ -23,13 +23,70 @@ jobs:
|
||||
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
|
||||
echo ${{ github.ref_name }} > dist/assets/version
|
||||
zip -r dist.zip dist
|
||||
|
||||
- name: Upload to Cloudflare R2
|
||||
env:
|
||||
R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }}
|
||||
R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
|
||||
R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
|
||||
R2_BUCKET_NAME: "astrbot"
|
||||
R2_OBJECT_NAME: "astrbot-webui-latest.zip"
|
||||
VERSION_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
echo "Installing rclone..."
|
||||
curl https://rclone.org/install.sh | sudo bash
|
||||
|
||||
echo "Configuring rclone remote..."
|
||||
mkdir -p ~/.config/rclone
|
||||
cat <<EOF > ~/.config/rclone/rclone.conf
|
||||
[r2]
|
||||
type = s3
|
||||
provider = Cloudflare
|
||||
access_key_id = $R2_ACCESS_KEY_ID
|
||||
secret_access_key = $R2_SECRET_ACCESS_KEY
|
||||
endpoint = https://${R2_ACCOUNT_ID}.r2.cloudflarestorage.com
|
||||
EOF
|
||||
|
||||
echo "Uploading dist.zip to R2 bucket: $R2_BUCKET_NAME/$R2_OBJECT_NAME"
|
||||
mv dashboard/dist.zip dashboard/$R2_OBJECT_NAME
|
||||
rclone copy dashboard/$R2_OBJECT_NAME r2:$R2_BUCKET_NAME --progress
|
||||
mv dashboard/$R2_OBJECT_NAME dashboard/astrbot-webui-${VERSION_TAG}.zip
|
||||
rclone copy dashboard/astrbot-webui-${VERSION_TAG}.zip r2:$R2_BUCKET_NAME --progress
|
||||
mv dashboard/astrbot-webui-${VERSION_TAG}.zip dashboard/dist.zip
|
||||
|
||||
- name: Fetch Changelog
|
||||
run: |
|
||||
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Create Release
|
||||
- name: Create GitHub Release
|
||||
uses: ncipollo/release-action@v1
|
||||
with:
|
||||
bodyFile: ${{ env.changelog }}
|
||||
artifacts: "dashboard/dist.zip"
|
||||
artifacts: "dashboard/dist.zip"
|
||||
|
||||
build-and-publish-to-pypi:
|
||||
# 构建并发布到 PyPI
|
||||
runs-on: ubuntu-latest
|
||||
needs: build-and-publish-to-github-release
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
python -m pip install uv
|
||||
|
||||
- name: Build package
|
||||
run: |
|
||||
uv build
|
||||
|
||||
- name: Publish to PyPI
|
||||
env:
|
||||
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
uv publish
|
||||
|
||||
2
.github/workflows/codeql.yml
vendored
2
.github/workflows/codeql.yml
vendored
@@ -56,7 +56,7 @@ jobs:
|
||||
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
|
||||
24
.github/workflows/coverage_test.yml
vendored
24
.github/workflows/coverage_test.yml
vendored
@@ -1,6 +1,6 @@
|
||||
name: Run tests and upload coverage
|
||||
|
||||
on:
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
@@ -8,6 +8,7 @@ on:
|
||||
- 'README.md'
|
||||
- 'changelogs/**'
|
||||
- 'dashboard/**'
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
@@ -16,30 +17,29 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-cov pytest-asyncio
|
||||
pip install pytest pytest-asyncio pytest-cov
|
||||
pip install --editable .
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
mkdir data
|
||||
mkdir data/plugins
|
||||
mkdir data/config
|
||||
mkdir data/temp
|
||||
mkdir -p data/plugins
|
||||
mkdir -p data/config
|
||||
mkdir -p data/temp
|
||||
export TESTING=true
|
||||
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
|
||||
pytest --cov=. -v -o log_cli=true -o log_level=DEBUG
|
||||
|
||||
- name: Upload results to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
20
.github/workflows/dashboard_ci.yml
vendored
20
.github/workflows/dashboard_ci.yml
vendored
@@ -1,13 +1,17 @@
|
||||
name: AstrBot Dashboard CI
|
||||
|
||||
on: [push]
|
||||
on:
|
||||
push:
|
||||
branches: [ "master" ]
|
||||
pull_request:
|
||||
branches: [ "master" ]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: npm install, build
|
||||
run: |
|
||||
@@ -21,6 +25,8 @@ jobs:
|
||||
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
|
||||
mkdir -p dashboard/dist/assets
|
||||
echo $COMMIT_SHA > dashboard/dist/assets/version
|
||||
cd dashboard
|
||||
zip -r dist.zip dist
|
||||
|
||||
- name: Archive production artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
@@ -29,3 +35,13 @@ jobs:
|
||||
path: |
|
||||
dashboard/dist
|
||||
!dist/**/*.md
|
||||
|
||||
- name: Create GitHub Release
|
||||
uses: ncipollo/release-action@v1
|
||||
with:
|
||||
tag: release-${{ github.sha }}
|
||||
owner: AstrBotDevs
|
||||
repo: astrbot-release-harbour
|
||||
body: "Automated release from commit ${{ github.sha }}"
|
||||
token: ${{ secrets.ASTRBOT_HARBOUR_TOKEN }}
|
||||
artifacts: "dashboard/dist.zip"
|
||||
66
.github/workflows/docker-image.yml
vendored
66
.github/workflows/docker-image.yml
vendored
@@ -11,33 +11,79 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: 拉取源码
|
||||
uses: actions/checkout@v3
|
||||
- name: Pull The Codes
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 1
|
||||
fetch-depth: 0 # Must be 0 so we can fetch tags
|
||||
|
||||
- name: 设置 QEMU
|
||||
- name: Get latest tag (only on manual trigger)
|
||||
id: get-latest-tag
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
run: |
|
||||
tag=$(git describe --tags --abbrev=0)
|
||||
echo "latest_tag=$tag" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Checkout to latest tag (only on manual trigger)
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
|
||||
|
||||
- name: Check if version is pre-release
|
||||
id: check-prerelease
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
version="${{ steps.get-latest-tag.outputs.latest_tag }}"
|
||||
else
|
||||
version="${{ github.ref_name }}"
|
||||
fi
|
||||
if [[ "$version" == *"beta"* ]] || [[ "$version" == *"alpha"* ]]; then
|
||||
echo "is_prerelease=true" >> $GITHUB_OUTPUT
|
||||
echo "Version $version is a pre-release, will not push latest tag"
|
||||
else
|
||||
echo "is_prerelease=false" >> $GITHUB_OUTPUT
|
||||
echo "Version $version is a stable release, will push latest tag"
|
||||
fi
|
||||
|
||||
- name: Build Dashboard
|
||||
run: |
|
||||
cd dashboard
|
||||
npm install
|
||||
npm run build
|
||||
mkdir -p dist/assets
|
||||
echo $(git rev-parse HEAD) > dist/assets/version
|
||||
cd ..
|
||||
mkdir -p data
|
||||
cp -r dashboard/dist data/
|
||||
|
||||
- name: Set QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: 设置 Docker Buildx
|
||||
- name: Set Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: 登录到 DockerHub
|
||||
- name: Log in to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||
|
||||
- name: 构建和推送 Docker hub
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: Soulter
|
||||
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
|
||||
|
||||
- name: Build and Push Docker to DockerHub and Github GHCR
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.ref_name }}
|
||||
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && format('{0}/astrbot:latest', secrets.DOCKER_HUB_USERNAME) || '' }}
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
|
||||
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && 'ghcr.io/soulter/astrbot:latest' || '' }}
|
||||
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
|
||||
|
||||
- name: Post build notifications
|
||||
run: echo "Docker image has been built and pushed successfully"
|
||||
|
||||
|
||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@v5
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: 'Stale issue message'
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -30,3 +30,4 @@ packages/python_interpreter/workplace
|
||||
.conda/
|
||||
.idea
|
||||
pytest.ini
|
||||
.astrbot
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.10
|
||||
@@ -1,9 +1,11 @@
|
||||
FROM python:3.10-slim
|
||||
FROM python:3.11-slim
|
||||
WORKDIR /AstrBot
|
||||
|
||||
COPY . /AstrBot/
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
nodejs \
|
||||
npm \
|
||||
gcc \
|
||||
build-essential \
|
||||
python3-dev \
|
||||
@@ -28,3 +30,6 @@ EXPOSE 6185
|
||||
EXPOSE 6186
|
||||
|
||||
CMD [ "python", "main.py" ]
|
||||
|
||||
|
||||
|
||||
|
||||
194
README.md
194
README.md
@@ -1,23 +1,21 @@
|
||||
<p align="center">
|
||||
|
||||

|
||||
<img width="430" height="31" alt="image" src="https://github.com/user-attachments/assets/474c822c-fab7-41be-8c23-6dae252823ed" /><p align="center">
|
||||
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||

|
||||

|
||||
|
||||

|
||||

|
||||
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
@@ -25,48 +23,52 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
||||
</div>
|
||||
|
||||
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
||||
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
|
||||
|
||||
[](https://gitcode.com/Soulter/AstrBot)
|
||||
## 主要功能
|
||||
|
||||
<!-- [](https://codecov.io/gh/Soulter/AstrBot)
|
||||
-->
|
||||
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
|
||||
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
|
||||
3. **Agent**。完善适配的 Agentic 能力。支持多轮工具调用、内置沙盒代码执行器、网页搜索等功能。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
|
||||
5. **WebUI**。可视化配置和管理机器人,功能齐全。
|
||||
|
||||
## ✨ 近期更新
|
||||
|
||||
1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
|
||||
|
||||
## ✨ 主要功能
|
||||
|
||||
> [!NOTE]
|
||||
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
|
||||
|
||||
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
||||
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat)、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
||||
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
||||
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
|
||||
|
||||
> [!TIP]
|
||||
> WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
>
|
||||
> 用户名: `astrbot`, 密码: `astrbot`。
|
||||
|
||||
## ✨ 使用方式
|
||||
## 部署方式
|
||||
|
||||
#### Docker 部署
|
||||
|
||||
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
|
||||
|
||||
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
|
||||
|
||||
#### 宝塔面板部署
|
||||
|
||||
AstrBot 与宝塔面板合作,已上架至宝塔面板。
|
||||
|
||||
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
|
||||
|
||||
#### 1Panel 部署
|
||||
|
||||
AstrBot 已由 1Panel 官方上架至 1Panel 面板。
|
||||
|
||||
请参阅官方文档 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html) 。
|
||||
|
||||
#### 在 雨云 上部署
|
||||
|
||||
AstrBot 已由雨云官方上架至云应用平台,可一键部署。
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
#### 在 Replit 上部署
|
||||
|
||||
社区贡献的部署方式。
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
|
||||
#### Windows 一键安装器部署
|
||||
|
||||
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
|
||||
|
||||
#### 宝塔面板部署
|
||||
|
||||
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
|
||||
|
||||
#### CasaOS 部署
|
||||
|
||||
社区贡献的部署方式。
|
||||
@@ -75,55 +77,83 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
|
||||
#### 手动部署
|
||||
|
||||
推荐使用 `uv`。
|
||||
首先安装 uv:
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
通过 Git Clone 安装 AstrBot:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
pip install uv
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
|
||||
#### Replit 部署
|
||||
## 🌍 社区
|
||||
|
||||
### QQ 群组
|
||||
|
||||
- 1 群:322154837
|
||||
- 3 群:630166526
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 开发者群:753075035
|
||||
- 开发者群(备份):295657329
|
||||
|
||||
### Telegram 群组
|
||||
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
### Discord 群组
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
|
||||
## ⚡ 消息平台支持情况
|
||||
|
||||
| 平台 | 支持性 | 详情 | 消息类型 |
|
||||
| -------- | ------- | ------- | ------ |
|
||||
| QQ(官方机器人接口) | ✔ | 私聊、群聊,QQ 频道私聊、群聊 | 文字、图片 |
|
||||
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
|
||||
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
|
||||
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 |
|
||||
| 飞书 | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| 钉钉 | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
||||
| Discord | 🚧 | 计划内 | - |
|
||||
| WhatsApp | 🚧 | 计划内 | - |
|
||||
| 小爱音响 | 🚧 | 计划内 | - |
|
||||
| 平台 | 支持性 |
|
||||
| -------- | ------- |
|
||||
| QQ(官方机器人接口) | ✔ |
|
||||
| QQ(OneBot) | ✔ |
|
||||
| Telegram | ✔ |
|
||||
| 企业微信 | ✔ |
|
||||
| 微信客服 | ✔ |
|
||||
| 微信公众号 | ✔ |
|
||||
| 飞书 | ✔ |
|
||||
| 钉钉 | ✔ |
|
||||
| Slack | ✔ |
|
||||
| Discord | ✔ |
|
||||
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
||||
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
||||
|
||||
## ⚡ 提供商支持情况
|
||||
|
||||
| 名称 | 支持性 | 类型 | 备注 |
|
||||
| -------- | ------- | ------- | ------- |
|
||||
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、硅基流动、xAI 等兼容 OpenAI API 的服务 |
|
||||
| Claude API | ✔ | 文本生成 | |
|
||||
| Google Gemini API | ✔ | 文本生成 | |
|
||||
| OpenAI | ✔ | 文本生成 | 支持任何兼容 OpenAI API 的服务 |
|
||||
| Anthropic | ✔ | 文本生成 | |
|
||||
| Google Gemini | ✔ | 文本生成 | |
|
||||
| Dify | ✔ | LLMOps | |
|
||||
| DashScope(阿里云百炼应用) | ✔ | LLMOps | |
|
||||
| 阿里云百炼应用 | ✔ | LLMOps | |
|
||||
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
|
||||
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
|
||||
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
|
||||
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
||||
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
|
||||
| OneAPI | ✔ | LLM 分发系统 | |
|
||||
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
|
||||
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
|
||||
| OpenAI TTS API | ✔ | 文本转语音 | |
|
||||
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
|
||||
| Fishaudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
||||
| Edge-TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
||||
| GPT-SoVITs | ✔ | 文本转语音 | GPT-Sovits-Inference |
|
||||
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
||||
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
||||
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
|
||||
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
|
||||
|
||||
## ❤️ 贡献
|
||||
|
||||
@@ -143,34 +173,6 @@ pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## 🌟 支持
|
||||
|
||||
- Star 这个项目!
|
||||
- 在[爱发电](https://afdian.com/a/soulter)支持我!
|
||||
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
|
||||
|
||||
## ✨ Demo
|
||||
|
||||
<div align='center'>
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||
|
||||
_✨基于 Docker 的沙箱化代码执行器(Beta 测试)✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||
|
||||
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
||||
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
||||
|
||||
_✨ 插件系统——部分插件展示 ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/0cdbf564-2f59-4da5-b524-ce0e7ef3d978" width=600>
|
||||
|
||||
_✨ WebUI ✨_
|
||||
|
||||
</div>
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
@@ -180,6 +182,10 @@ _✨ WebUI ✨_
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
此外,本项目的诞生离不开以下开源项目:
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
|
||||
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
@@ -192,11 +198,9 @@ _✨ WebUI ✨_
|
||||
|
||||
</div>
|
||||
|
||||
## Disclaimer
|
||||
|
||||
1. The project is protected under the `AGPL-v3` opensource license.
|
||||
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
|
||||
3. Please ensure compliance with local laws and regulations when using this project.
|
||||
</details>
|
||||
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
|
||||
15
README_ja.md
15
README_ja.md
@@ -1,5 +1,5 @@
|
||||
<p align="center">
|
||||
|
||||
|
||||

|
||||
|
||||
</p>
|
||||
@@ -27,7 +27,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
|
||||
## ✨ 主な機能
|
||||
|
||||
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
|
||||
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、WeChat(Gewechat)、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
||||
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
||||
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
|
||||
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
|
||||
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
|
||||
@@ -35,7 +35,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
|
||||
|
||||
> [!TIP]
|
||||
> 管理パネルのオンラインデモを体験する: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
>
|
||||
>
|
||||
> ユーザー名: `astrbot`, パスワード: `astrbot`。LLM が設定されていないため、チャットページで大規模モデルを使用することはできません。(デモのログインパスワードを変更しないでください 😭)
|
||||
|
||||
## ✨ 使用方法
|
||||
@@ -136,11 +136,11 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> [!TIP]
|
||||
> このプロジェクトがあなたの生活や仕事に役立った場合、またはこのプロジェクトの将来の発展に関心がある場合は、プロジェクトに Star を付けてください。これはこのオープンソースプロジェクトを維持するためのモチベーションです <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
|
||||
[](https://star-history.com/#soulter/astrbot&Date)
|
||||
|
||||
</div>
|
||||
@@ -152,8 +152,7 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
||||
## 免責事項
|
||||
|
||||
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
|
||||
2. WeChat(個人アカウント)のデプロイメントには [Gewechat](https://github.com/Devo919/Gewechat) サービスを利用しています。AstrBot は Gewechat との接続を保証するだけであり、アカウントのリスク管理に関しては、このプロジェクトの著者は一切の責任を負いません。
|
||||
3. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
|
||||
2. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
|
||||
|
||||
<!-- ## ✨ ATRI [ベータテスト]
|
||||
|
||||
@@ -165,6 +164,4 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
||||
4. TTS
|
||||
-->
|
||||
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
|
||||
@@ -3,5 +3,18 @@ from astrbot import logger
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.star.register import register_llm_tool as llm_tool
|
||||
from astrbot.core.star.register import register_agent as agent
|
||||
from astrbot.core.agent.tool import ToolSet, FunctionTool
|
||||
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
||||
|
||||
__all__ = ["AstrBotConfig", "logger", "html_renderer", "llm_tool", "sp"]
|
||||
__all__ = [
|
||||
"AstrBotConfig",
|
||||
"logger",
|
||||
"html_renderer",
|
||||
"llm_tool",
|
||||
"agent",
|
||||
"sp",
|
||||
"ToolSet",
|
||||
"FunctionTool",
|
||||
"BaseFunctionToolExecutor",
|
||||
]
|
||||
|
||||
1
astrbot/cli/__init__.py
Normal file
1
astrbot/cli/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "3.5.23"
|
||||
59
astrbot/cli/__main__.py
Normal file
59
astrbot/cli/__main__.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
AstrBot CLI入口
|
||||
"""
|
||||
|
||||
import click
|
||||
import sys
|
||||
from . import __version__
|
||||
from .commands import init, run, plug, conf
|
||||
|
||||
logo_tmpl = r"""
|
||||
___ _______.___________..______ .______ ______ .___________.
|
||||
/ \ / | || _ \ | _ \ / __ \ | |
|
||||
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
|
||||
/ /_\ \ \ \ | | | / | _ < | | | | | |
|
||||
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
|
||||
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
|
||||
"""
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(__version__, prog_name="AstrBot")
|
||||
def cli() -> None:
|
||||
"""The AstrBot CLI"""
|
||||
click.echo(logo_tmpl)
|
||||
click.echo("Welcome to AstrBot CLI!")
|
||||
click.echo(f"AstrBot CLI version: {__version__}")
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("command_name", required=False, type=str)
|
||||
def help(command_name: str | None) -> None:
|
||||
"""显示命令的帮助信息
|
||||
|
||||
如果提供了 COMMAND_NAME,则显示该命令的详细帮助信息。
|
||||
否则,显示通用帮助信息。
|
||||
"""
|
||||
ctx = click.get_current_context()
|
||||
if command_name:
|
||||
# 查找指定命令
|
||||
command = cli.get_command(ctx, command_name)
|
||||
if command:
|
||||
# 显示特定命令的帮助信息
|
||||
click.echo(command.get_help(ctx))
|
||||
else:
|
||||
click.echo(f"Unknown command: {command_name}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
# 显示通用帮助信息
|
||||
click.echo(cli.get_help(ctx))
|
||||
|
||||
|
||||
cli.add_command(init)
|
||||
cli.add_command(run)
|
||||
cli.add_command(help)
|
||||
cli.add_command(plug)
|
||||
cli.add_command(conf)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
6
astrbot/cli/commands/__init__.py
Normal file
6
astrbot/cli/commands/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .cmd_init import init
|
||||
from .cmd_run import run
|
||||
from .cmd_plug import plug
|
||||
from .cmd_conf import conf
|
||||
|
||||
__all__ = ["init", "run", "plug", "conf"]
|
||||
206
astrbot/cli/commands/cmd_conf.py
Normal file
206
astrbot/cli/commands/cmd_conf.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import json
|
||||
import click
|
||||
import hashlib
|
||||
import zoneinfo
|
||||
from typing import Any, Callable
|
||||
from ..utils import get_astrbot_root, check_astrbot_root
|
||||
|
||||
|
||||
def _validate_log_level(value: str) -> str:
|
||||
"""验证日志级别"""
|
||||
value = value.upper()
|
||||
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
|
||||
raise click.ClickException(
|
||||
"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def _validate_dashboard_port(value: str) -> int:
|
||||
"""验证 Dashboard 端口"""
|
||||
try:
|
||||
port = int(value)
|
||||
if port < 1 or port > 65535:
|
||||
raise click.ClickException("端口必须在 1-65535 范围内")
|
||||
return port
|
||||
except ValueError:
|
||||
raise click.ClickException("端口必须是数字")
|
||||
|
||||
|
||||
def _validate_dashboard_username(value: str) -> str:
|
||||
"""验证 Dashboard 用户名"""
|
||||
if not value:
|
||||
raise click.ClickException("用户名不能为空")
|
||||
return value
|
||||
|
||||
|
||||
def _validate_dashboard_password(value: str) -> str:
|
||||
"""验证 Dashboard 密码"""
|
||||
if not value:
|
||||
raise click.ClickException("密码不能为空")
|
||||
return hashlib.md5(value.encode()).hexdigest()
|
||||
|
||||
|
||||
def _validate_timezone(value: str) -> str:
|
||||
"""验证时区"""
|
||||
try:
|
||||
zoneinfo.ZoneInfo(value)
|
||||
except Exception:
|
||||
raise click.ClickException(f"无效的时区: {value},请使用有效的IANA时区名称")
|
||||
return value
|
||||
|
||||
|
||||
def _validate_callback_api_base(value: str) -> str:
|
||||
"""验证回调接口基址"""
|
||||
if not value.startswith("http://") and not value.startswith("https://"):
|
||||
raise click.ClickException("回调接口基址必须以 http:// 或 https:// 开头")
|
||||
return value
|
||||
|
||||
|
||||
# 可通过CLI设置的配置项,配置键到验证器函数的映射
|
||||
CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
|
||||
"timezone": _validate_timezone,
|
||||
"log_level": _validate_log_level,
|
||||
"dashboard.port": _validate_dashboard_port,
|
||||
"dashboard.username": _validate_dashboard_username,
|
||||
"dashboard.password": _validate_dashboard_password,
|
||||
"callback_api_base": _validate_callback_api_base,
|
||||
}
|
||||
|
||||
|
||||
def _load_config() -> dict[str, Any]:
|
||||
"""加载或初始化配置文件"""
|
||||
root = get_astrbot_root()
|
||||
if not check_astrbot_root(root):
|
||||
raise click.ClickException(
|
||||
f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
|
||||
)
|
||||
|
||||
config_path = root / "data" / "cmd_config.json"
|
||||
if not config_path.exists():
|
||||
from astrbot.core.config.default import DEFAULT_CONFIG
|
||||
|
||||
config_path.write_text(
|
||||
json.dumps(DEFAULT_CONFIG, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8-sig",
|
||||
)
|
||||
|
||||
try:
|
||||
return json.loads(config_path.read_text(encoding="utf-8-sig"))
|
||||
except json.JSONDecodeError as e:
|
||||
raise click.ClickException(f"配置文件解析失败: {str(e)}")
|
||||
|
||||
|
||||
def _save_config(config: dict[str, Any]) -> None:
|
||||
"""保存配置文件"""
|
||||
config_path = get_astrbot_root() / "data" / "cmd_config.json"
|
||||
|
||||
config_path.write_text(
|
||||
json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8-sig"
|
||||
)
|
||||
|
||||
|
||||
def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
|
||||
"""设置嵌套字典中的值"""
|
||||
parts = path.split(".")
|
||||
for part in parts[:-1]:
|
||||
if part not in obj:
|
||||
obj[part] = {}
|
||||
elif not isinstance(obj[part], dict):
|
||||
raise click.ClickException(
|
||||
f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典"
|
||||
)
|
||||
obj = obj[part]
|
||||
obj[parts[-1]] = value
|
||||
|
||||
|
||||
def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
|
||||
"""获取嵌套字典中的值"""
|
||||
parts = path.split(".")
|
||||
for part in parts:
|
||||
obj = obj[part]
|
||||
return obj
|
||||
|
||||
|
||||
@click.group(name="conf")
|
||||
def conf():
|
||||
"""配置管理命令
|
||||
|
||||
支持的配置项:
|
||||
|
||||
- timezone: 时区设置 (例如: Asia/Shanghai)
|
||||
|
||||
- log_level: 日志级别 (DEBUG/INFO/WARNING/ERROR/CRITICAL)
|
||||
|
||||
- dashboard.port: Dashboard 端口
|
||||
|
||||
- dashboard.username: Dashboard 用户名
|
||||
|
||||
- dashboard.password: Dashboard 密码
|
||||
|
||||
- callback_api_base: 回调接口基址
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@conf.command(name="set")
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
def set_config(key: str, value: str):
|
||||
"""设置配置项的值"""
|
||||
if key not in CONFIG_VALIDATORS.keys():
|
||||
raise click.ClickException(f"不支持的配置项: {key}")
|
||||
|
||||
config = _load_config()
|
||||
|
||||
try:
|
||||
old_value = _get_nested_item(config, key)
|
||||
validated_value = CONFIG_VALIDATORS[key](value)
|
||||
_set_nested_item(config, key, validated_value)
|
||||
_save_config(config)
|
||||
|
||||
click.echo(f"配置已更新: {key}")
|
||||
if key == "dashboard.password":
|
||||
click.echo(" 原值: ********")
|
||||
click.echo(" 新值: ********")
|
||||
else:
|
||||
click.echo(f" 原值: {old_value}")
|
||||
click.echo(f" 新值: {validated_value}")
|
||||
|
||||
except KeyError:
|
||||
raise click.ClickException(f"未知的配置项: {key}")
|
||||
except Exception as e:
|
||||
raise click.UsageError(f"设置配置失败: {str(e)}")
|
||||
|
||||
|
||||
@conf.command(name="get")
|
||||
@click.argument("key", required=False)
|
||||
def get_config(key: str = None):
|
||||
"""获取配置项的值,不提供key则显示所有可配置项"""
|
||||
config = _load_config()
|
||||
|
||||
if key:
|
||||
if key not in CONFIG_VALIDATORS.keys():
|
||||
raise click.ClickException(f"不支持的配置项: {key}")
|
||||
|
||||
try:
|
||||
value = _get_nested_item(config, key)
|
||||
if key == "dashboard.password":
|
||||
value = "********"
|
||||
click.echo(f"{key}: {value}")
|
||||
except KeyError:
|
||||
raise click.ClickException(f"未知的配置项: {key}")
|
||||
except Exception as e:
|
||||
raise click.UsageError(f"获取配置失败: {str(e)}")
|
||||
else:
|
||||
click.echo("当前配置:")
|
||||
for key in CONFIG_VALIDATORS.keys():
|
||||
try:
|
||||
value = (
|
||||
"********"
|
||||
if key == "dashboard.password"
|
||||
else _get_nested_item(config, key)
|
||||
)
|
||||
click.echo(f" {key}: {value}")
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
55
astrbot/cli/commands/cmd_init.py
Normal file
55
astrbot/cli/commands/cmd_init.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import asyncio
|
||||
|
||||
import click
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from ..utils import check_dashboard, get_astrbot_root
|
||||
|
||||
|
||||
async def initialize_astrbot(astrbot_root) -> None:
|
||||
"""执行 AstrBot 初始化逻辑"""
|
||||
dot_astrbot = astrbot_root / ".astrbot"
|
||||
|
||||
if not dot_astrbot.exists():
|
||||
click.echo(f"Current Directory: {astrbot_root}")
|
||||
click.echo(
|
||||
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。"
|
||||
)
|
||||
if click.confirm(
|
||||
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
dot_astrbot.touch()
|
||||
click.echo(f"Created {dot_astrbot}")
|
||||
|
||||
paths = {
|
||||
"data": astrbot_root / "data",
|
||||
"config": astrbot_root / "data" / "config",
|
||||
"plugins": astrbot_root / "data" / "plugins",
|
||||
"temp": astrbot_root / "data" / "temp",
|
||||
}
|
||||
|
||||
for name, path in paths.items():
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
click.echo(f"{'Created' if not path.exists() else 'Directory exists'}: {path}")
|
||||
|
||||
await check_dashboard(astrbot_root / "data")
|
||||
|
||||
|
||||
@click.command()
|
||||
def init() -> None:
|
||||
"""初始化 AstrBot"""
|
||||
click.echo("Initializing AstrBot...")
|
||||
astrbot_root = get_astrbot_root()
|
||||
lock_file = astrbot_root / "astrbot.lock"
|
||||
lock = FileLock(lock_file, timeout=5)
|
||||
|
||||
try:
|
||||
with lock.acquire():
|
||||
asyncio.run(initialize_astrbot(astrbot_root))
|
||||
except Timeout:
|
||||
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
|
||||
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"初始化失败: {e!s}")
|
||||
247
astrbot/cli/commands/cmd_plug.py
Normal file
247
astrbot/cli/commands/cmd_plug.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import shutil
|
||||
|
||||
|
||||
from ..utils import (
|
||||
get_git_repo,
|
||||
build_plug_list,
|
||||
manage_plugin,
|
||||
PluginStatus,
|
||||
check_astrbot_root,
|
||||
get_astrbot_root,
|
||||
)
|
||||
|
||||
|
||||
@click.group()
|
||||
def plug():
|
||||
"""插件管理"""
|
||||
pass
|
||||
|
||||
|
||||
def _get_data_path() -> Path:
|
||||
base = get_astrbot_root()
|
||||
if not check_astrbot_root(base):
|
||||
raise click.ClickException(
|
||||
f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
|
||||
)
|
||||
return (base / "data").resolve()
|
||||
|
||||
|
||||
def display_plugins(plugins, title=None, color=None):
|
||||
if title:
|
||||
click.echo(click.style(title, fg=color, bold=True))
|
||||
|
||||
click.echo(f"{'名称':<20} {'版本':<10} {'状态':<10} {'作者':<15} {'描述':<30}")
|
||||
click.echo("-" * 85)
|
||||
|
||||
for p in plugins:
|
||||
desc = p["desc"][:30] + ("..." if len(p["desc"]) > 30 else "")
|
||||
click.echo(
|
||||
f"{p['name']:<20} {p['version']:<10} {p['status']:<10} "
|
||||
f"{p['author']:<15} {desc:<30}"
|
||||
)
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
def new(name: str):
|
||||
"""创建新插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins" / name
|
||||
|
||||
if plug_path.exists():
|
||||
raise click.ClickException(f"插件 {name} 已存在")
|
||||
|
||||
author = click.prompt("请输入插件作者", type=str)
|
||||
desc = click.prompt("请输入插件描述", type=str)
|
||||
version = click.prompt("请输入插件版本", type=str)
|
||||
if not re.match(r"^\d+\.\d+(\.\d+)?$", version.lower().lstrip("v")):
|
||||
raise click.ClickException("版本号必须为 x.y 或 x.y.z 格式")
|
||||
repo = click.prompt("请输入插件仓库:", type=str)
|
||||
if not repo.startswith("http"):
|
||||
raise click.ClickException("仓库地址必须以 http 开头")
|
||||
|
||||
click.echo("下载插件模板...")
|
||||
get_git_repo(
|
||||
"https://github.com/Soulter/helloworld",
|
||||
plug_path,
|
||||
)
|
||||
|
||||
click.echo("重写插件信息...")
|
||||
# 重写 metadata.yaml
|
||||
with open(plug_path / "metadata.yaml", "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
f"name: {name}\n"
|
||||
f"desc: {desc}\n"
|
||||
f"version: {version}\n"
|
||||
f"author: {author}\n"
|
||||
f"repo: {repo}\n"
|
||||
)
|
||||
|
||||
# 重写 README.md
|
||||
with open(plug_path / "README.md", "w", encoding="utf-8") as f:
|
||||
f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n")
|
||||
|
||||
# 重写 main.py
|
||||
with open(plug_path / "main.py", "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
new_content = content.replace(
|
||||
'@register("helloworld", "YourName", "一个简单的 Hello World 插件", "1.0.0")',
|
||||
f'@register("{name}", "{author}", "{desc}", "{version}")',
|
||||
)
|
||||
|
||||
with open(plug_path / "main.py", "w", encoding="utf-8") as f:
|
||||
f.write(new_content)
|
||||
|
||||
click.echo(f"插件 {name} 创建成功")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.option("--all", "-a", is_flag=True, help="列出未安装的插件")
|
||||
def list(all: bool):
|
||||
"""列出插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
# 未发布的插件
|
||||
not_published_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NOT_PUBLISHED
|
||||
]
|
||||
if not_published_plugins:
|
||||
display_plugins(not_published_plugins, "未发布的插件", "red")
|
||||
|
||||
# 需要更新的插件
|
||||
need_update_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
|
||||
]
|
||||
if need_update_plugins:
|
||||
display_plugins(need_update_plugins, "需要更新的插件", "yellow")
|
||||
|
||||
# 已安装的插件
|
||||
installed_plugins = [p for p in plugins if p["status"] == PluginStatus.INSTALLED]
|
||||
if installed_plugins:
|
||||
display_plugins(installed_plugins, "已安装的插件", "green")
|
||||
|
||||
# 未安装的插件
|
||||
not_installed_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NOT_INSTALLED
|
||||
]
|
||||
if not_installed_plugins and all:
|
||||
display_plugins(not_installed_plugins, "未安装的插件", "blue")
|
||||
|
||||
if (
|
||||
not any([not_published_plugins, need_update_plugins, installed_plugins])
|
||||
and not all
|
||||
):
|
||||
click.echo("未安装任何插件")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
@click.option("--proxy", help="代理服务器地址")
|
||||
def install(name: str, proxy: str | None):
|
||||
"""安装插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins"
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
plugin = next(
|
||||
(
|
||||
p
|
||||
for p in plugins
|
||||
if p["name"] == name and p["status"] == PluginStatus.NOT_INSTALLED
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not plugin:
|
||||
raise click.ClickException(f"未找到可安装的插件 {name},可能是不存在或已安装")
|
||||
|
||||
manage_plugin(plugin, plug_path, is_update=False, proxy=proxy)
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
def remove(name: str):
|
||||
"""卸载插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
plugin = next((p for p in plugins if p["name"] == name), None)
|
||||
|
||||
if not plugin or not plugin.get("local_path"):
|
||||
raise click.ClickException(f"插件 {name} 不存在或未安装")
|
||||
|
||||
plugin_path = plugin["local_path"]
|
||||
|
||||
click.confirm(f"确定要卸载插件 {name} 吗?", default=False, abort=True)
|
||||
|
||||
try:
|
||||
shutil.rmtree(plugin_path)
|
||||
click.echo(f"插件 {name} 已卸载")
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"卸载插件 {name} 失败: {e}")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name", required=False)
|
||||
@click.option("--proxy", help="Github代理地址")
|
||||
def update(name: str, proxy: str | None):
|
||||
"""更新插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins"
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
if name:
|
||||
plugin = next(
|
||||
(
|
||||
p
|
||||
for p in plugins
|
||||
if p["name"] == name and p["status"] == PluginStatus.NEED_UPDATE
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not plugin:
|
||||
raise click.ClickException(f"插件 {name} 不需要更新或无法更新")
|
||||
|
||||
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
|
||||
else:
|
||||
need_update_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
|
||||
]
|
||||
|
||||
if not need_update_plugins:
|
||||
click.echo("没有需要更新的插件")
|
||||
return
|
||||
|
||||
click.echo(f"发现 {len(need_update_plugins)} 个插件需要更新")
|
||||
for plugin in need_update_plugins:
|
||||
plugin_name = plugin["name"]
|
||||
click.echo(f"正在更新插件 {plugin_name}...")
|
||||
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("query")
|
||||
def search(query: str):
|
||||
"""搜索插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
matched_plugins = [
|
||||
p
|
||||
for p in plugins
|
||||
if query.lower() in p["name"].lower()
|
||||
or query.lower() in p["desc"].lower()
|
||||
or query.lower() in p["author"].lower()
|
||||
]
|
||||
|
||||
if not matched_plugins:
|
||||
click.echo(f"未找到匹配 '{query}' 的插件")
|
||||
return
|
||||
|
||||
display_plugins(matched_plugins, f"搜索结果: '{query}'", "cyan")
|
||||
63
astrbot/cli/commands/cmd_run.py
Normal file
63
astrbot/cli/commands/cmd_run.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from ..utils import check_dashboard, check_astrbot_root, get_astrbot_root
|
||||
|
||||
|
||||
async def run_astrbot(astrbot_root: Path):
|
||||
"""运行 AstrBot"""
|
||||
from astrbot.core import logger, LogManager, LogBroker, db_helper
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
|
||||
await check_dashboard(astrbot_root / "data")
|
||||
|
||||
log_broker = LogBroker()
|
||||
LogManager.set_queue_handler(logger, log_broker)
|
||||
db = db_helper
|
||||
|
||||
core_lifecycle = InitialLoader(db, log_broker)
|
||||
|
||||
await core_lifecycle.start()
|
||||
|
||||
|
||||
@click.option("--reload", "-r", is_flag=True, help="插件自动重载")
|
||||
@click.option("--port", "-p", help="Astrbot Dashboard端口", required=False, type=str)
|
||||
@click.command()
|
||||
def run(reload: bool, port: str) -> None:
|
||||
"""运行 AstrBot"""
|
||||
try:
|
||||
os.environ["ASTRBOT_CLI"] = "1"
|
||||
astrbot_root = get_astrbot_root()
|
||||
|
||||
if not check_astrbot_root(astrbot_root):
|
||||
raise click.ClickException(
|
||||
f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
|
||||
)
|
||||
|
||||
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
|
||||
sys.path.insert(0, str(astrbot_root))
|
||||
|
||||
if port:
|
||||
os.environ["DASHBOARD_PORT"] = port
|
||||
|
||||
if reload:
|
||||
click.echo("启用插件自动重载")
|
||||
os.environ["ASTRBOT_RELOAD"] = "1"
|
||||
|
||||
lock_file = astrbot_root / "astrbot.lock"
|
||||
lock = FileLock(lock_file, timeout=5)
|
||||
with lock.acquire():
|
||||
asyncio.run(run_astrbot(astrbot_root))
|
||||
except KeyboardInterrupt:
|
||||
click.echo("AstrBot 已关闭...")
|
||||
except Timeout:
|
||||
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"运行时出现错误: {e}\n{traceback.format_exc()}")
|
||||
18
astrbot/cli/utils/__init__.py
Normal file
18
astrbot/cli/utils/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from .basic import (
|
||||
get_astrbot_root,
|
||||
check_astrbot_root,
|
||||
check_dashboard,
|
||||
)
|
||||
from .plugin import get_git_repo, manage_plugin, build_plug_list, PluginStatus
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
__all__ = [
|
||||
"get_astrbot_root",
|
||||
"check_astrbot_root",
|
||||
"check_dashboard",
|
||||
"get_git_repo",
|
||||
"manage_plugin",
|
||||
"build_plug_list",
|
||||
"VersionComparator",
|
||||
"PluginStatus",
|
||||
]
|
||||
76
astrbot/cli/utils/basic.py
Normal file
76
astrbot/cli/utils/basic.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
|
||||
def check_astrbot_root(path: str | Path) -> bool:
|
||||
"""检查路径是否为 AstrBot 根目录"""
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
if not path.exists() or not path.is_dir():
|
||||
return False
|
||||
if not (path / ".astrbot").exists():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_astrbot_root() -> Path:
|
||||
"""获取Astrbot根目录路径"""
|
||||
return Path.cwd()
|
||||
|
||||
|
||||
async def check_dashboard(astrbot_root: Path) -> None:
|
||||
"""检查是否安装了dashboard"""
|
||||
from astrbot.core.utils.io import get_dashboard_version, download_dashboard
|
||||
from astrbot.core.config.default import VERSION
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
try:
|
||||
dashboard_version = await get_dashboard_version()
|
||||
match dashboard_version:
|
||||
case None:
|
||||
click.echo("未安装管理面板")
|
||||
if click.confirm(
|
||||
"是否安装管理面板?",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
click.echo("正在安装管理面板...")
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip",
|
||||
extract_path=str(astrbot_root),
|
||||
version=f"v{VERSION}",
|
||||
latest=False,
|
||||
)
|
||||
click.echo("管理面板安装完成")
|
||||
|
||||
case str():
|
||||
if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
|
||||
click.echo("管理面板已是最新版本")
|
||||
return
|
||||
else:
|
||||
try:
|
||||
version = dashboard_version.split("v")[1]
|
||||
click.echo(f"管理面板版本: {version}")
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip",
|
||||
extract_path=str(astrbot_root),
|
||||
version=f"v{VERSION}",
|
||||
latest=False,
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"下载管理面板失败: {e}")
|
||||
return
|
||||
except FileNotFoundError:
|
||||
click.echo("初始化管理面板目录...")
|
||||
try:
|
||||
await download_dashboard(
|
||||
path=str(astrbot_root / "dashboard.zip"),
|
||||
extract_path=str(astrbot_root),
|
||||
version=f"v{VERSION}",
|
||||
latest=False,
|
||||
)
|
||||
click.echo("管理面板初始化完成")
|
||||
except Exception as e:
|
||||
click.echo(f"下载管理面板失败: {e}")
|
||||
return
|
||||
233
astrbot/cli/utils/plugin.py
Normal file
233
astrbot/cli/utils/plugin.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
||||
import click
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
|
||||
class PluginStatus(str, Enum):
|
||||
INSTALLED = "已安装"
|
||||
NEED_UPDATE = "需更新"
|
||||
NOT_INSTALLED = "未安装"
|
||||
NOT_PUBLISHED = "未发布"
|
||||
|
||||
|
||||
def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
|
||||
"""从 Git 仓库下载代码并解压到指定路径"""
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
try:
|
||||
# 解析仓库信息
|
||||
repo_namespace = url.split("/")[-2:]
|
||||
author = repo_namespace[0]
|
||||
repo = repo_namespace[1]
|
||||
|
||||
# 尝试获取最新的 release
|
||||
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
|
||||
try:
|
||||
with httpx.Client(
|
||||
proxy=proxy if proxy else None, follow_redirects=True
|
||||
) as client:
|
||||
resp = client.get(release_url)
|
||||
resp.raise_for_status()
|
||||
releases = resp.json()
|
||||
|
||||
if releases:
|
||||
# 使用最新的 release
|
||||
download_url = releases[0]["zipball_url"]
|
||||
else:
|
||||
# 没有 release,使用默认分支
|
||||
click.echo(f"正在从默认分支下载 {author}/{repo}")
|
||||
download_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||
except Exception as e:
|
||||
click.echo(f"获取 release 信息失败: {e},将直接使用提供的 URL")
|
||||
download_url = url
|
||||
|
||||
# 应用代理
|
||||
if proxy:
|
||||
download_url = f"{proxy}/{download_url}"
|
||||
|
||||
# 下载并解压
|
||||
with httpx.Client(
|
||||
proxy=proxy if proxy else None, follow_redirects=True
|
||||
) as client:
|
||||
resp = client.get(download_url)
|
||||
if (
|
||||
resp.status_code == 404
|
||||
and "archive/refs/heads/master.zip" in download_url
|
||||
):
|
||||
alt_url = download_url.replace("master.zip", "main.zip")
|
||||
click.echo("master 分支不存在,尝试下载 main 分支")
|
||||
resp = client.get(alt_url)
|
||||
resp.raise_for_status()
|
||||
else:
|
||||
resp.raise_for_status()
|
||||
zip_content = BytesIO(resp.content)
|
||||
with ZipFile(zip_content) as z:
|
||||
z.extractall(temp_dir)
|
||||
namelist = z.namelist()
|
||||
root_dir = Path(namelist[0]).parts[0] if namelist else ""
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path)
|
||||
shutil.move(temp_dir / root_dir, target_path)
|
||||
finally:
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def load_yaml_metadata(plugin_dir: Path) -> dict:
|
||||
"""从 metadata.yaml 文件加载插件元数据
|
||||
|
||||
Args:
|
||||
plugin_dir: 插件目录路径
|
||||
|
||||
Returns:
|
||||
dict: 包含元数据的字典,如果读取失败则返回空字典
|
||||
"""
|
||||
yaml_path = plugin_dir / "metadata.yaml"
|
||||
if yaml_path.exists():
|
||||
try:
|
||||
return yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
|
||||
except Exception as e:
|
||||
click.echo(f"读取 {yaml_path} 失败: {e}", err=True)
|
||||
return {}
|
||||
|
||||
|
||||
def build_plug_list(plugins_dir: Path) -> list:
|
||||
"""构建插件列表,包含本地和在线插件信息
|
||||
|
||||
Args:
|
||||
plugins_dir (Path): 插件目录路径
|
||||
|
||||
Returns:
|
||||
list: 包含插件信息的字典列表
|
||||
"""
|
||||
# 获取本地插件信息
|
||||
result = []
|
||||
if plugins_dir.exists():
|
||||
for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]:
|
||||
plugin_dir = plugins_dir / plugin_name
|
||||
|
||||
# 从 metadata.yaml 加载元数据
|
||||
metadata = load_yaml_metadata(plugin_dir)
|
||||
|
||||
if "desc" not in metadata and "description" in metadata:
|
||||
metadata["desc"] = metadata["description"]
|
||||
|
||||
# 如果成功加载元数据,添加到结果列表
|
||||
if metadata and all(
|
||||
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
||||
):
|
||||
result.append({
|
||||
"name": str(metadata.get("name", "")),
|
||||
"desc": str(metadata.get("desc", "")),
|
||||
"version": str(metadata.get("version", "")),
|
||||
"author": str(metadata.get("author", "")),
|
||||
"repo": str(metadata.get("repo", "")),
|
||||
"status": PluginStatus.INSTALLED,
|
||||
"local_path": str(plugin_dir),
|
||||
})
|
||||
|
||||
# 获取在线插件列表
|
||||
online_plugins = []
|
||||
try:
|
||||
with httpx.Client() as client:
|
||||
resp = client.get("https://api.soulter.top/astrbot/plugins")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
for plugin_id, plugin_info in data.items():
|
||||
online_plugins.append({
|
||||
"name": str(plugin_id),
|
||||
"desc": str(plugin_info.get("desc", "")),
|
||||
"version": str(plugin_info.get("version", "")),
|
||||
"author": str(plugin_info.get("author", "")),
|
||||
"repo": str(plugin_info.get("repo", "")),
|
||||
"status": PluginStatus.NOT_INSTALLED,
|
||||
"local_path": None,
|
||||
})
|
||||
except Exception as e:
|
||||
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
||||
|
||||
# 与在线插件比对,更新状态
|
||||
online_plugin_names = {plugin["name"] for plugin in online_plugins}
|
||||
for local_plugin in result:
|
||||
if local_plugin["name"] in online_plugin_names:
|
||||
# 查找对应的在线插件
|
||||
online_plugin = next(
|
||||
p for p in online_plugins if p["name"] == local_plugin["name"]
|
||||
)
|
||||
if (
|
||||
VersionComparator.compare_version(
|
||||
local_plugin["version"], online_plugin["version"]
|
||||
)
|
||||
< 0
|
||||
):
|
||||
local_plugin["status"] = PluginStatus.NEED_UPDATE
|
||||
else:
|
||||
# 本地插件未在线上发布
|
||||
local_plugin["status"] = PluginStatus.NOT_PUBLISHED
|
||||
|
||||
# 添加未安装的在线插件
|
||||
for online_plugin in online_plugins:
|
||||
if not any(plugin["name"] == online_plugin["name"] for plugin in result):
|
||||
result.append(online_plugin)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def manage_plugin(
|
||||
plugin: dict, plugins_dir: Path, is_update: bool = False, proxy: str | None = None
|
||||
) -> None:
|
||||
"""安装或更新插件
|
||||
|
||||
Args:
|
||||
plugin (dict): 插件信息字典
|
||||
plugins_dir (Path): 插件目录
|
||||
is_update (bool, optional): 是否为更新操作. 默认为 False
|
||||
proxy (str, optional): 代理服务器地址
|
||||
"""
|
||||
plugin_name = plugin["name"]
|
||||
repo_url = plugin["repo"]
|
||||
|
||||
# 如果是更新且有本地路径,直接使用本地路径
|
||||
if is_update and plugin.get("local_path"):
|
||||
target_path = Path(plugin["local_path"])
|
||||
else:
|
||||
target_path = plugins_dir / plugin_name
|
||||
|
||||
backup_path = Path(f"{target_path}_backup") if is_update else None
|
||||
|
||||
# 检查插件是否存在
|
||||
if is_update and not target_path.exists():
|
||||
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
|
||||
|
||||
# 备份现有插件
|
||||
if is_update and backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
if is_update:
|
||||
shutil.copytree(target_path, backup_path)
|
||||
|
||||
try:
|
||||
click.echo(
|
||||
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}..."
|
||||
)
|
||||
get_git_repo(repo_url, target_path, proxy)
|
||||
|
||||
# 更新成功,删除备份
|
||||
if is_update and backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
|
||||
except Exception as e:
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
if is_update and backup_path.exists():
|
||||
shutil.move(backup_path, target_path)
|
||||
raise click.ClickException(
|
||||
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}"
|
||||
)
|
||||
92
astrbot/cli/utils/version_comparator.py
Normal file
92
astrbot/cli/utils/version_comparator.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
拷贝自 astrbot.core.utils.version_comparator
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
class VersionComparator:
|
||||
@staticmethod
|
||||
def compare_version(v1: str, v2: str) -> int:
|
||||
"""根据 Semver 语义版本规范来比较版本号的大小。支持不仅局限于 3 个数字的版本号,并处理预发布标签。
|
||||
|
||||
参考: https://semver.org/lang/zh-CN/
|
||||
|
||||
返回 1 表示 v1 > v2,返回 -1 表示 v1 < v2,返回 0 表示 v1 = v2。
|
||||
"""
|
||||
v1 = v1.lower().replace("v", "")
|
||||
v2 = v2.lower().replace("v", "")
|
||||
|
||||
def split_version(version):
|
||||
match = re.match(
|
||||
r"^([0-9]+(?:\.[0-9]+)*)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(.+))?$",
|
||||
version,
|
||||
)
|
||||
if not match:
|
||||
return [], None
|
||||
major_minor_patch = match.group(1).split(".")
|
||||
prerelease = match.group(2)
|
||||
# buildmetadata = match.group(3) # 构建元数据在比较时忽略
|
||||
parts = [int(x) for x in major_minor_patch]
|
||||
prerelease = VersionComparator._split_prerelease(prerelease)
|
||||
return parts, prerelease
|
||||
|
||||
v1_parts, v1_prerelease = split_version(v1)
|
||||
v2_parts, v2_prerelease = split_version(v2)
|
||||
|
||||
# 比较数字部分
|
||||
length = max(len(v1_parts), len(v2_parts))
|
||||
v1_parts.extend([0] * (length - len(v1_parts)))
|
||||
v2_parts.extend([0] * (length - len(v2_parts)))
|
||||
|
||||
for i in range(length):
|
||||
if v1_parts[i] > v2_parts[i]:
|
||||
return 1
|
||||
elif v1_parts[i] < v2_parts[i]:
|
||||
return -1
|
||||
|
||||
# 比较预发布标签
|
||||
if v1_prerelease is None and v2_prerelease is not None:
|
||||
return 1 # 没有预发布标签的版本高于有预发布标签的版本
|
||||
elif v1_prerelease is not None and v2_prerelease is None:
|
||||
return -1 # 有预发布标签的版本低于没有预发布标签的版本
|
||||
elif v1_prerelease is not None and v2_prerelease is not None:
|
||||
len_pre = max(len(v1_prerelease), len(v2_prerelease))
|
||||
for i in range(len_pre):
|
||||
p1 = v1_prerelease[i] if i < len(v1_prerelease) else None
|
||||
p2 = v2_prerelease[i] if i < len(v2_prerelease) else None
|
||||
|
||||
if p1 is None and p2 is not None:
|
||||
return -1
|
||||
elif p1 is not None and p2 is None:
|
||||
return 1
|
||||
elif isinstance(p1, int) and isinstance(p2, str):
|
||||
return -1
|
||||
elif isinstance(p1, str) and isinstance(p2, int):
|
||||
return 1
|
||||
elif isinstance(p1, int) and isinstance(p2, int):
|
||||
if p1 > p2:
|
||||
return 1
|
||||
elif p1 < p2:
|
||||
return -1
|
||||
elif isinstance(p1, str) and isinstance(p2, str):
|
||||
if p1 > p2:
|
||||
return 1
|
||||
elif p1 < p2:
|
||||
return -1
|
||||
return 0 # 预发布标签完全相同
|
||||
|
||||
return 0 # 数字部分和预发布标签都相同
|
||||
|
||||
@staticmethod
|
||||
def _split_prerelease(prerelease):
|
||||
if not prerelease:
|
||||
return None
|
||||
parts = prerelease.split(".")
|
||||
result = []
|
||||
for part in parts:
|
||||
if part.isdigit():
|
||||
result.append(int(part))
|
||||
else:
|
||||
result.append(part)
|
||||
return result
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import asyncio
|
||||
from .log import LogManager, LogBroker # noqa
|
||||
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
@@ -7,24 +6,24 @@ from astrbot.core.utils.pip_installer import PipInstaller
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.file_token_service import FileTokenService
|
||||
from .utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
# 初始化数据存储文件夹
|
||||
os.makedirs("data", exist_ok=True)
|
||||
os.makedirs(get_astrbot_data_path(), exist_ok=True)
|
||||
|
||||
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||
|
||||
astrbot_config = AstrBotConfig()
|
||||
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
|
||||
html_renderer = HtmlRenderer(t2i_base_url)
|
||||
logger = LogManager.GetLogger(log_name="astrbot")
|
||||
|
||||
if os.environ.get("TESTING", ""):
|
||||
logger.setLevel("DEBUG")
|
||||
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
sp = (
|
||||
SharedPreferences()
|
||||
) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||
pip_installer = PipInstaller(astrbot_config.get("pip_install_arg", ""))
|
||||
web_chat_queue = asyncio.Queue(maxsize=32)
|
||||
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||
sp = SharedPreferences(db_helper=db_helper)
|
||||
# 文件令牌服务
|
||||
file_token_service = FileTokenService()
|
||||
pip_installer = PipInstaller(
|
||||
astrbot_config.get("pip_install_arg", ""),
|
||||
astrbot_config.get("pypi_index_url", None),
|
||||
)
|
||||
|
||||
13
astrbot/core/agent/agent.py
Normal file
13
astrbot/core/agent/agent.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
from .tool import FunctionTool
|
||||
from typing import Generic
|
||||
from .run_context import TContext
|
||||
from .hooks import BaseAgentRunHooks
|
||||
|
||||
|
||||
@dataclass
|
||||
class Agent(Generic[TContext]):
|
||||
name: str
|
||||
instructions: str | None = None
|
||||
tools: list[str, FunctionTool] | None = None
|
||||
run_hooks: BaseAgentRunHooks[TContext] | None = None
|
||||
34
astrbot/core/agent/handoff.py
Normal file
34
astrbot/core/agent/handoff.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Generic
|
||||
from .tool import FunctionTool
|
||||
from .agent import Agent
|
||||
from .run_context import TContext
|
||||
|
||||
|
||||
class HandoffTool(FunctionTool, Generic[TContext]):
|
||||
"""Handoff tool for delegating tasks to another agent."""
|
||||
|
||||
def __init__(
|
||||
self, agent: Agent[TContext], parameters: dict | None = None, **kwargs
|
||||
):
|
||||
self.agent = agent
|
||||
super().__init__(
|
||||
name=f"transfer_to_{agent.name}",
|
||||
parameters=parameters or self.default_parameters(),
|
||||
description=agent.instructions or self.default_description(agent.name),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def default_parameters(self) -> dict:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {
|
||||
"type": "string",
|
||||
"description": "The input to be handed off to another agent. This should be a clear and concise request or task.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def default_description(self, agent_name: str | None) -> str:
|
||||
agent_name = agent_name or "another"
|
||||
return f"Delegate tasks to {self.name} agent to handle the request."
|
||||
27
astrbot/core/agent/hooks.py
Normal file
27
astrbot/core/agent/hooks.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import mcp
|
||||
from dataclasses import dataclass
|
||||
from .run_context import ContextWrapper, TContext
|
||||
from typing import Generic
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseAgentRunHooks(Generic[TContext]):
|
||||
async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
|
||||
async def on_tool_start(
|
||||
self,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool: FunctionTool,
|
||||
tool_args: dict | None,
|
||||
): ...
|
||||
async def on_tool_end(
|
||||
self,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool: FunctionTool,
|
||||
tool_args: dict | None,
|
||||
tool_result: mcp.types.CallToolResult | None,
|
||||
): ...
|
||||
async def on_agent_done(
|
||||
self, run_context: ContextWrapper[TContext], llm_response: LLMResponse
|
||||
): ...
|
||||
208
astrbot/core/agent/mcp_client.py
Normal file
208
astrbot/core/agent/mcp_client.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
from contextlib import AsyncExitStack
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
|
||||
try:
|
||||
import mcp
|
||||
from mcp.client.sse import sse_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning(
|
||||
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
|
||||
)
|
||||
|
||||
|
||||
def _prepare_config(config: dict) -> dict:
|
||||
"""准备配置,处理嵌套格式"""
|
||||
if "mcpServers" in config and config["mcpServers"]:
|
||||
first_key = next(iter(config["mcpServers"]))
|
||||
config = config["mcpServers"][first_key]
|
||||
config.pop("active", None)
|
||||
return config
|
||||
|
||||
|
||||
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
"""快速测试 MCP 服务器可达性"""
|
||||
import aiohttp
|
||||
|
||||
cfg = _prepare_config(config.copy())
|
||||
|
||||
url = cfg["url"]
|
||||
headers = cfg.get("headers", {})
|
||||
timeout = cfg.get("timeout", 10)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
test_payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
"id": 0,
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test-client", "version": "1.2.3"},
|
||||
},
|
||||
}
|
||||
async with session.post(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
json=test_payload,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
else:
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
else:
|
||||
async with session.get(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
else:
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return False, f"连接超时: {timeout}秒"
|
||||
except Exception as e:
|
||||
return False, f"{e!s}"
|
||||
|
||||
|
||||
class MCPClient:
|
||||
def __init__(self):
|
||||
# Initialize session and client objects
|
||||
self.session: Optional[mcp.ClientSession] = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
self.name = None
|
||||
self.active: bool = True
|
||||
self.tools: list[mcp.Tool] = []
|
||||
self.server_errlogs: list[str] = []
|
||||
self.running_event = asyncio.Event()
|
||||
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
"""连接到 MCP 服务器
|
||||
|
||||
如果 `url` 参数存在:
|
||||
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
|
||||
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
|
||||
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
|
||||
|
||||
Args:
|
||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
"""
|
||||
cfg = _prepare_config(mcp_server_config.copy())
|
||||
|
||||
def logging_callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
print(f"MCP Server {name} Error: {msg}")
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
if "url" in cfg:
|
||||
success, error_msg = await _quick_test_mcp_connection(cfg)
|
||||
if not success:
|
||||
raise Exception(error_msg)
|
||||
|
||||
if cfg.get("transport") != "streamable_http":
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=cfg.get("timeout", 5),
|
||||
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
streams = await self.exit_stack.enter_async_context(
|
||||
self._streams_context
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
*streams,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
)
|
||||
)
|
||||
else:
|
||||
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||
sse_read_timeout = timedelta(
|
||||
seconds=cfg.get("sse_read_timeout", 60 * 5)
|
||||
)
|
||||
self._streams_context = streamablehttp_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||
)
|
||||
read_s, write_s, _ = await self.exit_stack.enter_async_context(
|
||||
self._streams_context
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
read_stream=read_s,
|
||||
write_stream=write_s,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
server_params = mcp.StdioServerParameters(
|
||||
**cfg,
|
||||
)
|
||||
|
||||
def callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
mcp.stdio_client(
|
||||
server_params,
|
||||
errlog=LogPipe(
|
||||
level=logging.ERROR,
|
||||
logger=logger,
|
||||
identifier=f"MCPServer-{name}",
|
||||
callback=callback,
|
||||
), # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*stdio_transport)
|
||||
)
|
||||
await self.session.initialize()
|
||||
|
||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||
"""List all tools from the server and save them to self.tools"""
|
||||
response = await self.session.list_tools()
|
||||
self.tools = response.tools
|
||||
return response
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
await self.exit_stack.aclose()
|
||||
self.running_event.set() # Set the running event to indicate cleanup is done
|
||||
12
astrbot/core/agent/response.py
Normal file
12
astrbot/core/agent/response.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
import typing as T
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
chain: MessageChain
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResponse:
|
||||
type: str
|
||||
data: AgentResponseData
|
||||
17
astrbot/core/agent/run_context.py
Normal file
17
astrbot/core/agent/run_context.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
TContext = TypeVar("TContext", default=Any)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextWrapper(Generic[TContext]):
|
||||
"""A context for running an agent, which can be used to pass additional data or state."""
|
||||
|
||||
context: TContext
|
||||
event: AstrMessageEvent
|
||||
|
||||
NoContext = ContextWrapper[None]
|
||||
3
astrbot/core/agent/runners/__init__.py
Normal file
3
astrbot/core/agent/runners/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseAgentRunner
|
||||
|
||||
__all__ = ["BaseAgentRunner"]
|
||||
58
astrbot/core/agent/runners/base.py
Normal file
58
astrbot/core/agent/runners/base.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import abc
|
||||
import typing as T
|
||||
from enum import Enum, auto
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..response import AgentResponse
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
|
||||
class AgentState(Enum):
|
||||
"""Defines the state of the agent."""
|
||||
|
||||
IDLE = auto() # Initial state
|
||||
RUNNING = auto() # Currently processing
|
||||
DONE = auto() # Completed
|
||||
ERROR = auto() # Error state
|
||||
|
||||
|
||||
class BaseAgentRunner(T.Generic[TContext]):
|
||||
@abc.abstractmethod
|
||||
async def reset(
|
||||
self,
|
||||
provider: Provider,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
"""
|
||||
Reset the agent to its initial state.
|
||||
This method should be called before starting a new run.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
"""
|
||||
Process a single step of the agent.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def done(self) -> bool:
|
||||
"""
|
||||
Check if the agent has completed its task.
|
||||
Returns True if the agent is done, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||
"""
|
||||
Get the final observation from the agent.
|
||||
This method should be called after the agent is done.
|
||||
"""
|
||||
...
|
||||
334
astrbot/core/agent/runners/tool_loop_agent_runner.py
Normal file
334
astrbot/core/agent/runners/tool_loop_agent_runner.py
Normal file
@@ -0,0 +1,334 @@
|
||||
import sys
|
||||
import traceback
|
||||
import typing as T
|
||||
from .base import BaseAgentRunner, AgentResponse, AgentState
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..response import AgentResponseData
|
||||
from astrbot.core.provider.provider import Provider
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
)
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
LLMResponse,
|
||||
ToolCallMessageSegment,
|
||||
AssistantMessageSegment,
|
||||
ToolCallsResult,
|
||||
)
|
||||
from mcp.types import (
|
||||
TextContent,
|
||||
ImageContent,
|
||||
EmbeddedResource,
|
||||
TextResourceContents,
|
||||
BlobResourceContents,
|
||||
CallToolResult,
|
||||
)
|
||||
from astrbot import logger
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
@override
|
||||
async def reset(
|
||||
self,
|
||||
provider: Provider,
|
||||
request: ProviderRequest,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.provider = provider
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
self.tool_executor = tool_executor
|
||||
self.agent_hooks = agent_hooks
|
||||
self.run_context = run_context
|
||||
|
||||
def _transition_state(self, new_state: AgentState) -> None:
|
||||
"""转换 Agent 状态"""
|
||||
if self._state != new_state:
|
||||
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
|
||||
self._state = new_state
|
||||
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
if self.streaming:
|
||||
stream = self.provider.text_chat_stream(**self.req.__dict__)
|
||||
async for resp in stream: # type: ignore
|
||||
yield resp
|
||||
else:
|
||||
yield await self.provider.text_chat(**self.req.__dict__)
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
"""
|
||||
Process a single step of the agent.
|
||||
This method should return the result of the step.
|
||||
"""
|
||||
if not self.req:
|
||||
raise ValueError("Request is not set. Please call reset() first.")
|
||||
|
||||
if self._state == AgentState.IDLE:
|
||||
try:
|
||||
await self.agent_hooks.on_agent_begin(self.run_context)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
|
||||
|
||||
# 开始处理,转换到运行状态
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
llm_resp_result = None
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
assert isinstance(llm_response, LLMResponse)
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.result_chain:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(chain=llm_response.result_chain),
|
||||
)
|
||||
else:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(llm_response.completion_text)
|
||||
),
|
||||
)
|
||||
continue
|
||||
llm_resp_result = llm_response
|
||||
break # got final response
|
||||
|
||||
if not llm_resp_result:
|
||||
return
|
||||
|
||||
# 处理 LLM 响应
|
||||
llm_resp = llm_resp_result
|
||||
|
||||
if llm_resp.role == "err":
|
||||
# 如果 LLM 响应错误,转换到错误状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.ERROR)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(
|
||||
f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
if not llm_resp.tools_call_name:
|
||||
# 如果没有工具调用,转换到完成状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.DONE)
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
|
||||
|
||||
# 返回 LLM 结果
|
||||
if llm_resp.result_chain:
|
||||
yield AgentResponse(
|
||||
type="llm_result",
|
||||
data=AgentResponseData(chain=llm_resp.result_chain),
|
||||
)
|
||||
elif llm_resp.completion_text:
|
||||
yield AgentResponse(
|
||||
type="llm_result",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(llm_resp.completion_text)
|
||||
),
|
||||
)
|
||||
|
||||
# 如果有工具调用,还需处理工具调用
|
||||
if llm_resp.tools_call_name:
|
||||
tool_call_result_blocks = []
|
||||
for tool_call_name in llm_resp.tools_call_name:
|
||||
yield AgentResponse(
|
||||
type="tool_call",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}")
|
||||
),
|
||||
)
|
||||
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||
if isinstance(result, list):
|
||||
tool_call_result_blocks = result
|
||||
elif isinstance(result, MessageChain):
|
||||
yield AgentResponse(
|
||||
type="tool_call_result",
|
||||
data=AgentResponseData(chain=result),
|
||||
)
|
||||
# 将结果添加到上下文中
|
||||
tool_calls_result = ToolCallsResult(
|
||||
tool_calls_info=AssistantMessageSegment(
|
||||
role="assistant",
|
||||
tool_calls=llm_resp.to_openai_tool_calls(),
|
||||
content=llm_resp.completion_text,
|
||||
),
|
||||
tool_calls_result=tool_call_result_blocks,
|
||||
)
|
||||
self.req.append_tool_calls_result(tool_calls_result)
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse,
|
||||
) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]:
|
||||
"""处理函数工具调用。"""
|
||||
tool_call_result_blocks: list[ToolCallMessageSegment] = []
|
||||
logger.info(f"Agent 使用工具: {llm_response.tools_call_name}")
|
||||
|
||||
# 执行函数调用
|
||||
for func_tool_name, func_tool_args, func_tool_id in zip(
|
||||
llm_response.tools_call_name,
|
||||
llm_response.tools_call_args,
|
||||
llm_response.tools_call_ids,
|
||||
):
|
||||
try:
|
||||
if not req.func_tool:
|
||||
return
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_start(
|
||||
self.run_context, func_tool, func_tool_args
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_tool_start hook: {e}", exc_info=True)
|
||||
|
||||
executor = self.tool_executor.execute(
|
||||
tool=func_tool,
|
||||
run_context=self.run_context,
|
||||
**func_tool_args,
|
||||
)
|
||||
async for resp in executor:
|
||||
if isinstance(resp, CallToolResult):
|
||||
res = resp
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=res.content[0].text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
yield MessageChain(type="tool_direct_result").base64_image(
|
||||
res.content[0].data
|
||||
)
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resource.text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
and resource.mimeType.startswith("image/")
|
||||
):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
yield MessageChain(
|
||||
type="tool_direct_result"
|
||||
).base64_image(res.content[0].data)
|
||||
else:
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context,
|
||||
func_tool_name,
|
||||
func_tool_args,
|
||||
resp,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||
)
|
||||
elif resp is None:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
self._transition_state(AgentState.DONE)
|
||||
if res := self.run_context.event.get_result():
|
||||
if res.chain:
|
||||
yield MessageChain(
|
||||
chain=res.chain, type="tool_direct_result"
|
||||
)
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool_name, func_tool_args, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
|
||||
)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool_name, func_tool_args, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||
)
|
||||
|
||||
self.run_context.event.clear_result()
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=f"error: {str(e)}",
|
||||
)
|
||||
)
|
||||
|
||||
# 处理函数调用响应
|
||||
if tool_call_result_blocks:
|
||||
yield tool_call_result_blocks
|
||||
|
||||
def done(self) -> bool:
|
||||
"""检查 Agent 是否已完成工作"""
|
||||
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||
|
||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||
return self.final_llm_resp
|
||||
256
astrbot/core/agent/tool.py
Normal file
256
astrbot/core/agent/tool.py
Normal file
@@ -0,0 +1,256 @@
|
||||
from dataclasses import dataclass
|
||||
from deprecated import deprecated
|
||||
from typing import Awaitable, Literal, Any, Optional
|
||||
from .mcp_client import MCPClient
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionTool:
|
||||
"""A class representing a function tool that can be used in function calling."""
|
||||
|
||||
name: str | None = None
|
||||
parameters: dict | None = None
|
||||
description: str | None = None
|
||||
handler: Awaitable | None = None
|
||||
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
||||
handler_module_path: str | None = None
|
||||
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
||||
|
||||
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
|
||||
"""
|
||||
active: bool = True
|
||||
"""是否激活"""
|
||||
|
||||
origin: Literal["local", "mcp"] = "local"
|
||||
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
|
||||
|
||||
# MCP 相关字段
|
||||
mcp_server_name: str | None = None
|
||||
"""MCP 服务名称,当 origin 为 mcp 时有效"""
|
||||
mcp_client: MCPClient | None = None
|
||||
"""MCP 客户端,当 origin 为 mcp 时有效"""
|
||||
|
||||
def __repr__(self):
|
||||
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
|
||||
|
||||
def __dict__(self) -> dict[str, Any]:
|
||||
"""将 FunctionTool 转换为字典格式"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"parameters": self.parameters,
|
||||
"description": self.description,
|
||||
"active": self.active,
|
||||
"origin": self.origin,
|
||||
"mcp_server_name": self.mcp_server_name,
|
||||
}
|
||||
|
||||
|
||||
class ToolSet:
|
||||
"""A set of function tools that can be used in function calling.
|
||||
|
||||
This class provides methods to add, remove, and retrieve tools, as well as
|
||||
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
|
||||
|
||||
def __init__(self, tools: list[FunctionTool] = None):
|
||||
self.tools: list[FunctionTool] = tools or []
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Check if the tool set is empty."""
|
||||
return len(self.tools) == 0
|
||||
|
||||
def add_tool(self, tool: FunctionTool):
|
||||
"""Add a tool to the set."""
|
||||
# 检查是否已存在同名工具
|
||||
for i, existing_tool in enumerate(self.tools):
|
||||
if existing_tool.name == tool.name:
|
||||
self.tools[i] = tool
|
||||
return
|
||||
self.tools.append(tool)
|
||||
|
||||
def remove_tool(self, name: str):
|
||||
"""Remove a tool by its name."""
|
||||
self.tools = [tool for tool in self.tools if tool.name != name]
|
||||
|
||||
def get_tool(self, name: str) -> Optional[FunctionTool]:
|
||||
"""Get a tool by its name."""
|
||||
for tool in self.tools:
|
||||
if tool.name == name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
@deprecated(reason="Use add_tool() instead", version="4.0.0")
|
||||
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
|
||||
"""Add a function tool to the set."""
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
"properties": {},
|
||||
}
|
||||
for param in func_args:
|
||||
params["properties"][param["name"]] = {
|
||||
"type": param["type"],
|
||||
"description": param["description"],
|
||||
}
|
||||
_func = FunctionTool(
|
||||
name=name,
|
||||
parameters=params,
|
||||
description=desc,
|
||||
handler=handler,
|
||||
)
|
||||
self.add_tool(_func)
|
||||
|
||||
@deprecated(reason="Use remove_tool() instead", version="4.0.0")
|
||||
def remove_func(self, name: str):
|
||||
"""Remove a function tool by its name."""
|
||||
self.remove_tool(name)
|
||||
|
||||
@deprecated(reason="Use get_tool() instead", version="4.0.0")
|
||||
def get_func(self, name: str) -> list[FunctionTool]:
|
||||
"""Get all function tools."""
|
||||
return self.get_tool(name)
|
||||
|
||||
@property
|
||||
def func_list(self) -> list[FunctionTool]:
|
||||
"""Get the list of function tools."""
|
||||
return self.tools
|
||||
|
||||
def openai_schema(self, omit_empty_parameter_field: bool = False) -> list[dict]:
|
||||
"""Convert tools to OpenAI API function calling schema format."""
|
||||
result = []
|
||||
for tool in self.tools:
|
||||
func_def = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
},
|
||||
}
|
||||
|
||||
if tool.parameters.get("properties") or not omit_empty_parameter_field:
|
||||
func_def["function"]["parameters"] = tool.parameters
|
||||
|
||||
result.append(func_def)
|
||||
return result
|
||||
|
||||
def anthropic_schema(self) -> list[dict]:
|
||||
"""Convert tools to Anthropic API format."""
|
||||
result = []
|
||||
for tool in self.tools:
|
||||
tool_def = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": tool.parameters.get("properties", {}),
|
||||
"required": tool.parameters.get("required", []),
|
||||
},
|
||||
}
|
||||
result.append(tool_def)
|
||||
return result
|
||||
|
||||
def google_schema(self) -> dict:
|
||||
"""Convert tools to Google GenAI API format."""
|
||||
|
||||
def convert_schema(schema: dict) -> dict:
|
||||
"""Convert schema to Gemini API format."""
|
||||
supported_types = {
|
||||
"string",
|
||||
"number",
|
||||
"integer",
|
||||
"boolean",
|
||||
"array",
|
||||
"object",
|
||||
"null",
|
||||
}
|
||||
supported_formats = {
|
||||
"string": {"enum", "date-time"},
|
||||
"integer": {"int32", "int64"},
|
||||
"number": {"float", "double"},
|
||||
}
|
||||
|
||||
if "anyOf" in schema:
|
||||
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
|
||||
|
||||
result = {}
|
||||
|
||||
if "type" in schema and schema["type"] in supported_types:
|
||||
result["type"] = schema["type"]
|
||||
if "format" in schema and schema["format"] in supported_formats.get(
|
||||
result["type"], set()
|
||||
):
|
||||
result["format"] = schema["format"]
|
||||
else:
|
||||
result["type"] = "null"
|
||||
|
||||
support_fields = {
|
||||
"title",
|
||||
"description",
|
||||
"enum",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"maxItems",
|
||||
"minItems",
|
||||
"nullable",
|
||||
"required",
|
||||
}
|
||||
result.update({k: schema[k] for k in support_fields if k in schema})
|
||||
|
||||
if "properties" in schema:
|
||||
properties = {}
|
||||
for key, value in schema["properties"].items():
|
||||
prop_value = convert_schema(value)
|
||||
if "default" in prop_value:
|
||||
del prop_value["default"]
|
||||
properties[key] = prop_value
|
||||
|
||||
if properties:
|
||||
result["properties"] = properties
|
||||
|
||||
if "items" in schema:
|
||||
result["items"] = convert_schema(schema["items"])
|
||||
|
||||
return result
|
||||
|
||||
tools = [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": convert_schema(tool.parameters),
|
||||
}
|
||||
for tool in self.tools
|
||||
]
|
||||
|
||||
declarations = {}
|
||||
if tools:
|
||||
declarations["function_declarations"] = tools
|
||||
return declarations
|
||||
|
||||
@deprecated(reason="Use openai_schema() instead", version="4.0.0")
|
||||
def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False):
|
||||
return self.openai_schema(omit_empty_parameter_field)
|
||||
|
||||
@deprecated(reason="Use anthropic_schema() instead", version="4.0.0")
|
||||
def get_func_desc_anthropic_style(self):
|
||||
return self.anthropic_schema()
|
||||
|
||||
@deprecated(reason="Use google_schema() instead", version="4.0.0")
|
||||
def get_func_desc_google_genai_style(self):
|
||||
return self.google_schema()
|
||||
|
||||
def names(self) -> list[str]:
|
||||
"""获取所有工具的名称列表"""
|
||||
return [tool.name for tool in self.tools]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.tools)
|
||||
|
||||
def __bool__(self):
|
||||
return len(self.tools) > 0
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.tools)
|
||||
|
||||
def __repr__(self):
|
||||
return f"ToolSet(tools={self.tools})"
|
||||
|
||||
def __str__(self):
|
||||
return f"ToolSet(tools={self.tools})"
|
||||
11
astrbot/core/agent/tool_executor.py
Normal file
11
astrbot/core/agent/tool_executor.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import mcp
|
||||
from typing import Any, Generic, AsyncGenerator
|
||||
from .run_context import TContext, ContextWrapper
|
||||
from .tool import FunctionTool
|
||||
|
||||
|
||||
class BaseFunctionToolExecutor(Generic[TContext]):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls, tool: FunctionTool, run_context: ContextWrapper[TContext], **tool_args
|
||||
) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ...
|
||||
11
astrbot/core/astr_agent_context.py
Normal file
11
astrbot/core/astr_agent_context.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
|
||||
|
||||
@dataclass
|
||||
class AstrAgentContext:
|
||||
provider: Provider
|
||||
first_provider_request: ProviderRequest
|
||||
curr_provider_request: ProviderRequest
|
||||
streaming: bool
|
||||
283
astrbot/core/astrbot_config_mgr.py
Normal file
283
astrbot/core/astrbot_config_mgr.py
Normal file
@@ -0,0 +1,283 @@
|
||||
import os
|
||||
import uuid
|
||||
from astrbot.core import AstrBotConfig, logger
|
||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
|
||||
from astrbot.core.config.default import DEFAULT_CONFIG
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_config_path
|
||||
from typing import TypeVar, TypedDict
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class ConfInfo(TypedDict):
|
||||
"""Configuration information for a specific session or platform."""
|
||||
|
||||
id: str # UUID of the configuration or "default"
|
||||
umop: list[str] # Unified Message Origin Pattern
|
||||
name: str
|
||||
path: str # File name to the configuration file
|
||||
|
||||
|
||||
DEFAULT_CONFIG_CONF_INFO = ConfInfo(
|
||||
id="default",
|
||||
umop=["::"],
|
||||
name="default",
|
||||
path=ASTRBOT_CONFIG_PATH,
|
||||
)
|
||||
|
||||
|
||||
class AstrBotConfigManager:
|
||||
"""A class to manage the system configuration of AstrBot, aka ACM"""
|
||||
|
||||
def __init__(self, default_config: AstrBotConfig, sp: SharedPreferences):
|
||||
self.sp = sp
|
||||
self.confs: dict[str, AstrBotConfig] = {}
|
||||
"""uuid / "default" -> AstrBotConfig"""
|
||||
self.confs["default"] = default_config
|
||||
self.abconf_data = None
|
||||
self._load_all_configs()
|
||||
|
||||
def _get_abconf_data(self) -> dict:
|
||||
"""获取所有的 abconf 数据"""
|
||||
if self.abconf_data is None:
|
||||
self.abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
return self.abconf_data
|
||||
|
||||
def _load_all_configs(self):
|
||||
"""Load all configurations from the shared preferences."""
|
||||
abconf_data = self._get_abconf_data()
|
||||
self.abconf_data = abconf_data
|
||||
for uuid_, meta in abconf_data.items():
|
||||
filename = meta["path"]
|
||||
conf_path = os.path.join(get_astrbot_config_path(), filename)
|
||||
if os.path.exists(conf_path):
|
||||
conf = AstrBotConfig(config_path=conf_path)
|
||||
self.confs[uuid_] = conf
|
||||
else:
|
||||
logger.warning(
|
||||
f"Config file {conf_path} for UUID {uuid_} does not exist, skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
def _is_umo_match(self, p1: str, p2: str) -> bool:
|
||||
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
|
||||
p1_ls = p1.split(":")
|
||||
p2_ls = p2.split(":")
|
||||
|
||||
if len(p1_ls) != 3 or len(p2_ls) != 3:
|
||||
return False # 非法格式
|
||||
|
||||
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
|
||||
|
||||
def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo:
|
||||
"""获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default")
|
||||
|
||||
Returns:
|
||||
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
|
||||
"""
|
||||
# uuid -> { "umop": list, "path": str, "name": str }
|
||||
abconf_data = self._get_abconf_data()
|
||||
if isinstance(umo, MessageSession):
|
||||
umo = str(umo)
|
||||
else:
|
||||
try:
|
||||
umo = str(MessageSession.from_str(umo)) # validate
|
||||
except Exception:
|
||||
return DEFAULT_CONFIG_CONF_INFO
|
||||
|
||||
for uuid_, meta in abconf_data.items():
|
||||
for pattern in meta["umop"]:
|
||||
if self._is_umo_match(pattern, umo):
|
||||
return ConfInfo(**meta, id=uuid_)
|
||||
|
||||
return DEFAULT_CONFIG_CONF_INFO
|
||||
|
||||
def _save_conf_mapping(
|
||||
self,
|
||||
abconf_path: str,
|
||||
abconf_id: str,
|
||||
umo_parts: list[str] | list[MessageSession],
|
||||
abconf_name: str | None = None,
|
||||
) -> None:
|
||||
"""保存配置文件的映射关系"""
|
||||
for part in umo_parts:
|
||||
if isinstance(part, MessageSession):
|
||||
part = str(part)
|
||||
elif not isinstance(part, str):
|
||||
raise ValueError(
|
||||
"umo_parts must be a list of strings or MessageSession instances"
|
||||
)
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
random_word = abconf_name or uuid.uuid4().hex[:8]
|
||||
abconf_data[abconf_id] = {
|
||||
"umop": umo_parts,
|
||||
"path": abconf_path,
|
||||
"name": random_word,
|
||||
}
|
||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||
self.abconf_data = abconf_data
|
||||
|
||||
def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig:
|
||||
"""获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。"""
|
||||
if not umo:
|
||||
return self.confs["default"]
|
||||
if isinstance(umo, MessageSession):
|
||||
umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}"
|
||||
|
||||
uuid_ = self._load_conf_mapping(umo)["id"]
|
||||
|
||||
conf = self.confs.get(uuid_)
|
||||
if not conf:
|
||||
conf = self.confs["default"] # default MUST exists
|
||||
|
||||
return conf
|
||||
|
||||
@property
|
||||
def default_conf(self) -> AstrBotConfig:
|
||||
"""获取默认配置文件"""
|
||||
return self.confs["default"]
|
||||
|
||||
def get_conf_info(self, umo: str | MessageSession) -> ConfInfo:
|
||||
"""获取指定 umo 的配置文件元数据"""
|
||||
if isinstance(umo, MessageSession):
|
||||
umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}"
|
||||
|
||||
return self._load_conf_mapping(umo)
|
||||
|
||||
def get_conf_list(self) -> list[ConfInfo]:
|
||||
"""获取所有配置文件的元数据列表"""
|
||||
conf_list = []
|
||||
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
|
||||
abconf_mapping = self._get_abconf_data()
|
||||
for uuid_, meta in abconf_mapping.items():
|
||||
conf_list.append(ConfInfo(**meta, id=uuid_))
|
||||
return conf_list
|
||||
|
||||
def create_conf(
|
||||
self,
|
||||
umo_parts: list[str] | list[MessageSession],
|
||||
config: dict = DEFAULT_CONFIG,
|
||||
name: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
|
||||
|
||||
umo_parts 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。
|
||||
"""
|
||||
conf_uuid = str(uuid.uuid4())
|
||||
conf_file_name = f"abconf_{conf_uuid}.json"
|
||||
conf_path = os.path.join(get_astrbot_config_path(), conf_file_name)
|
||||
conf = AstrBotConfig(config_path=conf_path, default_config=config)
|
||||
conf.save_config()
|
||||
self._save_conf_mapping(conf_file_name, conf_uuid, umo_parts, abconf_name=name)
|
||||
self.confs[conf_uuid] = conf
|
||||
return conf_uuid
|
||||
|
||||
def delete_conf(self, conf_id: str) -> bool:
|
||||
"""删除指定配置文件
|
||||
|
||||
Args:
|
||||
conf_id: 配置文件的 UUID
|
||||
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
|
||||
Raises:
|
||||
ValueError: 如果试图删除默认配置文件
|
||||
"""
|
||||
if conf_id == "default":
|
||||
raise ValueError("不能删除默认配置文件")
|
||||
|
||||
# 从映射中移除
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
if conf_id not in abconf_data:
|
||||
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
|
||||
return False
|
||||
|
||||
# 获取配置文件路径
|
||||
conf_path = os.path.join(
|
||||
get_astrbot_config_path(), abconf_data[conf_id]["path"]
|
||||
)
|
||||
|
||||
# 删除配置文件
|
||||
try:
|
||||
if os.path.exists(conf_path):
|
||||
os.remove(conf_path)
|
||||
logger.info(f"已删除配置文件: {conf_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"删除配置文件 {conf_path} 失败: {e}")
|
||||
return False
|
||||
|
||||
# 从内存中移除
|
||||
if conf_id in self.confs:
|
||||
del self.confs[conf_id]
|
||||
|
||||
# 从映射中移除
|
||||
del abconf_data[conf_id]
|
||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||
self.abconf_data = abconf_data
|
||||
|
||||
logger.info(f"成功删除配置文件 {conf_id}")
|
||||
return True
|
||||
|
||||
def update_conf_info(
|
||||
self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None
|
||||
) -> bool:
|
||||
"""更新配置文件信息
|
||||
|
||||
Args:
|
||||
conf_id: 配置文件的 UUID
|
||||
name: 新的配置文件名称 (可选)
|
||||
umo_parts: 新的 UMO 部分列表 (可选)
|
||||
|
||||
Returns:
|
||||
bool: 更新是否成功
|
||||
"""
|
||||
if conf_id == "default":
|
||||
raise ValueError("不能更新默认配置文件的信息")
|
||||
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
if conf_id not in abconf_data:
|
||||
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
|
||||
return False
|
||||
|
||||
# 更新名称
|
||||
if name is not None:
|
||||
abconf_data[conf_id]["name"] = name
|
||||
|
||||
# 更新 UMO 部分
|
||||
if umo_parts is not None:
|
||||
# 验证 UMO 部分格式
|
||||
for part in umo_parts:
|
||||
if isinstance(part, MessageSession):
|
||||
part = str(part)
|
||||
elif not isinstance(part, str):
|
||||
raise ValueError(
|
||||
"umo_parts must be a list of strings or MessageSession instances"
|
||||
)
|
||||
abconf_data[conf_id]["umop"] = umo_parts
|
||||
|
||||
# 保存更新
|
||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||
self.abconf_data = abconf_data
|
||||
logger.info(f"成功更新配置文件 {conf_id} 的信息")
|
||||
return True
|
||||
|
||||
def g(
|
||||
self, umo: str | None = None, key: str | None = None, default: _VT = None
|
||||
) -> _VT:
|
||||
"""获取配置项。umo 为 None 时使用默认配置"""
|
||||
if umo is None:
|
||||
return self.confs["default"].get(key, default)
|
||||
conf = self.get_conf(umo)
|
||||
return conf.get(key, default)
|
||||
@@ -4,8 +4,9 @@ import logging
|
||||
import enum
|
||||
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
|
||||
from typing import Dict
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
|
||||
ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
@@ -42,11 +43,10 @@ class AstrBotConfig(dict):
|
||||
"""不存在时载入默认配置"""
|
||||
with open(config_path, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(default_config, f, indent=4, ensure_ascii=False)
|
||||
object.__setattr__(self, "first_deploy", True) # 标记第一次部署
|
||||
|
||||
with open(config_path, "r", encoding="utf-8-sig") as f:
|
||||
conf_str = f.read()
|
||||
if conf_str.startswith("/ufeff"): # remove BOM
|
||||
conf_str = conf_str.encode("utf8")[3:].decode("utf8")
|
||||
conf = json.loads(conf_str)
|
||||
|
||||
# 检查配置完整性,并插入
|
||||
@@ -83,23 +83,61 @@ class AstrBotConfig(dict):
|
||||
return conf
|
||||
|
||||
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
|
||||
"""检查配置完整性,如果有新的配置项则返回 True"""
|
||||
"""检查配置完整性,如果有新的配置项或顺序不一致则返回 True"""
|
||||
has_new = False
|
||||
|
||||
# 创建一个新的有序字典以保持参考配置的顺序
|
||||
new_conf = {}
|
||||
|
||||
# 先按照参考配置的顺序添加配置项
|
||||
for key, value in refer_conf.items():
|
||||
if key not in conf:
|
||||
# logger.info(f"检查到配置项 {path + "." + key if path else key} 不存在,已插入默认值 {value}")
|
||||
# 配置项不存在,插入默认值
|
||||
path_ = path + "." + key if path else key
|
||||
logger.info(f"检查到配置项 {path_} 不存在,已插入默认值 {value}")
|
||||
conf[key] = value
|
||||
new_conf[key] = value
|
||||
has_new = True
|
||||
else:
|
||||
if conf[key] is None:
|
||||
conf[key] = value
|
||||
# 配置项为 None,使用默认值
|
||||
new_conf[key] = value
|
||||
has_new = True
|
||||
elif isinstance(value, dict):
|
||||
has_new |= self.check_config_integrity(
|
||||
value, conf[key], path + "." + key if path else key
|
||||
)
|
||||
# 递归检查子配置项
|
||||
if not isinstance(conf[key], dict):
|
||||
# 类型不匹配,使用默认值
|
||||
new_conf[key] = value
|
||||
has_new = True
|
||||
else:
|
||||
# 递归检查并同步顺序
|
||||
child_has_new = self.check_config_integrity(
|
||||
value, conf[key], path + "." + key if path else key
|
||||
)
|
||||
new_conf[key] = conf[key]
|
||||
has_new |= child_has_new
|
||||
else:
|
||||
# 直接使用现有配置
|
||||
new_conf[key] = conf[key]
|
||||
|
||||
# 检查是否存在参考配置中没有的配置项
|
||||
for key in list(conf.keys()):
|
||||
if key not in refer_conf:
|
||||
path_ = path + "." + key if path else key
|
||||
logger.info(f"检查到配置项 {path_} 不存在,将从当前配置中删除")
|
||||
has_new = True
|
||||
|
||||
# 顺序不一致也算作变更
|
||||
if list(conf.keys()) != list(new_conf.keys()):
|
||||
if path:
|
||||
logger.info(f"检查到配置项 {path} 的子项顺序不一致,已重新排序")
|
||||
else:
|
||||
logger.info("检查到配置项顺序不一致,已重新排序")
|
||||
has_new = True
|
||||
|
||||
# 更新原始配置
|
||||
conf.clear()
|
||||
conf.update(new_conf)
|
||||
|
||||
return has_new
|
||||
|
||||
def save_config(self, replace_config: Dict = None):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,40 +5,44 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json
|
||||
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
from astrbot.core import sp
|
||||
from typing import Dict, List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Conversation
|
||||
from astrbot.core.db.po import Conversation, ConversationV2
|
||||
|
||||
|
||||
class ConversationManager:
|
||||
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
|
||||
|
||||
def __init__(self, db_helper: BaseDatabase):
|
||||
# session_conversations 字典记录会话ID-对话ID 映射关系
|
||||
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
|
||||
self.session_conversations: Dict[str, str] = {}
|
||||
self.db = db_helper
|
||||
self.save_interval = 60 # 每 60 秒保存一次
|
||||
self._start_periodic_save()
|
||||
|
||||
def _start_periodic_save(self):
|
||||
"""启动定时保存任务"""
|
||||
asyncio.create_task(self._periodic_save())
|
||||
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
|
||||
"""将 ConversationV2 对象转换为 Conversation 对象"""
|
||||
created_at = int(conv_v2.created_at.timestamp())
|
||||
updated_at = int(conv_v2.updated_at.timestamp())
|
||||
return Conversation(
|
||||
platform_id=conv_v2.platform_id,
|
||||
user_id=conv_v2.user_id,
|
||||
cid=conv_v2.conversation_id,
|
||||
history=json.dumps(conv_v2.content or []),
|
||||
title=conv_v2.title,
|
||||
persona_id=conv_v2.persona_id,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
|
||||
async def _periodic_save(self):
|
||||
"""定时保存会话对话映射关系到存储中"""
|
||||
while True:
|
||||
await asyncio.sleep(self.save_interval)
|
||||
self._save_to_storage()
|
||||
|
||||
def _save_to_storage(self):
|
||||
"""保存会话对话映射关系到存储中"""
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
|
||||
async def new_conversation(self, unified_msg_origin: str) -> str:
|
||||
async def new_conversation(
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
platform_id: str | None = None,
|
||||
content: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
) -> str:
|
||||
"""新建对话,并将当前会话的对话转移到新对话
|
||||
|
||||
Args:
|
||||
@@ -46,11 +50,23 @@ class ConversationManager:
|
||||
Returns:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
||||
self.session_conversations[unified_msg_origin] = conversation_id
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
return conversation_id
|
||||
if not platform_id:
|
||||
# 如果没有提供 platform_id,则从 unified_msg_origin 中解析
|
||||
parts = unified_msg_origin.split(":")
|
||||
if len(parts) >= 3:
|
||||
platform_id = parts[0]
|
||||
if not platform_id:
|
||||
platform_id = "unknown"
|
||||
conv = await self.db.create_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
platform_id=platform_id,
|
||||
content=content,
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
)
|
||||
self.session_conversations[unified_msg_origin] = conv.conversation_id
|
||||
await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
|
||||
return conv.conversation_id
|
||||
|
||||
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
|
||||
"""切换会话的对话
|
||||
@@ -60,10 +76,10 @@ class ConversationManager:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
self.session_conversations[unified_msg_origin] = conversation_id
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)
|
||||
|
||||
async def delete_conversation(
|
||||
self, unified_msg_origin: str, conversation_id: str = None
|
||||
self, unified_msg_origin: str, conversation_id: str | None = None
|
||||
):
|
||||
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
|
||||
|
||||
@@ -71,13 +87,18 @@ class ConversationManager:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
f = False
|
||||
if not conversation_id:
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
f = True
|
||||
if conversation_id:
|
||||
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
||||
del self.session_conversations[unified_msg_origin]
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
await self.db.delete_conversation(cid=conversation_id)
|
||||
if f:
|
||||
self.session_conversations.pop(unified_msg_origin, None)
|
||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
|
||||
"""获取会话当前的对话 ID
|
||||
|
||||
Args:
|
||||
@@ -85,11 +106,19 @@ class ConversationManager:
|
||||
Returns:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
return self.session_conversations.get(unified_msg_origin, None)
|
||||
ret = self.session_conversations.get(unified_msg_origin, None)
|
||||
if not ret:
|
||||
ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None)
|
||||
if ret:
|
||||
self.session_conversations[unified_msg_origin] = ret
|
||||
return ret
|
||||
|
||||
async def get_conversation(
|
||||
self, unified_msg_origin: str, conversation_id: str
|
||||
) -> Conversation:
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
conversation_id: str,
|
||||
create_if_not_exists: bool = False,
|
||||
) -> Conversation | None:
|
||||
"""获取会话的对话
|
||||
|
||||
Args:
|
||||
@@ -98,20 +127,74 @@ class ConversationManager:
|
||||
Returns:
|
||||
conversation (Conversation): 对话对象
|
||||
"""
|
||||
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
||||
conv = await self.db.get_conversation_by_id(cid=conversation_id)
|
||||
if not conv and create_if_not_exists:
|
||||
# 如果对话不存在且需要创建,则新建一个对话
|
||||
conversation_id = await self.new_conversation(unified_msg_origin)
|
||||
conv = await self.db.get_conversation_by_id(cid=conversation_id)
|
||||
conv_res = None
|
||||
if conv:
|
||||
conv_res = self._convert_conv_from_v2_to_v1(conv)
|
||||
return conv_res
|
||||
|
||||
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
|
||||
"""获取会话的所有对话
|
||||
async def get_conversations(
|
||||
self, unified_msg_origin: str | None = None, platform_id: str | None = None
|
||||
) -> List[Conversation]:
|
||||
"""获取对话列表
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id,可选
|
||||
platform_id (str): 平台 ID, 可选参数, 用于过滤对话
|
||||
Returns:
|
||||
conversations (List[Conversation]): 对话对象列表
|
||||
"""
|
||||
return self.db.get_conversations(unified_msg_origin)
|
||||
convs = await self.db.get_conversations(
|
||||
user_id=unified_msg_origin, platform_id=platform_id
|
||||
)
|
||||
convs_res = []
|
||||
for conv in convs:
|
||||
conv_res = self._convert_conv_from_v2_to_v1(conv)
|
||||
convs_res.append(conv_res)
|
||||
return convs_res
|
||||
|
||||
async def get_filtered_conversations(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
platform_ids: list[str] | None = None,
|
||||
search_query: str = "",
|
||||
**kwargs,
|
||||
) -> tuple[list[Conversation], int]:
|
||||
"""获取过滤后的对话列表
|
||||
|
||||
Args:
|
||||
page (int): 页码, 默认为 1
|
||||
page_size (int): 每页大小, 默认为 20
|
||||
platform_ids (list[str]): 平台 ID 列表, 可选
|
||||
search_query (str): 搜索查询字符串, 可选
|
||||
Returns:
|
||||
conversations (list[Conversation]): 对话对象列表
|
||||
"""
|
||||
convs, cnt = await self.db.get_filtered_conversations(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
platform_ids=platform_ids,
|
||||
search_query=search_query,
|
||||
**kwargs,
|
||||
)
|
||||
convs_res = []
|
||||
for conv in convs:
|
||||
conv_res = self._convert_conv_from_v2_to_v1(conv)
|
||||
convs_res.append(conv_res)
|
||||
return convs_res, cnt
|
||||
|
||||
async def update_conversation(
|
||||
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
conversation_id: str | None = None,
|
||||
history: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
):
|
||||
"""更新会话的对话
|
||||
|
||||
@@ -120,40 +203,55 @@ class ConversationManager:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
||||
"""
|
||||
if not conversation_id:
|
||||
# 如果没有提供 conversation_id,则获取当前的
|
||||
conversation_id = await self.get_curr_conversation_id(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
await self.db.update_conversation(
|
||||
cid=conversation_id,
|
||||
history=json.dumps(history),
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
content=history,
|
||||
)
|
||||
|
||||
async def update_conversation_title(self, unified_msg_origin: str, title: str):
|
||||
async def update_conversation_title(
|
||||
self, unified_msg_origin: str, title: str, conversation_id: str | None = None
|
||||
):
|
||||
"""更新会话的对话标题
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
title (str): 对话标题
|
||||
|
||||
Deprecated:
|
||||
Use `update_conversation` with `title` parameter instead.
|
||||
"""
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation_title(
|
||||
user_id=unified_msg_origin, cid=conversation_id, title=title
|
||||
)
|
||||
await self.update_conversation(
|
||||
unified_msg_origin=unified_msg_origin,
|
||||
conversation_id=conversation_id,
|
||||
title=title,
|
||||
)
|
||||
|
||||
async def update_conversation_persona_id(
|
||||
self, unified_msg_origin: str, persona_id: str
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
persona_id: str,
|
||||
conversation_id: str | None = None,
|
||||
):
|
||||
"""更新会话的对话 Persona ID
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
persona_id (str): 对话 Persona ID
|
||||
|
||||
Deprecated:
|
||||
Use `update_conversation` with `persona_id` parameter instead.
|
||||
"""
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation_persona_id(
|
||||
user_id=unified_msg_origin, cid=conversation_id, persona_id=persona_id
|
||||
)
|
||||
await self.update_conversation(
|
||||
unified_msg_origin=unified_msg_origin,
|
||||
conversation_id=conversation_id,
|
||||
persona_id=persona_id,
|
||||
)
|
||||
|
||||
async def get_human_readable_context(
|
||||
self, unified_msg_origin, conversation_id, page=1, page_size=10
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
|
||||
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||
|
||||
工作流程:
|
||||
@@ -15,21 +15,23 @@ import time
|
||||
import threading
|
||||
import os
|
||||
from .event_bus import EventBus
|
||||
from . import astrbot_config
|
||||
from . import astrbot_config, html_renderer
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
|
||||
@@ -37,7 +39,7 @@ from astrbot.core.star.star_handler import star_map
|
||||
class AstrBotCoreLifecycle:
|
||||
"""
|
||||
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、
|
||||
EventBus 等。
|
||||
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||
"""
|
||||
@@ -47,14 +49,28 @@ class AstrBotCoreLifecycle:
|
||||
self.astrbot_config = astrbot_config # 初始化配置
|
||||
self.db = db # 初始化数据库
|
||||
|
||||
# 根据环境变量设置代理
|
||||
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
|
||||
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
|
||||
os.environ["no_proxy"] = "localhost"
|
||||
# 设置代理
|
||||
proxy_config = self.astrbot_config.get("http_proxy", "")
|
||||
if proxy_config != "":
|
||||
os.environ["https_proxy"] = proxy_config
|
||||
os.environ["http_proxy"] = proxy_config
|
||||
logger.debug(f"Using proxy: {proxy_config}")
|
||||
# 设置 no_proxy
|
||||
no_proxy_list = self.astrbot_config.get("no_proxy", [])
|
||||
os.environ["no_proxy"] = ",".join(no_proxy_list)
|
||||
else:
|
||||
# 清空代理环境变量
|
||||
if "https_proxy" in os.environ:
|
||||
del os.environ["https_proxy"]
|
||||
if "http_proxy" in os.environ:
|
||||
del os.environ["http_proxy"]
|
||||
if "no_proxy" in os.environ:
|
||||
del os.environ["no_proxy"]
|
||||
logger.debug("HTTP proxy cleared")
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||
"""
|
||||
|
||||
# 初始化日志代理
|
||||
@@ -64,21 +80,36 @@ class AstrBotCoreLifecycle:
|
||||
else:
|
||||
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
|
||||
|
||||
await self.db.initialize()
|
||||
|
||||
await html_renderer.initialize()
|
||||
|
||||
# 初始化 AstrBot 配置管理器
|
||||
self.astrbot_config_mgr = AstrBotConfigManager(
|
||||
default_config=self.astrbot_config, sp=sp
|
||||
)
|
||||
|
||||
# 初始化事件队列
|
||||
self.event_queue = Queue()
|
||||
|
||||
# 初始化人格管理器
|
||||
self.persona_mgr = PersonaManager(self.db, self.astrbot_config_mgr)
|
||||
await self.persona_mgr.initialize()
|
||||
|
||||
# 初始化供应商管理器
|
||||
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
|
||||
self.provider_manager = ProviderManager(
|
||||
self.astrbot_config_mgr, self.db, self.persona_mgr
|
||||
)
|
||||
|
||||
# 初始化平台管理器
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
||||
|
||||
# 初始化对话管理器
|
||||
self.conversation_manager = ConversationManager(self.db)
|
||||
|
||||
# 初始化平台消息历史管理器
|
||||
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
|
||||
|
||||
# 初始化提供给插件的上下文
|
||||
self.star_context = Context(
|
||||
self.event_queue,
|
||||
@@ -87,7 +118,9 @@ class AstrBotCoreLifecycle:
|
||||
self.provider_manager,
|
||||
self.platform_manager,
|
||||
self.conversation_manager,
|
||||
self.knowledge_db_manager,
|
||||
self.platform_message_history_manager,
|
||||
self.persona_mgr,
|
||||
self.astrbot_config_mgr,
|
||||
)
|
||||
|
||||
# 初始化插件管理器
|
||||
@@ -100,16 +133,16 @@ class AstrBotCoreLifecycle:
|
||||
await self.provider_manager.initialize()
|
||||
|
||||
# 初始化消息事件流水线调度器
|
||||
self.pipeline_scheduler = PipelineScheduler(
|
||||
PipelineContext(self.astrbot_config, self.plugin_manager)
|
||||
)
|
||||
await self.pipeline_scheduler.initialize()
|
||||
|
||||
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
|
||||
|
||||
# 初始化更新器
|
||||
self.astrbot_updator = AstrBotUpdator(self.astrbot_config["plugin_repo_mirror"])
|
||||
self.astrbot_updator = AstrBotUpdator()
|
||||
|
||||
# 初始化事件总线
|
||||
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
|
||||
self.event_bus = EventBus(
|
||||
self.event_queue, self.pipeline_scheduler_mapping, self.astrbot_config_mgr
|
||||
)
|
||||
|
||||
# 记录启动时间
|
||||
self.start_time = int(time.time())
|
||||
@@ -226,6 +259,39 @@ class AstrBotCoreLifecycle:
|
||||
platform_insts = self.platform_manager.get_insts()
|
||||
for platform_inst in platform_insts:
|
||||
tasks.append(
|
||||
asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name)
|
||||
asyncio.create_task(
|
||||
platform_inst.run(),
|
||||
name=f"{platform_inst.meta().id}({platform_inst.meta().name})",
|
||||
)
|
||||
)
|
||||
return tasks
|
||||
|
||||
async def load_pipeline_scheduler(self) -> dict[str, PipelineScheduler]:
|
||||
"""加载消息事件流水线调度器
|
||||
|
||||
Returns:
|
||||
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
|
||||
"""
|
||||
mapping = {}
|
||||
for conf_id, ab_config in self.astrbot_config_mgr.confs.items():
|
||||
scheduler = PipelineScheduler(
|
||||
PipelineContext(ab_config, self.plugin_manager, conf_id)
|
||||
)
|
||||
await scheduler.initialize()
|
||||
mapping[conf_id] = scheduler
|
||||
return mapping
|
||||
|
||||
async def reload_pipeline_scheduler(self, conf_id: str):
|
||||
"""重新加载消息事件流水线调度器
|
||||
|
||||
Returns:
|
||||
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
|
||||
"""
|
||||
ab_config = self.astrbot_config_mgr.confs.get(conf_id)
|
||||
if not ab_config:
|
||||
raise ValueError(f"配置文件 {conf_id} 不存在")
|
||||
scheduler = PipelineScheduler(
|
||||
PipelineContext(ab_config, self.plugin_manager, conf_id)
|
||||
)
|
||||
await scheduler.initialize()
|
||||
self.pipeline_scheduler_mapping[conf_id] = scheduler
|
||||
|
||||
@@ -1,7 +1,20 @@
|
||||
import abc
|
||||
import datetime
|
||||
import typing as T
|
||||
from deprecated import deprecated
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
|
||||
from astrbot.core.db.po import (
|
||||
Stats,
|
||||
PlatformStat,
|
||||
ConversationV2,
|
||||
PlatformMessageHistory,
|
||||
Attachment,
|
||||
Persona,
|
||||
Preference,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -10,152 +23,262 @@ class BaseDatabase(abc.ABC):
|
||||
数据库基类
|
||||
"""
|
||||
|
||||
DATABASE_URL = ""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
self.AsyncSessionLocal = sessionmaker(
|
||||
self.engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化数据库连接"""
|
||||
pass
|
||||
|
||||
def insert_base_metrics(self, metrics: dict):
|
||||
"""插入基础指标数据"""
|
||||
self.insert_platform_metrics(metrics["platform_stats"])
|
||||
self.insert_plugin_metrics(metrics["plugin_stats"])
|
||||
self.insert_command_metrics(metrics["command_stats"])
|
||||
self.insert_llm_metrics(metrics["llm_stats"])
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_platform_metrics(self, metrics: dict):
|
||||
"""插入平台指标数据"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_plugin_metrics(self, metrics: dict):
|
||||
"""插入插件指标数据"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_command_metrics(self, metrics: dict):
|
||||
"""插入指令指标数据"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_llm_metrics(self, metrics: dict):
|
||||
"""插入 LLM 指标数据"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_llm_history(self, session_id: str, content: str, provider_type: str):
|
||||
"""更新 LLM 历史记录。当不存在 session_id 时插入"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_llm_history(
|
||||
self, session_id: str = None, provider_type: str = None
|
||||
) -> List[LLMHistory]:
|
||||
"""获取 LLM 历史记录, 如果 session_id 为 None, 返回所有"""
|
||||
raise NotImplementedError
|
||||
@asynccontextmanager
|
||||
async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
|
||||
"""Get a database session."""
|
||||
if not self.inited:
|
||||
await self.initialize()
|
||||
self.inited = True
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
yield session
|
||||
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
|
||||
@abc.abstractmethod
|
||||
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
"""获取基础统计数据"""
|
||||
raise NotImplementedError
|
||||
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
|
||||
@abc.abstractmethod
|
||||
def get_total_message_count(self) -> int:
|
||||
"""获取总消息数"""
|
||||
raise NotImplementedError
|
||||
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
|
||||
@abc.abstractmethod
|
||||
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
"""获取基础统计数据(合并)"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_atri_vision_data(self, vision_data: ATRIVision):
|
||||
"""插入 ATRI 视觉数据"""
|
||||
raise NotImplementedError
|
||||
# New methods in v4.0.0
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_atri_vision_data(self) -> List[ATRIVision]:
|
||||
"""获取 ATRI 视觉数据"""
|
||||
raise NotImplementedError
|
||||
async def insert_platform_stats(
|
||||
self,
|
||||
platform_id: str,
|
||||
platform_type: str,
|
||||
count: int = 1,
|
||||
timestamp: datetime.datetime | None = None,
|
||||
) -> None:
|
||||
"""Insert a new platform statistic record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_atri_vision_data_by_path_or_id(
|
||||
self, url_or_path: str, id: str
|
||||
) -> ATRIVision:
|
||||
"""通过 url 或 path 获取 ATRI 视觉数据"""
|
||||
raise NotImplementedError
|
||||
async def count_platform_stats(self) -> int:
|
||||
"""Count the number of platform statistics records."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
"""通过 user_id 和 cid 获取 Conversation"""
|
||||
raise NotImplementedError
|
||||
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
|
||||
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def new_conversation(self, user_id: str, cid: str):
|
||||
"""新建 Conversation"""
|
||||
raise NotImplementedError
|
||||
async def get_conversations(
|
||||
self, user_id: str | None = None, platform_id: str | None = None
|
||||
) -> list[ConversationV2]:
|
||||
"""Get all conversations for a specific user and platform_id(optional).
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_conversations(self, user_id: str) -> List[Conversation]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||
"""更新 Conversation"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_conversation(self, user_id: str, cid: str):
|
||||
"""删除 Conversation"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||
"""更新 Conversation 标题"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||
"""更新 Conversation Persona ID"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_all_conversations(
|
||||
self, page: int = 1, page_size: int = 20
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""获取所有对话,支持分页
|
||||
|
||||
Args:
|
||||
page: 页码,从1开始
|
||||
page_size: 每页数量
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
|
||||
content is not included in the result.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_filtered_conversations(
|
||||
async def get_conversation_by_id(self, cid: str) -> ConversationV2:
|
||||
"""Get a specific conversation by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_all_conversations(
|
||||
self, page: int = 1, page_size: int = 20
|
||||
) -> list[ConversationV2]:
|
||||
"""Get all conversations with pagination."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_filtered_conversations(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
platforms: List[str] = None,
|
||||
message_types: List[str] = None,
|
||||
search_query: str = None,
|
||||
exclude_ids: List[str] = None,
|
||||
exclude_platforms: List[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""获取筛选后的对话列表
|
||||
platform_ids: list[str] | None = None,
|
||||
search_query: str = "",
|
||||
**kwargs,
|
||||
) -> tuple[list[ConversationV2], int]:
|
||||
"""Get conversations filtered by platform IDs and search query."""
|
||||
...
|
||||
|
||||
Args:
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
platforms: 平台筛选列表
|
||||
message_types: 消息类型筛选列表
|
||||
search_query: 搜索关键词
|
||||
exclude_ids: 排除的用户ID列表
|
||||
exclude_platforms: 排除的平台列表
|
||||
@abc.abstractmethod
|
||||
async def create_conversation(
|
||||
self,
|
||||
user_id: str,
|
||||
platform_id: str,
|
||||
content: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
cid: str | None = None,
|
||||
created_at: datetime.datetime | None = None,
|
||||
updated_at: datetime.datetime | None = None,
|
||||
) -> ConversationV2:
|
||||
"""Create a new conversation."""
|
||||
...
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@abc.abstractmethod
|
||||
async def update_conversation(
|
||||
self,
|
||||
cid: str,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
content: list[dict] | None = None,
|
||||
) -> None:
|
||||
"""Update a conversation's history."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_conversation(self, cid: str) -> None:
|
||||
"""Delete a conversation by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_platform_message_history(
|
||||
self,
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
content: list[dict],
|
||||
sender_id: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
) -> None:
|
||||
"""Insert a new platform message history record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_platform_message_offset(
|
||||
self, platform_id: str, user_id: str, offset_sec: int = 86400
|
||||
) -> None:
|
||||
"""Delete platform message history records older than the specified offset."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_platform_message_history(
|
||||
self,
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> list[PlatformMessageHistory]:
|
||||
"""Get platform message history for a specific user."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_attachment(
|
||||
self,
|
||||
path: str,
|
||||
type: str,
|
||||
mime_type: str,
|
||||
):
|
||||
"""Insert a new attachment record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_attachment_by_id(self, attachment_id: str) -> Attachment:
|
||||
"""Get an attachment by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str,
|
||||
begin_dialogs: list[str] | None = None,
|
||||
tools: list[str] | None = None,
|
||||
) -> Persona:
|
||||
"""Insert a new persona record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_persona_by_id(self, persona_id: str) -> Persona:
|
||||
"""Get a persona by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_personas(self) -> list[Persona]:
|
||||
"""Get all personas for a specific bot."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str | None = None,
|
||||
begin_dialogs: list[str] | None = None,
|
||||
tools: list[str] | None = None,
|
||||
) -> Persona | None:
|
||||
"""Update a persona's system prompt or begin dialogs."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_persona(self, persona_id: str) -> None:
|
||||
"""Delete a persona by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_preference_or_update(
|
||||
self, scope: str, scope_id: str, key: str, value: dict
|
||||
) -> Preference:
|
||||
"""Insert a new preference record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference:
|
||||
"""Get a preference by scope ID and key."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_preferences(
|
||||
self, scope: str, scope_id: str | None = None, key: str | None = None
|
||||
) -> list[Preference]:
|
||||
"""Get all preferences for a specific scope ID or key."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def remove_preference(self, scope: str, scope_id: str, key: str) -> None:
|
||||
"""Remove a preference by scope ID and key."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def clear_preferences(self, scope: str, scope_id: str) -> None:
|
||||
"""Clear all preferences for a specific scope ID."""
|
||||
...
|
||||
|
||||
# @abc.abstractmethod
|
||||
# async def insert_llm_message(
|
||||
# self,
|
||||
# cid: str,
|
||||
# role: str,
|
||||
# content: list,
|
||||
# tool_calls: list = None,
|
||||
# tool_call_id: str = None,
|
||||
# parent_id: str = None,
|
||||
# ) -> LLMMessage:
|
||||
# """Insert a new LLM message into the conversation."""
|
||||
# ...
|
||||
|
||||
# @abc.abstractmethod
|
||||
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
|
||||
# """Get all LLM messages for a specific conversation."""
|
||||
# ...
|
||||
|
||||
64
astrbot/core/db/migration/helper.py
Normal file
64
astrbot/core/db/migration/helper.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import os
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.api import logger, sp
|
||||
from .migra_3_to_4 import (
|
||||
migration_conversation_table,
|
||||
migration_platform_table,
|
||||
migration_webchat_data,
|
||||
migration_persona_data,
|
||||
migration_preferences,
|
||||
)
|
||||
|
||||
|
||||
async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
|
||||
"""
|
||||
检查是否需要进行数据库迁移
|
||||
如果存在 data_v3.db 并且 preference 中没有 migration_done_v4,则需要进行迁移。
|
||||
"""
|
||||
data_v3_exists = os.path.exists(get_astrbot_data_path())
|
||||
if not data_v3_exists:
|
||||
return False
|
||||
migration_done = await db_helper.get_preference(
|
||||
"global", "global", "migration_done_v4"
|
||||
)
|
||||
if migration_done:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def do_migration_v4(
|
||||
db_helper: BaseDatabase,
|
||||
platform_id_map: dict[str, dict[str, str]],
|
||||
astrbot_config: AstrBotConfig,
|
||||
):
|
||||
"""
|
||||
执行数据库迁移
|
||||
迁移旧的 webchat_conversation 表到新的 conversation 表。
|
||||
迁移旧的 platform 到新的 platform_stats 表。
|
||||
"""
|
||||
if not await check_migration_needed_v4(db_helper):
|
||||
return
|
||||
|
||||
logger.info("开始执行数据库迁移...")
|
||||
|
||||
# 执行会话表迁移
|
||||
await migration_conversation_table(db_helper, platform_id_map)
|
||||
|
||||
# 执行人格数据迁移
|
||||
await migration_persona_data(db_helper, astrbot_config)
|
||||
|
||||
# 执行 WebChat 数据迁移
|
||||
await migration_webchat_data(db_helper, platform_id_map)
|
||||
|
||||
# 执行偏好设置迁移
|
||||
await migration_preferences(db_helper,platform_id_map)
|
||||
|
||||
# 执行平台统计表迁移
|
||||
await migration_platform_table(db_helper, platform_id_map)
|
||||
|
||||
# 标记迁移完成
|
||||
await sp.put_async("global", "global", "migration_done_v4", True)
|
||||
|
||||
logger.info("数据库迁移完成。")
|
||||
338
astrbot/core/db/migration/migra_3_to_4.py
Normal file
338
astrbot/core/db/migration/migra_3_to_4.py
Normal file
@@ -0,0 +1,338 @@
|
||||
import json
|
||||
import datetime
|
||||
from .. import BaseDatabase
|
||||
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
|
||||
from .shared_preferences_v3 import sp as sp_v3
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
|
||||
from sqlalchemy import text
|
||||
|
||||
"""
|
||||
1. 迁移旧的 webchat_conversation 表到新的 conversation 表。
|
||||
2. 迁移旧的 platform 到新的 platform_stats 表。
|
||||
"""
|
||||
|
||||
|
||||
def get_platform_id(
|
||||
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
|
||||
) -> str:
|
||||
return platform_id_map.get(
|
||||
old_platform_name,
|
||||
{"platform_id": old_platform_name, "platform_type": old_platform_name},
|
||||
).get("platform_id", old_platform_name)
|
||||
|
||||
|
||||
def get_platform_type(
|
||||
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
|
||||
) -> str:
|
||||
return platform_id_map.get(
|
||||
old_platform_name,
|
||||
{"platform_id": old_platform_name, "platform_type": old_platform_name},
|
||||
).get("platform_type", old_platform_name)
|
||||
|
||||
|
||||
async def migration_conversation_table(
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
|
||||
):
|
||||
db_helper_v3 = SQLiteV3DatabaseV3(
|
||||
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
|
||||
)
|
||||
conversations, total_cnt = db_helper_v3.get_all_conversations(
|
||||
page=1, page_size=10000000
|
||||
)
|
||||
logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...")
|
||||
|
||||
async with db_helper.get_db() as dbsession:
|
||||
dbsession: AsyncSession
|
||||
async with dbsession.begin():
|
||||
for idx, conversation in enumerate(conversations):
|
||||
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
|
||||
progress = int((idx + 1) / total_cnt * 100)
|
||||
if progress % 10 == 0:
|
||||
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
|
||||
try:
|
||||
conv = db_helper_v3.get_conversation_by_user_id(
|
||||
user_id=conversation.get("user_id", "unknown"),
|
||||
cid=conversation.get("cid", "unknown"),
|
||||
)
|
||||
if not conv:
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
|
||||
)
|
||||
if ":" not in conv.user_id:
|
||||
continue
|
||||
session = MessageSesion.from_str(session_str=conv.user_id)
|
||||
platform_id = get_platform_id(
|
||||
platform_id_map, session.platform_name
|
||||
)
|
||||
session.platform_id = platform_id # 更新平台名称为新的 ID
|
||||
conv_v2 = ConversationV2(
|
||||
user_id=str(session),
|
||||
content=json.loads(conv.history) if conv.history else [],
|
||||
platform_id=platform_id,
|
||||
title=conv.title,
|
||||
persona_id=conv.persona_id,
|
||||
conversation_id=conv.cid,
|
||||
created_at=datetime.datetime.fromtimestamp(conv.created_at),
|
||||
updated_at=datetime.datetime.fromtimestamp(conv.updated_at),
|
||||
)
|
||||
dbsession.add(conv_v2)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"迁移旧会话 {conversation.get('cid', 'unknown')} 失败: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(f"成功迁移 {total_cnt} 条旧的会话数据到新表。")
|
||||
|
||||
|
||||
async def migration_platform_table(
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
|
||||
):
|
||||
db_helper_v3 = SQLiteV3DatabaseV3(
|
||||
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
|
||||
)
|
||||
secs_from_2023_4_10_to_now = (
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
- datetime.datetime(2023, 4, 10, tzinfo=datetime.timezone.utc)
|
||||
).total_seconds()
|
||||
offset_sec = int(secs_from_2023_4_10_to_now)
|
||||
logger.info(f"迁移旧平台数据,offset_sec: {offset_sec} 秒。")
|
||||
stats = db_helper_v3.get_base_stats(offset_sec=offset_sec)
|
||||
logger.info(f"迁移 {len(stats.platform)} 条旧的平台数据到新的表中...")
|
||||
platform_stats_v3 = stats.platform
|
||||
|
||||
if not platform_stats_v3:
|
||||
logger.info("没有找到旧平台数据,跳过迁移。")
|
||||
return
|
||||
|
||||
first_time_stamp = platform_stats_v3[0].timestamp
|
||||
end_time_stamp = platform_stats_v3[-1].timestamp
|
||||
start_time = first_time_stamp - (first_time_stamp % 3600) # 向下取整到小时
|
||||
end_time = end_time_stamp + (3600 - (end_time_stamp % 3600)) # 向上取整到小时
|
||||
|
||||
idx = 0
|
||||
|
||||
async with db_helper.get_db() as dbsession:
|
||||
dbsession: AsyncSession
|
||||
async with dbsession.begin():
|
||||
total_buckets = (end_time - start_time) // 3600
|
||||
for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)):
|
||||
if bucket_idx % 500 == 0:
|
||||
progress = int((bucket_idx + 1) / total_buckets * 100)
|
||||
logger.info(f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})")
|
||||
cnt = 0
|
||||
while (
|
||||
idx < len(platform_stats_v3)
|
||||
and platform_stats_v3[idx].timestamp < bucket_end
|
||||
):
|
||||
cnt += platform_stats_v3[idx].count
|
||||
idx += 1
|
||||
if cnt == 0:
|
||||
continue
|
||||
platform_id = get_platform_id(
|
||||
platform_id_map, platform_stats_v3[idx].name
|
||||
)
|
||||
platform_type = get_platform_type(
|
||||
platform_id_map, platform_stats_v3[idx].name
|
||||
)
|
||||
try:
|
||||
await dbsession.execute(
|
||||
text("""
|
||||
INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
|
||||
VALUES (:timestamp, :platform_id, :platform_type, :count)
|
||||
ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET
|
||||
count = platform_stats.count + EXCLUDED.count
|
||||
"""),
|
||||
{
|
||||
"timestamp": datetime.datetime.fromtimestamp(
|
||||
bucket_end, tz=datetime.timezone.utc
|
||||
),
|
||||
"platform_id": platform_id,
|
||||
"platform_type": platform_type,
|
||||
"count": cnt,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"迁移平台统计数据失败: {platform_id}, {platform_type}, 时间戳: {bucket_end}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(f"成功迁移 {len(platform_stats_v3)} 条旧的平台数据到新表。")
|
||||
|
||||
|
||||
async def migration_webchat_data(
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
|
||||
):
|
||||
"""迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中"""
|
||||
db_helper_v3 = SQLiteV3DatabaseV3(
|
||||
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
|
||||
)
|
||||
conversations, total_cnt = db_helper_v3.get_all_conversations(
|
||||
page=1, page_size=10000000
|
||||
)
|
||||
logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...")
|
||||
|
||||
async with db_helper.get_db() as dbsession:
|
||||
dbsession: AsyncSession
|
||||
async with dbsession.begin():
|
||||
for idx, conversation in enumerate(conversations):
|
||||
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
|
||||
progress = int((idx + 1) / total_cnt * 100)
|
||||
if progress % 10 == 0:
|
||||
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
|
||||
try:
|
||||
conv = db_helper_v3.get_conversation_by_user_id(
|
||||
user_id=conversation.get("user_id", "unknown"),
|
||||
cid=conversation.get("cid", "unknown"),
|
||||
)
|
||||
if not conv:
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
|
||||
)
|
||||
if ":" in conv.user_id:
|
||||
continue
|
||||
platform_id = "webchat"
|
||||
history = json.loads(conv.history) if conv.history else []
|
||||
for msg in history:
|
||||
type_ = msg.get("type") # user type, "bot" or "user"
|
||||
new_history = PlatformMessageHistory(
|
||||
platform_id=platform_id,
|
||||
user_id=conv.cid, # we use conv.cid as user_id for webchat
|
||||
content=msg,
|
||||
sender_id=type_,
|
||||
sender_name=type_,
|
||||
)
|
||||
dbsession.add(new_history)
|
||||
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"迁移旧 WebChat 会话 {conversation.get('cid', 'unknown')} 失败",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
logger.info(f"成功迁移 {total_cnt} 条旧的 WebChat 会话数据到新表。")
|
||||
|
||||
|
||||
async def migration_persona_data(
|
||||
db_helper: BaseDatabase, astrbot_config: AstrBotConfig
|
||||
):
|
||||
"""
|
||||
迁移 Persona 数据到新的表中。
|
||||
旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
|
||||
"""
|
||||
v3_persona_config: list[dict] = astrbot_config.get("persona", [])
|
||||
total_personas = len(v3_persona_config)
|
||||
logger.info(f"迁移 {total_personas} 个 Persona 配置到新表中...")
|
||||
|
||||
for idx, persona in enumerate(v3_persona_config):
|
||||
if total_personas > 0 and (idx + 1) % max(1, total_personas // 10) == 0:
|
||||
progress = int((idx + 1) / total_personas * 100)
|
||||
if progress % 10 == 0:
|
||||
logger.info(f"进度: {progress}% ({idx + 1}/{total_personas})")
|
||||
try:
|
||||
begin_dialogs = persona.get("begin_dialogs", [])
|
||||
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
|
||||
mood_prompt = ""
|
||||
user_turn = True
|
||||
for mood_dialog in mood_imitation_dialogs:
|
||||
if user_turn:
|
||||
mood_prompt += f"A: {mood_dialog}\n"
|
||||
else:
|
||||
mood_prompt += f"B: {mood_dialog}\n"
|
||||
user_turn = not user_turn
|
||||
system_prompt = persona.get("prompt", "")
|
||||
if mood_prompt:
|
||||
system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
|
||||
persona_new = await db_helper.insert_persona(
|
||||
persona_id=persona["name"],
|
||||
system_prompt=system_prompt,
|
||||
begin_dialogs=begin_dialogs,
|
||||
)
|
||||
logger.info(
|
||||
f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
|
||||
async def migration_preferences(
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
|
||||
):
|
||||
# 1. global scope migration
|
||||
keys = [
|
||||
"inactivated_llm_tools",
|
||||
"inactivated_plugins",
|
||||
"curr_provider",
|
||||
"curr_provider_tts",
|
||||
"curr_provider_stt",
|
||||
"alter_cmd",
|
||||
]
|
||||
for key in keys:
|
||||
value = sp_v3.get(key)
|
||||
if value is not None:
|
||||
await sp.put_async("global", "global", key, value)
|
||||
logger.info(f"迁移全局偏好设置 {key} 成功,值: {value}")
|
||||
|
||||
# 2. umo scope migration
|
||||
session_conversation = sp_v3.get("session_conversation", default={})
|
||||
for umo, conversation_id in session_conversation.items():
|
||||
if not umo or not conversation_id:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
await sp.put_async("umo", str(session), "sel_conv_id", conversation_id)
|
||||
logger.info(f"迁移会话 {umo} 的对话数据到新表成功,平台 ID: {platform_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的对话数据失败: {e}", exc_info=True)
|
||||
|
||||
session_service_config = sp_v3.get("session_service_config", default={})
|
||||
for umo, config in session_service_config.items():
|
||||
if not umo or not config:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
|
||||
await sp.put_async("umo", str(session), "session_service_config", config)
|
||||
|
||||
logger.info(f"迁移会话 {umo} 的服务配置到新表成功,平台 ID: {platform_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的服务配置失败: {e}", exc_info=True)
|
||||
|
||||
session_variables = sp_v3.get("session_variables", default={})
|
||||
for umo, variables in session_variables.items():
|
||||
if not umo or not variables:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
await sp.put_async("umo", str(session), "session_variables", variables)
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的变量失败: {e}", exc_info=True)
|
||||
|
||||
session_provider_perf = sp_v3.get("session_provider_perf", default={})
|
||||
for umo, perf in session_provider_perf.items():
|
||||
if not umo or not perf:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
|
||||
for provider_type, provider_id in perf.items():
|
||||
await sp.put_async(
|
||||
"umo", str(session), f"provider_perf_{provider_type}", provider_id
|
||||
)
|
||||
logger.info(
|
||||
f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True)
|
||||
45
astrbot/core/db/migration/shared_preferences_v3.py
Normal file
45
astrbot/core/db/migration/shared_preferences_v3.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TypeVar
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
class SharedPreferences:
|
||||
def __init__(self, path=None):
|
||||
if path is None:
|
||||
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
|
||||
self.path = path
|
||||
self._data = self._load_preferences()
|
||||
|
||||
def _load_preferences(self):
|
||||
if os.path.exists(self.path):
|
||||
try:
|
||||
with open(self.path, "r") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
os.remove(self.path)
|
||||
return {}
|
||||
|
||||
def _save_preferences(self):
|
||||
with open(self.path, "w") as f:
|
||||
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def get(self, key, default: _VT = None) -> _VT:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def put(self, key, value):
|
||||
self._data[key] = value
|
||||
self._save_preferences()
|
||||
|
||||
def remove(self, key):
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
self._save_preferences()
|
||||
|
||||
def clear(self):
|
||||
self._data.clear()
|
||||
self._save_preferences()
|
||||
|
||||
sp = SharedPreferences()
|
||||
493
astrbot/core/db/migration/sqlite_v3.py
Normal file
493
astrbot/core/db/migration/sqlite_v3.py
Normal file
@@ -0,0 +1,493 @@
|
||||
import sqlite3
|
||||
import time
|
||||
from astrbot.core.db.po import Platform, Stats
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话存储
|
||||
|
||||
对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。
|
||||
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
|
||||
"""
|
||||
|
||||
user_id: str
|
||||
cid: str
|
||||
history: str = ""
|
||||
"""字符串格式的列表。"""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
title: str = ""
|
||||
persona_id: str = ""
|
||||
|
||||
|
||||
INIT_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS platform(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS llm(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS plugin(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS command(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS llm_history(
|
||||
provider_type VARCHAR(32),
|
||||
session_id VARCHAR(32),
|
||||
content TEXT
|
||||
);
|
||||
|
||||
-- ATRI
|
||||
CREATE TABLE IF NOT EXISTS atri_vision(
|
||||
id TEXT,
|
||||
url_or_path TEXT,
|
||||
caption TEXT,
|
||||
is_meme BOOLEAN,
|
||||
keywords TEXT,
|
||||
platform_name VARCHAR(32),
|
||||
session_id VARCHAR(32),
|
||||
sender_nickname VARCHAR(32),
|
||||
timestamp INTEGER
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS webchat_conversation(
|
||||
user_id TEXT, -- 会话 id
|
||||
cid TEXT, -- 对话 id
|
||||
history TEXT,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER,
|
||||
title TEXT,
|
||||
persona_id TEXT
|
||||
);
|
||||
|
||||
PRAGMA encoding = 'UTF-8';
|
||||
"""
|
||||
|
||||
|
||||
class SQLiteDatabase():
|
||||
def __init__(self, db_path: str) -> None:
|
||||
super().__init__()
|
||||
self.db_path = db_path
|
||||
|
||||
sql = INIT_SQL
|
||||
|
||||
# 初始化数据库
|
||||
self.conn = self._get_conn(self.db_path)
|
||||
c = self.conn.cursor()
|
||||
c.executescript(sql)
|
||||
self.conn.commit()
|
||||
|
||||
# 检查 webchat_conversation 的 title 字段是否存在
|
||||
c.execute(
|
||||
"""
|
||||
PRAGMA table_info(webchat_conversation)
|
||||
"""
|
||||
)
|
||||
res = c.fetchall()
|
||||
has_title = False
|
||||
has_persona_id = False
|
||||
for row in res:
|
||||
if row[1] == "title":
|
||||
has_title = True
|
||||
if row[1] == "persona_id":
|
||||
has_persona_id = True
|
||||
if not has_title:
|
||||
c.execute(
|
||||
"""
|
||||
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
|
||||
"""
|
||||
)
|
||||
self.conn.commit()
|
||||
if not has_persona_id:
|
||||
c.execute(
|
||||
"""
|
||||
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
|
||||
"""
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
c.close()
|
||||
|
||||
def _get_conn(self, db_path: str) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.text_factory = str
|
||||
return conn
|
||||
|
||||
def _exec_sql(self, sql: str, params: Tuple = None):
|
||||
conn = self.conn
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
conn = self._get_conn(self.db_path)
|
||||
c = conn.cursor()
|
||||
|
||||
if params:
|
||||
c.execute(sql, params)
|
||||
c.close()
|
||||
else:
|
||||
c.execute(sql)
|
||||
c.close()
|
||||
|
||||
conn.commit()
|
||||
|
||||
def insert_platform_metrics(self, metrics: dict):
|
||||
for k, v in metrics.items():
|
||||
self._exec_sql(
|
||||
"""
|
||||
INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
|
||||
""",
|
||||
(k, v, int(time.time())),
|
||||
)
|
||||
|
||||
def insert_llm_metrics(self, metrics: dict):
|
||||
for k, v in metrics.items():
|
||||
self._exec_sql(
|
||||
"""
|
||||
INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
|
||||
""",
|
||||
(k, v, int(time.time())),
|
||||
)
|
||||
|
||||
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
"""获取 offset_sec 秒前到现在的基础统计数据"""
|
||||
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
||||
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
SELECT * FROM platform
|
||||
"""
|
||||
+ where_clause
|
||||
)
|
||||
|
||||
platform = []
|
||||
for row in c.fetchall():
|
||||
platform.append(Platform(*row))
|
||||
|
||||
c.close()
|
||||
|
||||
return Stats(platform=platform)
|
||||
|
||||
def get_total_message_count(self) -> int:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
SELECT SUM(count) FROM platform
|
||||
"""
|
||||
)
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
return res[0]
|
||||
|
||||
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
"""获取 offset_sec 秒前到现在的基础统计数据(合并)"""
|
||||
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
||||
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
SELECT name, SUM(count), timestamp FROM platform
|
||||
"""
|
||||
+ where_clause
|
||||
+ " GROUP BY name"
|
||||
)
|
||||
|
||||
platform = []
|
||||
for row in c.fetchall():
|
||||
platform.append(Platform(*row))
|
||||
|
||||
c.close()
|
||||
|
||||
return Stats(platform, [], [])
|
||||
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
""",
|
||||
(user_id, cid),
|
||||
)
|
||||
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
|
||||
if not res:
|
||||
return
|
||||
|
||||
return Conversation(*res)
|
||||
|
||||
def new_conversation(self, user_id: str, cid: str):
|
||||
history = "[]"
|
||||
updated_at = int(time.time())
|
||||
created_at = updated_at
|
||||
self._exec_sql(
|
||||
"""
|
||||
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, cid, history, updated_at, created_at),
|
||||
)
|
||||
|
||||
def get_conversations(self, user_id: str) -> Tuple:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
res = c.fetchall()
|
||||
c.close()
|
||||
conversations = []
|
||||
for row in res:
|
||||
cid = row[0]
|
||||
created_at = row[1]
|
||||
updated_at = row[2]
|
||||
title = row[3]
|
||||
persona_id = row[4]
|
||||
conversations.append(
|
||||
Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
|
||||
)
|
||||
return conversations
|
||||
|
||||
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||
"""更新对话,并且同时更新时间"""
|
||||
updated_at = int(time.time())
|
||||
self._exec_sql(
|
||||
"""
|
||||
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
|
||||
""",
|
||||
(history, updated_at, user_id, cid),
|
||||
)
|
||||
|
||||
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||
self._exec_sql(
|
||||
"""
|
||||
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
|
||||
""",
|
||||
(title, user_id, cid),
|
||||
)
|
||||
|
||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||
self._exec_sql(
|
||||
"""
|
||||
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
|
||||
""",
|
||||
(persona_id, user_id, cid),
|
||||
)
|
||||
|
||||
def delete_conversation(self, user_id: str, cid: str):
|
||||
self._exec_sql(
|
||||
"""
|
||||
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
""",
|
||||
(user_id, cid),
|
||||
)
|
||||
|
||||
def get_all_conversations(
|
||||
self, page: int = 1, page_size: int = 20
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""获取所有对话,支持分页,按更新时间降序排序"""
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
try:
|
||||
# 获取总记录数
|
||||
c.execute("""
|
||||
SELECT COUNT(*) FROM webchat_conversation
|
||||
""")
|
||||
total_count = c.fetchone()[0]
|
||||
|
||||
# 计算偏移量
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 获取分页数据,按更新时间降序排序
|
||||
c.execute(
|
||||
"""
|
||||
SELECT user_id, cid, created_at, updated_at, title, persona_id
|
||||
FROM webchat_conversation
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(page_size, offset),
|
||||
)
|
||||
|
||||
rows = c.fetchall()
|
||||
|
||||
conversations = []
|
||||
|
||||
for row in rows:
|
||||
user_id, cid, created_at, updated_at, title, persona_id = row
|
||||
# 确保 cid 是字符串类型且至少有8个字符,否则使用一个默认值
|
||||
safe_cid = str(cid) if cid else "unknown"
|
||||
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
|
||||
|
||||
conversations.append(
|
||||
{
|
||||
"user_id": user_id or "",
|
||||
"cid": safe_cid,
|
||||
"title": title or f"对话 {display_cid}",
|
||||
"persona_id": persona_id or "",
|
||||
"created_at": created_at or 0,
|
||||
"updated_at": updated_at or 0,
|
||||
}
|
||||
)
|
||||
|
||||
return conversations, total_count
|
||||
|
||||
except Exception as _:
|
||||
# 返回空列表和0,确保即使出错也有有效的返回值
|
||||
return [], 0
|
||||
finally:
|
||||
c.close()
|
||||
|
||||
def get_filtered_conversations(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
platforms: List[str] = None,
|
||||
message_types: List[str] = None,
|
||||
search_query: str = None,
|
||||
exclude_ids: List[str] = None,
|
||||
exclude_platforms: List[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""获取筛选后的对话列表"""
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
try:
|
||||
# 构建查询条件
|
||||
where_clauses = []
|
||||
params = []
|
||||
|
||||
# 平台筛选
|
||||
if platforms and len(platforms) > 0:
|
||||
platform_conditions = []
|
||||
for platform in platforms:
|
||||
platform_conditions.append("user_id LIKE ?")
|
||||
params.append(f"{platform}:%")
|
||||
|
||||
if platform_conditions:
|
||||
where_clauses.append(f"({' OR '.join(platform_conditions)})")
|
||||
|
||||
# 消息类型筛选
|
||||
if message_types and len(message_types) > 0:
|
||||
message_type_conditions = []
|
||||
for msg_type in message_types:
|
||||
message_type_conditions.append("user_id LIKE ?")
|
||||
params.append(f"%:{msg_type}:%")
|
||||
|
||||
if message_type_conditions:
|
||||
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
|
||||
|
||||
# 搜索关键词
|
||||
if search_query:
|
||||
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
||||
where_clauses.append(
|
||||
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
|
||||
)
|
||||
search_param = f"%{search_query}%"
|
||||
params.extend([search_param, search_param, search_param, search_param])
|
||||
|
||||
# 排除特定用户ID
|
||||
if exclude_ids and len(exclude_ids) > 0:
|
||||
for exclude_id in exclude_ids:
|
||||
where_clauses.append("user_id NOT LIKE ?")
|
||||
params.append(f"{exclude_id}%")
|
||||
|
||||
# 排除特定平台
|
||||
if exclude_platforms and len(exclude_platforms) > 0:
|
||||
for exclude_platform in exclude_platforms:
|
||||
where_clauses.append("user_id NOT LIKE ?")
|
||||
params.append(f"{exclude_platform}:%")
|
||||
|
||||
# 构建完整的 WHERE 子句
|
||||
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
|
||||
|
||||
# 构建计数查询
|
||||
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
|
||||
|
||||
# 获取总记录数
|
||||
c.execute(count_sql, params)
|
||||
total_count = c.fetchone()[0]
|
||||
|
||||
# 计算偏移量
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 构建分页数据查询
|
||||
data_sql = f"""
|
||||
SELECT user_id, cid, created_at, updated_at, title, persona_id
|
||||
FROM webchat_conversation
|
||||
{where_sql}
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
query_params = params + [page_size, offset]
|
||||
|
||||
# 获取分页数据
|
||||
c.execute(data_sql, query_params)
|
||||
rows = c.fetchall()
|
||||
|
||||
conversations = []
|
||||
|
||||
for row in rows:
|
||||
user_id, cid, created_at, updated_at, title, persona_id = row
|
||||
# 确保 cid 是字符串类型,否则使用一个默认值
|
||||
safe_cid = str(cid) if cid else "unknown"
|
||||
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
|
||||
|
||||
conversations.append(
|
||||
{
|
||||
"user_id": user_id or "",
|
||||
"cid": safe_cid,
|
||||
"title": title or f"对话 {display_cid}",
|
||||
"persona_id": persona_id or "",
|
||||
"created_at": created_at or 0,
|
||||
"updated_at": updated_at or 0,
|
||||
}
|
||||
)
|
||||
|
||||
return conversations, total_count
|
||||
|
||||
except Exception as _:
|
||||
# 返回空列表和0,确保即使出错也有有效的返回值
|
||||
return [], 0
|
||||
finally:
|
||||
c.close()
|
||||
@@ -1,112 +0,0 @@
|
||||
import json
|
||||
import aiosqlite
|
||||
import os
|
||||
from typing import Any
|
||||
from .plugin_storage import PluginStorage
|
||||
|
||||
DBPATH = "data/plugin_data/sqlite/plugin_data.db"
|
||||
|
||||
|
||||
class SQLitePluginStorage(PluginStorage):
|
||||
"""插件数据的 SQLite 存储实现类。
|
||||
|
||||
该类提供异步方式将插件数据存储到 SQLite 数据库中,支持数据的增删改查操作。
|
||||
所有数据以 (plugin, key) 作为复合主键进行索引。
|
||||
"""
|
||||
|
||||
_instance = None # Standalone instance of the class
|
||||
_db_conn = None
|
||||
db_path = None
|
||||
|
||||
def __new__(cls):
|
||||
"""
|
||||
创建或获取 SQLitePluginStorage 的单例实例。
|
||||
如果实例已存在,则返回现有实例;否则创建一个新实例。
|
||||
数据在 `data/plugin_data/sqlite/plugin_data.db` 下。
|
||||
"""
|
||||
os.makedirs(os.path.dirname(DBPATH), exist_ok=True)
|
||||
if cls._instance is None:
|
||||
cls._instance = super(SQLitePluginStorage, cls).__new__(cls)
|
||||
cls._instance.db_path = DBPATH
|
||||
return cls._instance
|
||||
|
||||
async def _init_db(self):
|
||||
"""初始化数据库连接(只执行一次)"""
|
||||
if SQLitePluginStorage._db_conn is None:
|
||||
SQLitePluginStorage._db_conn = await aiosqlite.connect(self.db_path)
|
||||
await self._setup_db()
|
||||
|
||||
async def _setup_db(self):
|
||||
"""
|
||||
异步初始化数据库。
|
||||
|
||||
创建插件数据表,如果表不存在则创建,表结构包含 plugin、key 和 value 字段,
|
||||
其中 plugin 和 key 组合作为主键。
|
||||
"""
|
||||
await self._db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS plugin_data (
|
||||
plugin TEXT,
|
||||
key TEXT,
|
||||
value TEXT,
|
||||
PRIMARY KEY (plugin, key)
|
||||
)
|
||||
""")
|
||||
await self._db_conn.commit()
|
||||
|
||||
async def set(self, plugin: str, key: str, value: Any):
|
||||
"""
|
||||
异步存储数据。
|
||||
|
||||
将指定插件的键值对存入数据库,如果键已存在则更新值。
|
||||
值会被序列化为 JSON 字符串后存储。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 数据键名
|
||||
value: 要存储的数据值(任意类型,将被 JSON 序列化)
|
||||
"""
|
||||
await self._init_db()
|
||||
await self._db_conn.execute(
|
||||
"INSERT INTO plugin_data (plugin, key, value) VALUES (?, ?, ?) "
|
||||
"ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value",
|
||||
(plugin, key, json.dumps(value)),
|
||||
)
|
||||
await self._db_conn.commit()
|
||||
|
||||
async def get(self, plugin: str, key: str) -> Any:
|
||||
"""
|
||||
异步获取数据。
|
||||
|
||||
从数据库中获取指定插件和键名对应的值,
|
||||
返回的值会从 JSON 字符串反序列化为原始数据类型。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 数据键名
|
||||
|
||||
Returns:
|
||||
Any: 存储的数据值,如果未找到则返回 None
|
||||
"""
|
||||
await self._init_db()
|
||||
async with self._db_conn.execute(
|
||||
"SELECT value FROM plugin_data WHERE plugin = ? AND key = ?",
|
||||
(plugin, key),
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return json.loads(row[0]) if row else None
|
||||
|
||||
async def delete(self, plugin: str, key: str):
|
||||
"""
|
||||
异步删除数据。
|
||||
|
||||
从数据库中删除指定插件和键名对应的数据项。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 要删除的数据键名
|
||||
"""
|
||||
await self._init_db()
|
||||
await self._db_conn.execute(
|
||||
"DELETE FROM plugin_data WHERE plugin = ? AND key = ?", (plugin, key)
|
||||
)
|
||||
await self._db_conn.commit()
|
||||
@@ -1,7 +1,233 @@
|
||||
"""指标数据"""
|
||||
import uuid
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
from sqlmodel import (
|
||||
SQLModel,
|
||||
Text,
|
||||
JSON,
|
||||
UniqueConstraint,
|
||||
Field,
|
||||
)
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
|
||||
class PlatformStat(SQLModel, table=True):
|
||||
"""This class represents the statistics of bot usage across different platforms.
|
||||
|
||||
Note: In astrbot v4, we moved `platform` table to here.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_stats"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
timestamp: datetime = Field(nullable=False)
|
||||
platform_id: str = Field(nullable=False)
|
||||
platform_type: str = Field(nullable=False) # such as "aiocqhttp", "slack", etc.
|
||||
count: int = Field(default=0, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"timestamp",
|
||||
"platform_id",
|
||||
"platform_type",
|
||||
name="uix_platform_stats",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ConversationV2(SQLModel, table=True):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
inner_conversation_id: int = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
conversation_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
platform_id: str = Field(nullable=False)
|
||||
user_id: str = Field(nullable=False)
|
||||
content: Optional[list] = Field(default=None, sa_type=JSON)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
title: Optional[str] = Field(default=None, max_length=255)
|
||||
persona_id: Optional[str] = Field(default=None)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"conversation_id",
|
||||
name="uix_conversation_id",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Persona(SQLModel, table=True):
|
||||
"""Persona is a set of instructions for LLMs to follow.
|
||||
|
||||
It can be used to customize the behavior of LLMs.
|
||||
"""
|
||||
|
||||
__tablename__ = "personas"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
persona_id: str = Field(max_length=255, nullable=False)
|
||||
system_prompt: str = Field(sa_type=Text, nullable=False)
|
||||
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
|
||||
"""a list of strings, each representing a dialog to start with"""
|
||||
tools: Optional[list] = Field(default=None, sa_type=JSON)
|
||||
"""None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"persona_id",
|
||||
name="uix_persona_id",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Preference(SQLModel, table=True):
|
||||
"""This class represents preferences for bots."""
|
||||
|
||||
__tablename__ = "preferences"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
scope: str = Field(nullable=False)
|
||||
"""Scope of the preference, such as 'global', 'umo', 'plugin'."""
|
||||
scope_id: str = Field(nullable=False)
|
||||
"""ID of the scope, such as 'global', 'umo', 'plugin_name'."""
|
||||
key: str = Field(nullable=False)
|
||||
value: dict = Field(sa_type=JSON, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"scope",
|
||||
"scope_id",
|
||||
"key",
|
||||
name="uix_preference_scope_scope_id_key",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class PlatformMessageHistory(SQLModel, table=True):
|
||||
"""This class represents the message history for a specific platform.
|
||||
|
||||
It is used to store messages that are not LLM-generated, such as user messages
|
||||
or platform-specific messages.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_message_history"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
platform_id: str = Field(nullable=False)
|
||||
user_id: str = Field(nullable=False) # An id of group, user in platform
|
||||
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
|
||||
sender_name: Optional[str] = Field(
|
||||
default=None
|
||||
) # Name of the sender in the platform
|
||||
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
class Attachment(SQLModel, table=True):
|
||||
"""This class represents attachments for messages in AstrBot.
|
||||
|
||||
Attachments can be images, files, or other media types.
|
||||
"""
|
||||
|
||||
__tablename__ = "attachments"
|
||||
|
||||
inner_attachment_id: int = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
attachment_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
path: str = Field(nullable=False) # Path to the file on disk
|
||||
type: str = Field(nullable=False) # Type of the file (e.g., 'image', 'file')
|
||||
mime_type: str = Field(nullable=False) # MIME type of the file
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"attachment_id",
|
||||
name="uix_attachment_id",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话类
|
||||
|
||||
对于 WebChat,history 存储了包括指令、回复、图片等在内的所有消息。
|
||||
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
|
||||
|
||||
在 v4.0.0 版本及之后,WebChat 的历史记录被迁移至 `PlatformMessageHistory` 表中,
|
||||
"""
|
||||
|
||||
platform_id: str
|
||||
user_id: str
|
||||
cid: str
|
||||
"""对话 ID, 是 uuid 格式的字符串"""
|
||||
history: str = ""
|
||||
"""字符串格式的对话列表。"""
|
||||
title: str | None = ""
|
||||
persona_id: str | None = ""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
"""LLM 人格类。
|
||||
|
||||
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
|
||||
"""
|
||||
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
begin_dialogs: list[str] = []
|
||||
mood_imitation_dialogs: list[str] = []
|
||||
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
|
||||
tools: list[str] | None = None
|
||||
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
|
||||
|
||||
# cache
|
||||
_begin_dialogs_processed: list[dict] = []
|
||||
_mood_imitation_dialogs_processed: str = ""
|
||||
|
||||
|
||||
# ====
|
||||
# Deprecated, and will be removed in future versions.
|
||||
# ====
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -13,77 +239,6 @@ class Platform:
|
||||
timestamp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Provider:
|
||||
"""供应商使用统计数据"""
|
||||
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Plugin:
|
||||
"""插件使用统计数据"""
|
||||
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Command:
|
||||
"""命令使用统计数据"""
|
||||
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Stats:
|
||||
platform: List[Platform] = field(default_factory=list)
|
||||
command: List[Command] = field(default_factory=list)
|
||||
llm: List[Provider] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMHistory:
|
||||
"""LLM 聊天时持久化的信息"""
|
||||
|
||||
provider_type: str
|
||||
session_id: str
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ATRIVision:
|
||||
"""Deprecated"""
|
||||
|
||||
id: str
|
||||
url_or_path: str
|
||||
caption: str
|
||||
is_meme: bool
|
||||
keywords: List[str]
|
||||
platform_name: str
|
||||
session_id: str
|
||||
sender_nickname: str
|
||||
timestamp: int = -1
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话存储
|
||||
|
||||
对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。
|
||||
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
|
||||
"""
|
||||
|
||||
user_id: str
|
||||
cid: str
|
||||
history: str = ""
|
||||
"""字符串格式的列表。"""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
title: str = ""
|
||||
persona_id: str = ""
|
||||
platform: list[Platform] = field(default_factory=list)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,50 +0,0 @@
|
||||
CREATE TABLE IF NOT EXISTS platform(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS llm(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS plugin(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS command(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS llm_history(
|
||||
provider_type VARCHAR(32),
|
||||
session_id VARCHAR(32),
|
||||
content TEXT
|
||||
);
|
||||
|
||||
-- ATRI
|
||||
CREATE TABLE IF NOT EXISTS atri_vision(
|
||||
id TEXT,
|
||||
url_or_path TEXT,
|
||||
caption TEXT,
|
||||
is_meme BOOLEAN,
|
||||
keywords TEXT,
|
||||
platform_name VARCHAR(32),
|
||||
session_id VARCHAR(32),
|
||||
sender_nickname VARCHAR(32),
|
||||
timestamp INTEGER
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS webchat_conversation(
|
||||
user_id TEXT, -- 会话 id
|
||||
cid TEXT, -- 对话 id
|
||||
history TEXT,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER,
|
||||
title TEXT,
|
||||
persona_id TEXT
|
||||
);
|
||||
|
||||
PRAGMA encoding = 'UTF-8';
|
||||
46
astrbot/core/db/vec_db/base.py
Normal file
46
astrbot/core/db/vec_db/base.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
similarity: float
|
||||
data: dict
|
||||
|
||||
|
||||
class BaseVecDB:
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化向量数据库
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
|
||||
"""
|
||||
搜索最相似的文档。
|
||||
Args:
|
||||
query (str): 查询文本
|
||||
top_k (int): 返回的最相似文档的数量
|
||||
Returns:
|
||||
List[Result]: 查询结果
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete(self, doc_id: str) -> bool:
|
||||
"""
|
||||
删除指定文档。
|
||||
Args:
|
||||
doc_id (str): 要删除的文档 ID
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
...
|
||||
3
astrbot/core/db/vec_db/faiss_impl/__init__.py
Normal file
3
astrbot/core/db/vec_db/faiss_impl/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .vec_db import FaissVecDB
|
||||
|
||||
__all__ = ["FaissVecDB"]
|
||||
121
astrbot/core/db/vec_db/faiss_impl/document_storage.py
Normal file
121
astrbot/core/db/vec_db/faiss_impl/document_storage.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import aiosqlite
|
||||
import os
|
||||
|
||||
|
||||
class DocumentStorage:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self.connection = None
|
||||
self.sqlite_init_path = os.path.join(
|
||||
os.path.dirname(__file__), "sqlite_init.sql"
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
|
||||
if not os.path.exists(self.db_path):
|
||||
await self.connect()
|
||||
async with self.connection.cursor() as cursor:
|
||||
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
|
||||
sql_script = f.read()
|
||||
await cursor.executescript(sql_script)
|
||||
await self.connection.commit()
|
||||
else:
|
||||
await self.connect()
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the SQLite database."""
|
||||
self.connection = await aiosqlite.connect(self.db_path)
|
||||
|
||||
async def get_documents(self, metadata_filters: dict, ids: list = None):
|
||||
"""Retrieve documents by metadata filters and ids.
|
||||
|
||||
Args:
|
||||
metadata_filters (dict): The metadata filters to apply.
|
||||
|
||||
Returns:
|
||||
list: The list of document IDs(primary key, not doc_id) that match the filters.
|
||||
"""
|
||||
# metadata filter -> SQL WHERE clause
|
||||
where_clauses = []
|
||||
values = []
|
||||
for key, val in metadata_filters.items():
|
||||
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
|
||||
values.append(val)
|
||||
if ids is not None and len(ids) > 0:
|
||||
ids = [str(i) for i in ids if i != -1]
|
||||
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
|
||||
values.extend(ids)
|
||||
where_sql = " AND ".join(where_clauses) or "1=1"
|
||||
|
||||
result = []
|
||||
async with self.connection.cursor() as cursor:
|
||||
sql = "SELECT * FROM documents WHERE " + where_sql
|
||||
await cursor.execute(sql, values)
|
||||
for row in await cursor.fetchall():
|
||||
result.append(await self.tuple_to_dict(row))
|
||||
return result
|
||||
|
||||
async def get_document_by_doc_id(self, doc_id: str):
|
||||
"""Retrieve a document by its doc_id.
|
||||
|
||||
Args:
|
||||
doc_id (str): The doc_id of the document to retrieve.
|
||||
|
||||
Returns:
|
||||
dict: The document data.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return await self.tuple_to_dict(row)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
|
||||
"""Retrieve a document by its doc_id.
|
||||
|
||||
Args:
|
||||
doc_id (str): The doc_id.
|
||||
new_text (str): The new text to update the document with.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
|
||||
)
|
||||
await self.connection.commit()
|
||||
|
||||
async def get_user_ids(self) -> list[str]:
|
||||
"""Retrieve all user IDs from the documents table.
|
||||
|
||||
Returns:
|
||||
list: A list of user IDs.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT DISTINCT user_id FROM documents")
|
||||
rows = await cursor.fetchall()
|
||||
return [row[0] for row in rows]
|
||||
|
||||
async def tuple_to_dict(self, row):
|
||||
"""Convert a tuple to a dictionary.
|
||||
|
||||
Args:
|
||||
row (tuple): The row to convert.
|
||||
|
||||
Returns:
|
||||
dict: The converted dictionary.
|
||||
"""
|
||||
return {
|
||||
"id": row[0],
|
||||
"doc_id": row[1],
|
||||
"text": row[2],
|
||||
"metadata": row[3],
|
||||
"created_at": row[4],
|
||||
"updated_at": row[5],
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the SQLite database."""
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
self.connection = None
|
||||
59
astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
Normal file
59
astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
Normal file
@@ -0,0 +1,59 @@
|
||||
try:
|
||||
import faiss
|
||||
except ModuleNotFoundError:
|
||||
raise ImportError(
|
||||
"faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。"
|
||||
)
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
class EmbeddingStorage:
|
||||
def __init__(self, dimension: int, path: str = None):
|
||||
self.dimension = dimension
|
||||
self.path = path
|
||||
self.index = None
|
||||
if path and os.path.exists(path):
|
||||
self.index = faiss.read_index(path)
|
||||
else:
|
||||
base_index = faiss.IndexFlatL2(dimension)
|
||||
self.index = faiss.IndexIDMap(base_index)
|
||||
self.storage = {}
|
||||
|
||||
async def insert(self, vector: np.ndarray, id: int):
|
||||
"""插入向量
|
||||
|
||||
Args:
|
||||
vector (np.ndarray): 要插入的向量
|
||||
id (int): 向量的ID
|
||||
Raises:
|
||||
ValueError: 如果向量的维度与存储的维度不匹配
|
||||
"""
|
||||
if vector.shape[0] != self.dimension:
|
||||
raise ValueError(
|
||||
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
|
||||
)
|
||||
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
|
||||
self.storage[id] = vector
|
||||
await self.save_index()
|
||||
|
||||
async def search(self, vector: np.ndarray, k: int) -> tuple:
|
||||
"""搜索最相似的向量
|
||||
|
||||
Args:
|
||||
vector (np.ndarray): 查询向量
|
||||
k (int): 返回的最相似向量的数量
|
||||
Returns:
|
||||
tuple: (距离, 索引)
|
||||
"""
|
||||
faiss.normalize_L2(vector)
|
||||
distances, indices = self.index.search(vector, k)
|
||||
return distances, indices
|
||||
|
||||
async def save_index(self):
|
||||
"""保存索引
|
||||
|
||||
Args:
|
||||
path (str): 保存索引的路径
|
||||
"""
|
||||
faiss.write_index(self.index, self.path)
|
||||
17
astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql
Normal file
17
astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql
Normal file
@@ -0,0 +1,17 @@
|
||||
-- 创建文档存储表,包含 faiss 中文档的 id,文档文本,create_at,updated_at
|
||||
CREATE TABLE documents (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
doc_id TEXT NOT NULL,
|
||||
text TEXT NOT NULL,
|
||||
metadata TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
ALTER TABLE documents
|
||||
ADD COLUMN group_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.group_id')) STORED;
|
||||
ALTER TABLE documents
|
||||
ADD COLUMN user_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED;
|
||||
|
||||
CREATE INDEX idx_documents_user_id ON documents(user_id);
|
||||
CREATE INDEX idx_documents_group_id ON documents(group_id);
|
||||
140
astrbot/core/db/vec_db/faiss_impl/vec_db.py
Normal file
140
astrbot/core/db/vec_db/faiss_impl/vec_db.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import uuid
|
||||
import json
|
||||
import numpy as np
|
||||
from .document_storage import DocumentStorage
|
||||
from .embedding_storage import EmbeddingStorage
|
||||
from ..base import Result, BaseVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
|
||||
|
||||
class FaissVecDB(BaseVecDB):
|
||||
"""
|
||||
A class to represent a vector database.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
doc_store_path: str,
|
||||
index_store_path: str,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
rerank_provider: RerankProvider | None = None,
|
||||
):
|
||||
self.doc_store_path = doc_store_path
|
||||
self.index_store_path = index_store_path
|
||||
self.embedding_provider = embedding_provider
|
||||
self.document_storage = DocumentStorage(doc_store_path)
|
||||
self.embedding_storage = EmbeddingStorage(
|
||||
embedding_provider.get_dim(), index_store_path
|
||||
)
|
||||
self.embedding_provider = embedding_provider
|
||||
self.rerank_provider = rerank_provider
|
||||
|
||||
async def initialize(self):
|
||||
await self.document_storage.initialize()
|
||||
|
||||
async def insert(
|
||||
self, content: str, metadata: dict | None = None, id: str | None = None
|
||||
) -> int:
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
metadata = metadata or {}
|
||||
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
|
||||
|
||||
vector = await self.embedding_provider.get_embedding(content)
|
||||
vector = np.array(vector, dtype=np.float32)
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
|
||||
(str_id, content, json.dumps(metadata)),
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
result = await self.document_storage.get_document_by_doc_id(str_id)
|
||||
int_id = result["id"]
|
||||
|
||||
# 插入向量到 FAISS
|
||||
await self.embedding_storage.insert(vector, int_id)
|
||||
return int_id
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 5,
|
||||
fetch_k: int = 20,
|
||||
rerank: bool = False,
|
||||
metadata_filters: dict | None = None,
|
||||
) -> list[Result]:
|
||||
"""
|
||||
搜索最相似的文档。
|
||||
|
||||
Args:
|
||||
query (str): 查询文本
|
||||
k (int): 返回的最相似文档的数量
|
||||
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
|
||||
rerank (bool): 是否使用重排序。这需要在实例化时提供 rerank_provider, 如果未提供并且 rerank 为 True, 不会抛出异常。
|
||||
metadata_filters (dict): 元数据过滤器
|
||||
|
||||
Returns:
|
||||
List[Result]: 查询结果
|
||||
"""
|
||||
embedding = await self.embedding_provider.get_embedding(query)
|
||||
scores, indices = await self.embedding_storage.search(
|
||||
vector=np.array([embedding]).astype("float32"),
|
||||
k=fetch_k if metadata_filters else k,
|
||||
)
|
||||
if len(indices[0]) == 0 or indices[0][0] == -1:
|
||||
return []
|
||||
# normalize scores
|
||||
scores[0] = 1.0 - (scores[0] / 2.0)
|
||||
# NOTE: maybe the size is less than k.
|
||||
fetched_docs = await self.document_storage.get_documents(
|
||||
metadata_filters=metadata_filters or {}, ids=indices[0]
|
||||
)
|
||||
if not fetched_docs:
|
||||
return []
|
||||
result_docs: list[Result] = []
|
||||
|
||||
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
|
||||
for i, indice_idx in enumerate(indices[0]):
|
||||
pos = idx_pos.get(indice_idx)
|
||||
if pos is None:
|
||||
continue
|
||||
fetch_doc = fetched_docs[pos]
|
||||
score = scores[0][i]
|
||||
result_docs.append(Result(similarity=float(score), data=fetch_doc))
|
||||
|
||||
top_k_results = result_docs[:k]
|
||||
|
||||
if rerank and self.rerank_provider:
|
||||
documents = [doc.data["text"] for doc in top_k_results]
|
||||
reranked_results = await self.rerank_provider.rerank(query, documents)
|
||||
reranked_results = sorted(
|
||||
reranked_results, key=lambda x: x.relevance_score, reverse=True
|
||||
)
|
||||
top_k_results = [
|
||||
top_k_results[reranked_result.index] for reranked_result in reranked_results
|
||||
]
|
||||
|
||||
return top_k_results
|
||||
|
||||
async def delete(self, doc_id: int):
|
||||
"""
|
||||
删除一条文档
|
||||
"""
|
||||
await self.document_storage.connection.execute(
|
||||
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
|
||||
async def close(self):
|
||||
await self.document_storage.close()
|
||||
|
||||
async def count_documents(self) -> int:
|
||||
"""
|
||||
计算文档数量
|
||||
"""
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT COUNT(*) FROM documents")
|
||||
count = await cursor.fetchone()
|
||||
return count[0] if count else 0
|
||||
@@ -16,30 +16,32 @@ from asyncio import Queue
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler
|
||||
from astrbot.core import logger
|
||||
from .platform import AstrMessageEvent
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
|
||||
|
||||
class EventBus:
|
||||
"""事件总线: 用于处理事件的分发和处理
|
||||
"""用于处理事件的分发和处理"""
|
||||
|
||||
维护一个异步队列, 来接受各种消息事件
|
||||
"""
|
||||
|
||||
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: Queue,
|
||||
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
|
||||
astrbot_config_mgr: AstrBotConfigManager = None,
|
||||
):
|
||||
self.event_queue = event_queue # 事件队列
|
||||
self.pipeline_scheduler = pipeline_scheduler # 管道调度器
|
||||
# abconf uuid -> scheduler
|
||||
self.pipeline_scheduler_mapping = pipeline_scheduler_mapping
|
||||
self.astrbot_config_mgr = astrbot_config_mgr
|
||||
|
||||
async def dispatch(self):
|
||||
"""无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑"""
|
||||
while True:
|
||||
event: AstrMessageEvent = (
|
||||
await self.event_queue.get()
|
||||
) # 从事件队列中获取新的事件
|
||||
self._print_event(event) # 打印日志
|
||||
asyncio.create_task(
|
||||
self.pipeline_scheduler.execute(event)
|
||||
) # 创建新的异步任务来执行管道调度器的处理逻辑
|
||||
event: AstrMessageEvent = await self.event_queue.get()
|
||||
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
|
||||
self._print_event(event, conf_info["name"])
|
||||
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
|
||||
asyncio.create_task(scheduler.execute(event))
|
||||
|
||||
def _print_event(self, event: AstrMessageEvent):
|
||||
def _print_event(self, event: AstrMessageEvent, conf_name: str):
|
||||
"""用于记录事件信息
|
||||
|
||||
Args:
|
||||
@@ -48,10 +50,10 @@ class EventBus:
|
||||
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
|
||||
if event.get_sender_name():
|
||||
logger.info(
|
||||
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
|
||||
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
|
||||
)
|
||||
# 没有发送者名称: [平台名] 发送者ID: 消息概要
|
||||
else:
|
||||
logger.info(
|
||||
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
|
||||
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}"
|
||||
)
|
||||
|
||||
92
astrbot/core/file_token_service.py
Normal file
92
astrbot/core/file_token_service.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
import time
|
||||
from urllib.parse import urlparse, unquote
|
||||
import platform
|
||||
|
||||
|
||||
class FileTokenService:
|
||||
"""维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。"""
|
||||
|
||||
def __init__(self, default_timeout: float = 300):
|
||||
self.lock = asyncio.Lock()
|
||||
self.staged_files = {} # token: (file_path, expire_time)
|
||||
self.default_timeout = default_timeout
|
||||
|
||||
async def _cleanup_expired_tokens(self):
|
||||
"""清理过期的令牌"""
|
||||
now = time.time()
|
||||
expired_tokens = [
|
||||
token for token, (_, expire) in self.staged_files.items() if expire < now
|
||||
]
|
||||
for token in expired_tokens:
|
||||
self.staged_files.pop(token, None)
|
||||
|
||||
async def register_file(self, file_path: str, timeout: float = None) -> str:
|
||||
"""向令牌服务注册一个文件。
|
||||
|
||||
Args:
|
||||
file_path(str): 文件路径
|
||||
timeout(float): 超时时间,单位秒(可选)
|
||||
|
||||
Returns:
|
||||
str: 一个单次令牌
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 当路径不存在时抛出
|
||||
"""
|
||||
|
||||
# 处理 file:///
|
||||
try:
|
||||
parsed_uri = urlparse(file_path)
|
||||
if parsed_uri.scheme == "file":
|
||||
local_path = unquote(parsed_uri.path)
|
||||
if platform.system() == "Windows" and local_path.startswith("/"):
|
||||
local_path = local_path[1:]
|
||||
else:
|
||||
# 如果没有 file:/// 前缀,则认为是普通路径
|
||||
local_path = file_path
|
||||
except Exception:
|
||||
# 解析失败时,按原路径处理
|
||||
local_path = file_path
|
||||
|
||||
async with self.lock:
|
||||
await self._cleanup_expired_tokens()
|
||||
|
||||
if not os.path.exists(local_path):
|
||||
raise FileNotFoundError(
|
||||
f"文件不存在: {local_path} (原始输入: {file_path})"
|
||||
)
|
||||
|
||||
file_token = str(uuid.uuid4())
|
||||
expire_time = time.time() + (
|
||||
timeout if timeout is not None else self.default_timeout
|
||||
)
|
||||
# 存储转换后的真实路径
|
||||
self.staged_files[file_token] = (local_path, expire_time)
|
||||
return file_token
|
||||
|
||||
async def handle_file(self, file_token: str) -> str:
|
||||
"""根据令牌获取文件路径,使用后令牌失效。
|
||||
|
||||
Args:
|
||||
file_token(str): 注册时返回的令牌
|
||||
|
||||
Returns:
|
||||
str: 文件路径
|
||||
|
||||
Raises:
|
||||
KeyError: 当令牌不存在或已过期时抛出
|
||||
FileNotFoundError: 当文件本身已被删除时抛出
|
||||
"""
|
||||
async with self.lock:
|
||||
await self._cleanup_expired_tokens()
|
||||
|
||||
if file_token not in self.staged_files:
|
||||
raise KeyError(f"无效或过期的文件 token: {file_token}")
|
||||
|
||||
file_path, _ = self.staged_files.pop(file_token)
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
return file_path
|
||||
@@ -26,13 +26,14 @@ class InitialLoader:
|
||||
async def start(self):
|
||||
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
||||
|
||||
core_task = []
|
||||
try:
|
||||
await core_lifecycle.initialize()
|
||||
core_task = core_lifecycle.start()
|
||||
except Exception as e:
|
||||
logger.critical(traceback.format_exc())
|
||||
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
|
||||
return
|
||||
|
||||
core_task = core_lifecycle.start()
|
||||
|
||||
self.dashboard_server = AstrBotDashboard(
|
||||
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
|
||||
|
||||
@@ -25,6 +25,7 @@ import logging
|
||||
import colorlog
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from collections import deque
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
@@ -95,8 +96,6 @@ class LogBroker:
|
||||
Queue: 订阅者的队列, 可用于接收日志消息
|
||||
"""
|
||||
q = Queue(maxsize=CACHED_SIZE + 10)
|
||||
for log in self.log_cache:
|
||||
q.put_nowait(log)
|
||||
self.subscribers.append(q)
|
||||
return q
|
||||
|
||||
@@ -171,7 +170,9 @@ class LogManager:
|
||||
if logger.hasHandlers():
|
||||
return logger
|
||||
# 如果logger没有处理器
|
||||
console_handler = logging.StreamHandler() # 创建一个StreamHandler用于控制台输出
|
||||
console_handler = logging.StreamHandler(
|
||||
sys.stdout
|
||||
) # 创建一个StreamHandler用于控制台输出
|
||||
console_handler.setLevel(
|
||||
logging.DEBUG
|
||||
) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
|
||||
|
||||
@@ -22,17 +22,22 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import typing as T
|
||||
import uuid
|
||||
from enum import Enum
|
||||
|
||||
from pydantic.v1 import BaseModel
|
||||
from astrbot.core.utils.io import download_image_by_url, file_to_base64
|
||||
|
||||
from astrbot.core import astrbot_config, file_token_service, logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
|
||||
|
||||
|
||||
class ComponentType(Enum):
|
||||
class ComponentType(str, Enum):
|
||||
Plain = "Plain" # 纯文本消息
|
||||
Face = "Face" # QQ表情
|
||||
Record = "Record" # 语音
|
||||
@@ -97,9 +102,13 @@ class BaseMessageComponent(BaseModel):
|
||||
data[k] = v
|
||||
return {"type": self.type.lower(), "data": data}
|
||||
|
||||
async def to_dict(self) -> dict:
|
||||
# 默认情况下,回退到旧的同步 toDict()
|
||||
return self.toDict()
|
||||
|
||||
|
||||
class Plain(BaseMessageComponent):
|
||||
type: ComponentType = "Plain"
|
||||
type = ComponentType.Plain
|
||||
text: str
|
||||
convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息
|
||||
|
||||
@@ -113,9 +122,15 @@ class Plain(BaseMessageComponent):
|
||||
self.text.replace("&", "&").replace("[", "[").replace("]", "]")
|
||||
)
|
||||
|
||||
def toDict(self):
|
||||
return {"type": "text", "data": {"text": self.text.strip()}}
|
||||
|
||||
async def to_dict(self):
|
||||
return {"type": "text", "data": {"text": self.text}}
|
||||
|
||||
|
||||
class Face(BaseMessageComponent):
|
||||
type: ComponentType = "Face"
|
||||
type = ComponentType.Face
|
||||
id: int
|
||||
|
||||
def __init__(self, **_):
|
||||
@@ -123,7 +138,7 @@ class Face(BaseMessageComponent):
|
||||
|
||||
|
||||
class Record(BaseMessageComponent):
|
||||
type: ComponentType = "Record"
|
||||
type = ComponentType.Record
|
||||
file: T.Optional[str] = ""
|
||||
magic: T.Optional[bool] = False
|
||||
url: T.Optional[str] = ""
|
||||
@@ -150,28 +165,33 @@ class Record(BaseMessageComponent):
|
||||
return Record(file=url, **_)
|
||||
raise Exception("not a valid url")
|
||||
|
||||
@staticmethod
|
||||
def fromBase64(bs64_data: str, **_):
|
||||
return Record(file=f"base64://{bs64_data}", **_)
|
||||
|
||||
async def convert_to_file_path(self) -> str:
|
||||
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
|
||||
|
||||
Returns:
|
||||
str: 语音的本地路径,以绝对路径表示。
|
||||
"""
|
||||
if self.file and self.file.startswith("file:///"):
|
||||
file_path = self.file[8:]
|
||||
return file_path
|
||||
elif self.file and self.file.startswith("http"):
|
||||
if not self.file:
|
||||
raise Exception(f"not a valid file: {self.file}")
|
||||
if self.file.startswith("file:///"):
|
||||
return self.file[8:]
|
||||
elif self.file.startswith("http"):
|
||||
file_path = await download_image_by_url(self.file)
|
||||
return os.path.abspath(file_path)
|
||||
elif self.file and self.file.startswith("base64://"):
|
||||
elif self.file.startswith("base64://"):
|
||||
bs64_data = self.file.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
file_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(file_path)
|
||||
elif os.path.exists(self.file):
|
||||
file_path = self.file
|
||||
return os.path.abspath(file_path)
|
||||
return os.path.abspath(self.file)
|
||||
else:
|
||||
raise Exception(f"not a valid file: {self.file}")
|
||||
|
||||
@@ -182,22 +202,48 @@ class Record(BaseMessageComponent):
|
||||
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
||||
"""
|
||||
# convert to base64
|
||||
if self.file and self.file.startswith("file:///"):
|
||||
if not self.file:
|
||||
raise Exception(f"not a valid file: {self.file}")
|
||||
if self.file.startswith("file:///"):
|
||||
bs64_data = file_to_base64(self.file[8:])
|
||||
elif self.file and self.file.startswith("http"):
|
||||
elif self.file.startswith("http"):
|
||||
file_path = await download_image_by_url(self.file)
|
||||
bs64_data = file_to_base64(file_path)
|
||||
elif self.file and self.file.startswith("base64://"):
|
||||
elif self.file.startswith("base64://"):
|
||||
bs64_data = self.file
|
||||
elif os.path.exists(self.file):
|
||||
bs64_data = file_to_base64(self.file)
|
||||
else:
|
||||
raise Exception(f"not a valid file: {self.file}")
|
||||
bs64_data = bs64_data.removeprefix("base64://")
|
||||
return bs64_data
|
||||
|
||||
async def register_to_file_service(self) -> str:
|
||||
"""
|
||||
将语音注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.convert_to_file_path()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
|
||||
class Video(BaseMessageComponent):
|
||||
type: ComponentType = "Video"
|
||||
type = ComponentType.Video
|
||||
file: str
|
||||
cover: T.Optional[str] = ""
|
||||
c: T.Optional[int] = 2
|
||||
@@ -205,9 +251,6 @@ class Video(BaseMessageComponent):
|
||||
path: T.Optional[str] = ""
|
||||
|
||||
def __init__(self, file: str, **_):
|
||||
# for k in _.keys():
|
||||
# if k == "c" and _[k] not in [2, 3]:
|
||||
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
|
||||
super().__init__(file=file, **_)
|
||||
|
||||
@staticmethod
|
||||
@@ -220,15 +263,85 @@ class Video(BaseMessageComponent):
|
||||
return Video(file=url, **_)
|
||||
raise Exception("not a valid url")
|
||||
|
||||
async def convert_to_file_path(self) -> str:
|
||||
"""将这个视频统一转换为本地文件路径。这个方法避免了手动判断视频数据类型,直接返回视频数据的本地路径(如果是网络 URL,则会自动进行下载)。
|
||||
|
||||
Returns:
|
||||
str: 视频的本地路径,以绝对路径表示。
|
||||
"""
|
||||
url = self.file
|
||||
if url and url.startswith("file:///"):
|
||||
return url[8:]
|
||||
elif url and url.startswith("http"):
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
await download_file(url, video_file_path)
|
||||
if os.path.exists(video_file_path):
|
||||
return os.path.abspath(video_file_path)
|
||||
else:
|
||||
raise Exception(f"download failed: {url}")
|
||||
elif os.path.exists(url):
|
||||
return os.path.abspath(url)
|
||||
else:
|
||||
raise Exception(f"not a valid file: {url}")
|
||||
|
||||
async def register_to_file_service(self):
|
||||
"""
|
||||
将视频注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.convert_to_file_path()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
async def to_dict(self):
|
||||
"""需要和 toDict 区分开,toDict 是同步方法"""
|
||||
url_or_path = self.file
|
||||
if url_or_path.startswith("http"):
|
||||
payload_file = url_or_path
|
||||
elif callback_host := astrbot_config.get("callback_api_base"):
|
||||
callback_host = str(callback_host).removesuffix("/")
|
||||
token = await file_token_service.register_file(url_or_path)
|
||||
payload_file = f"{callback_host}/api/file/{token}"
|
||||
logger.debug(f"Generated video file callback link: {payload_file}")
|
||||
else:
|
||||
payload_file = url_or_path
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {
|
||||
"file": payload_file,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class At(BaseMessageComponent):
|
||||
type: ComponentType = "At"
|
||||
type = ComponentType.At
|
||||
qq: T.Union[int, str] # 此处str为all时代表所有人
|
||||
name: T.Optional[str] = ""
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
def toDict(self):
|
||||
return {
|
||||
"type": "at",
|
||||
"data": {"qq": str(self.qq)},
|
||||
}
|
||||
|
||||
|
||||
class AtAll(At):
|
||||
qq: str = "all"
|
||||
@@ -238,28 +351,28 @@ class AtAll(At):
|
||||
|
||||
|
||||
class RPS(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "RPS"
|
||||
type = ComponentType.RPS
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Dice(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "Dice"
|
||||
type = ComponentType.Dice
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Shake(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "Shake"
|
||||
type = ComponentType.Shake
|
||||
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Anonymous(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "Anonymous"
|
||||
type = ComponentType.Anonymous
|
||||
ignore: T.Optional[bool] = False
|
||||
|
||||
def __init__(self, **_):
|
||||
@@ -267,7 +380,7 @@ class Anonymous(BaseMessageComponent): # TODO
|
||||
|
||||
|
||||
class Share(BaseMessageComponent):
|
||||
type: ComponentType = "Share"
|
||||
type = ComponentType.Share
|
||||
url: str
|
||||
title: str
|
||||
content: T.Optional[str] = ""
|
||||
@@ -278,7 +391,7 @@ class Share(BaseMessageComponent):
|
||||
|
||||
|
||||
class Contact(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "Contact"
|
||||
type = ComponentType.Contact
|
||||
_type: str # type 字段冲突
|
||||
id: T.Optional[int] = 0
|
||||
|
||||
@@ -287,7 +400,7 @@ class Contact(BaseMessageComponent): # TODO
|
||||
|
||||
|
||||
class Location(BaseMessageComponent): # TODO
|
||||
type: ComponentType = "Location"
|
||||
type = ComponentType.Location
|
||||
lat: float
|
||||
lon: float
|
||||
title: T.Optional[str] = ""
|
||||
@@ -298,7 +411,7 @@ class Location(BaseMessageComponent): # TODO
|
||||
|
||||
|
||||
class Music(BaseMessageComponent):
|
||||
type: ComponentType = "Music"
|
||||
type = ComponentType.Music
|
||||
_type: str
|
||||
id: T.Optional[int] = 0
|
||||
url: T.Optional[str] = ""
|
||||
@@ -315,7 +428,7 @@ class Music(BaseMessageComponent):
|
||||
|
||||
|
||||
class Image(BaseMessageComponent):
|
||||
type: ComponentType = "Image"
|
||||
type = ComponentType.Image
|
||||
file: T.Optional[str] = ""
|
||||
_type: T.Optional[str] = ""
|
||||
subType: T.Optional[int] = 0
|
||||
@@ -358,23 +471,24 @@ class Image(BaseMessageComponent):
|
||||
Returns:
|
||||
str: 图片的本地路径,以绝对路径表示。
|
||||
"""
|
||||
url = self.url if self.url else self.file
|
||||
if url and url.startswith("file:///"):
|
||||
image_file_path = url[8:]
|
||||
return image_file_path
|
||||
elif url and url.startswith("http"):
|
||||
url = self.url or self.file
|
||||
if not url:
|
||||
raise ValueError("No valid file or URL provided")
|
||||
if url.startswith("file:///"):
|
||||
return url[8:]
|
||||
elif url.startswith("http"):
|
||||
image_file_path = await download_image_by_url(url)
|
||||
return os.path.abspath(image_file_path)
|
||||
elif url and url.startswith("base64://"):
|
||||
elif url.startswith("base64://"):
|
||||
bs64_data = url.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
image_file_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
|
||||
with open(image_file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(image_file_path)
|
||||
elif os.path.exists(url):
|
||||
image_file_path = url
|
||||
return os.path.abspath(image_file_path)
|
||||
return os.path.abspath(url)
|
||||
else:
|
||||
raise Exception(f"not a valid file: {url}")
|
||||
|
||||
@@ -385,37 +499,61 @@ class Image(BaseMessageComponent):
|
||||
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
||||
"""
|
||||
# convert to base64
|
||||
url = self.url if self.url else self.file
|
||||
if url and url.startswith("file:///"):
|
||||
url = self.url or self.file
|
||||
if not url:
|
||||
raise ValueError("No valid file or URL provided")
|
||||
if url.startswith("file:///"):
|
||||
bs64_data = file_to_base64(url[8:])
|
||||
elif url and url.startswith("http"):
|
||||
elif url.startswith("http"):
|
||||
image_file_path = await download_image_by_url(url)
|
||||
bs64_data = file_to_base64(image_file_path)
|
||||
elif url and url.startswith("base64://"):
|
||||
elif url.startswith("base64://"):
|
||||
bs64_data = url
|
||||
elif os.path.exists(url):
|
||||
bs64_data = file_to_base64(url)
|
||||
else:
|
||||
raise Exception(f"not a valid file: {url}")
|
||||
bs64_data = bs64_data.removeprefix("base64://")
|
||||
return bs64_data
|
||||
|
||||
async def register_to_file_service(self) -> str:
|
||||
"""
|
||||
将图片注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.convert_to_file_path()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
|
||||
class Reply(BaseMessageComponent):
|
||||
type: ComponentType = "Reply"
|
||||
type = ComponentType.Reply
|
||||
id: T.Union[str, int]
|
||||
"""所引用的消息 ID"""
|
||||
chain: T.Optional[T.List["BaseMessageComponent"]] = []
|
||||
"""引用的消息段列表"""
|
||||
"""被引用的消息段列表"""
|
||||
sender_id: T.Optional[int] | T.Optional[str] = 0
|
||||
"""引用的消息发送者 ID"""
|
||||
"""被引用的消息对应的发送者的 ID"""
|
||||
sender_nickname: T.Optional[str] = ""
|
||||
"""引用的消息发送者昵称"""
|
||||
"""被引用的消息对应的发送者的昵称"""
|
||||
time: T.Optional[int] = 0
|
||||
"""引用的消息发送时间"""
|
||||
"""被引用的消息发送时间"""
|
||||
message_str: T.Optional[str] = ""
|
||||
"""解析后的纯文本消息字符串"""
|
||||
sender_str: T.Optional[str] = ""
|
||||
"""被引用的消息纯文本"""
|
||||
"""被引用的消息解析后的纯文本消息字符串"""
|
||||
|
||||
text: T.Optional[str] = ""
|
||||
"""deprecated"""
|
||||
@@ -429,7 +567,7 @@ class Reply(BaseMessageComponent):
|
||||
|
||||
|
||||
class RedBag(BaseMessageComponent):
|
||||
type: ComponentType = "RedBag"
|
||||
type = ComponentType.RedBag
|
||||
title: str
|
||||
|
||||
def __init__(self, **_):
|
||||
@@ -437,7 +575,7 @@ class RedBag(BaseMessageComponent):
|
||||
|
||||
|
||||
class Poke(BaseMessageComponent):
|
||||
type: str = ""
|
||||
type: str = ComponentType.Poke
|
||||
id: T.Optional[int] = 0
|
||||
qq: T.Optional[int] = 0
|
||||
|
||||
@@ -447,7 +585,7 @@ class Poke(BaseMessageComponent):
|
||||
|
||||
|
||||
class Forward(BaseMessageComponent):
|
||||
type: ComponentType = "Forward"
|
||||
type = ComponentType.Forward
|
||||
id: str
|
||||
|
||||
def __init__(self, **_):
|
||||
@@ -457,46 +595,85 @@ class Forward(BaseMessageComponent):
|
||||
class Node(BaseMessageComponent):
|
||||
"""群合并转发消息"""
|
||||
|
||||
type: ComponentType = "Node"
|
||||
type = ComponentType.Node
|
||||
id: T.Optional[int] = 0 # 忽略
|
||||
name: T.Optional[str] = "" # qq昵称
|
||||
uin: T.Optional[int] = 0 # qq号
|
||||
content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表
|
||||
uin: T.Optional[str] = "0" # qq号
|
||||
content: T.Optional[list[BaseMessageComponent]] = []
|
||||
seq: T.Optional[T.Union[str, list]] = "" # 忽略
|
||||
time: T.Optional[int] = 0
|
||||
time: T.Optional[int] = 0 # 忽略
|
||||
|
||||
def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_):
|
||||
if isinstance(content, list):
|
||||
_content = None
|
||||
if all(isinstance(item, Node) for item in content):
|
||||
_content = [node.toDict() for node in content]
|
||||
else:
|
||||
_content = ""
|
||||
for chain in content:
|
||||
_content += chain.toString()
|
||||
content = _content
|
||||
elif isinstance(content, Node):
|
||||
content = content.toDict()
|
||||
def __init__(self, content: list[BaseMessageComponent], **_):
|
||||
if isinstance(content, Node):
|
||||
# back
|
||||
content = [content]
|
||||
super().__init__(content=content, **_)
|
||||
|
||||
def toString(self):
|
||||
# logger.warn("Protocol: node doesn't support stringify")
|
||||
return ""
|
||||
async def to_dict(self):
|
||||
data_content = []
|
||||
for comp in self.content:
|
||||
if isinstance(comp, (Image, Record)):
|
||||
# For Image and Record segments, we convert them to base64
|
||||
bs64 = await comp.convert_to_base64()
|
||||
data_content.append(
|
||||
{
|
||||
"type": comp.type.lower(),
|
||||
"data": {"file": f"base64://{bs64}"},
|
||||
}
|
||||
)
|
||||
elif isinstance(comp, Plain):
|
||||
# For Plain segments, we need to handle the plain differently
|
||||
d = await comp.to_dict()
|
||||
data_content.append(d)
|
||||
elif isinstance(comp, File):
|
||||
# For File segments, we need to handle the file differently
|
||||
d = await comp.to_dict()
|
||||
data_content.append(d)
|
||||
elif isinstance(comp, (Node, Nodes)):
|
||||
# For Node segments, we recursively convert them to dict
|
||||
d = await comp.to_dict()
|
||||
data_content.append(d)
|
||||
else:
|
||||
d = comp.toDict()
|
||||
data_content.append(d)
|
||||
return {
|
||||
"type": "node",
|
||||
"data": {
|
||||
"user_id": str(self.uin),
|
||||
"nickname": self.name,
|
||||
"content": data_content,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Nodes(BaseMessageComponent):
|
||||
type: ComponentType = "Nodes"
|
||||
type = ComponentType.Nodes
|
||||
nodes: T.List[Node]
|
||||
|
||||
def __init__(self, nodes: T.List[Node], **_):
|
||||
super().__init__(nodes=nodes, **_)
|
||||
|
||||
def toDict(self):
|
||||
return {"messages": [node.toDict() for node in self.nodes]}
|
||||
"""Deprecated. Use to_dict instead"""
|
||||
ret = {
|
||||
"messages": [],
|
||||
}
|
||||
for node in self.nodes:
|
||||
d = node.toDict()
|
||||
ret["messages"].append(d)
|
||||
return ret
|
||||
|
||||
async def to_dict(self):
|
||||
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
|
||||
ret = {"messages": []}
|
||||
for node in self.nodes:
|
||||
d = await node.to_dict()
|
||||
ret["messages"].append(d)
|
||||
return ret
|
||||
|
||||
|
||||
class Xml(BaseMessageComponent):
|
||||
type: ComponentType = "Xml"
|
||||
type = ComponentType.Xml
|
||||
data: str
|
||||
resid: T.Optional[int] = 0
|
||||
|
||||
@@ -505,7 +682,7 @@ class Xml(BaseMessageComponent):
|
||||
|
||||
|
||||
class Json(BaseMessageComponent):
|
||||
type: ComponentType = "Json"
|
||||
type = ComponentType.Json
|
||||
data: T.Union[str, dict]
|
||||
resid: T.Optional[int] = 0
|
||||
|
||||
@@ -516,7 +693,7 @@ class Json(BaseMessageComponent):
|
||||
|
||||
|
||||
class CardImage(BaseMessageComponent):
|
||||
type: ComponentType = "CardImage"
|
||||
type = ComponentType.CardImage
|
||||
file: str
|
||||
cache: T.Optional[bool] = True
|
||||
minwidth: T.Optional[int] = 400
|
||||
@@ -535,7 +712,7 @@ class CardImage(BaseMessageComponent):
|
||||
|
||||
|
||||
class TTS(BaseMessageComponent):
|
||||
type: ComponentType = "TTS"
|
||||
type = ComponentType.TTS
|
||||
text: str
|
||||
|
||||
def __init__(self, **_):
|
||||
@@ -543,7 +720,7 @@ class TTS(BaseMessageComponent):
|
||||
|
||||
|
||||
class Unknown(BaseMessageComponent):
|
||||
type: ComponentType = "Unknown"
|
||||
type = ComponentType.Unknown
|
||||
text: str
|
||||
|
||||
def toString(self):
|
||||
@@ -552,19 +729,140 @@ class Unknown(BaseMessageComponent):
|
||||
|
||||
class File(BaseMessageComponent):
|
||||
"""
|
||||
目前此消息段只适配了 Napcat。
|
||||
文件消息段
|
||||
"""
|
||||
|
||||
type: ComponentType = "File"
|
||||
type = ComponentType.File
|
||||
name: T.Optional[str] = "" # 名字
|
||||
file: T.Optional[str] = "" # url(本地路径)
|
||||
file_: T.Optional[str] = "" # 本地路径
|
||||
url: T.Optional[str] = "" # url
|
||||
|
||||
def __init__(self, name: str, file: str):
|
||||
super().__init__(name=name, file=file)
|
||||
def __init__(self, name: str, file: str = "", url: str = ""):
|
||||
"""文件消息段。"""
|
||||
super().__init__(name=name, file_=file, url=url)
|
||||
|
||||
@property
|
||||
def file(self) -> str:
|
||||
"""
|
||||
获取文件路径,如果文件不存在但有URL,则同步下载文件
|
||||
|
||||
Returns:
|
||||
str: 文件路径
|
||||
"""
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
if self.url:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
logger.warning(
|
||||
(
|
||||
"不可以在异步上下文中同步等待下载! "
|
||||
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
|
||||
"请使用 await get_file() 代替直接获取 <File>.file 字段"
|
||||
)
|
||||
)
|
||||
return ""
|
||||
else:
|
||||
# 等待下载完成
|
||||
loop.run_until_complete(self._download_file())
|
||||
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
except Exception as e:
|
||||
logger.error(f"文件下载失败: {e}")
|
||||
|
||||
return ""
|
||||
|
||||
@file.setter
|
||||
def file(self, value: str):
|
||||
"""
|
||||
向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
|
||||
|
||||
Args:
|
||||
value (str): 文件路径或URL
|
||||
"""
|
||||
if value.startswith("http://") or value.startswith("https://"):
|
||||
self.url = value
|
||||
else:
|
||||
self.file_ = value
|
||||
|
||||
async def get_file(self, allow_return_url: bool = False) -> str:
|
||||
"""异步获取文件。请注意在使用后清理下载的文件, 以免占用过多空间
|
||||
|
||||
Args:
|
||||
allow_return_url: 是否允许以文件 http 下载链接的形式返回,这允许您自行控制是否需要下载文件。
|
||||
注意,如果为 True,也可能返回文件路径。
|
||||
Returns:
|
||||
str: 文件路径或者 http 下载链接
|
||||
"""
|
||||
if allow_return_url and self.url:
|
||||
return self.url
|
||||
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
if self.url:
|
||||
await self._download_file()
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
return ""
|
||||
|
||||
async def _download_file(self):
|
||||
"""下载文件"""
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
await download_file(self.url, file_path)
|
||||
self.file_ = os.path.abspath(file_path)
|
||||
|
||||
async def register_to_file_service(self):
|
||||
"""
|
||||
将文件注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.get_file()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
async def to_dict(self):
|
||||
"""需要和 toDict 区分开,toDict 是同步方法"""
|
||||
url_or_path = await self.get_file(allow_return_url=True)
|
||||
if url_or_path.startswith("http"):
|
||||
payload_file = url_or_path
|
||||
elif callback_host := astrbot_config.get("callback_api_base"):
|
||||
callback_host = str(callback_host).removesuffix("/")
|
||||
token = await file_token_service.register_file(url_or_path)
|
||||
payload_file = f"{callback_host}/api/file/{token}"
|
||||
logger.debug(f"Generated file callback link: {payload_file}")
|
||||
else:
|
||||
payload_file = url_or_path
|
||||
return {
|
||||
"type": "file",
|
||||
"data": {
|
||||
"name": self.name,
|
||||
"file": payload_file,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class WechatEmoji(BaseMessageComponent):
|
||||
type: ComponentType = "WechatEmoji"
|
||||
type = ComponentType.WechatEmoji
|
||||
md5: T.Optional[str] = ""
|
||||
md5_len: T.Optional[int] = 0
|
||||
cdnurl: T.Optional[str] = ""
|
||||
|
||||
@@ -24,6 +24,8 @@ class MessageChain:
|
||||
|
||||
chain: List[BaseMessageComponent] = field(default_factory=list)
|
||||
use_t2i_: Optional[bool] = None # None 为跟随用户设置
|
||||
type: Optional[str] = None
|
||||
"""消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。"""
|
||||
|
||||
def message(self, message: str):
|
||||
"""添加一条文本消息到消息链 `chain` 中。
|
||||
@@ -98,6 +100,15 @@ class MessageChain:
|
||||
self.chain.append(Image.fromFileSystem(path))
|
||||
return self
|
||||
|
||||
def base64_image(self, base64_str: str):
|
||||
"""添加一条图片消息(base64 编码字符串)到消息链 `chain` 中。
|
||||
Example:
|
||||
|
||||
CommandResult().base64_image("iVBORw0KGgoAAAANSUhEUgAAAAUA...")
|
||||
"""
|
||||
self.chain.append(Image.fromBase64(base64_str))
|
||||
return self
|
||||
|
||||
def use_t2i(self, use_t2i: bool):
|
||||
"""设置是否使用文本转图片服务。
|
||||
|
||||
@@ -157,7 +168,7 @@ class ResultContentType(enum.Enum):
|
||||
"""普通的消息结果"""
|
||||
STREAMING_RESULT = enum.auto()
|
||||
"""调用 LLM 产生的流式结果"""
|
||||
STREAMING_FINISH= enum.auto()
|
||||
STREAMING_FINISH = enum.auto()
|
||||
"""流式输出完成"""
|
||||
|
||||
|
||||
|
||||
183
astrbot/core/persona_mgr.py
Normal file
183
astrbot/core/persona_mgr.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Persona, Personality
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot import logger
|
||||
|
||||
DEFAULT_PERSONALITY = Personality(
|
||||
prompt="You are a helpful and friendly assistant.",
|
||||
name="default",
|
||||
begin_dialogs=[],
|
||||
mood_imitation_dialogs=[],
|
||||
tools=None,
|
||||
_begin_dialogs_processed=[],
|
||||
_mood_imitation_dialogs_processed="",
|
||||
)
|
||||
|
||||
|
||||
class PersonaManager:
|
||||
def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager):
|
||||
self.db = db_helper
|
||||
self.acm = acm
|
||||
default_ps = acm.default_conf.get("provider_settings", {})
|
||||
self.default_persona: str = default_ps.get("default_personality", "default")
|
||||
self.personas: list[Persona] = []
|
||||
self.selected_default_persona: Persona | None = None
|
||||
|
||||
self.personas_v3: list[Personality] = []
|
||||
self.selected_default_persona_v3: Personality | None = None
|
||||
self.persona_v3_config: list[dict] = []
|
||||
|
||||
async def initialize(self):
|
||||
self.personas = await self.get_all_personas()
|
||||
self.get_v3_persona_data()
|
||||
logger.info(f"已加载 {len(self.personas)} 个人格。")
|
||||
|
||||
async def get_persona(self, persona_id: str):
|
||||
"""获取指定 persona 的信息"""
|
||||
persona = await self.db.get_persona_by_id(persona_id)
|
||||
if not persona:
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist.")
|
||||
return persona
|
||||
|
||||
async def get_default_persona_v3(
|
||||
self, umo: str | MessageSession | None = None
|
||||
) -> Personality:
|
||||
"""获取默认 persona"""
|
||||
cfg = self.acm.get_conf(umo)
|
||||
default_persona_id = cfg.get("provider_settings", {}).get(
|
||||
"default_personality", "default"
|
||||
)
|
||||
if not default_persona_id or default_persona_id == "default":
|
||||
return DEFAULT_PERSONALITY
|
||||
try:
|
||||
return next(p for p in self.personas_v3 if p["name"] == default_persona_id)
|
||||
except Exception:
|
||||
return DEFAULT_PERSONALITY
|
||||
|
||||
async def delete_persona(self, persona_id: str):
|
||||
"""删除指定 persona"""
|
||||
if not await self.db.get_persona_by_id(persona_id):
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist.")
|
||||
await self.db.delete_persona(persona_id)
|
||||
self.personas = [p for p in self.personas if p.persona_id != persona_id]
|
||||
self.get_v3_persona_data()
|
||||
|
||||
async def update_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str = None,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
):
|
||||
"""更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
|
||||
existing_persona = await self.db.get_persona_by_id(persona_id)
|
||||
if not existing_persona:
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist.")
|
||||
persona = await self.db.update_persona(
|
||||
persona_id, system_prompt, begin_dialogs, tools=tools
|
||||
)
|
||||
if persona:
|
||||
for i, p in enumerate(self.personas):
|
||||
if p.persona_id == persona_id:
|
||||
self.personas[i] = persona
|
||||
break
|
||||
self.get_v3_persona_data()
|
||||
return persona
|
||||
|
||||
async def get_all_personas(self) -> list[Persona]:
|
||||
"""获取所有 personas"""
|
||||
return await self.db.get_personas()
|
||||
|
||||
async def create_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
) -> Persona:
|
||||
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
|
||||
if await self.db.get_persona_by_id(persona_id):
|
||||
raise ValueError(f"Persona with ID {persona_id} already exists.")
|
||||
new_persona = await self.db.insert_persona(
|
||||
persona_id, system_prompt, begin_dialogs, tools=tools
|
||||
)
|
||||
self.personas.append(new_persona)
|
||||
self.get_v3_persona_data()
|
||||
return new_persona
|
||||
|
||||
def get_v3_persona_data(
|
||||
self,
|
||||
) -> tuple[list[dict], list[Personality], Personality]:
|
||||
"""获取 AstrBot <4.0.0 版本的 persona 数据。
|
||||
|
||||
Returns:
|
||||
- list[dict]: 包含 persona 配置的字典列表。
|
||||
- list[Personality]: 包含 Personality 对象的列表。
|
||||
- Personality: 默认选择的 Personality 对象。
|
||||
"""
|
||||
v3_persona_config = [
|
||||
{
|
||||
"prompt": persona.system_prompt,
|
||||
"name": persona.persona_id,
|
||||
"begin_dialogs": persona.begin_dialogs or [],
|
||||
"mood_imitation_dialogs": [], # deprecated
|
||||
"tools": persona.tools,
|
||||
}
|
||||
for persona in self.personas
|
||||
]
|
||||
|
||||
personas_v3: list[Personality] = []
|
||||
selected_default_persona: Personality | None = None
|
||||
|
||||
for persona_cfg in v3_persona_config:
|
||||
begin_dialogs = persona_cfg.get("begin_dialogs", [])
|
||||
bd_processed = []
|
||||
if begin_dialogs:
|
||||
if len(begin_dialogs) % 2 != 0:
|
||||
logger.error(
|
||||
f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。"
|
||||
)
|
||||
begin_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in begin_dialogs:
|
||||
bd_processed.append(
|
||||
{
|
||||
"role": "user" if user_turn else "assistant",
|
||||
"content": dialog,
|
||||
"_no_save": None, # 不持久化到 db
|
||||
}
|
||||
)
|
||||
user_turn = not user_turn
|
||||
|
||||
try:
|
||||
persona = Personality(
|
||||
**persona_cfg,
|
||||
_begin_dialogs_processed=bd_processed,
|
||||
_mood_imitation_dialogs_processed="", # deprecated
|
||||
)
|
||||
if persona["name"] == self.default_persona:
|
||||
selected_default_persona = persona
|
||||
personas_v3.append(persona)
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
if not selected_default_persona and len(personas_v3) > 0:
|
||||
# 默认选择第一个
|
||||
selected_default_persona = personas_v3[0]
|
||||
|
||||
if not selected_default_persona:
|
||||
selected_default_persona = DEFAULT_PERSONALITY
|
||||
personas_v3.append(selected_default_persona)
|
||||
|
||||
self.personas_v3 = personas_v3
|
||||
self.selected_default_persona_v3 = selected_default_persona
|
||||
self.persona_v3_config = v3_persona_config
|
||||
self.selected_default_persona = Persona(
|
||||
persona_id=selected_default_persona["name"],
|
||||
system_prompt=selected_default_persona["prompt"],
|
||||
begin_dialogs=selected_default_persona["begin_dialogs"],
|
||||
tools=selected_default_persona["tools"] or None,
|
||||
)
|
||||
|
||||
return v3_persona_config, personas_v3, selected_default_persona
|
||||
@@ -1,25 +1,25 @@
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageEventResult,
|
||||
EventResultType,
|
||||
MessageEventResult,
|
||||
)
|
||||
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .platform_compatibility.stage import PlatformCompatibilityStage
|
||||
from .preprocess_stage.stage import PreProcessStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .respond.stage import RespondStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .session_status_check.stage import SessionStatusCheckStage
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
|
||||
# 管道阶段顺序
|
||||
STAGES_ORDER = [
|
||||
"WakingCheckStage", # 检查是否需要唤醒
|
||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||
"SessionStatusCheckStage", # 检查会话是否整体启用
|
||||
"RateLimitStage", # 检查会话是否超过频率限制
|
||||
"ContentSafetyCheckStage", # 检查内容安全
|
||||
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
|
||||
"PreProcessStage", # 预处理
|
||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||
@@ -29,9 +29,9 @@ STAGES_ORDER = [
|
||||
__all__ = [
|
||||
"WakingCheckStage",
|
||||
"WhitelistCheckStage",
|
||||
"SessionStatusCheckStage",
|
||||
"RateLimitStage",
|
||||
"ContentSafetyCheckStage",
|
||||
"PlatformCompatibilityStage",
|
||||
"PreProcessStage",
|
||||
"ProcessStage",
|
||||
"ResultDecorateStage",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.star import PluginManager
|
||||
from .context_utils import call_handler, call_event_hook
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -9,3 +10,6 @@ class PipelineContext:
|
||||
|
||||
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
||||
plugin_manager: PluginManager # 插件管理器对象
|
||||
astrbot_config_id: str
|
||||
call_handler = call_handler
|
||||
call_event_hook = call_event_hook
|
||||
|
||||
98
astrbot/core/pipeline/context_utils.py
Normal file
98
astrbot/core/pipeline/context_utils.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import inspect
|
||||
import traceback
|
||||
import typing as T
|
||||
from astrbot import logger
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
|
||||
async def call_handler(
|
||||
event: AstrMessageEvent,
|
||||
handler: T.Awaitable,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[T.Any, None]:
|
||||
"""执行事件处理函数并处理其返回结果
|
||||
|
||||
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
|
||||
2. 协程: 执行一次并处理返回值
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象
|
||||
handler (Awaitable): 事件处理函数
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||||
"""
|
||||
ready_to_call = None # 一个协程或者异步生成器
|
||||
|
||||
trace_ = None
|
||||
|
||||
try:
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
except TypeError:
|
||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||
|
||||
if inspect.isasyncgen(ready_to_call):
|
||||
_has_yielded = False
|
||||
try:
|
||||
async for ret in ready_to_call:
|
||||
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
_has_yielded = True
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
# 如果返回值是 None, 则不设置结果并继续
|
||||
# 继续执行后续阶段
|
||||
yield ret
|
||||
if not _has_yielded:
|
||||
# 如果这个异步生成器没有执行到 yield 分支
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Previous Error: {trace_}")
|
||||
raise e
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个协程, 直接执行
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
yield ret
|
||||
|
||||
|
||||
async def call_event_hook(
|
||||
event: AstrMessageEvent,
|
||||
hook_type: EventType,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""调用事件钩子函数
|
||||
|
||||
Returns:
|
||||
bool: 如果事件被终止,返回 True
|
||||
# """
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
hook_type, plugins_name=event.plugins_name
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event, *args, **kwargs)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
|
||||
return event.is_stopped()
|
||||
@@ -1,56 +0,0 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_stage
|
||||
class PlatformCompatibilityStage(Stage):
|
||||
"""检查所有处理器的平台兼容性。
|
||||
|
||||
这个阶段会检查所有处理器是否在当前平台启用,如果未启用则设置platform_compatible属性为False。
|
||||
"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
"""初始化平台兼容性检查阶段
|
||||
|
||||
Args:
|
||||
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||
"""
|
||||
self.ctx = ctx
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
# 获取当前平台ID
|
||||
platform_id = event.get_platform_id()
|
||||
|
||||
# 获取已激活的处理器
|
||||
activated_handlers = event.get_extra("activated_handlers")
|
||||
if activated_handlers is None:
|
||||
activated_handlers = []
|
||||
|
||||
# 标记不兼容的处理器
|
||||
for handler in activated_handlers:
|
||||
if not isinstance(handler, StarHandlerMetadata):
|
||||
continue
|
||||
# 检查处理器是否在当前平台启用
|
||||
enabled = handler.is_enabled_for_platform(platform_id)
|
||||
if not enabled:
|
||||
if handler.handler_module_path in star_map:
|
||||
plugin_name = star_map[handler.handler_module_path].name
|
||||
logger.debug(
|
||||
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
|
||||
)
|
||||
# 设置处理器为平台不兼容状态
|
||||
# TODO: 更好的标记方式
|
||||
handler.platform_compatible = False
|
||||
else:
|
||||
# 确保处理器为平台兼容状态
|
||||
handler.platform_compatible = True
|
||||
|
||||
# 更新已激活的处理器列表
|
||||
event.set_extra("activated_handlers", activated_handlers)
|
||||
@@ -43,31 +43,31 @@ class PreProcessStage(Stage):
|
||||
# STT
|
||||
if self.stt_settings.get("enable", False):
|
||||
# TODO: 独立
|
||||
stt_provider = (
|
||||
self.plugin_manager.context.provider_manager.curr_stt_provider_inst
|
||||
)
|
||||
if stt_provider:
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, Record) and component.url:
|
||||
path = component.url.removeprefix("file://")
|
||||
retry = 5
|
||||
for i in range(retry):
|
||||
try:
|
||||
result = await stt_provider.get_text(audio_url=path)
|
||||
if result:
|
||||
logger.info("语音转文本结果: " + result)
|
||||
message_chain[idx] = Plain(result)
|
||||
event.message_str += result
|
||||
event.message_obj.message_str += result
|
||||
break
|
||||
except FileNotFoundError as e:
|
||||
# napcat workaround
|
||||
logger.warning(e)
|
||||
logger.warning(f"重试中: {i + 1}/{retry}")
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"语音转文本失败: {e}")
|
||||
break
|
||||
ctx = self.plugin_manager.context
|
||||
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
|
||||
if not stt_provider:
|
||||
return
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, Record) and component.url:
|
||||
path = component.url.removeprefix("file://")
|
||||
retry = 5
|
||||
for i in range(retry):
|
||||
try:
|
||||
result = await stt_provider.get_text(audio_url=path)
|
||||
if result:
|
||||
logger.info("语音转文本结果: " + result)
|
||||
message_chain[idx] = Plain(result)
|
||||
event.message_str += result
|
||||
event.message_obj.message_str += result
|
||||
break
|
||||
except FileNotFoundError as e:
|
||||
# napcat workaround
|
||||
logger.warning(e)
|
||||
logger.warning(f"重试中: {i + 1}/{retry}")
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"语音转文本失败: {e}")
|
||||
break
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,7 +2,7 @@
|
||||
本地 Agent 模式的 AstrBot 插件调用 Stage
|
||||
"""
|
||||
|
||||
from ...context import PipelineContext
|
||||
from ...context import PipelineContext, call_handler
|
||||
from ..stage import Stage
|
||||
from typing import Dict, Any, List, AsyncGenerator, Union
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
@@ -33,16 +33,6 @@ class StarRequestSubStage(Stage):
|
||||
handlers_parsed_params = {}
|
||||
|
||||
for handler in activated_handlers:
|
||||
# 检查处理器是否在当前平台兼容
|
||||
if (
|
||||
hasattr(handler, "platform_compatible")
|
||||
and handler.platform_compatible is False
|
||||
):
|
||||
logger.debug(
|
||||
f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行"
|
||||
)
|
||||
continue
|
||||
|
||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||
try:
|
||||
if handler.handler_module_path not in star_map:
|
||||
@@ -50,7 +40,7 @@ class StarRequestSubStage(Stage):
|
||||
logger.debug(
|
||||
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
|
||||
)
|
||||
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
|
||||
wrapper = call_handler(event, handler.handler, **params)
|
||||
async for ret in wrapper:
|
||||
yield ret
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
|
||||
@@ -58,33 +58,30 @@ class RateLimitStage(Stage):
|
||||
now = datetime.now()
|
||||
|
||||
async with self.locks[session_id]: # 确保同一会话不会并发修改队列
|
||||
timestamps = self.event_timestamps[session_id]
|
||||
# 检查并处理限流,可能需要多次检查直到满足条件
|
||||
while True:
|
||||
timestamps = self.event_timestamps[session_id]
|
||||
self._remove_expired_timestamps(timestamps, now)
|
||||
|
||||
self._remove_expired_timestamps(timestamps, now)
|
||||
if len(timestamps) < self.rate_limit_count:
|
||||
timestamps.append(now)
|
||||
break
|
||||
else:
|
||||
next_window_time = timestamps[0] + self.rate_limit_time
|
||||
stall_duration = (next_window_time - now).total_seconds() + 0.3
|
||||
|
||||
if len(timestamps) >= self.rate_limit_count:
|
||||
# 达到限流阈值,计算下一个窗口的时间
|
||||
next_window_time = timestamps[0] + self.rate_limit_time
|
||||
stall_duration = (next_window_time - now).total_seconds()
|
||||
|
||||
match self.rl_strategy:
|
||||
case RateLimitStrategy.STALL.value:
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
|
||||
)
|
||||
await asyncio.sleep(stall_duration)
|
||||
case RateLimitStrategy.DISCARD.value:
|
||||
# event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
|
||||
)
|
||||
return event.stop_event()
|
||||
|
||||
self._remove_expired_timestamps(
|
||||
timestamps, now + timedelta(seconds=stall_duration)
|
||||
)
|
||||
|
||||
timestamps.append(now)
|
||||
match self.rl_strategy:
|
||||
case RateLimitStrategy.STALL.value:
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
|
||||
)
|
||||
await asyncio.sleep(stall_duration)
|
||||
now = datetime.now()
|
||||
case RateLimitStrategy.DISCARD.value:
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
|
||||
)
|
||||
return event.stop_event()
|
||||
|
||||
def _remove_expired_timestamps(
|
||||
self, timestamps: Deque[datetime], now: datetime
|
||||
|
||||
@@ -12,6 +12,8 @@ from astrbot.core import logger
|
||||
from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.utils.path_util import path_Mapping
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -25,37 +27,19 @@ class RespondStage(Stage):
|
||||
Comp.Record: lambda comp: bool(comp.file), # 语音
|
||||
Comp.Video: lambda comp: bool(comp.file), # 视频
|
||||
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
|
||||
Comp.AtAll: lambda comp: True, # @所有人
|
||||
Comp.RPS: lambda comp: True, # 不知道是啥(未完成)
|
||||
Comp.Dice: lambda comp: True, # 骰子(未完成)
|
||||
Comp.Shake: lambda comp: True, # 摇一摇(未完成)
|
||||
Comp.Anonymous: lambda comp: True, # 匿名(未完成)
|
||||
Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享
|
||||
Comp.Contact: lambda comp: True, # 联系人(未完成)
|
||||
Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置
|
||||
Comp.Music: lambda comp: bool(comp._type)
|
||||
and bool(comp.url)
|
||||
and bool(comp.audio), # 音乐
|
||||
Comp.Image: lambda comp: bool(comp.file), # 图片
|
||||
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
|
||||
Comp.RedBag: lambda comp: bool(comp.title), # 红包
|
||||
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
|
||||
Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发
|
||||
Comp.Node: lambda comp: bool(comp.name)
|
||||
and comp.uin != 0
|
||||
and bool(comp.content), # 一个转发节点
|
||||
Comp.Node: lambda comp: bool(comp.content), # 转发节点
|
||||
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
|
||||
Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML
|
||||
Comp.Json: lambda comp: bool(comp.data), # JSON
|
||||
Comp.CardImage: lambda comp: bool(comp.file), # 卡片图片
|
||||
Comp.TTS: lambda comp: bool(comp.text and comp.text.strip()), # 语音合成
|
||||
Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), # 未知消息
|
||||
Comp.File: lambda comp: bool(comp.file), # 文件
|
||||
Comp.WechatEmoji: lambda comp: bool(comp.md5), # 微信表情
|
||||
Comp.File: lambda comp: bool(comp.file_ or comp.url),
|
||||
Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情
|
||||
}
|
||||
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
self.ctx = ctx
|
||||
self.config = ctx.astrbot_config
|
||||
self.platform_settings: dict = self.config.get("platform_settings", {})
|
||||
|
||||
self.reply_with_mention = ctx.astrbot_config["platform_settings"][
|
||||
"reply_with_mention"
|
||||
@@ -126,8 +110,6 @@ class RespondStage(Stage):
|
||||
if comp_type in self._component_validators:
|
||||
if self._component_validators[comp_type](comp):
|
||||
return False
|
||||
else:
|
||||
logger.info(f"空内容检查: 无法识别的组件类型: {comp_type.__name__}")
|
||||
|
||||
# 如果所有组件都为空
|
||||
return True
|
||||
@@ -143,27 +125,42 @@ class RespondStage(Stage):
|
||||
|
||||
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||
# 流式结果直接交付平台适配器处理
|
||||
logger.info(f"应用流式输出({event.get_platform_name()})")
|
||||
await event._pre_send()
|
||||
await event.send_streaming(result.async_stream)
|
||||
await event._post_send()
|
||||
use_fallback = self.config.get("provider_settings", {}).get(
|
||||
"streaming_segmented", False
|
||||
)
|
||||
logger.info(f"应用流式输出({event.get_platform_id()})")
|
||||
await event.send_streaming(result.async_stream, use_fallback)
|
||||
return
|
||||
elif len(result.chain) > 0:
|
||||
await event._pre_send()
|
||||
# 检查路径映射
|
||||
if mappings := self.platform_settings.get("path_mapping", []):
|
||||
for idx, component in enumerate(result.chain):
|
||||
if isinstance(component, Comp.File) and component.file:
|
||||
# 支持 File 消息段的路径映射。
|
||||
component.file = path_Mapping(mappings, component.file)
|
||||
event.get_result().chain[idx] = component
|
||||
|
||||
# 检查消息链是否为空
|
||||
try:
|
||||
if await self._is_empty_message_chain(result.chain):
|
||||
logger.info("消息为空,跳过发送阶段")
|
||||
event.clear_result()
|
||||
event.stop_event()
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"空内容检查异常: {e}")
|
||||
|
||||
if self.enable_seg and (
|
||||
(self.only_llm_result and result.is_llm_result())
|
||||
or not self.only_llm_result
|
||||
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
|
||||
non_record_comps = [
|
||||
c for c in result.chain if not isinstance(c, Comp.Record)
|
||||
]
|
||||
|
||||
if (
|
||||
self.enable_seg
|
||||
and (
|
||||
(self.only_llm_result and result.is_llm_result())
|
||||
or not self.only_llm_result
|
||||
)
|
||||
and event.get_platform_name()
|
||||
not in ["qq_official", "weixin_official_account", "dingtalk"]
|
||||
):
|
||||
decorated_comps = []
|
||||
if self.reply_with_mention:
|
||||
@@ -178,27 +175,46 @@ class RespondStage(Stage):
|
||||
decorated_comps.append(comp)
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
# 分段回复
|
||||
for comp in result.chain:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
|
||||
# leverage lock to guarentee the order of message sending among different events
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
for rcomp in record_comps:
|
||||
i = await self._calc_comp_interval(rcomp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
# 分段回复
|
||||
for comp in non_record_comps:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([*decorated_comps, comp]))
|
||||
decorated_comps = [] # 清空已发送的装饰组件
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
else:
|
||||
for rcomp in record_comps:
|
||||
try:
|
||||
await event.send(MessageChain([*decorated_comps, comp]))
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
else:
|
||||
|
||||
try:
|
||||
await event.send(result)
|
||||
await event.send(MessageChain(non_record_comps))
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
await event._post_send()
|
||||
|
||||
logger.info(
|
||||
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||
)
|
||||
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
|
||||
EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
import time
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage, registered_stages
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
from astrbot.core import file_token_service, html_renderer, logger
|
||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage, registered_stages
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -34,6 +36,7 @@ class ResultDecorateStage(Stage):
|
||||
self.t2i_word_threshold = 150
|
||||
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
|
||||
self.t2i_use_network = self.t2i_strategy == "remote"
|
||||
self.t2i_active_template = ctx.astrbot_config["t2i_active_template"]
|
||||
|
||||
self.forward_threshold = ctx.astrbot_config["platform_settings"][
|
||||
"forward_threshold"
|
||||
@@ -62,9 +65,10 @@ class ResultDecorateStage(Stage):
|
||||
]
|
||||
self.content_safe_check_stage = None
|
||||
if self.content_safe_check_reply:
|
||||
for stage in registered_stages:
|
||||
if stage.__class__.__name__ == "ContentSafetyCheckStage":
|
||||
self.content_safe_check_stage = stage
|
||||
for stage_cls in registered_stages:
|
||||
if stage_cls.__name__ == "ContentSafetyCheckStage":
|
||||
self.content_safe_check_stage = stage_cls()
|
||||
await self.content_safe_check_stage.initialize(ctx)
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
@@ -96,7 +100,7 @@ class ResultDecorateStage(Stage):
|
||||
|
||||
# 发送消息前事件钩子
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id()
|
||||
EventType.OnDecoratingResultEvent, plugins_name=event.plugins_name
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
@@ -140,7 +144,11 @@ class ResultDecorateStage(Stage):
|
||||
break
|
||||
|
||||
# 分段回复
|
||||
if self.enable_segmented_reply:
|
||||
if self.enable_segmented_reply and event.get_platform_name() not in [
|
||||
"qq_official",
|
||||
"weixin_official_account",
|
||||
"dingtalk",
|
||||
]:
|
||||
if (
|
||||
self.only_llm_result and result.is_llm_result()
|
||||
) or not self.only_llm_result:
|
||||
@@ -151,9 +159,9 @@ class ResultDecorateStage(Stage):
|
||||
# 不分段回复
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
split_response = []
|
||||
for line in comp.text.split("\n"):
|
||||
split_response.extend(re.findall(self.regex, line))
|
||||
split_response = re.findall(
|
||||
self.regex, comp.text, re.DOTALL | re.MULTILINE
|
||||
)
|
||||
if not split_response:
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
@@ -168,28 +176,57 @@ class ResultDecorateStage(Stage):
|
||||
result.chain = new_chain
|
||||
|
||||
# TTS
|
||||
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
and result.is_llm_result()
|
||||
and tts_provider
|
||||
and SessionServiceManager.should_process_tts_request(event)
|
||||
):
|
||||
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info("TTS 请求: " + comp.text)
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info("TTS 结果: " + audio_path)
|
||||
if audio_path:
|
||||
new_chain.append(
|
||||
Record(file=audio_path, url=audio_path)
|
||||
)
|
||||
else:
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}"
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
|
||||
)
|
||||
new_chain.append(comp)
|
||||
except BaseException:
|
||||
continue
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
)
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
@@ -211,7 +248,10 @@ class ResultDecorateStage(Stage):
|
||||
render_start = time.time()
|
||||
try:
|
||||
url = await html_renderer.render_t2i(
|
||||
plain_str, return_url=True, use_network=self.t2i_use_network
|
||||
plain_str,
|
||||
return_url=True,
|
||||
use_network=self.t2i_use_network,
|
||||
template_name=self.t2i_active_template,
|
||||
)
|
||||
except BaseException:
|
||||
logger.error("文本转图片失败,使用文本发送。")
|
||||
@@ -223,6 +263,14 @@ class ResultDecorateStage(Stage):
|
||||
if url:
|
||||
if url.startswith("http"):
|
||||
result.chain = [Image.fromURL(url)]
|
||||
elif (
|
||||
self.ctx.astrbot_config["t2i_use_file_service"]
|
||||
and self.ctx.astrbot_config["callback_api_base"]
|
||||
):
|
||||
token = await file_token_service.register_file(url)
|
||||
url = f"{self.ctx.astrbot_config['callback_api_base']}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
result.chain = [Image.fromURL(url)]
|
||||
else:
|
||||
result.chain = [Image.fromFileSystem(url)]
|
||||
|
||||
|
||||
@@ -11,16 +11,17 @@ class PipelineScheduler:
|
||||
|
||||
def __init__(self, context: PipelineContext):
|
||||
registered_stages.sort(
|
||||
key=lambda x: STAGES_ORDER.index(x.__class__.__name__)
|
||||
key=lambda x: STAGES_ORDER.index(x.__name__)
|
||||
) # 按照顺序排序
|
||||
self.ctx = context # 上下文对象
|
||||
self.stages = [] # 存储阶段实例
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化管道调度器时, 初始化所有阶段"""
|
||||
for stage in registered_stages:
|
||||
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
|
||||
|
||||
await stage.initialize(self.ctx)
|
||||
for stage_cls in registered_stages:
|
||||
stage_instance = stage_cls() # 创建实例
|
||||
await stage_instance.initialize(self.ctx)
|
||||
self.stages.append(stage_instance)
|
||||
|
||||
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
|
||||
"""依次执行各个阶段
|
||||
@@ -29,9 +30,9 @@ class PipelineScheduler:
|
||||
event (AstrMessageEvent): 事件对象
|
||||
from_stage (int): 从第几个阶段开始执行, 默认从0开始
|
||||
"""
|
||||
for i in range(from_stage, len(registered_stages)):
|
||||
stage = registered_stages[i] # 获取当前要执行的阶段
|
||||
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
|
||||
for i in range(from_stage, len(self.stages)):
|
||||
stage = self.stages[i] # 获取当前要执行的阶段
|
||||
# logger.debug(f"执行阶段 {stage.__class__.__name__}")
|
||||
coroutine = stage.process(
|
||||
event
|
||||
) # 调用阶段的process方法, 返回协程或者异步生成器
|
||||
@@ -73,7 +74,7 @@ class PipelineScheduler:
|
||||
await self._process_stages(event)
|
||||
|
||||
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||
if not event._has_send_oper and event.get_platform_name() == "webchat":
|
||||
if event.get_platform_name() == "webchat":
|
||||
await event.send(None)
|
||||
|
||||
logger.debug("pipeline 执行完毕。")
|
||||
|
||||
22
astrbot/core/pipeline/session_status_check/stage.py
Normal file
22
astrbot/core/pipeline/session_status_check/stage.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import AsyncGenerator, Union
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_stage
|
||||
class SessionStatusCheckStage(Stage):
|
||||
"""检查会话是否整体启用"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
pass
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
# 检查会话是否整体启用
|
||||
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
||||
event.stop_event()
|
||||
@@ -1,19 +1,15 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import inspect
|
||||
import traceback
|
||||
from astrbot.api import logger
|
||||
from typing import List, AsyncGenerator, Union, Awaitable
|
||||
from typing import List, AsyncGenerator, Union, Type
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from .context import PipelineContext
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
|
||||
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
|
||||
registered_stages: List[Type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型
|
||||
|
||||
|
||||
def register_stage(cls):
|
||||
"""一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类"""
|
||||
registered_stages.append(cls())
|
||||
registered_stages.append(cls)
|
||||
return cls
|
||||
|
||||
|
||||
@@ -41,70 +37,3 @@ class Stage(abc.ABC):
|
||||
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _call_handler(
|
||||
self,
|
||||
ctx: PipelineContext,
|
||||
event: AstrMessageEvent,
|
||||
handler: Awaitable,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
"""执行事件处理函数并处理其返回结果
|
||||
|
||||
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||
1. 异步生成器: 实现洋葱模型,每次yield都会将控制权交回上层
|
||||
2. 协程: 执行一次并处理返回值
|
||||
|
||||
Args:
|
||||
ctx (PipelineContext): 消息管道上下文对象
|
||||
event (AstrMessageEvent): 待处理的事件对象
|
||||
handler (Awaitable): 事件处理函数
|
||||
*args: 传递给handler的位置参数
|
||||
**kwargs: 传递给handler的关键字参数
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||||
"""
|
||||
ready_to_call = None # 一个协程或者异步生成器(async def)
|
||||
|
||||
trace_ = None
|
||||
|
||||
try:
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
except TypeError as _:
|
||||
# 向下兼容
|
||||
trace_ = traceback.format_exc()
|
||||
# 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份
|
||||
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
|
||||
|
||||
if isinstance(ready_to_call, AsyncGenerator):
|
||||
# 如果是一个异步生成器, 进入洋葱模型
|
||||
_has_yielded = False # 是否返回过值
|
||||
try:
|
||||
async for ret in ready_to_call:
|
||||
# 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码
|
||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
_has_yielded = True
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||
event.set_result(ret)
|
||||
yield # 传递控制权给上一层的process函数
|
||||
else:
|
||||
# 如果返回值是 None, 则不设置结果并继续
|
||||
# 继续执行后续阶段
|
||||
yield ret # 传递控制权给上一层的process函数
|
||||
if not _has_yielded:
|
||||
# 如果这个异步生成器没有执行到yield分支
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Previous Error: {trace_}")
|
||||
raise e
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个协程, 直接执行
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield # 传递控制权给上一层的process函数
|
||||
else:
|
||||
yield ret # 传递控制权给上一层的process函数
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
from astrbot import logger
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.message.components import At, AtAll, Reply
|
||||
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.message.components import At
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -35,10 +38,24 @@ class WakingCheckStage(Stage):
|
||||
self.friend_message_needs_wake_prefix = self.ctx.astrbot_config[
|
||||
"platform_settings"
|
||||
].get("friend_message_needs_wake_prefix", False)
|
||||
# 是否忽略机器人自己发送的消息
|
||||
self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get(
|
||||
"ignore_bot_self_message", False
|
||||
)
|
||||
self.ignore_at_all = self.ctx.astrbot_config["platform_settings"].get(
|
||||
"ignore_at_all", False
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
if (
|
||||
self.ignore_bot_self_message
|
||||
and event.get_self_id() == event.get_sender_id()
|
||||
):
|
||||
# 忽略机器人自己发送的消息
|
||||
event.stop_event()
|
||||
return
|
||||
# 设置 sender 身份
|
||||
event.message_str = event.message_str.strip()
|
||||
for admin_id in self.ctx.astrbot_config["admins_id"]:
|
||||
@@ -66,11 +83,18 @@ class WakingCheckStage(Stage):
|
||||
event.message_str = event.message_str[len(wake_prefix) :].strip()
|
||||
break
|
||||
if not is_wake:
|
||||
# 检查是否有 at 消息
|
||||
# 检查是否有at消息 / at全体成员消息 / 引用了bot的消息
|
||||
for message in messages:
|
||||
if isinstance(message, At) and (
|
||||
str(message.qq) == str(event.get_self_id())
|
||||
or str(message.qq) == "all"
|
||||
if (
|
||||
(
|
||||
isinstance(message, At)
|
||||
and (str(message.qq) == str(event.get_self_id()))
|
||||
)
|
||||
or (isinstance(message, AtAll) and not self.ignore_at_all)
|
||||
or (
|
||||
isinstance(message, Reply)
|
||||
and str(message.sender_id) == str(event.get_self_id())
|
||||
)
|
||||
):
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
@@ -88,8 +112,17 @@ class WakingCheckStage(Stage):
|
||||
activated_handlers = []
|
||||
handlers_parsed_params = {} # 注册了指令的 handler
|
||||
|
||||
# 将 plugins_name 设置到 event 中
|
||||
enabled_plugins_name = self.ctx.astrbot_config.get("plugin_set", ["*"])
|
||||
if enabled_plugins_name == ["*"]:
|
||||
# 如果是 *,则表示所有插件都启用
|
||||
event.plugins_name = None
|
||||
else:
|
||||
event.plugins_name = enabled_plugins_name
|
||||
logger.debug(f"enabled_plugins_name: {enabled_plugins_name}")
|
||||
|
||||
for handler in star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.AdapterMessageEvent
|
||||
EventType.AdapterMessageEvent, plugins_name=event.plugins_name
|
||||
):
|
||||
# filter 需满足 AND 逻辑关系
|
||||
passed = True
|
||||
@@ -114,7 +147,6 @@ class WakingCheckStage(Stage):
|
||||
f"插件 {star_map[handler.handler_module_path].name}: {e}"
|
||||
)
|
||||
)
|
||||
await event._post_send()
|
||||
event.stop_event()
|
||||
passed = False
|
||||
break
|
||||
@@ -126,10 +158,9 @@ class WakingCheckStage(Stage):
|
||||
if self.no_permission_reply:
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"
|
||||
f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。"
|
||||
)
|
||||
)
|
||||
await event._post_send()
|
||||
logger.info(
|
||||
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
|
||||
)
|
||||
@@ -145,7 +176,12 @@ class WakingCheckStage(Stage):
|
||||
"parsed_params"
|
||||
)
|
||||
|
||||
event.clear_extra()
|
||||
event._extras.pop("parsed_params", None)
|
||||
|
||||
# 根据会话配置过滤插件处理器
|
||||
activated_handlers = SessionPluginManager.filter_handlers_by_session(
|
||||
event, activated_handlers
|
||||
)
|
||||
|
||||
event.set_extra("activated_handlers", activated_handlers)
|
||||
event.set_extra("handlers_parsed_params", handlers_parsed_params)
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import abc
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
import re
|
||||
import hashlib
|
||||
import uuid
|
||||
|
||||
from typing import List, Union, Optional, AsyncGenerator
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.db.po import Conversation
|
||||
from astrbot.core.message.components import (
|
||||
Plain,
|
||||
@@ -20,21 +24,7 @@ from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from .astrbot_message import AstrBotMessage, Group
|
||||
from .platform_metadata import PlatformMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSesion:
|
||||
platform_name: str
|
||||
message_type: MessageType
|
||||
session_id: str
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.platform_name}:{self.message_type.value}:{self.session_id}"
|
||||
|
||||
@staticmethod
|
||||
def from_str(session_str: str):
|
||||
platform_name, message_type, session_id = session_str.split(":")
|
||||
return MessageSesion(platform_name, MessageType(message_type), session_id)
|
||||
from .message_session import MessageSession, MessageSesion # noqa
|
||||
|
||||
|
||||
class AstrMessageEvent(abc.ABC):
|
||||
@@ -61,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
|
||||
self._extras = {}
|
||||
self.session = MessageSesion(
|
||||
platform_name=platform_meta.name,
|
||||
platform_name=platform_meta.id,
|
||||
message_type=message_obj.type,
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -75,13 +65,23 @@ class AstrMessageEvent(abc.ABC):
|
||||
self.call_llm = False
|
||||
"""是否在此消息事件中禁止默认的 LLM 请求"""
|
||||
|
||||
self.plugins_name: list[str] | None = None
|
||||
"""该事件启用的插件名称列表。None 表示所有插件都启用。空列表表示没有启用任何插件。"""
|
||||
|
||||
# back_compability
|
||||
self.platform = platform_meta
|
||||
|
||||
def get_platform_name(self):
|
||||
"""获取这个事件所属的平台的类型(如 aiocqhttp, slack, discord 等)。
|
||||
|
||||
NOTE: 用户可能会同时运行多个相同类型的平台适配器。"""
|
||||
return self.platform_meta.name
|
||||
|
||||
def get_platform_id(self):
|
||||
"""获取这个事件所属的平台的 ID。
|
||||
|
||||
NOTE: 用户可能会同时运行多个相同类型的平台适配器,但能确定的是 ID 是唯一的。
|
||||
"""
|
||||
return self.platform_meta.id
|
||||
|
||||
def get_message_str(self) -> str:
|
||||
@@ -185,6 +185,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
清除额外的信息。
|
||||
"""
|
||||
logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}")
|
||||
self._extras.clear()
|
||||
|
||||
def is_private_chat(self) -> bool:
|
||||
@@ -205,9 +206,26 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
return self.role == "admin"
|
||||
|
||||
async def send_streaming(self, generator: AsyncGenerator[MessageChain, None]):
|
||||
async def process_buffer(self, buffer: str, pattern: re.Pattern) -> str:
|
||||
"""
|
||||
将消息缓冲区中的文本按指定正则表达式分割后发送至消息平台,作为不支持流式输出平台的Fallback。
|
||||
"""
|
||||
while True:
|
||||
match = re.search(pattern, buffer)
|
||||
if not match:
|
||||
break
|
||||
matched_text = match.group()
|
||||
await self.send(MessageChain([Plain(matched_text)]))
|
||||
buffer = buffer[match.end() :]
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
return buffer
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
||||
):
|
||||
"""发送流式消息到消息平台,使用异步生成器。
|
||||
目前仅支持: telegram,qq official 私聊。
|
||||
Fallback仅支持 aiocqhttp。
|
||||
"""
|
||||
asyncio.create_task(
|
||||
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||
@@ -215,10 +233,10 @@ class AstrMessageEvent(abc.ABC):
|
||||
self._has_send_oper = True
|
||||
|
||||
async def _pre_send(self):
|
||||
"""调度器会在执行 send() 前调用该方法"""
|
||||
"""调度器会在执行 send() 前调用该方法 deprecated in v3.5.18"""
|
||||
|
||||
async def _post_send(self):
|
||||
"""调度器会在执行 send() 后调用该方法"""
|
||||
"""调度器会在执行 send() 后调用该方法 deprecated in v3.5.18"""
|
||||
|
||||
def set_result(self, result: Union[MessageEventResult, str]):
|
||||
"""设置消息事件的结果。
|
||||
@@ -384,8 +402,13 @@ class AstrMessageEvent(abc.ABC):
|
||||
Args:
|
||||
message (MessageChain): 消息链,具体使用方式请参考文档。
|
||||
"""
|
||||
# Leverage BLAKE2 hash function to generate a non-reversible hash of the sender ID for privacy.
|
||||
hash_obj = hashlib.blake2b(self.get_sender_id().encode("utf-8"), digest_size=16)
|
||||
sid = str(uuid.UUID(bytes=hash_obj.digest()))
|
||||
asyncio.create_task(
|
||||
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||
Metric.upload(
|
||||
msg_event_tick=1, adapter_name=self.platform_meta.name, sid=sid
|
||||
)
|
||||
)
|
||||
self._has_send_oper = True
|
||||
|
||||
@@ -394,7 +417,6 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
适配情况:
|
||||
|
||||
- gewechat
|
||||
- aiocqhttp(OneBotv11)
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -18,6 +18,9 @@ class PlatformManager:
|
||||
|
||||
self.platforms_config = config["platform"]
|
||||
self.settings = config["platform_settings"]
|
||||
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
|
||||
这个配置中的 unique_session 需要特殊处理,
|
||||
约定整个项目中对 unique_session 的引用都从 default 的配置中获取"""
|
||||
self.event_queue = event_queue
|
||||
|
||||
async def initialize(self):
|
||||
@@ -58,9 +61,9 @@ class PlatformManager:
|
||||
from .sources.qqofficial_webhook.qo_webhook_adapter import (
|
||||
QQOfficialWebhookPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "gewechat":
|
||||
from .sources.gewechat.gewechat_platform_adapter import (
|
||||
GewechatPlatformAdapter, # noqa: F401
|
||||
case "wechatpadpro":
|
||||
from .sources.wechatpadpro.wechatpadpro_adapter import (
|
||||
WeChatPadProAdapter, # noqa: F401
|
||||
)
|
||||
case "lark":
|
||||
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
|
||||
@@ -72,6 +75,18 @@ class PlatformManager:
|
||||
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
|
||||
case "wecom":
|
||||
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
|
||||
case "weixin_official_account":
|
||||
from .sources.weixin_official_account.weixin_offacc_adapter import (
|
||||
WeixinOfficialAccountPlatformAdapter, # noqa
|
||||
)
|
||||
case "discord":
|
||||
from .sources.discord.discord_platform_adapter import (
|
||||
DiscordPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "slack":
|
||||
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
|
||||
case "satori":
|
||||
from .sources.satori.satori_adapter import SatoriPlatformAdapter # noqa: F401
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.error(
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
|
||||
|
||||
28
astrbot/core/platform/message_session.py
Normal file
28
astrbot/core/platform/message_session.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSession:
|
||||
"""描述一条消息在 AstrBot 中对应的会话的唯一标识。
|
||||
如果您需要实例化 MessageSession,请不要给 platform_id 赋值(或者同时给 platform_name 和 platform_id 赋值相同值)。它会在 __post_init__ 中自动设置为 platform_name 的值。"""
|
||||
|
||||
platform_name: str
|
||||
"""平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。"""
|
||||
message_type: MessageType
|
||||
session_id: str
|
||||
platform_id: str = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.platform_id}:{self.message_type.value}:{self.session_id}"
|
||||
|
||||
def __post_init__(self):
|
||||
self.platform_id = self.platform_name
|
||||
|
||||
@staticmethod
|
||||
def from_str(session_str: str):
|
||||
platform_id, message_type, session_id = session_str.split(":")
|
||||
return MessageSession(platform_id, MessageType(message_type), session_id)
|
||||
|
||||
|
||||
MessageSesion = MessageSession # back compatibility
|
||||
@@ -5,7 +5,7 @@ from asyncio import Queue
|
||||
from .platform_metadata import PlatformMetadata
|
||||
from .astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from .astr_message_event import MessageSesion
|
||||
from .message_session import MessageSesion
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
@dataclass
|
||||
class PlatformMetadata:
|
||||
name: str
|
||||
"""平台的名称"""
|
||||
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
|
||||
description: str
|
||||
"""平台的描述"""
|
||||
id: str = None
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
import asyncio
|
||||
import typing
|
||||
import re
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import (
|
||||
Image,
|
||||
Node,
|
||||
Nodes,
|
||||
Plain,
|
||||
Record,
|
||||
Video,
|
||||
File,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
from astrbot.api.platform import Group, MessageMember
|
||||
from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes
|
||||
from aiocqhttp import CQHttp
|
||||
|
||||
|
||||
class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
@@ -13,87 +23,167 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
|
||||
@staticmethod
|
||||
async def _from_segment_to_dict(segment: BaseMessageComponent) -> dict:
|
||||
"""修复部分字段"""
|
||||
if isinstance(segment, (Image, Record)):
|
||||
# For Image and Record segments, we convert them to base64
|
||||
bs64 = await segment.convert_to_base64()
|
||||
return {
|
||||
"type": segment.type.lower(),
|
||||
"data": {
|
||||
"file": f"base64://{bs64}",
|
||||
},
|
||||
}
|
||||
elif isinstance(segment, File):
|
||||
# For File segments, we need to handle the file differently
|
||||
d = await segment.to_dict()
|
||||
return d
|
||||
elif isinstance(segment, Video):
|
||||
d = await segment.to_dict()
|
||||
return d
|
||||
else:
|
||||
# For other segments, we simply convert them to a dict by calling toDict
|
||||
return segment.toDict()
|
||||
|
||||
@staticmethod
|
||||
async def _parse_onebot_json(message_chain: MessageChain):
|
||||
"""解析成 OneBot json 格式"""
|
||||
ret = []
|
||||
for segment in message_chain.chain:
|
||||
d = segment.toDict()
|
||||
if isinstance(segment, Plain):
|
||||
d["type"] = "text"
|
||||
d["data"]["text"] = segment.text.strip()
|
||||
# 如果是空文本或者只带换行符的文本,不发送
|
||||
if not d["data"]["text"]:
|
||||
if not segment.text.strip():
|
||||
continue
|
||||
elif isinstance(segment, (Image, Record)):
|
||||
# convert to base64
|
||||
bs64 = await segment.convert_to_base64()
|
||||
d["data"] = {
|
||||
"file": bs64,
|
||||
}
|
||||
elif isinstance(segment, At):
|
||||
d["data"] = {
|
||||
"qq": str(segment.qq) # 转换为字符串
|
||||
}
|
||||
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
|
||||
ret.append(d)
|
||||
return ret
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||
@classmethod
|
||||
async def _dispatch_send(
|
||||
cls,
|
||||
bot: CQHttp,
|
||||
event: Event | None,
|
||||
is_group: bool,
|
||||
session_id: str,
|
||||
messages: list[dict],
|
||||
):
|
||||
# session_id 必须是纯数字字符串
|
||||
session_id = int(session_id) if session_id.isdigit() else None
|
||||
|
||||
if not ret:
|
||||
return
|
||||
|
||||
send_one_by_one = False
|
||||
for seg in message.chain:
|
||||
if isinstance(seg, (Node, Nodes)):
|
||||
# 转发消息不能和普通消息混在一起发送
|
||||
send_one_by_one = True
|
||||
break
|
||||
|
||||
if send_one_by_one:
|
||||
for seg in message.chain:
|
||||
if isinstance(seg, (Node, Nodes)):
|
||||
# 合并转发消息
|
||||
|
||||
if isinstance(seg, Node):
|
||||
nodes = Nodes([seg])
|
||||
seg = nodes
|
||||
|
||||
payload = seg.toDict()
|
||||
if self.get_group_id():
|
||||
payload["group_id"] = self.get_group_id()
|
||||
await self.bot.call_action("send_group_forward_msg", **payload)
|
||||
else:
|
||||
payload["user_id"] = self.get_sender_id()
|
||||
await self.bot.call_action(
|
||||
"send_private_forward_msg", **payload
|
||||
)
|
||||
else:
|
||||
await self.bot.send(
|
||||
self.message_obj.raw_message,
|
||||
await AiocqhttpMessageEvent._parse_onebot_json(
|
||||
MessageChain([seg])
|
||||
),
|
||||
)
|
||||
await asyncio.sleep(0.5)
|
||||
if is_group and isinstance(session_id, int):
|
||||
await bot.send_group_msg(group_id=session_id, message=messages)
|
||||
elif not is_group and isinstance(session_id, int):
|
||||
await bot.send_private_msg(user_id=session_id, message=messages)
|
||||
elif isinstance(event, Event): # 最后兜底
|
||||
await bot.send(event=event, message=messages)
|
||||
else:
|
||||
await self.bot.send(self.message_obj.raw_message, ret)
|
||||
raise ValueError(
|
||||
f"无法发送消息:缺少有效的数字 session_id({session_id}) 或 event({event})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def send_message(
|
||||
cls,
|
||||
bot: CQHttp,
|
||||
message_chain: MessageChain,
|
||||
event: Event | None = None,
|
||||
is_group: bool = False,
|
||||
session_id: str = None,
|
||||
):
|
||||
"""发送消息至 QQ 协议端(aiocqhttp)。
|
||||
|
||||
Args:
|
||||
bot (CQHttp): aiocqhttp 机器人实例
|
||||
message_chain (MessageChain): 要发送的消息链
|
||||
event (Event | None, optional): aiocqhttp 事件对象.
|
||||
is_group (bool, optional): 是否为群消息.
|
||||
session_id (str | None, optional): 会话 ID(群号或 QQ 号
|
||||
"""
|
||||
|
||||
# 转发消息、文件消息不能和普通消息混在一起发送
|
||||
send_one_by_one = any(
|
||||
isinstance(seg, (Node, Nodes, File)) for seg in message_chain.chain
|
||||
)
|
||||
if not send_one_by_one:
|
||||
ret = await cls._parse_onebot_json(message_chain)
|
||||
if not ret:
|
||||
return
|
||||
await cls._dispatch_send(bot, event, is_group, session_id, ret)
|
||||
return
|
||||
for seg in message_chain.chain:
|
||||
if isinstance(seg, (Node, Nodes)):
|
||||
# 合并转发消息
|
||||
if isinstance(seg, Node):
|
||||
nodes = Nodes([seg])
|
||||
seg = nodes
|
||||
|
||||
payload = await seg.to_dict()
|
||||
|
||||
if is_group:
|
||||
payload["group_id"] = session_id
|
||||
await bot.call_action("send_group_forward_msg", **payload)
|
||||
else:
|
||||
payload["user_id"] = session_id
|
||||
await bot.call_action("send_private_forward_msg", **payload)
|
||||
elif isinstance(seg, File):
|
||||
d = await cls._from_segment_to_dict(seg)
|
||||
await cls._dispatch_send(bot, event, is_group, session_id, [d])
|
||||
else:
|
||||
messages = await cls._parse_onebot_json(MessageChain([seg]))
|
||||
if not messages:
|
||||
continue
|
||||
await cls._dispatch_send(bot, event, is_group, session_id, messages)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
"""发送消息"""
|
||||
event = getattr(self.message_obj, "raw_message", None)
|
||||
|
||||
is_group = bool(self.get_group_id())
|
||||
session_id = self.get_group_id() if is_group else self.get_sender_id()
|
||||
|
||||
await self.send_message(
|
||||
bot=self.bot,
|
||||
message_chain=message,
|
||||
event=event, # 不强制要求一定是 Event
|
||||
is_group=is_group,
|
||||
session_id=session_id,
|
||||
)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||
):
|
||||
if not use_fallback:
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
buffer = ""
|
||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
buffer = await self.process_buffer(buffer, pattern)
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def get_group(self, group_id=None, **kwargs):
|
||||
if isinstance(group_id, str) and group_id.isdigit():
|
||||
@@ -108,7 +198,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
members: typing.List[typing.Dict] = await self.bot.call_action(
|
||||
members: List[Dict] = await self.bot.call_action(
|
||||
"get_group_member_list",
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
import itertools
|
||||
from typing import Awaitable, Any
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from astrbot.api.platform import (
|
||||
@@ -20,7 +20,6 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from aiocqhttp.exceptions import ActionFailed
|
||||
from astrbot.core.utils.io import download_file
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
@@ -45,7 +44,12 @@ class AiocqhttpAdapter(Platform):
|
||||
)
|
||||
|
||||
self.bot = CQHttp(
|
||||
use_ws_reverse=True, import_name="aiocqhttp", api_timeout_sec=180
|
||||
use_ws_reverse=True,
|
||||
import_name="aiocqhttp",
|
||||
api_timeout_sec=180,
|
||||
access_token=platform_config.get(
|
||||
"ws_reverse_token"
|
||||
), # 以防旧版本配置不存在
|
||||
)
|
||||
|
||||
@self.bot.on_request()
|
||||
@@ -79,19 +83,18 @@ class AiocqhttpAdapter(Platform):
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
|
||||
match session.message_type.value:
|
||||
case MessageType.GROUP_MESSAGE.value:
|
||||
if "_" in session.session_id:
|
||||
# 独立会话
|
||||
_, group_id = session.session_id.split("_")
|
||||
await self.bot.send_group_msg(group_id=group_id, message=ret)
|
||||
else:
|
||||
await self.bot.send_group_msg(
|
||||
group_id=session.session_id, message=ret
|
||||
)
|
||||
case MessageType.FRIEND_MESSAGE.value:
|
||||
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
|
||||
is_group = session.message_type == MessageType.GROUP_MESSAGE
|
||||
if is_group:
|
||||
session_id = session.session_id.split("_")[-1]
|
||||
else:
|
||||
session_id = session.session_id
|
||||
await AiocqhttpMessageEvent.send_message(
|
||||
bot=self.bot,
|
||||
message_chain=message_chain,
|
||||
event=None, # 这里不需要 event,因为是通过 session 发送的
|
||||
is_group=is_group,
|
||||
session_id=session_id,
|
||||
)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
async def convert_message(self, event: Event) -> AstrBotMessage:
|
||||
@@ -99,6 +102,9 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
if event["post_type"] == "message":
|
||||
abm = await self._convert_handle_message_event(event)
|
||||
if abm.sender.user_id == "2854196310":
|
||||
# 屏蔽 QQ 管家的消息
|
||||
return
|
||||
elif event["post_type"] == "notice":
|
||||
abm = await self._convert_handle_notice_event(event)
|
||||
elif event["post_type"] == "request":
|
||||
@@ -119,6 +125,12 @@ class AiocqhttpAdapter(Platform):
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id)
|
||||
else:
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
abm.timestamp = int(time.time())
|
||||
@@ -202,82 +214,139 @@ class AiocqhttpAdapter(Platform):
|
||||
return
|
||||
|
||||
# 按消息段类型类型适配
|
||||
for m in event.message:
|
||||
t = m["type"]
|
||||
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
|
||||
a = None
|
||||
if t == "text":
|
||||
message_str += m["data"]["text"].strip()
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
current_text = "".join(m["data"]["text"] for m in m_group).strip()
|
||||
if not current_text:
|
||||
# 如果文本段为空,则跳过
|
||||
continue
|
||||
message_str += current_text
|
||||
a = ComponentTypes[t](text=current_text) # noqa: F405
|
||||
abm.message.append(a)
|
||||
|
||||
elif t == "file":
|
||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
for m in m_group:
|
||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
file_name = m["data"].get("file_name", "file")
|
||||
abm.message.append(File(name=file_name, url=m["data"]["url"]))
|
||||
else:
|
||||
try:
|
||||
# Napcat
|
||||
ret = None
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
ret = await self.bot.call_action(
|
||||
action="get_group_file_url",
|
||||
file_id=event.message[0]["data"]["file_id"],
|
||||
group_id=event.group_id,
|
||||
)
|
||||
elif abm.type == MessageType.FRIEND_MESSAGE:
|
||||
ret = await self.bot.call_action(
|
||||
action="get_private_file_url",
|
||||
file_id=event.message[0]["data"]["file_id"],
|
||||
)
|
||||
if ret and "url" in ret:
|
||||
file_url = ret["url"] # https
|
||||
a = File(name="", url=file_url)
|
||||
abm.message.append(a)
|
||||
else:
|
||||
logger.error(f"获取文件失败: {ret}")
|
||||
|
||||
file_name = m["data"].get("file_name", "file")
|
||||
path = os.path.join("data/temp", file_name)
|
||||
await download_file(m["data"]["url"], path)
|
||||
|
||||
m["data"] = {"file": path, "name": file_name}
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
|
||||
else:
|
||||
try:
|
||||
# Napcat, LLBot
|
||||
ret = await self.bot.call_action(
|
||||
action="get_file",
|
||||
file_id=event.message[0]["data"]["file_id"],
|
||||
)
|
||||
if not ret.get("file", None):
|
||||
raise ValueError(f"无法解析文件响应: {ret}")
|
||||
if not os.path.exists(ret["file"]):
|
||||
raise FileNotFoundError(
|
||||
f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot"
|
||||
)
|
||||
|
||||
m["data"] = {"file": ret["file"], "name": ret["file_name"]}
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
|
||||
elif t == "reply":
|
||||
if not get_reply:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
else:
|
||||
try:
|
||||
reply_event_data = await self.bot.call_action(
|
||||
action="get_msg",
|
||||
message_id=int(m["data"]["id"]),
|
||||
)
|
||||
abm_reply = await self._convert_handle_message_event(
|
||||
Event.from_payload(reply_event_data), get_reply=False
|
||||
)
|
||||
|
||||
reply_seg = Reply(
|
||||
id=abm_reply.message_id,
|
||||
chain=abm_reply.message,
|
||||
sender_id=abm_reply.sender.user_id,
|
||||
sender_nickname=abm_reply.sender.nickname,
|
||||
time=abm_reply.timestamp,
|
||||
message_str=abm_reply.message_str,
|
||||
text=abm_reply.message_str, # for compatibility
|
||||
qq=abm_reply.sender.user_id, # for compatibility
|
||||
)
|
||||
|
||||
abm.message.append(reply_seg)
|
||||
except BaseException as e:
|
||||
logger.error(f"获取引用消息失败: {e}。")
|
||||
for m in m_group:
|
||||
if not get_reply:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
else:
|
||||
try:
|
||||
reply_event_data = await self.bot.call_action(
|
||||
action="get_msg",
|
||||
message_id=int(m["data"]["id"]),
|
||||
)
|
||||
# 添加必要的 post_type 字段,防止 Event.from_payload 报错
|
||||
reply_event_data["post_type"] = "message"
|
||||
new_event = Event.from_payload(reply_event_data)
|
||||
if not new_event:
|
||||
logger.error(
|
||||
f"无法从回复消息数据构造 Event 对象: {reply_event_data}"
|
||||
)
|
||||
continue
|
||||
abm_reply = await self._convert_handle_message_event(
|
||||
new_event, get_reply=False
|
||||
)
|
||||
|
||||
reply_seg = Reply(
|
||||
id=abm_reply.message_id,
|
||||
chain=abm_reply.message,
|
||||
sender_id=abm_reply.sender.user_id,
|
||||
sender_nickname=abm_reply.sender.nickname,
|
||||
time=abm_reply.timestamp,
|
||||
message_str=abm_reply.message_str,
|
||||
text=abm_reply.message_str, # for compatibility
|
||||
qq=abm_reply.sender.user_id, # for compatibility
|
||||
)
|
||||
|
||||
abm.message.append(reply_seg)
|
||||
except BaseException as e:
|
||||
logger.error(f"获取引用消息失败: {e}。")
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
elif t == "at":
|
||||
first_at_self_processed = False
|
||||
|
||||
for m in m_group:
|
||||
try:
|
||||
if m["data"]["qq"] == "all":
|
||||
abm.message.append(At(qq="all", name="全体成员"))
|
||||
continue
|
||||
|
||||
at_info = await self.bot.call_action(
|
||||
action="get_group_member_info",
|
||||
group_id=event.group_id,
|
||||
user_id=int(m["data"]["qq"]),
|
||||
no_cache=False,
|
||||
)
|
||||
if at_info:
|
||||
nickname = at_info.get("card", "")
|
||||
if nickname == "":
|
||||
at_info = await self.bot.call_action(
|
||||
action="get_stranger_info",
|
||||
user_id=int(m["data"]["qq"]),
|
||||
no_cache=False,
|
||||
)
|
||||
nickname = at_info.get("nick", "") or at_info.get("nickname", "")
|
||||
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
|
||||
|
||||
abm.message.append(
|
||||
At(
|
||||
qq=m["data"]["qq"],
|
||||
name=nickname,
|
||||
)
|
||||
)
|
||||
|
||||
if is_at_self and not first_at_self_processed:
|
||||
# 第一个@是机器人,不添加到message_str
|
||||
first_at_self_processed = True
|
||||
else:
|
||||
# 非第一个@机器人或@其他用户,添加到message_str
|
||||
message_str += f" @{nickname}({m['data']['qq']}) "
|
||||
else:
|
||||
abm.message.append(At(qq=str(m["data"]["qq"]), name=""))
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
|
||||
else:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
for m in m_group:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
import dingtalk_stream
|
||||
@@ -19,6 +20,7 @@ from ...register import register_platform_adapter
|
||||
from astrbot import logger
|
||||
from dingtalk_stream import AckMessage
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class MyEventHandler(dingtalk_stream.EventHandler):
|
||||
@@ -152,7 +154,8 @@ class DingtalkPlatformAdapter(Platform):
|
||||
"downloadCode": download_code,
|
||||
"robotCode": robot_code,
|
||||
}
|
||||
f_path = f"data/dingtalk_file_{uuid.uuid4()}.{ext}"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
f_path = os.path.join(temp_dir, f"dingtalk_file_{uuid.uuid4()}.{ext}")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
"https://api.dingtalk.com/v1.0/robot/messageFiles/download",
|
||||
|
||||
@@ -26,42 +26,42 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
client.reply_markdown,
|
||||
"AstrBot",
|
||||
segment.text,
|
||||
segment.text,
|
||||
self.message_obj.raw_message,
|
||||
)
|
||||
elif isinstance(segment, Comp.Image):
|
||||
markdown_str = ""
|
||||
if segment.file and segment.file.startswith("file:///"):
|
||||
logger.warning(
|
||||
"dingtalk only support url image, not: " + segment.file
|
||||
)
|
||||
continue
|
||||
elif segment.file and segment.file.startswith("http"):
|
||||
markdown_str += f"\n\n"
|
||||
elif segment.file and segment.file.startswith("base64://"):
|
||||
logger.warning("dingtalk only support url image, not base64")
|
||||
continue
|
||||
else:
|
||||
logger.warning(
|
||||
"dingtalk only support url image, not: " + segment.file
|
||||
)
|
||||
continue
|
||||
|
||||
ret = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
client.reply_markdown,
|
||||
"😄",
|
||||
markdown_str,
|
||||
self.message_obj.raw_message,
|
||||
)
|
||||
logger.debug(f"send image: {ret}")
|
||||
try:
|
||||
if not segment.file:
|
||||
logger.warning("钉钉图片 segment 缺少 file 字段,跳过")
|
||||
continue
|
||||
if segment.file.startswith(("http://", "https://")):
|
||||
image_url = segment.file
|
||||
else:
|
||||
image_url = await segment.register_to_file_service()
|
||||
|
||||
markdown_str = f"\n\n"
|
||||
|
||||
ret = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
client.reply_markdown,
|
||||
"😄",
|
||||
markdown_str,
|
||||
self.message_obj.raw_message,
|
||||
)
|
||||
logger.debug(f"send image: {ret}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"钉钉图片处理失败: {e}")
|
||||
logger.warning(f"跳过图片发送: {image_path}")
|
||||
continue
|
||||
async def send(self, message: MessageChain):
|
||||
await self.send_with_client(self.client, message)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
@@ -72,4 +72,4 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
126
astrbot/core/platform/sources/discord/client.py
Normal file
126
astrbot/core/platform/sources/discord/client.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import discord
|
||||
from astrbot import logger
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# Discord Bot客户端
|
||||
class DiscordBotClient(discord.Bot):
|
||||
"""Discord客户端封装"""
|
||||
|
||||
def __init__(self, token: str, proxy: str = None):
|
||||
self.token = token
|
||||
self.proxy = proxy
|
||||
|
||||
# 设置Intent权限,遵循权限最小化原则
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True # 订阅消息内容事件 (Privileged)
|
||||
intents.members = True # 订阅成员事件 (Privileged)
|
||||
|
||||
# 初始化Bot
|
||||
super().__init__(intents=intents, proxy=proxy)
|
||||
|
||||
# 回调函数
|
||||
self.on_message_received = None
|
||||
self.on_ready_once_callback = None
|
||||
self._ready_once_fired = False
|
||||
|
||||
@override
|
||||
async def on_ready(self):
|
||||
"""当机器人成功连接并准备就绪时触发"""
|
||||
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
|
||||
logger.info("[Discord] 客户端已准备就绪。")
|
||||
|
||||
if self.on_ready_once_callback and not self._ready_once_fired:
|
||||
self._ready_once_fired = True
|
||||
try:
|
||||
await self.on_ready_once_callback()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True)
|
||||
|
||||
def _create_message_data(self, message: discord.Message) -> dict:
|
||||
"""从 discord.Message 创建数据字典"""
|
||||
is_mentioned = self.user in message.mentions
|
||||
return {
|
||||
"message": message,
|
||||
"bot_id": str(self.user.id),
|
||||
"content": message.content,
|
||||
"username": message.author.display_name,
|
||||
"userid": str(message.author.id),
|
||||
"message_id": str(message.id),
|
||||
"channel_id": str(message.channel.id),
|
||||
"guild_id": str(message.guild.id) if message.guild else None,
|
||||
"type": "message",
|
||||
"is_mentioned": is_mentioned,
|
||||
"clean_content": message.clean_content,
|
||||
}
|
||||
|
||||
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
|
||||
"""从 discord.Interaction 创建数据字典"""
|
||||
return {
|
||||
"interaction": interaction,
|
||||
"bot_id": str(self.user.id),
|
||||
"content": self._extract_interaction_content(interaction),
|
||||
"username": interaction.user.display_name,
|
||||
"userid": str(interaction.user.id),
|
||||
"message_id": str(interaction.id),
|
||||
"channel_id": str(interaction.channel_id)
|
||||
if interaction.channel_id
|
||||
else None,
|
||||
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
|
||||
"type": "interaction",
|
||||
}
|
||||
|
||||
@override
|
||||
async def on_message(self, message: discord.Message):
|
||||
"""当接收到消息时触发"""
|
||||
if message.author.bot:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"[Discord] 收到原始消息 from {message.author.name}: {message.content}"
|
||||
)
|
||||
|
||||
if self.on_message_received:
|
||||
message_data = self._create_message_data(message)
|
||||
await self.on_message_received(message_data)
|
||||
|
||||
|
||||
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
|
||||
"""从交互中提取内容"""
|
||||
interaction_type = interaction.type
|
||||
interaction_data = getattr(interaction, "data", {})
|
||||
|
||||
if not interaction_data:
|
||||
return ""
|
||||
|
||||
if interaction_type == discord.InteractionType.application_command:
|
||||
command_name = interaction_data.get("name", "")
|
||||
if options := interaction_data.get("options", []):
|
||||
params = " ".join(
|
||||
[f"{opt['name']}:{opt.get('value', '')}" for opt in options]
|
||||
)
|
||||
return f"/{command_name} {params}"
|
||||
return f"/{command_name}"
|
||||
|
||||
elif interaction_type == discord.InteractionType.component:
|
||||
custom_id = interaction_data.get("custom_id", "")
|
||||
component_type = interaction_data.get("component_type", "")
|
||||
return f"component:{custom_id}:{component_type}"
|
||||
|
||||
return str(interaction_data)
|
||||
|
||||
async def start_polling(self):
|
||||
"""开始轮询消息,这是个阻塞方法"""
|
||||
await self.start(self.token)
|
||||
|
||||
@override
|
||||
async def close(self):
|
||||
"""关闭客户端"""
|
||||
if not self.is_closed():
|
||||
await super().close()
|
||||
133
astrbot/core/platform/sources/discord/components.py
Normal file
133
astrbot/core/platform/sources/discord/components.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import discord
|
||||
from typing import List
|
||||
from astrbot.api.message_components import BaseMessageComponent
|
||||
|
||||
|
||||
# Discord专用组件
|
||||
class DiscordEmbed(BaseMessageComponent):
|
||||
"""Discord Embed消息组件"""
|
||||
|
||||
type: str = "discord_embed"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
title: str = None,
|
||||
description: str = None,
|
||||
color: int = None,
|
||||
url: str = None,
|
||||
thumbnail: str = None,
|
||||
image: str = None,
|
||||
footer: str = None,
|
||||
fields: List[dict] = None,
|
||||
):
|
||||
self.title = title
|
||||
self.description = description
|
||||
self.color = color
|
||||
self.url = url
|
||||
self.thumbnail = thumbnail
|
||||
self.image = image
|
||||
self.footer = footer
|
||||
self.fields = fields or []
|
||||
|
||||
def to_discord_embed(self) -> discord.Embed:
|
||||
"""转换为Discord Embed对象"""
|
||||
embed = discord.Embed()
|
||||
|
||||
if self.title:
|
||||
embed.title = self.title
|
||||
if self.description:
|
||||
embed.description = self.description
|
||||
if self.color:
|
||||
embed.color = self.color
|
||||
if self.url:
|
||||
embed.url = self.url
|
||||
if self.thumbnail:
|
||||
embed.set_thumbnail(url=self.thumbnail)
|
||||
if self.image:
|
||||
embed.set_image(url=self.image)
|
||||
if self.footer:
|
||||
embed.set_footer(text=self.footer)
|
||||
|
||||
for field in self.fields:
|
||||
embed.add_field(
|
||||
name=field.get("name", ""),
|
||||
value=field.get("value", ""),
|
||||
inline=field.get("inline", False),
|
||||
)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
class DiscordButton(BaseMessageComponent):
|
||||
"""Discord按钮组件"""
|
||||
|
||||
type: str = "discord_button"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
label: str,
|
||||
custom_id: str = None,
|
||||
style: str = "primary",
|
||||
emoji: str = None,
|
||||
url: str = None,
|
||||
disabled: bool = False,
|
||||
):
|
||||
self.label = label
|
||||
self.custom_id = custom_id
|
||||
self.style = style
|
||||
self.emoji = emoji
|
||||
self.url = url
|
||||
self.disabled = disabled
|
||||
|
||||
class DiscordReference(BaseMessageComponent):
|
||||
"""Discord引用组件"""
|
||||
type: str = "discord_reference"
|
||||
def __init__(self, message_id: str, channel_id: str):
|
||||
self.message_id = message_id
|
||||
self.channel_id = channel_id
|
||||
|
||||
|
||||
class DiscordView(BaseMessageComponent):
|
||||
"""Discord视图组件,包含按钮和选择菜单"""
|
||||
|
||||
type: str = "discord_view"
|
||||
|
||||
def __init__(
|
||||
self, components: List[BaseMessageComponent] = None, timeout: float = None
|
||||
):
|
||||
self.components = components or []
|
||||
self.timeout = timeout
|
||||
|
||||
|
||||
def to_discord_view(self) -> discord.ui.View:
|
||||
"""转换为Discord View对象"""
|
||||
view = discord.ui.View(timeout=self.timeout)
|
||||
|
||||
for component in self.components:
|
||||
if isinstance(component, DiscordButton):
|
||||
button_style = getattr(
|
||||
discord.ButtonStyle, component.style, discord.ButtonStyle.primary
|
||||
)
|
||||
|
||||
if component.url:
|
||||
# URL按钮
|
||||
button = discord.ui.Button(
|
||||
label=component.label,
|
||||
style=discord.ButtonStyle.link,
|
||||
url=component.url,
|
||||
emoji=component.emoji,
|
||||
disabled=component.disabled,
|
||||
)
|
||||
else:
|
||||
# 普通按钮
|
||||
button = discord.ui.Button(
|
||||
label=component.label,
|
||||
style=button_style,
|
||||
custom_id=component.custom_id,
|
||||
emoji=component.emoji,
|
||||
disabled=component.disabled,
|
||||
)
|
||||
|
||||
view.add_item(button)
|
||||
|
||||
return view
|
||||
@@ -0,0 +1,455 @@
|
||||
import asyncio
|
||||
import discord
|
||||
import sys
|
||||
import re
|
||||
from discord.abc import Messageable
|
||||
from discord.channel import DMChannel
|
||||
from astrbot.api.platform import (
|
||||
Platform,
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
PlatformMetadata,
|
||||
MessageType,
|
||||
)
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, File
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.api.platform import register_platform_adapter
|
||||
from astrbot import logger
|
||||
from .client import DiscordBotClient
|
||||
from .discord_platform_event import DiscordPlatformEvent
|
||||
|
||||
from typing import Any, Tuple
|
||||
from astrbot.core.star.filter.command import CommandFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# 注册平台适配器
|
||||
@register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)")
|
||||
class DiscordPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.client_self_id = None
|
||||
self.registered_handlers = []
|
||||
# 指令注册相关
|
||||
self.enable_command_register = self.config.get("discord_command_register", True)
|
||||
self.guild_id = self.config.get("discord_guild_id_for_debug", None)
|
||||
self.activity_name = self.config.get("discord_activity_name", None)
|
||||
self.shutdown_event = asyncio.Event()
|
||||
self._polling_task = None
|
||||
|
||||
@override
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
"""通过会话发送消息"""
|
||||
# 创建一个 message_obj 以便在 event 中使用
|
||||
message_obj = AstrBotMessage()
|
||||
if "_" in session.session_id:
|
||||
session.session_id = session.session_id.split("_")[1]
|
||||
channel_id_str = session.session_id
|
||||
channel = None
|
||||
try:
|
||||
channel_id = int(channel_id_str)
|
||||
channel = self.client.get_channel(channel_id)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"[Discord] Invalid channel ID format: {channel_id_str}")
|
||||
|
||||
if channel:
|
||||
message_obj.type = self._get_message_type(channel)
|
||||
message_obj.group_id = self._get_channel_id(channel)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[Discord] Can't get channel info for {channel_id_str}, will guess message type."
|
||||
)
|
||||
message_obj.type = MessageType.GROUP_MESSAGE
|
||||
message_obj.group_id = session.session_id
|
||||
|
||||
message_obj.message_str = message_chain.get_plain_text()
|
||||
message_obj.sender = MessageMember(
|
||||
user_id=str(self.client_self_id), nickname=self.client.user.display_name
|
||||
)
|
||||
message_obj.self_id = self.client_self_id
|
||||
message_obj.session_id = session.session_id
|
||||
message_obj.message = message_chain
|
||||
|
||||
# 创建临时事件对象来发送消息
|
||||
temp_event = DiscordPlatformEvent(
|
||||
message_str=message_chain.get_plain_text(),
|
||||
message_obj=message_obj,
|
||||
platform_meta=self.meta(),
|
||||
session_id=session.session_id,
|
||||
client=self.client,
|
||||
)
|
||||
await temp_event.send(message_chain)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
"""返回平台元数据"""
|
||||
return PlatformMetadata(
|
||||
"discord",
|
||||
"Discord 适配器",
|
||||
id=self.config.get("id"),
|
||||
default_config_tmpl=self.config,
|
||||
)
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
"""主要运行逻辑"""
|
||||
|
||||
# 初始化回调函数
|
||||
async def on_received(message_data):
|
||||
logger.debug(f"[Discord] 收到消息: {message_data}")
|
||||
if self.client_self_id is None:
|
||||
self.client_self_id = message_data.get("bot_id")
|
||||
abm = await self.convert_message(data=message_data)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
# 初始化 Discord 客户端
|
||||
token = str(self.config.get("discord_token"))
|
||||
if not token:
|
||||
logger.error("[Discord] Bot Token 未配置。请在配置文件中正确设置 token。")
|
||||
return
|
||||
|
||||
proxy = self.config.get("discord_proxy") or None
|
||||
self.client = DiscordBotClient(token, proxy)
|
||||
self.client.on_message_received = on_received
|
||||
|
||||
async def callback():
|
||||
if self.enable_command_register:
|
||||
await self._collect_and_register_commands()
|
||||
if self.activity_name:
|
||||
await self.client.change_presence(
|
||||
status=discord.Status.online,
|
||||
activity=discord.CustomActivity(name=self.activity_name),
|
||||
)
|
||||
|
||||
self.client.on_ready_once_callback = callback
|
||||
|
||||
try:
|
||||
self._polling_task = asyncio.create_task(self.client.start_polling())
|
||||
await self.shutdown_event.wait()
|
||||
except discord.errors.LoginFailure:
|
||||
logger.error("[Discord] 登录失败。请检查你的 Bot Token 是否正确。")
|
||||
except discord.errors.ConnectionClosed:
|
||||
logger.warning("[Discord] 与 Discord 的连接已关闭。")
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 适配器运行时发生意外错误: {e}", exc_info=True)
|
||||
|
||||
def _get_message_type(
|
||||
self, channel: Messageable, guild_id: int | None = None
|
||||
) -> MessageType:
|
||||
"""根据 channel 对象和 guild_id 判断消息类型"""
|
||||
if guild_id is not None:
|
||||
return MessageType.GROUP_MESSAGE
|
||||
if isinstance(channel, DMChannel) or getattr(channel, "guild", None) is None:
|
||||
return MessageType.FRIEND_MESSAGE
|
||||
return MessageType.GROUP_MESSAGE
|
||||
|
||||
def _get_channel_id(self, channel: Messageable) -> str:
|
||||
"""根据 channel 对象获取ID"""
|
||||
return str(getattr(channel, "id", None))
|
||||
|
||||
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
|
||||
"""将普通消息转换为 AstrBotMessage"""
|
||||
message: discord.Message = data["message"]
|
||||
|
||||
content = message.content
|
||||
|
||||
# 如果机器人被@,移除@部分
|
||||
# 剥离 User Mention (<@id>, <@!id>)
|
||||
if self.client and self.client.user:
|
||||
mention_str = f"<@{self.client.user.id}>"
|
||||
mention_str_nickname = f"<@!{self.client.user.id}>"
|
||||
if content.startswith(mention_str):
|
||||
content = content[len(mention_str) :].lstrip()
|
||||
elif content.startswith(mention_str_nickname):
|
||||
content = content[len(mention_str_nickname) :].lstrip()
|
||||
|
||||
# 剥离 Role Mention(bot 拥有的任一角色被提及,<@&role_id>)
|
||||
if (
|
||||
hasattr(message, "role_mentions")
|
||||
and hasattr(message, "guild")
|
||||
and message.guild
|
||||
):
|
||||
bot_member = (
|
||||
message.guild.get_member(self.client.user.id)
|
||||
if self.client and self.client.user
|
||||
else None
|
||||
)
|
||||
if bot_member and hasattr(bot_member, "roles"):
|
||||
for role in bot_member.roles:
|
||||
role_mention_str = f"<@&{role.id}>"
|
||||
if content.startswith(role_mention_str):
|
||||
content = content[len(role_mention_str) :].lstrip()
|
||||
break # 只剥离第一个匹配的角色 mention
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.type = self._get_message_type(message.channel)
|
||||
abm.group_id = self._get_channel_id(message.channel)
|
||||
abm.message_str = content
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(message.author.id), nickname=message.author.display_name
|
||||
)
|
||||
message_chain = []
|
||||
if abm.message_str:
|
||||
message_chain.append(Plain(text=abm.message_str))
|
||||
if message.attachments:
|
||||
for attachment in message.attachments:
|
||||
if attachment.content_type and attachment.content_type.startswith(
|
||||
"image/"
|
||||
):
|
||||
message_chain.append(
|
||||
Image(file=attachment.url, filename=attachment.filename)
|
||||
)
|
||||
else:
|
||||
message_chain.append(
|
||||
File(name=attachment.filename, url=attachment.url)
|
||||
)
|
||||
abm.message = message_chain
|
||||
abm.raw_message = message
|
||||
abm.self_id = self.client_self_id
|
||||
abm.session_id = str(message.channel.id)
|
||||
abm.message_id = str(message.id)
|
||||
return abm
|
||||
|
||||
async def convert_message(self, data: dict) -> AstrBotMessage:
|
||||
"""将平台消息转换成 AstrBotMessage"""
|
||||
# 由于 on_interaction 已被禁用,我们只处理普通消息
|
||||
return self._convert_message_to_abm(data)
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage, followup_webhook=None):
|
||||
"""处理消息"""
|
||||
message_event = DiscordPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client,
|
||||
interaction_followup_webhook=followup_webhook,
|
||||
)
|
||||
|
||||
# 检查是否为斜杠指令
|
||||
is_slash_command = message_event.interaction_followup_webhook is not None
|
||||
|
||||
# 检查是否被@(User Mention 或 Bot 拥有的 Role Mention)
|
||||
is_mention = False
|
||||
# User Mention
|
||||
if (
|
||||
self.client
|
||||
and self.client.user
|
||||
and hasattr(message.raw_message, "mentions")
|
||||
):
|
||||
if self.client.user in message.raw_message.mentions:
|
||||
is_mention = True
|
||||
# Role Mention(Bot 拥有的角色被提及)
|
||||
if not is_mention and hasattr(message.raw_message, "role_mentions"):
|
||||
bot_member = None
|
||||
if hasattr(message.raw_message, "guild") and message.raw_message.guild:
|
||||
try:
|
||||
bot_member = message.raw_message.guild.get_member(
|
||||
self.client.user.id
|
||||
)
|
||||
except Exception:
|
||||
bot_member = None
|
||||
if bot_member and hasattr(bot_member, "roles"):
|
||||
bot_roles = set(bot_member.roles)
|
||||
mentioned_roles = set(message.raw_message.role_mentions)
|
||||
if (
|
||||
bot_roles
|
||||
and mentioned_roles
|
||||
and bot_roles.intersection(mentioned_roles)
|
||||
):
|
||||
is_mention = True
|
||||
|
||||
# 如果是斜杠指令或被@的消息,设置为唤醒状态
|
||||
if is_slash_command or is_mention:
|
||||
message_event.is_wake = True
|
||||
message_event.is_at_or_wake_command = True
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
@override
|
||||
async def terminate(self):
|
||||
"""终止适配器"""
|
||||
logger.info("[Discord] 正在终止适配器... (step 1: cancel polling task)")
|
||||
self.shutdown_event.set()
|
||||
# 优先 cancel polling_task
|
||||
if self._polling_task:
|
||||
self._polling_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(self._polling_task, timeout=10)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("[Discord] polling_task 已取消。")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Discord] polling_task 取消异常: {e}")
|
||||
logger.info("[Discord] 正在清理已注册的斜杠指令... (step 2)")
|
||||
# 清理指令
|
||||
if self.enable_command_register and self.client:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.client.sync_commands(
|
||||
commands=[],
|
||||
guild_ids=[self.guild_id] if self.guild_id else None,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
logger.info("[Discord] 指令清理完成。")
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 清理指令时发生错误: {e}", exc_info=True)
|
||||
logger.info("[Discord] 正在关闭 Discord 客户端... (step 3)")
|
||||
if self.client and hasattr(self.client, "close"):
|
||||
try:
|
||||
await asyncio.wait_for(self.client.close(), timeout=10)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Discord] 客户端关闭异常: {e}")
|
||||
logger.info("[Discord] 适配器已终止。")
|
||||
|
||||
def register_handler(self, handler_info):
|
||||
"""注册处理器信息"""
|
||||
self.registered_handlers.append(handler_info)
|
||||
|
||||
async def _collect_and_register_commands(self):
|
||||
"""收集所有指令并注册到Discord"""
|
||||
logger.info("[Discord] 开始收集并注册斜杠指令...")
|
||||
registered_commands = []
|
||||
|
||||
for handler_md in star_handlers_registry:
|
||||
if not star_map[handler_md.handler_module_path].activated:
|
||||
continue
|
||||
for event_filter in handler_md.event_filters:
|
||||
cmd_info = self._extract_command_info(event_filter, handler_md)
|
||||
if not cmd_info:
|
||||
continue
|
||||
|
||||
cmd_name, description, cmd_filter_instance = cmd_info
|
||||
|
||||
# 创建动态回调
|
||||
callback = self._create_dynamic_callback(cmd_name)
|
||||
|
||||
# 创建一个通用的参数选项来接收所有文本输入
|
||||
options = [
|
||||
discord.Option(
|
||||
name="params",
|
||||
description="指令的所有参数",
|
||||
type=discord.SlashCommandOptionType.string,
|
||||
required=False,
|
||||
)
|
||||
]
|
||||
|
||||
# 创建SlashCommand
|
||||
slash_command = discord.SlashCommand(
|
||||
name=cmd_name,
|
||||
description=description,
|
||||
func=callback,
|
||||
options=options,
|
||||
guild_ids=[self.guild_id] if self.guild_id else None,
|
||||
)
|
||||
self.client.add_application_command(slash_command)
|
||||
registered_commands.append(cmd_name)
|
||||
|
||||
if registered_commands:
|
||||
logger.info(
|
||||
f"[Discord] 准备同步 {len(registered_commands)} 个指令: {', '.join(registered_commands)}"
|
||||
)
|
||||
else:
|
||||
logger.info("[Discord] 没有发现可注册的指令。")
|
||||
|
||||
# 使用 Pycord 的方法同步指令
|
||||
# 注意:这可能需要一些时间,并且有频率限制
|
||||
await self.client.sync_commands()
|
||||
logger.info("[Discord] 指令同步完成。")
|
||||
|
||||
def _create_dynamic_callback(self, cmd_name: str):
|
||||
"""为每个指令动态创建一个异步回调函数"""
|
||||
|
||||
async def dynamic_callback(ctx: discord.ApplicationContext, params: str = None):
|
||||
# 将平台特定的前缀'/'剥离,以适配通用的CommandFilter
|
||||
logger.debug(f"[Discord] 回调函数触发: {cmd_name}")
|
||||
logger.debug(f"[Discord] 回调函数参数: {ctx}")
|
||||
logger.debug(f"[Discord] 回调函数参数: {params}")
|
||||
message_str_for_filter = cmd_name
|
||||
if params:
|
||||
message_str_for_filter += f" {params}"
|
||||
|
||||
logger.debug(
|
||||
f"[Discord] 斜杠指令 '{cmd_name}' 被触发。 "
|
||||
f"原始参数: '{params}'. "
|
||||
f"构建的指令字符串: '{message_str_for_filter}'"
|
||||
)
|
||||
|
||||
# 尝试立即响应,防止超时
|
||||
followup_webhook = None
|
||||
try:
|
||||
await ctx.defer()
|
||||
followup_webhook = ctx.followup
|
||||
except Exception as e:
|
||||
logger.warning(f"[Discord] 指令 '{cmd_name}' defer 失败: {e}")
|
||||
|
||||
# 2. 构建 AstrBotMessage
|
||||
abm = AstrBotMessage()
|
||||
abm.type = self._get_message_type(ctx.channel, ctx.guild_id)
|
||||
abm.group_id = self._get_channel_id(ctx.channel)
|
||||
abm.message_str = message_str_for_filter
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(ctx.author.id), nickname=ctx.author.display_name
|
||||
)
|
||||
abm.message = [Plain(text=message_str_for_filter)]
|
||||
abm.raw_message = ctx.interaction
|
||||
abm.self_id = self.client_self_id
|
||||
abm.session_id = str(ctx.channel_id)
|
||||
abm.message_id = str(ctx.interaction.id)
|
||||
|
||||
# 3. 将消息和 webhook 分别交给 handle_msg 处理
|
||||
await self.handle_msg(abm, followup_webhook)
|
||||
|
||||
return dynamic_callback
|
||||
|
||||
@staticmethod
|
||||
def _extract_command_info(
|
||||
event_filter: Any, handler_metadata: StarHandlerMetadata
|
||||
) -> Tuple[str, str, CommandFilter] | None:
|
||||
"""从事件过滤器中提取指令信息"""
|
||||
cmd_name = None
|
||||
# is_group = False
|
||||
cmd_filter_instance = None
|
||||
|
||||
if isinstance(event_filter, CommandFilter):
|
||||
# 暂不支持子指令注册为斜杠指令
|
||||
if (
|
||||
event_filter.parent_command_names
|
||||
and event_filter.parent_command_names != [""]
|
||||
):
|
||||
return None
|
||||
cmd_name = event_filter.command_name
|
||||
cmd_filter_instance = event_filter
|
||||
|
||||
elif isinstance(event_filter, CommandGroupFilter):
|
||||
# 暂不支持指令组直接注册为斜杠指令,因为它们没有 handle 方法
|
||||
return None
|
||||
|
||||
if not cmd_name:
|
||||
return None
|
||||
|
||||
# Discord 斜杠指令名称规范
|
||||
if not re.match(r"^[a-z0-9_-]{1,32}$", cmd_name):
|
||||
logger.debug(f"[Discord] 跳过不符合规范的指令: {cmd_name}")
|
||||
return None
|
||||
|
||||
description = handler_metadata.desc or f"指令: {cmd_name}"
|
||||
if len(description) > 100:
|
||||
description = f"{description[:97]}..."
|
||||
|
||||
return cmd_name, description, cmd_filter_instance
|
||||
291
astrbot/core/platform/sources/discord/discord_platform_event.py
Normal file
291
astrbot/core/platform/sources/discord/discord_platform_event.py
Normal file
@@ -0,0 +1,291 @@
|
||||
import asyncio
|
||||
import discord
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import sys
|
||||
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, At
|
||||
from astrbot.api.message_components import (
|
||||
Plain,
|
||||
Image,
|
||||
File,
|
||||
BaseMessageComponent,
|
||||
Reply,
|
||||
)
|
||||
from astrbot import logger
|
||||
from .client import DiscordBotClient
|
||||
from .components import DiscordEmbed, DiscordView
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# 自定义Discord视图组件(兼容旧版本)
|
||||
class DiscordViewComponent(BaseMessageComponent):
|
||||
type: str = "discord_view"
|
||||
|
||||
def __init__(self, view: discord.ui.View):
|
||||
self.view = view
|
||||
|
||||
|
||||
class DiscordPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client: DiscordBotClient,
|
||||
interaction_followup_webhook: Optional[discord.Webhook] = None,
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
self.interaction_followup_webhook = interaction_followup_webhook
|
||||
|
||||
@override
|
||||
async def send(self, message: MessageChain):
|
||||
"""发送消息到Discord平台"""
|
||||
|
||||
# 解析消息链为 Discord 所需的对象
|
||||
try:
|
||||
content, files, view, embeds, reference_message_id = await self._parse_to_discord(message)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
kwargs = {}
|
||||
if content:
|
||||
kwargs["content"] = content
|
||||
if files:
|
||||
kwargs["files"] = files
|
||||
if view:
|
||||
kwargs["view"] = view
|
||||
if embeds:
|
||||
kwargs["embeds"] = embeds
|
||||
if reference_message_id and not self.interaction_followup_webhook:
|
||||
kwargs["reference"] = self.client.get_message(int(reference_message_id))
|
||||
if not kwargs:
|
||||
logger.debug("[Discord] 尝试发送空消息,已忽略。")
|
||||
return
|
||||
|
||||
# 根据上下文执行发送/回复操作
|
||||
try:
|
||||
# -- 斜杠指令/交互上下文 --
|
||||
if self.interaction_followup_webhook:
|
||||
await self.interaction_followup_webhook.send(**kwargs)
|
||||
|
||||
# -- 常规消息上下文 --
|
||||
else:
|
||||
channel = await self._get_channel()
|
||||
if not channel:
|
||||
return
|
||||
else:
|
||||
await channel.send(**kwargs)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 发送消息时发生未知错误: {e}", exc_info=True)
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def _get_channel(self) -> Optional[discord.abc.Messageable]:
|
||||
"""获取当前事件对应的频道对象"""
|
||||
try:
|
||||
channel_id = int(self.session_id)
|
||||
return self.client.get_channel(
|
||||
channel_id
|
||||
) or await self.client.fetch_channel(channel_id)
|
||||
except (ValueError, discord.errors.NotFound, discord.errors.Forbidden):
|
||||
logger.error(f"[Discord] 无法获取频道 {self.session_id}")
|
||||
return None
|
||||
|
||||
async def _parse_to_discord(
|
||||
self,
|
||||
message: MessageChain,
|
||||
) -> tuple[str, list[discord.File], Optional[discord.ui.View], list[discord.Embed]]:
|
||||
"""将 MessageChain 解析为 Discord 发送所需的内容"""
|
||||
content = ""
|
||||
files = []
|
||||
view = None
|
||||
embeds = []
|
||||
reference_message_id = None
|
||||
for i in message.chain: # 遍历消息链
|
||||
if isinstance(i, Plain): # 如果是文字类型的
|
||||
content += i.text
|
||||
elif isinstance(i, Reply):
|
||||
reference_message_id = i.id
|
||||
elif isinstance(i, At):
|
||||
content += f"<@{i.qq}>"
|
||||
elif isinstance(i, Image):
|
||||
logger.debug(f"[Discord] 开始处理 Image 组件: {i}")
|
||||
try:
|
||||
filename = getattr(i, "filename", None)
|
||||
file_content = getattr(i, "file", None)
|
||||
|
||||
if not file_content:
|
||||
logger.warning(f"[Discord] Image 组件没有 file 属性: {i}")
|
||||
continue
|
||||
|
||||
discord_file = None
|
||||
|
||||
# 1. URL
|
||||
if file_content.startswith("http"):
|
||||
logger.debug(f"[Discord] 处理 URL 图片: {file_content}")
|
||||
embed = discord.Embed().set_image(url=file_content)
|
||||
embeds.append(embed)
|
||||
continue
|
||||
|
||||
# 2. File URI
|
||||
elif file_content.startswith("file:///"):
|
||||
logger.debug(f"[Discord] 处理 File URI: {file_content}")
|
||||
path = Path(file_content[8:])
|
||||
if await asyncio.to_thread(path.exists):
|
||||
file_bytes = await asyncio.to_thread(path.read_bytes)
|
||||
discord_file = discord.File(
|
||||
BytesIO(file_bytes), filename=filename or path.name
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Discord] 图片文件不存在: {path}")
|
||||
|
||||
# 3. Base64 URI
|
||||
elif file_content.startswith("base64://"):
|
||||
logger.debug("[Discord] 处理 Base64 URI")
|
||||
b64_data = file_content.split("base64://", 1)[1]
|
||||
missing_padding = len(b64_data) % 4
|
||||
if missing_padding:
|
||||
b64_data += "=" * (4 - missing_padding)
|
||||
img_bytes = base64.b64decode(b64_data)
|
||||
discord_file = discord.File(
|
||||
BytesIO(img_bytes), filename=filename or "image.png"
|
||||
)
|
||||
|
||||
# 4. 裸 Base64 或本地路径
|
||||
else:
|
||||
try:
|
||||
logger.debug("[Discord] 尝试作为裸 Base64 处理")
|
||||
b64_data = file_content
|
||||
missing_padding = len(b64_data) % 4
|
||||
if missing_padding:
|
||||
b64_data += "=" * (4 - missing_padding)
|
||||
img_bytes = base64.b64decode(b64_data)
|
||||
discord_file = discord.File(
|
||||
BytesIO(img_bytes), filename=filename or "image.png"
|
||||
)
|
||||
except (ValueError, TypeError, base64.binascii.Error):
|
||||
logger.debug(
|
||||
f"[Discord] 裸 Base64 解码失败,作为本地路径处理: {file_content}"
|
||||
)
|
||||
path = Path(file_content)
|
||||
if await asyncio.to_thread(path.exists):
|
||||
file_bytes = await asyncio.to_thread(path.read_bytes)
|
||||
discord_file = discord.File(
|
||||
BytesIO(file_bytes), filename=filename or path.name
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Discord] 图片文件不存在: {path}")
|
||||
|
||||
if discord_file:
|
||||
files.append(discord_file)
|
||||
|
||||
except Exception:
|
||||
# 使用 getattr 来安全地访问 i.file,以防 i 本身就是问题
|
||||
file_info = getattr(i, "file", "未知")
|
||||
logger.error(
|
||||
f"[Discord] 处理图片时发生未知严重错误: {file_info}",
|
||||
exc_info=True,
|
||||
)
|
||||
elif isinstance(i, File):
|
||||
try:
|
||||
file_path_str = await i.get_file()
|
||||
if file_path_str:
|
||||
path = Path(file_path_str)
|
||||
if await asyncio.to_thread(path.exists):
|
||||
file_bytes = await asyncio.to_thread(path.read_bytes)
|
||||
files.append(
|
||||
discord.File(BytesIO(file_bytes),
|
||||
filename=i.name)
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[Discord] 获取文件失败,路径不存在: {file_path_str}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Discord] 获取文件失败: {i.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Discord] 处理文件失败: {i.name}, 错误: {e}")
|
||||
elif isinstance(i, DiscordEmbed):
|
||||
# Discord Embed消息
|
||||
embeds.append(i.to_discord_embed())
|
||||
elif isinstance(i, DiscordView):
|
||||
# Discord视图组件(按钮、选择菜单等)
|
||||
view = i.to_discord_view()
|
||||
elif isinstance(i, DiscordViewComponent):
|
||||
# 如果消息链中包含Discord视图组件(兼容旧版本)
|
||||
if isinstance(i.view, discord.ui.View):
|
||||
view = i.view
|
||||
else:
|
||||
logger.debug(f"[Discord] 忽略了不支持的消息组件: {i.type}")
|
||||
|
||||
if len(content) > 2000:
|
||||
logger.warning("[Discord] 消息内容超过2000字符,将被截断。")
|
||||
content = content[:2000]
|
||||
return content, files, view, embeds, reference_message_id
|
||||
|
||||
async def react(self, emoji: str):
|
||||
"""对原消息添加反应"""
|
||||
try:
|
||||
if hasattr(self.message_obj, "raw_message") and hasattr(
|
||||
self.message_obj.raw_message, "add_reaction"
|
||||
):
|
||||
await self.message_obj.raw_message.add_reaction(emoji)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 添加反应失败: {e}")
|
||||
|
||||
def is_slash_command(self) -> bool:
|
||||
"""判断是否为斜杠命令"""
|
||||
return (
|
||||
hasattr(self.message_obj, "raw_message")
|
||||
and hasattr(self.message_obj.raw_message, "type")
|
||||
and self.message_obj.raw_message.type
|
||||
== discord.InteractionType.application_command
|
||||
)
|
||||
|
||||
def is_button_interaction(self) -> bool:
|
||||
"""判断是否为按钮交互"""
|
||||
return (
|
||||
hasattr(self.message_obj, "raw_message")
|
||||
and hasattr(self.message_obj.raw_message, "type")
|
||||
and self.message_obj.raw_message.type == discord.InteractionType.component
|
||||
)
|
||||
|
||||
def get_interaction_custom_id(self) -> str:
|
||||
"""获取交互组件的custom_id"""
|
||||
if self.is_button_interaction():
|
||||
try:
|
||||
return self.message_obj.raw_message.data.get("custom_id", "")
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
def is_mentioned(self) -> bool:
|
||||
"""判断机器人是否被@"""
|
||||
if hasattr(self.message_obj, "raw_message") and hasattr(
|
||||
self.message_obj.raw_message, "mentions"
|
||||
):
|
||||
return any(
|
||||
mention.id == int(self.message_obj.self_id)
|
||||
for mention in self.message_obj.raw_message.mentions
|
||||
)
|
||||
return False
|
||||
|
||||
def get_mention_clean_content(self) -> str:
|
||||
"""获取去除@后的清洁内容"""
|
||||
if hasattr(self.message_obj, "raw_message") and hasattr(
|
||||
self.message_obj.raw_message, "clean_content"
|
||||
):
|
||||
return self.message_obj.raw_message.clean_content
|
||||
return self.message_str
|
||||
@@ -1,754 +0,0 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
|
||||
import aiohttp
|
||||
import anyio
|
||||
import quart
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.api.message_components import Plain, Image, At, Record, Video
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from .downloader import GeweDownloader
|
||||
|
||||
try:
|
||||
from .xml_data_parser import GeweDataParser
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.warning(
|
||||
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class SimpleGewechatClient:
|
||||
"""针对 Gewechat 的简单实现。
|
||||
|
||||
@author: Soulter
|
||||
@website: https://github.com/Soulter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
nickname: str,
|
||||
host: str,
|
||||
port: int,
|
||||
event_queue: asyncio.Queue,
|
||||
):
|
||||
self.base_url = base_url
|
||||
if self.base_url.endswith("/"):
|
||||
self.base_url = self.base_url[:-1]
|
||||
|
||||
self.download_base_url = self.base_url.split(":")[:-1] # 去掉端口
|
||||
self.download_base_url = ":".join(self.download_base_url) + ":2532/download/"
|
||||
|
||||
self.base_url += "/v2/api"
|
||||
|
||||
logger.info(f"Gewechat API: {self.base_url}")
|
||||
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
|
||||
|
||||
if isinstance(port, str):
|
||||
port = int(port)
|
||||
|
||||
self.token = None
|
||||
self.headers = {}
|
||||
self.nickname = nickname
|
||||
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
|
||||
|
||||
self.server = quart.Quart(__name__)
|
||||
self.server.add_url_rule(
|
||||
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
|
||||
)
|
||||
self.server.add_url_rule(
|
||||
"/astrbot-gewechat/file/<file_id>",
|
||||
view_func=self._handle_file,
|
||||
methods=["GET"],
|
||||
)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
|
||||
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.multimedia_downloader = None
|
||||
|
||||
self.userrealnames = {}
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
async def get_token_id(self):
|
||||
"""获取 Gewechat Token。"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
|
||||
json_blob = await resp.json()
|
||||
self.token = json_blob["data"]
|
||||
logger.info(f"获取到 Gewechat Token: {self.token}")
|
||||
self.headers = {"X-GEWE-TOKEN": self.token}
|
||||
|
||||
async def _convert(self, data: dict) -> AstrBotMessage:
|
||||
if "TypeName" in data:
|
||||
type_name = data["TypeName"]
|
||||
elif "type_name" in data:
|
||||
type_name = data["type_name"]
|
||||
else:
|
||||
raise Exception("无法识别的消息类型")
|
||||
|
||||
# 以下没有业务处理,只是避免控制台打印太多的日志
|
||||
if type_name == "ModContacts":
|
||||
logger.info("gewechat下发:ModContacts消息通知。")
|
||||
return
|
||||
if type_name == "DelContacts":
|
||||
logger.info("gewechat下发:DelContacts消息通知。")
|
||||
return
|
||||
|
||||
if type_name == "Offline":
|
||||
logger.critical("收到 gewechat 下线通知。")
|
||||
return
|
||||
|
||||
d = None
|
||||
if "Data" in data:
|
||||
d = data["Data"]
|
||||
elif "data" in data:
|
||||
d = data["data"]
|
||||
|
||||
if not d:
|
||||
logger.warning(f"消息不含 data 字段: {data}")
|
||||
return
|
||||
|
||||
if "CreateTime" in d:
|
||||
# 得到系统 UTF+8 的 ts
|
||||
tz_offset = datetime.timedelta(hours=8)
|
||||
tz = datetime.timezone(tz_offset)
|
||||
ts = datetime.datetime.now(tz).timestamp()
|
||||
create_time = d["CreateTime"]
|
||||
if create_time < ts - 30:
|
||||
logger.warning(f"消息时间戳过旧: {create_time},当前时间戳: {ts}")
|
||||
return
|
||||
|
||||
abm = AstrBotMessage()
|
||||
|
||||
from_user_name = d["FromUserName"]["string"] # 消息来源
|
||||
d["to_wxid"] = from_user_name # 用于发信息
|
||||
|
||||
abm.message_id = str(d.get("MsgId"))
|
||||
abm.session_id = from_user_name
|
||||
abm.self_id = data["Wxid"] # 机器人的 wxid
|
||||
|
||||
user_id = "" # 发送人 wxid
|
||||
content = d["Content"]["string"] # 消息内容
|
||||
|
||||
at_me = False
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
_t = content.split(":\n")
|
||||
user_id = _t[0]
|
||||
content = _t[1]
|
||||
if "\u2005" in content:
|
||||
# at
|
||||
# content = content.split('\u2005')[1]
|
||||
content = re.sub(r"@[^\u2005]*\u2005", "", content)
|
||||
abm.group_id = from_user_name
|
||||
# at
|
||||
msg_source = d["MsgSource"]
|
||||
if (
|
||||
f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source
|
||||
or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source
|
||||
):
|
||||
at_me = True
|
||||
if "在群聊中@了你" in d.get("PushContent", ""):
|
||||
at_me = True
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
user_id = from_user_name
|
||||
|
||||
# 检查消息是否由自己发送,若是则忽略
|
||||
if user_id == abm.self_id:
|
||||
logger.info("忽略自己发送的消息")
|
||||
return None
|
||||
|
||||
abm.message = []
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id))
|
||||
|
||||
# 解析用户真实名字
|
||||
user_real_name = "unknown"
|
||||
if abm.group_id:
|
||||
if (
|
||||
abm.group_id not in self.userrealnames
|
||||
or user_id not in self.userrealnames[abm.group_id]
|
||||
):
|
||||
# 获取群成员列表,并且缓存
|
||||
if abm.group_id not in self.userrealnames:
|
||||
self.userrealnames[abm.group_id] = {}
|
||||
member_list = await self.get_chatroom_member_list(abm.group_id)
|
||||
logger.debug(f"获取到 {abm.group_id} 的群成员列表。")
|
||||
if member_list and "memberList" in member_list:
|
||||
for member in member_list["memberList"]:
|
||||
self.userrealnames[abm.group_id][member["wxid"]] = member[
|
||||
"nickName"
|
||||
]
|
||||
if user_id in self.userrealnames[abm.group_id]:
|
||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||
else:
|
||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||
else:
|
||||
user_real_name = d.get("PushContent", "unknown : ").split(" : ")[0]
|
||||
|
||||
abm.sender = MessageMember(user_id, user_real_name)
|
||||
abm.raw_message = d
|
||||
abm.message_str = ""
|
||||
|
||||
if user_id == "weixin":
|
||||
# 忽略微信团队消息
|
||||
return
|
||||
|
||||
# 不同消息类型
|
||||
match d["MsgType"]:
|
||||
case 1:
|
||||
# 文本消息
|
||||
abm.message.append(Plain(content))
|
||||
abm.message_str = content
|
||||
case 3:
|
||||
# 图片消息
|
||||
file_url = await self.multimedia_downloader.download_image(
|
||||
self.appid, content
|
||||
)
|
||||
logger.debug(f"下载图片: {file_url}")
|
||||
file_path = await download_image_by_url(file_url)
|
||||
abm.message.append(Image(file=file_path, url=file_path))
|
||||
|
||||
case 34:
|
||||
# 语音消息
|
||||
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
|
||||
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
|
||||
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
|
||||
|
||||
async with await anyio.open_file(file_path, "wb") as f:
|
||||
await f.write(voice_data)
|
||||
abm.message.append(Record(file=file_path, url=file_path))
|
||||
|
||||
# 以下已知消息类型,没有业务处理,只是避免控制台打印太多的日志
|
||||
case 37: # 好友申请
|
||||
logger.info("消息类型(37):好友申请")
|
||||
case 42: # 名片
|
||||
logger.info("消息类型(42):名片")
|
||||
case 43: # 视频
|
||||
video = Video(file="", cover=content)
|
||||
abm.message.append(video)
|
||||
case 47: # emoji
|
||||
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||
emoji = data_parser.parse_emoji()
|
||||
abm.message.append(emoji)
|
||||
case 48: # 地理位置
|
||||
logger.info("消息类型(48):地理位置")
|
||||
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
|
||||
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||
abm_data = data_parser.parse_mutil_49()
|
||||
if abm_data:
|
||||
abm.message.append(abm_data)
|
||||
case 51: # 帐号消息同步?
|
||||
logger.info("消息类型(51):帐号消息同步?")
|
||||
case 10000: # 被踢出群聊/更换群主/修改群名称
|
||||
logger.info("消息类型(10000):被踢出群聊/更换群主/修改群名称")
|
||||
case 10002: # 撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办
|
||||
logger.info(
|
||||
"消息类型(10002):撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办"
|
||||
)
|
||||
|
||||
case _:
|
||||
logger.info(f"未实现的消息类型: {d['MsgType']}")
|
||||
abm.raw_message = d
|
||||
|
||||
logger.debug(f"abm: {abm}")
|
||||
return abm
|
||||
|
||||
async def _callback(self):
|
||||
data = await quart.request.json
|
||||
logger.debug(f"收到 gewechat 回调: {data}")
|
||||
|
||||
if data.get("testMsg", None):
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
abm = None
|
||||
try:
|
||||
abm = await self._convert(data)
|
||||
except BaseException as e:
|
||||
logger.warning(
|
||||
f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}。"
|
||||
)
|
||||
|
||||
if abm:
|
||||
coro = getattr(self, "on_event_received")
|
||||
if coro:
|
||||
await coro(abm)
|
||||
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
async def _handle_file(self, file_id):
|
||||
file_path = f"data/temp/{file_id}"
|
||||
return await quart.send_file(file_path)
|
||||
|
||||
async def _set_callback_url(self):
|
||||
logger.info("设置回调,请等待...")
|
||||
await asyncio.sleep(3)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/tools/setCallback",
|
||||
headers=self.headers,
|
||||
json={"token": self.token, "callbackUrl": self.callback_url},
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"设置回调结果: {json_blob}")
|
||||
if json_blob["ret"] != 200:
|
||||
raise Exception(f"设置回调失败: {json_blob}")
|
||||
logger.info(
|
||||
f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。如果仍没收到请到管理面板聊天页输入 /gewe_logout 重新登录。"
|
||||
)
|
||||
|
||||
async def start_polling(self):
|
||||
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
|
||||
await self.server.run_task(
|
||||
host="0.0.0.0",
|
||||
port=self.port,
|
||||
shutdown_trigger=self.shutdown_trigger,
|
||||
)
|
||||
|
||||
async def shutdown_trigger(self):
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
async def check_online(self, appid: str):
|
||||
"""检查 APPID 对应的设备是否在线。"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/checkOnline",
|
||||
headers=self.headers,
|
||||
json={"appId": appid},
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
return json_blob["data"]
|
||||
|
||||
async def logout(self):
|
||||
"""登出 gewechat。"""
|
||||
if self.appid:
|
||||
online = await self.check_online(self.appid)
|
||||
if online:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/logout",
|
||||
headers=self.headers,
|
||||
json={"appId": self.appid},
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"登出结果: {json_blob}")
|
||||
|
||||
async def login(self):
|
||||
"""登录 gewechat。一般来说插件用不到这个方法。"""
|
||||
if self.token is None:
|
||||
await self.get_token_id()
|
||||
|
||||
self.multimedia_downloader = GeweDownloader(
|
||||
self.base_url, self.download_base_url, self.token
|
||||
)
|
||||
|
||||
if self.appid:
|
||||
try:
|
||||
online = await self.check_online(self.appid)
|
||||
if online:
|
||||
logger.info(f"APPID: {self.appid} 已在线")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"检查在线状态失败: {e}")
|
||||
sp.put(f"gewechat-appid-{self.nickname}", "")
|
||||
self.appid = None
|
||||
|
||||
payload = {"appId": self.appid}
|
||||
|
||||
if self.appid:
|
||||
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/getLoginQrCode",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
if json_blob["ret"] != 200:
|
||||
error_msg = json_blob.get("data", {}).get("msg", "")
|
||||
if "设备不存在" in error_msg:
|
||||
logger.error(
|
||||
f"检测到无效的appid: {self.appid},将清除并重新登录。"
|
||||
)
|
||||
sp.put(f"gewechat-appid-{self.nickname}", "")
|
||||
self.appid = None
|
||||
return await self.login()
|
||||
else:
|
||||
raise Exception(f"获取二维码失败: {json_blob}")
|
||||
qr_data = json_blob["data"]["qrData"]
|
||||
qr_uuid = json_blob["data"]["uuid"]
|
||||
appid = json_blob["data"]["appId"]
|
||||
logger.info(f"APPID: {appid}")
|
||||
logger.warning(
|
||||
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
# 执行登录
|
||||
retry_cnt = 64
|
||||
payload.update({"uuid": qr_uuid, "appId": appid})
|
||||
while retry_cnt > 0:
|
||||
retry_cnt -= 1
|
||||
|
||||
# 需要验证码
|
||||
if os.path.exists("data/temp/gewe_code"):
|
||||
with open("data/temp/gewe_code", "r") as f:
|
||||
code = f.read().strip()
|
||||
if not code:
|
||||
logger.warning(
|
||||
"未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
payload["captchCode"] = code
|
||||
logger.info(f"使用验证码: {code}")
|
||||
try:
|
||||
os.remove("data/temp/gewe_code")
|
||||
except Exception:
|
||||
logger.warning("删除验证码文件 data/temp/gewe_code 失败。")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/checkLogin",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"检查登录状态: {json_blob}")
|
||||
|
||||
ret = json_blob["ret"]
|
||||
msg = ""
|
||||
if json_blob["data"] and "msg" in json_blob["data"]:
|
||||
msg = json_blob["data"]["msg"]
|
||||
if ret == 500 and "安全验证码" in msg:
|
||||
logger.warning(
|
||||
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
||||
)
|
||||
else:
|
||||
status = json_blob["data"]["status"]
|
||||
nickname = json_blob["data"].get("nickName", "")
|
||||
if status == 1:
|
||||
logger.info(f"等待确认...{nickname}")
|
||||
elif status == 2:
|
||||
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
||||
break
|
||||
elif status == 0:
|
||||
logger.info("等待扫码...")
|
||||
else:
|
||||
logger.warning(f"未知状态: {status}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
if appid:
|
||||
sp.put(f"gewechat-appid-{self.nickname}", appid)
|
||||
self.appid = appid
|
||||
logger.info(f"已保存 APPID: {appid}")
|
||||
|
||||
"""API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1
|
||||
"""
|
||||
|
||||
async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict:
|
||||
"""获取群成员列表。
|
||||
|
||||
Args:
|
||||
chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。
|
||||
|
||||
Returns:
|
||||
dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。
|
||||
"""
|
||||
payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/getChatroomMemberList",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
return json_blob["data"]
|
||||
|
||||
async def post_text(self, to_wxid, content: str, ats: str = ""):
|
||||
"""发送纯文本消息"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"content": content,
|
||||
}
|
||||
if ats:
|
||||
payload["ats"] = ats
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postText", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送消息结果: {json_blob}")
|
||||
|
||||
async def post_image(self, to_wxid, image_url: str):
|
||||
"""发送图片消息"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"imgUrl": image_url,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postImage", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送图片结果: {json_blob}")
|
||||
|
||||
async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""):
|
||||
"""发送emoji消息"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"emojiMd5": emoji_md5,
|
||||
"emojiSize": emoji_size,
|
||||
}
|
||||
|
||||
# 优先表情包,若拿不到表情包的md5,就用当作图片发
|
||||
try:
|
||||
if emoji_md5 != "" and emoji_size != "":
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postEmoji",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(
|
||||
f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}"
|
||||
)
|
||||
else:
|
||||
await self.post_image(to_wxid, cdnurl)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
async def post_video(
|
||||
self, to_wxid, video_url: str, thumb_url: str, video_duration: int
|
||||
):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"videoUrl": video_url,
|
||||
"thumbUrl": thumb_url,
|
||||
"videoDuration": video_duration,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postVideo", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送视频结果: {json_blob}")
|
||||
|
||||
async def forward_video(self, to_wxid, cnd_xml: str):
|
||||
"""转发视频
|
||||
|
||||
Args:
|
||||
to_wxid (str): 发送给谁
|
||||
cnd_xml (str): 视频消息的cdn信息
|
||||
"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"xml": cnd_xml,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/forwardVideo",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"转发视频结果: {json_blob}")
|
||||
|
||||
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
|
||||
"""发送语音信息
|
||||
|
||||
Args:
|
||||
voice_url (str): 语音文件的网络链接
|
||||
voice_duration (int): 语音时长,毫秒
|
||||
"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"voiceUrl": voice_url,
|
||||
"voiceDuration": voice_duration,
|
||||
}
|
||||
|
||||
logger.debug(f"发送语音: {payload}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}")
|
||||
|
||||
async def post_file(self, to_wxid, file_url: str, file_name: str):
|
||||
"""发送文件
|
||||
|
||||
Args:
|
||||
to_wxid (string): 微信ID
|
||||
file_url (str): 文件的网络链接
|
||||
file_name (str): 文件名
|
||||
"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"fileUrl": file_url,
|
||||
"fileName": file_name,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postFile", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送文件结果: {json_blob}")
|
||||
|
||||
async def add_friend(self, v3: str, v4: str, content: str):
|
||||
"""申请添加好友"""
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"scene": 3,
|
||||
"content": content,
|
||||
"v4": v4,
|
||||
"v3": v3,
|
||||
"option": 2,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/contacts/addContacts",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"申请添加好友结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_group(self, group_id: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"chatroomId": group_id,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/getChatroomInfo",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_group_member(self, group_id: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"chatroomId": group_id,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/getChatroomMemberList",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def accept_group_invite(self, url: str):
|
||||
"""同意进群"""
|
||||
payload = {"appId": self.appid, "url": url}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/agreeJoinRoom",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def add_group_member_to_friend(
|
||||
self, group_id: str, to_wxid: str, content: str
|
||||
):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"chatroomId": group_id,
|
||||
"content": content,
|
||||
"memberWxid": to_wxid,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/group/addGroupMemberAsFriend",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_user_or_group_info(self, *ids):
|
||||
"""
|
||||
获取用户或群组信息。
|
||||
|
||||
:param ids: 可变数量的 wxid 参数
|
||||
"""
|
||||
|
||||
wxids_str = list(ids)
|
||||
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"wxids": wxids_str, # 使用逗号分隔的字符串
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/contacts/getDetailInfo",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_contacts_list(self):
|
||||
"""
|
||||
获取通讯录列表
|
||||
见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504
|
||||
"""
|
||||
payload = {"appId": self.appid}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/contacts/fetchContactsList",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取通讯录列表结果: {json_blob}")
|
||||
return json_blob
|
||||
@@ -1,55 +0,0 @@
|
||||
from astrbot import logger
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
|
||||
class GeweDownloader:
|
||||
def __init__(self, base_url: str, download_base_url: str, token: str):
|
||||
self.base_url = base_url
|
||||
self.download_base_url = download_base_url
|
||||
self.headers = {"Content-Type": "application/json", "X-GEWE-TOKEN": token}
|
||||
|
||||
async def _post_json(self, baseurl: str, route: str, payload: dict):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{baseurl}{route}", headers=self.headers, json=payload
|
||||
) as resp:
|
||||
return await resp.read()
|
||||
|
||||
async def download_voice(self, appid: str, xml: str, msg_id: str):
|
||||
payload = {"appId": appid, "xml": xml, "msgId": msg_id}
|
||||
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
|
||||
|
||||
async def download_image(self, appid: str, xml: str) -> str:
|
||||
"""返回一个可下载的 URL"""
|
||||
choices = [2, 3] # 2:常规图片 3:缩略图
|
||||
|
||||
for choice in choices:
|
||||
try:
|
||||
payload = {"appId": appid, "xml": xml, "type": choice}
|
||||
data = await self._post_json(
|
||||
self.base_url, "/message/downloadImage", payload
|
||||
)
|
||||
json_blob = json.loads(data)
|
||||
if "fileUrl" in json_blob["data"]:
|
||||
return self.download_base_url + json_blob["data"]["fileUrl"]
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(f"gewe download image: {e}")
|
||||
continue
|
||||
|
||||
raise Exception("无法下载图片")
|
||||
|
||||
async def download_emoji_md5(self, app_id, emoji_md5):
|
||||
"""下载emoji"""
|
||||
try:
|
||||
payload = {"appId": app_id, "emojiMd5": emoji_md5}
|
||||
|
||||
# gewe 计划中的接口,暂时没有实现。返回代码404
|
||||
data = await self._post_json(
|
||||
self.base_url, "/message/downloadEmojiMd5", payload
|
||||
)
|
||||
json_blob = json.loads(data)
|
||||
return json_blob
|
||||
except BaseException as e:
|
||||
logger.error(f"gewe download emoji: {e}")
|
||||
@@ -1,231 +0,0 @@
|
||||
import wave
|
||||
import uuid
|
||||
import traceback
|
||||
import os
|
||||
|
||||
from astrbot.core.utils.io import save_temp_img, download_file
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember
|
||||
from astrbot.api.message_components import (
|
||||
Plain,
|
||||
Image,
|
||||
Record,
|
||||
At,
|
||||
File,
|
||||
Video,
|
||||
WechatEmoji as Emoji,
|
||||
)
|
||||
from .client import SimpleGewechatClient
|
||||
|
||||
|
||||
def get_wav_duration(file_path):
|
||||
with wave.open(file_path, "rb") as wav_file:
|
||||
file_size = os.path.getsize(file_path)
|
||||
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
|
||||
if n_frames == 2147483647:
|
||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||
elif n_frames == 0:
|
||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||
else:
|
||||
duration = n_frames / float(framerate)
|
||||
return duration
|
||||
|
||||
|
||||
class GewechatPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client: SimpleGewechatClient,
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(
|
||||
message: MessageChain, to_wxid: str, client: SimpleGewechatClient
|
||||
):
|
||||
if not to_wxid:
|
||||
logger.error("无法获取到 to_wxid。")
|
||||
return
|
||||
|
||||
# 检查@
|
||||
ats = []
|
||||
ats_names = []
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, At):
|
||||
ats.append(comp.qq)
|
||||
ats_names.append(comp.name)
|
||||
has_at = False
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
text = comp.text
|
||||
payload = {
|
||||
"to_wxid": to_wxid,
|
||||
"content": text,
|
||||
}
|
||||
if not has_at and ats:
|
||||
ats = f"{','.join(ats)}"
|
||||
ats_names = f"@{' @'.join(ats_names)}"
|
||||
text = f"{ats_names} {text}"
|
||||
payload["content"] = text
|
||||
payload["ats"] = ats
|
||||
has_at = True
|
||||
await client.post_text(**payload)
|
||||
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
# 检查 record_path 是否在 data/temp 目录中
|
||||
temp_directory = os.path.abspath("data/temp")
|
||||
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
|
||||
with open(img_path, "rb") as f:
|
||||
img_path = save_temp_img(f.read())
|
||||
|
||||
file_id = os.path.basename(img_path)
|
||||
img_url = f"{client.file_server_url}/{file_id}"
|
||||
logger.debug(f"gewe callback img url: {img_url}")
|
||||
await client.post_image(to_wxid, img_url)
|
||||
elif isinstance(comp, Video):
|
||||
if comp.cover != "":
|
||||
await client.forward_video(to_wxid, comp.cover)
|
||||
else:
|
||||
try:
|
||||
from pyffmpeg import FFmpeg
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
logger.error(
|
||||
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
||||
)
|
||||
raise ModuleNotFoundError(
|
||||
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
|
||||
)
|
||||
|
||||
video_url = comp.file
|
||||
# 根据 url 下载视频
|
||||
video_filename = f"{uuid.uuid4()}.mp4"
|
||||
video_path = f"data/temp/{video_filename}"
|
||||
await download_file(video_url, video_path)
|
||||
|
||||
# 获取视频第一帧
|
||||
thumb_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||
try:
|
||||
ff = FFmpeg()
|
||||
command = f'-i "{video_path}" -ss 0 -vframes 1 "{thumb_path}"'
|
||||
ff.options(command)
|
||||
thumb_file_id = os.path.basename(thumb_path)
|
||||
thumb_url = f"{client.file_server_url}/{thumb_file_id}"
|
||||
except Exception as e:
|
||||
logger.error(f"获取视频第一帧失败: {e}")
|
||||
# 获取视频时长
|
||||
try:
|
||||
from pyffmpeg import FFprobe
|
||||
|
||||
# 创建 FFprobe 实例
|
||||
ffprobe = FFprobe(video_url)
|
||||
# 获取时长字符串
|
||||
duration_str = ffprobe.duration
|
||||
# 处理时长字符串
|
||||
video_duration = float(duration_str.replace(":", ""))
|
||||
except Exception as e:
|
||||
logger.error(f"获取时长失败: {e}")
|
||||
video_duration = 10
|
||||
|
||||
file_id = os.path.basename(video_path)
|
||||
video_url = f"{client.file_server_url}/{file_id}"
|
||||
await client.post_video(
|
||||
to_wxid, video_url, thumb_url, video_duration
|
||||
)
|
||||
|
||||
# 删除临时视频和缩略图文件
|
||||
if os.path.exists(video_path):
|
||||
os.remove(video_path)
|
||||
if os.path.exists(thumb_path):
|
||||
os.remove(thumb_path)
|
||||
elif isinstance(comp, Record):
|
||||
# 默认已经存在 data/temp 中
|
||||
record_url = comp.file
|
||||
record_path = await comp.convert_to_file_path()
|
||||
|
||||
silk_path = f"data/temp/{uuid.uuid4()}.silk"
|
||||
try:
|
||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
await client.post_text(to_wxid, f"语音文件转换失败。{str(e)}")
|
||||
logger.info("Silk 语音文件格式转换至: " + record_path)
|
||||
if duration == 0:
|
||||
duration = get_wav_duration(record_path)
|
||||
file_id = os.path.basename(silk_path)
|
||||
record_url = f"{client.file_server_url}/{file_id}"
|
||||
logger.debug(f"gewe callback record url: {record_url}")
|
||||
await client.post_voice(to_wxid, record_url, duration * 1000)
|
||||
elif isinstance(comp, File):
|
||||
file_path = comp.file
|
||||
file_name = comp.name
|
||||
if file_path.startswith("file:///"):
|
||||
file_path = file_path[8:]
|
||||
elif file_path.startswith("http"):
|
||||
await download_file(file_path, f"data/temp/{file_name}")
|
||||
else:
|
||||
file_path = file_path
|
||||
|
||||
file_id = os.path.basename(file_path)
|
||||
file_url = f"{client.file_server_url}/{file_id}"
|
||||
logger.debug(f"gewe callback file url: {file_url}")
|
||||
await client.post_file(to_wxid, file_url, file_id)
|
||||
elif isinstance(comp, Emoji):
|
||||
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
|
||||
elif isinstance(comp, At):
|
||||
pass
|
||||
else:
|
||||
logger.debug(f"gewechat 忽略: {comp.type}")
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
|
||||
await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client)
|
||||
await super().send(message)
|
||||
|
||||
async def get_group(self, group_id=None, **kwargs):
|
||||
# 确定有效的 group_id
|
||||
if group_id is None:
|
||||
group_id = self.get_group_id()
|
||||
|
||||
if not group_id:
|
||||
return None
|
||||
|
||||
res = await self.client.get_group(group_id)
|
||||
data: dict = res["data"]
|
||||
|
||||
if not data["chatroomId"]:
|
||||
return None
|
||||
|
||||
members = [
|
||||
MessageMember(user_id=member["wxid"], nickname=member["nickName"])
|
||||
for member in data.get("memberList", [])
|
||||
]
|
||||
|
||||
return Group(
|
||||
group_id=data["chatroomId"],
|
||||
group_name=data.get("nickName"),
|
||||
group_avatar=data.get("smallHeadImgUrl"),
|
||||
group_owner=data.get("chatRoomOwner"),
|
||||
members=members,
|
||||
)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
@@ -1,100 +0,0 @@
|
||||
import sys
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from .gewechat_event import GewechatPlatformEvent
|
||||
from .client import SimpleGewechatClient
|
||||
from astrbot import logger
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器")
|
||||
class GewechatPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settingss = platform_settings
|
||||
self.test_mode = os.environ.get("TEST_MODE", "off") == "on"
|
||||
self.client = None
|
||||
|
||||
self.client = SimpleGewechatClient(
|
||||
self.config["base_url"],
|
||||
self.config["nickname"],
|
||||
self.config["host"],
|
||||
self.config["port"],
|
||||
self._event_queue,
|
||||
)
|
||||
|
||||
async def on_event_received(abm: AstrBotMessage):
|
||||
await self.handle_msg(abm)
|
||||
|
||||
self.client.on_event_received = on_event_received
|
||||
|
||||
@override
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
session_id = session.session_id
|
||||
if "#" in session_id:
|
||||
# unique session
|
||||
to_wxid = session_id.split("#")[1]
|
||||
else:
|
||||
to_wxid = session_id
|
||||
|
||||
await GewechatPlatformEvent.send_with_client(
|
||||
message_chain, to_wxid, self.client
|
||||
)
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
name="gewechat",
|
||||
description="基于 gewechat 的 Wechat 适配器",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
async def terminate(self):
|
||||
self.client.shutdown_event.set()
|
||||
await self.client.server.shutdown()
|
||||
logger.info("Gewechat 适配器已被优雅地关闭。")
|
||||
|
||||
async def logout(self):
|
||||
await self.client.logout()
|
||||
|
||||
@override
|
||||
def run(self):
|
||||
return self._run()
|
||||
|
||||
async def _run(self):
|
||||
await self.client.login()
|
||||
await self.client.start_polling()
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
if message.type == MessageType.GROUP_MESSAGE:
|
||||
if self.settingss["unique_session"]:
|
||||
message.session_id = message.sender.user_id + "#" + message.group_id
|
||||
|
||||
message_event = GewechatPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client,
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
def get_client(self) -> SimpleGewechatClient:
|
||||
return self.client
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user