Compare commits
440 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9564166297 | ||
|
|
f5cf3c3c8e | ||
|
|
18f919fb6b | ||
|
|
0924835253 | ||
|
|
20d2e5c578 | ||
|
|
907801605c | ||
|
|
93bc684e8c | ||
|
|
a76c98d57e | ||
|
|
d937a800d0 | ||
|
|
d16f3a227f | ||
|
|
80c9a3eeda | ||
|
|
e68173b451 | ||
|
|
40c27d87f5 | ||
|
|
3c13b5049d | ||
|
|
8288d5e51f | ||
|
|
4ffbb18ab4 | ||
|
|
b27271b7a3 | ||
|
|
ebb6665f64 | ||
|
|
e4e5731ffd | ||
|
|
2ab5810f13 | ||
|
|
af934c5d09 | ||
|
|
1e0cf7c112 | ||
|
|
46859c93c9 | ||
|
|
1641549016 | ||
|
|
716a5dbb8a | ||
|
|
af98cb11c5 | ||
|
|
9a4c2cf341 | ||
|
|
2bc3bcd102 | ||
|
|
d6c663f79d | ||
|
|
2b4ee13b5e | ||
|
|
6959f86632 | ||
|
|
537d373e10 | ||
|
|
cceadf222c | ||
|
|
cf5a4af623 | ||
|
|
39aea11c22 | ||
|
|
c2f1227700 | ||
|
|
900f14d37c | ||
|
|
598249b1d6 | ||
|
|
7ed15bdf04 | ||
|
|
2fc0ec0f72 | ||
|
|
5e9c2a669b | ||
|
|
b310521884 | ||
|
|
288945bf7e | ||
|
|
4fc07cff36 | ||
|
|
b884fe0e86 | ||
|
|
855858c236 | ||
|
|
c11a2a5419 | ||
|
|
773a6572af | ||
|
|
88ad373c9b | ||
|
|
51666464b9 | ||
|
|
5af9cf2f52 | ||
|
|
12c4ae4b10 | ||
|
|
4e1bef414a | ||
|
|
e896c18644 | ||
|
|
c852685e74 | ||
|
|
1e99797df8 | ||
|
|
52a4c986a8 | ||
|
|
c501728204 | ||
|
|
6b067fa6a7 | ||
|
|
a1cd5c53a9 | ||
|
|
a46d487e03 | ||
|
|
3deb6d3ab3 | ||
|
|
af34cdd5d2 | ||
|
|
6e1393235a | ||
|
|
343e0b54b9 | ||
|
|
ecb70cb6f7 | ||
|
|
ca50618af6 | ||
|
|
29c07ba83e | ||
|
|
45fbb83a9f | ||
|
|
ae7ba2df25 | ||
|
|
c3ef57cc32 | ||
|
|
7bb4ca5a14 | ||
|
|
063783d81d | ||
|
|
42116c9b65 | ||
|
|
a36e11973d | ||
|
|
5125568ea2 | ||
|
|
0fa164e50d | ||
|
|
cf814e81ee | ||
|
|
43a45f18ce | ||
|
|
ad51381063 | ||
|
|
0b0e4ce904 | ||
|
|
6a3e04d688 | ||
|
|
4107a17370 | ||
|
|
06b4d8f169 | ||
|
|
1c0c820746 | ||
|
|
d061403a28 | ||
|
|
5c092321a6 | ||
|
|
bdd3f61c1f | ||
|
|
8023557d6e | ||
|
|
074b0ced7a | ||
|
|
3864b1ac9b | ||
|
|
6e9b43457d | ||
|
|
ca1aec8920 | ||
|
|
acac580862 | ||
|
|
673e1b2980 | ||
|
|
f62157be72 | ||
|
|
f894ecf3b6 | ||
|
|
66dd4e28ad | ||
|
|
939dc1b0fb | ||
|
|
56bf5d38a1 | ||
|
|
d09b70b295 | ||
|
|
205180387a | ||
|
|
39c8cfeda5 | ||
|
|
f38a329be5 | ||
|
|
a0cd069539 | ||
|
|
bf306a2f01 | ||
|
|
c31f93a8d1 | ||
|
|
4730ab6309 | ||
|
|
1ae78ca98c | ||
|
|
d2379da478 | ||
|
|
0f64981b20 | ||
|
|
0002e49bb5 | ||
|
|
db13a60274 | ||
|
|
db0f11a359 | ||
|
|
ac7f43520b | ||
|
|
f67b9f5f6e | ||
|
|
c75156c4ce | ||
|
|
10270b5595 | ||
|
|
f7458572ed | ||
|
|
d57b7222b2 | ||
|
|
62e70a673a | ||
|
|
5e9eba6478 | ||
|
|
cb02dfe1a4 | ||
|
|
b50739e1af | ||
|
|
8da1b0212d | ||
|
|
ca1f2acb33 | ||
|
|
c15f966669 | ||
|
|
7705b8781a | ||
|
|
b2502746f0 | ||
|
|
ab68094386 | ||
|
|
bbec701223 | ||
|
|
b29d14e600 | ||
|
|
86e51c5cd1 | ||
|
|
cb8267be3f | ||
|
|
eaed43915c | ||
|
|
bd91fd2c38 | ||
|
|
1203b214cd | ||
|
|
c3fec15f11 | ||
|
|
0545653494 | ||
|
|
db2989bdb4 | ||
|
|
587bd00a19 | ||
|
|
960ff438e8 | ||
|
|
98e7ea85d3 | ||
|
|
2549e44710 | ||
|
|
4d32b563ca | ||
|
|
3a4b732977 | ||
|
|
500909a28e | ||
|
|
07753eb25b | ||
|
|
c6eaf3d010 | ||
|
|
6723fe8271 | ||
|
|
3348b70435 | ||
|
|
35a8527c16 | ||
|
|
7afc475290 | ||
|
|
789bceaa3a | ||
|
|
abbc043969 | ||
|
|
654e5762f1 | ||
|
|
507c3e3629 | ||
|
|
991dfeb2f2 | ||
|
|
26482fc2d3 | ||
|
|
e0ce6d9688 | ||
|
|
946595216a | ||
|
|
864b6bc56d | ||
|
|
6ea5b7581f | ||
|
|
f70b8f0c10 | ||
|
|
1593bcb537 | ||
|
|
bf7fc02c8d | ||
|
|
143702b92b | ||
|
|
c5ccc1a084 | ||
|
|
2ecb52a9b2 | ||
|
|
6439917cbe | ||
|
|
d21c18f657 | ||
|
|
25ef0039e4 | ||
|
|
e6981290bc | ||
|
|
75c3d8abbd | ||
|
|
d88683f498 | ||
|
|
40b9aa3a4c | ||
|
|
b6d1515d58 | ||
|
|
e01d4264e3 | ||
|
|
2117b65487 | ||
|
|
a7823b352f | ||
|
|
c543b62a08 | ||
|
|
3923b87f08 | ||
|
|
b7ecdadb83 | ||
|
|
5ff121e1ed | ||
|
|
f486e5448f | ||
|
|
c5aae98558 | ||
|
|
6d8a3b9897 | ||
|
|
6d98780e19 | ||
|
|
3ad2c46f3f | ||
|
|
a730cee7fd | ||
|
|
77c823c100 | ||
|
|
124f21c67a | ||
|
|
e46cf20dd3 | ||
|
|
4bef5e8313 | ||
|
|
22e93b0af4 | ||
|
|
5aeca9662b | ||
|
|
b996cf1f05 | ||
|
|
878a106877 | ||
|
|
45d36f86fd | ||
|
|
b108ae403a | ||
|
|
887ed66768 | ||
|
|
dac840a887 | ||
|
|
238de4ba8c | ||
|
|
9a7bdade43 | ||
|
|
aa84556204 | ||
|
|
6b68069fcd | ||
|
|
42c7034fb2 | ||
|
|
060c7e0145 | ||
|
|
b5b085dfb1 | ||
|
|
fc06ce9d7f | ||
|
|
d8d81b05a7 | ||
|
|
a60f42b1f2 | ||
|
|
6e18be88d0 | ||
|
|
b45e439c48 | ||
|
|
b87061c18c | ||
|
|
f78aca7752 | ||
|
|
3ccca2aa10 | ||
|
|
6d7c40eb76 | ||
|
|
da4cd7fb65 | ||
|
|
c97cda6b84 | ||
|
|
7a7fd4167a | ||
|
|
dffc1a43d5 | ||
|
|
36897fea1e | ||
|
|
c7b34735f0 | ||
|
|
5b07176c88 | ||
|
|
474b40d660 | ||
|
|
a62901b948 | ||
|
|
25d8746327 | ||
|
|
aff1698223 | ||
|
|
7f8941745f | ||
|
|
b858401098 | ||
|
|
d5a158b80f | ||
|
|
f315f284aa | ||
|
|
c367f5009d | ||
|
|
6db1e63bda | ||
|
|
e22ab2ede6 | ||
|
|
b7d7e0b682 | ||
|
|
96bba15f2f | ||
|
|
fcf965a595 | ||
|
|
e1a20d3c22 | ||
|
|
2abd7d8c5d | ||
|
|
5b8f73cdd7 | ||
|
|
7fd765421f | ||
|
|
d9d94af022 | ||
|
|
790b924e57 | ||
|
|
4a62f877df | ||
|
|
ac47c57bb7 | ||
|
|
3ace4199a1 | ||
|
|
e6bd7524c1 | ||
|
|
699c86e8c1 | ||
|
|
f40fa0ecea | ||
|
|
626f94686b | ||
|
|
752d13b1b1 | ||
|
|
54c0dc1b2b | ||
|
|
c5bc709898 | ||
|
|
ccdbb01513 | ||
|
|
5206d750ac | ||
|
|
a800e3df67 | ||
|
|
ccb1f87a20 | ||
|
|
c111da4681 | ||
|
|
9cc4e97a53 | ||
|
|
dca1c0b0f3 | ||
|
|
f06be6ed21 | ||
|
|
3c8ec2f42e | ||
|
|
7e193f7f52 | ||
|
|
7069b02929 | ||
|
|
66995db927 | ||
|
|
c36054ca1b | ||
|
|
3e07fbf3dc | ||
|
|
bf3fbe3e96 | ||
|
|
0a93d22bc8 | ||
|
|
f5b3d94d16 | ||
|
|
4d1a6994aa | ||
|
|
05c686782c | ||
|
|
85609ea742 | ||
|
|
20dabc0615 | ||
|
|
356dd9bc2b | ||
|
|
cd5d7534c4 | ||
|
|
b4f12fc933 | ||
|
|
cbea387ce0 | ||
|
|
345b155374 | ||
|
|
29d216950e | ||
|
|
321b04772c | ||
|
|
5b924aee98 | ||
|
|
46d44e3405 | ||
|
|
4d5332fe25 | ||
|
|
18bd4c54f4 | ||
|
|
31c7768ca0 | ||
|
|
6ec643e9d1 | ||
|
|
2b39f6f61c | ||
|
|
bf3ca13961 | ||
|
|
82026370ec | ||
|
|
6d49bf5346 | ||
|
|
67431d87fb | ||
|
|
fdf55221e6 | ||
|
|
07f277dd3b | ||
|
|
cf8f0603ca | ||
|
|
5592408ab8 | ||
|
|
a01617b45c | ||
|
|
7abb4087b3 | ||
|
|
dff15cf27a | ||
|
|
aa858137e5 | ||
|
|
45cb143202 | ||
|
|
7a9c6ab8c4 | ||
|
|
e2c26c292d | ||
|
|
be7c3fd00e | ||
|
|
7e5461a2cf | ||
|
|
6ee9010645 | ||
|
|
a23d5be056 | ||
|
|
97a6a1fdc2 | ||
|
|
c8f567347b | ||
|
|
74c1e7f69e | ||
|
|
15a5fc0cae | ||
|
|
f07c54d47c | ||
|
|
70446be108 | ||
|
|
d6d21fca56 | ||
|
|
8d7273924f | ||
|
|
ea64afbaa7 | ||
|
|
45da9837ec | ||
|
|
8c19b7d163 | ||
|
|
ab227a08d0 | ||
|
|
40d6e77964 | ||
|
|
9326e3f1b0 | ||
|
|
0e1eb3daf6 | ||
|
|
05daac12ed | ||
|
|
c5b24b4764 | ||
|
|
cc16548e5f | ||
|
|
291d65bb3e | ||
|
|
bd3ad03da6 | ||
|
|
5fa6788357 | ||
|
|
c5c5a98ac4 | ||
|
|
a1151143cf | ||
|
|
f5024984f7 | ||
|
|
f4880fd90d | ||
|
|
0ae61d5865 | ||
|
|
d3bd775a79 | ||
|
|
da546cfe7f | ||
|
|
a211933e83 | ||
|
|
1d40b5a821 | ||
|
|
33836daeb7 | ||
|
|
d921b0f6bd | ||
|
|
0607b95df6 | ||
|
|
0de6d0e046 | ||
|
|
98427345cf | ||
|
|
9fedaa9f77 | ||
|
|
bf4c2ecd33 | ||
|
|
f8c18cc1e0 | ||
|
|
458b900412 | ||
|
|
192c776e0b | ||
|
|
5cdec18863 | ||
|
|
15f856f951 | ||
|
|
01d52cef74 | ||
|
|
95563c8659 | ||
|
|
31d8c40eca | ||
|
|
56001ed272 | ||
|
|
d916fda04c | ||
|
|
cfae655068 | ||
|
|
5596565ec4 | ||
|
|
afa1aa5d93 | ||
|
|
e98c3d8393 | ||
|
|
6687b816f0 | ||
|
|
ea8035e854 | ||
|
|
54b0171d49 | ||
|
|
676d4277b9 | ||
|
|
a4b1da3ca2 | ||
|
|
9e9c16e770 | ||
|
|
dc87006fed | ||
|
|
b9b260f26a | ||
|
|
33fd6a5016 | ||
|
|
97cbccc2ba | ||
|
|
1ee4685d5d | ||
|
|
aba18232b1 | ||
|
|
0a02441b75 | ||
|
|
1be5b4c7ff | ||
|
|
a0ce0cf18a | ||
|
|
7c54e5d093 | ||
|
|
b825e51dab | ||
|
|
589855c393 | ||
|
|
4c546f2f53 | ||
|
|
3753fce912 | ||
|
|
4c02857ec5 | ||
|
|
33f87ff7d7 | ||
|
|
784dcf2a9a | ||
|
|
43ee943acb | ||
|
|
a769fd7d13 | ||
|
|
2c4fd00b16 | ||
|
|
264771fe98 | ||
|
|
ecd92dafef | ||
|
|
c8b6e4bea3 | ||
|
|
3756cb766e | ||
|
|
068d9ca60b | ||
|
|
93f632d8b8 | ||
|
|
bb44ce7e74 | ||
|
|
6986c8d8f7 | ||
|
|
fe95506db4 | ||
|
|
310ed76b18 | ||
|
|
98830d147f | ||
|
|
19c9177d7b | ||
|
|
f41c5f97f6 | ||
|
|
648c125697 | ||
|
|
0dc2b89897 | ||
|
|
83745f83a5 | ||
|
|
2f91fe4535 | ||
|
|
739f09059e | ||
|
|
c86f9f0f5f | ||
|
|
9470ca6bc5 | ||
|
|
2a92c4d5de | ||
|
|
bb6e892657 | ||
|
|
c9079b9299 | ||
|
|
b6963c1bf9 | ||
|
|
9c29df47bb | ||
|
|
fc146d3d00 | ||
|
|
1bf5a21678 | ||
|
|
011542dc2b | ||
|
|
489784104e | ||
|
|
3860634fd2 | ||
|
|
709c324e18 | ||
|
|
b75d24d92c | ||
|
|
ed80e9424c | ||
|
|
2fe1f2060a | ||
|
|
c6df820164 | ||
|
|
7c1e8ce48c | ||
|
|
44dbe475af | ||
|
|
bd24cf3ea4 | ||
|
|
b493a808fe | ||
|
|
54035d108d | ||
|
|
c5e8bc7e20 | ||
|
|
3bbb4779a3 | ||
|
|
1b3963ebea | ||
|
|
e8ffebc006 | ||
|
|
2ca95eaa9f | ||
|
|
0dc5b4cdfc | ||
|
|
cc6cd96d8e | ||
|
|
4244d37625 | ||
|
|
0b766095d4 | ||
|
|
39693a27e3 | ||
|
|
b03fe438d0 | ||
|
|
7f56824b42 | ||
|
|
627da3a2bc | ||
|
|
e6b69042de |
@@ -17,4 +17,8 @@ ENV/
|
||||
.conda/
|
||||
README*.md
|
||||
dashboard/
|
||||
data/
|
||||
data/
|
||||
changelogs/
|
||||
tests/
|
||||
.ruff_cache/
|
||||
.astrbot
|
||||
15
.github/FUNDING.yml
vendored
Normal file
15
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||
patreon: # Replace with a single Patreon username
|
||||
open_collective: astrbot
|
||||
ko_fi: # Replace with a single Ko-fi username
|
||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
||||
polar: # Replace with a single Polar username
|
||||
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
|
||||
thanks_dev: # Replace with a single thanks.dev username
|
||||
custom: ['https://afdian.com/a/astrbot_team']
|
||||
11
.github/PULL_REQUEST_TEMPLATE.md
vendored
11
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,5 +1,5 @@
|
||||
<!-- 如果有的话,指定这个 PR 要解决的 ISSUE -->
|
||||
修复了 #XYZ
|
||||
解决了 #XYZ
|
||||
|
||||
### Motivation
|
||||
|
||||
@@ -10,5 +10,10 @@
|
||||
<!--简单解释你的改动-->
|
||||
|
||||
### Check
|
||||
- [ ] 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
|
||||
- [ ] 我新增/修复/优化的功能经过良好的测试
|
||||
|
||||
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容-->
|
||||
|
||||
- [ ] 😊 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
|
||||
- [ ] 👀 我的更改经过良好的测试
|
||||
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt` 和 `pyproject.toml` 文件相应位置。
|
||||
- [ ] 😮 我的更改没有引入恶意代码
|
||||
|
||||
33
.github/workflows/auto_release.yml
vendored
33
.github/workflows/auto_release.yml
vendored
@@ -7,7 +7,7 @@ on:
|
||||
name: Auto Release
|
||||
|
||||
jobs:
|
||||
build:
|
||||
build-and-publish-to-github-release:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -28,8 +28,35 @@ jobs:
|
||||
run: |
|
||||
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Create Release
|
||||
- name: Create GitHub Release
|
||||
uses: ncipollo/release-action@v1
|
||||
with:
|
||||
bodyFile: ${{ env.changelog }}
|
||||
artifacts: "dashboard/dist.zip"
|
||||
artifacts: "dashboard/dist.zip"
|
||||
|
||||
build-and-publish-to-pypi:
|
||||
# 构建并发布到 PyPI
|
||||
runs-on: ubuntu-latest
|
||||
needs: build-and-publish-to-github-release
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
python -m pip install uv
|
||||
|
||||
- name: Build package
|
||||
run: |
|
||||
uv build
|
||||
|
||||
- name: Publish to PyPI
|
||||
env:
|
||||
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
uv publish
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -30,3 +30,4 @@ packages/python_interpreter/workplace
|
||||
.conda/
|
||||
.idea
|
||||
pytest.ini
|
||||
.astrbot
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.10
|
||||
@@ -4,6 +4,8 @@ WORKDIR /AstrBot
|
||||
COPY . /AstrBot/
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
nodejs \
|
||||
npm \
|
||||
gcc \
|
||||
build-essential \
|
||||
python3-dev \
|
||||
@@ -28,3 +30,6 @@ EXPOSE 6185
|
||||
EXPOSE 6186
|
||||
|
||||
CMD [ "python", "main.py" ]
|
||||
|
||||
|
||||
|
||||
|
||||
61
README.md
61
README.md
@@ -13,11 +13,11 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||

|
||||

|
||||
|
||||

|
||||

|
||||
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
@@ -27,11 +27,14 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
|
||||
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
||||
|
||||
[](https://gitcode.com/Soulter/AstrBot)
|
||||
|
||||
<!-- [](https://codecov.io/gh/Soulter/AstrBot)
|
||||
-->
|
||||
|
||||
> [!NOTE]
|
||||
>
|
||||
> 个人微信接入所依赖的开源项目 Gewechat 近期已停止维护,`v3.5.10` 已经支持接入 WeChatPadPro 替换 gewechat 方式。详见文档 [WeChatPadPro](https://astrbot.app/deploy/platform/wechat/wechatpadpro.html)
|
||||
|
||||
## ✨ 近期更新
|
||||
|
||||
1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
|
||||
@@ -75,14 +78,29 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
|
||||
#### 手动部署
|
||||
|
||||
推荐使用 `uv`。
|
||||
> 推荐使用 `uv`。
|
||||
|
||||
首先,安装 uv:
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
通过 Git Clone 安装 AstrBot:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
pip install uv
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
或者,直接通过 uvx 安装 AstrBot:
|
||||
|
||||
```bash
|
||||
mkdir astrbot && cd astrbot
|
||||
uvx astrbot init
|
||||
# uvx astrbot run
|
||||
```
|
||||
|
||||
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
|
||||
#### Replit 部署
|
||||
@@ -95,9 +113,10 @@ uv run main.py
|
||||
| -------- | ------- | ------- | ------ |
|
||||
| QQ(官方机器人接口) | ✔ | 私聊、群聊,QQ 频道私聊、群聊 | 文字、图片 |
|
||||
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
|
||||
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
|
||||
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 |
|
||||
| 微信个人号 | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
|
||||
| Telegram | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| 企业微信 | ✔ | 私聊 | 文字、图片、语音 |
|
||||
| 微信客服 | ✔ | 私聊 | 文字、图片 |
|
||||
| 飞书 | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| 钉钉 | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
||||
@@ -109,21 +128,26 @@ uv run main.py
|
||||
|
||||
| 名称 | 支持性 | 类型 | 备注 |
|
||||
| -------- | ------- | ------- | ------- |
|
||||
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、硅基流动、xAI 等兼容 OpenAI API 的服务 |
|
||||
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、xAI 等兼容 OpenAI API 的服务 |
|
||||
| Claude API | ✔ | 文本生成 | |
|
||||
| Google Gemini API | ✔ | 文本生成 | |
|
||||
| Dify | ✔ | LLMOps | |
|
||||
| DashScope(阿里云百炼应用) | ✔ | LLMOps | |
|
||||
| 阿里云百炼应用 | ✔ | LLMOps | |
|
||||
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
|
||||
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
||||
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
|
||||
| OneAPI | ✔ | LLM 分发系统 | |
|
||||
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
|
||||
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
|
||||
| OpenAI TTS API | ✔ | 文本转语音 | |
|
||||
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
|
||||
| Fishaudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
||||
| Edge-TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
||||
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
||||
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
||||
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
|
||||
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
|
||||
|
||||
|
||||
## ❤️ 贡献
|
||||
|
||||
@@ -151,6 +175,8 @@ pre-commit install
|
||||
|
||||
## ✨ Demo
|
||||
|
||||
<details><summary>👉 点击展开多张 Demo 截图 👈</summary>
|
||||
|
||||
<div align='center'>
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||
@@ -172,6 +198,9 @@ _✨ WebUI ✨_
|
||||
|
||||
</div>
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
||||
@@ -180,6 +209,10 @@ _✨ WebUI ✨_
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
此外,本项目的诞生离不开以下开源项目:
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ)
|
||||
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
|
||||
1
astrbot/cli/__init__.py
Normal file
1
astrbot/cli/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "3.5.8"
|
||||
59
astrbot/cli/__main__.py
Normal file
59
astrbot/cli/__main__.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
AstrBot CLI入口
|
||||
"""
|
||||
|
||||
import click
|
||||
import sys
|
||||
from . import __version__
|
||||
from .commands import init, run, plug, conf
|
||||
|
||||
logo_tmpl = r"""
|
||||
___ _______.___________..______ .______ ______ .___________.
|
||||
/ \ / | || _ \ | _ \ / __ \ | |
|
||||
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
|
||||
/ /_\ \ \ \ | | | / | _ < | | | | | |
|
||||
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
|
||||
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
|
||||
"""
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(__version__, prog_name="AstrBot")
|
||||
def cli() -> None:
|
||||
"""The AstrBot CLI"""
|
||||
click.echo(logo_tmpl)
|
||||
click.echo("Welcome to AstrBot CLI!")
|
||||
click.echo(f"AstrBot CLI version: {__version__}")
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("command_name", required=False, type=str)
|
||||
def help(command_name: str | None) -> None:
|
||||
"""显示命令的帮助信息
|
||||
|
||||
如果提供了 COMMAND_NAME,则显示该命令的详细帮助信息。
|
||||
否则,显示通用帮助信息。
|
||||
"""
|
||||
ctx = click.get_current_context()
|
||||
if command_name:
|
||||
# 查找指定命令
|
||||
command = cli.get_command(ctx, command_name)
|
||||
if command:
|
||||
# 显示特定命令的帮助信息
|
||||
click.echo(command.get_help(ctx))
|
||||
else:
|
||||
click.echo(f"Unknown command: {command_name}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
# 显示通用帮助信息
|
||||
click.echo(cli.get_help(ctx))
|
||||
|
||||
|
||||
cli.add_command(init)
|
||||
cli.add_command(run)
|
||||
cli.add_command(help)
|
||||
cli.add_command(plug)
|
||||
cli.add_command(conf)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
6
astrbot/cli/commands/__init__.py
Normal file
6
astrbot/cli/commands/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .cmd_init import init
|
||||
from .cmd_run import run
|
||||
from .cmd_plug import plug
|
||||
from .cmd_conf import conf
|
||||
|
||||
__all__ = ["init", "run", "plug", "conf"]
|
||||
206
astrbot/cli/commands/cmd_conf.py
Normal file
206
astrbot/cli/commands/cmd_conf.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import json
|
||||
import click
|
||||
import hashlib
|
||||
import zoneinfo
|
||||
from typing import Any, Callable
|
||||
from ..utils import get_astrbot_root, check_astrbot_root
|
||||
|
||||
|
||||
def _validate_log_level(value: str) -> str:
|
||||
"""验证日志级别"""
|
||||
value = value.upper()
|
||||
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
|
||||
raise click.ClickException(
|
||||
"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def _validate_dashboard_port(value: str) -> int:
|
||||
"""验证 Dashboard 端口"""
|
||||
try:
|
||||
port = int(value)
|
||||
if port < 1 or port > 65535:
|
||||
raise click.ClickException("端口必须在 1-65535 范围内")
|
||||
return port
|
||||
except ValueError:
|
||||
raise click.ClickException("端口必须是数字")
|
||||
|
||||
|
||||
def _validate_dashboard_username(value: str) -> str:
|
||||
"""验证 Dashboard 用户名"""
|
||||
if not value:
|
||||
raise click.ClickException("用户名不能为空")
|
||||
return value
|
||||
|
||||
|
||||
def _validate_dashboard_password(value: str) -> str:
|
||||
"""验证 Dashboard 密码"""
|
||||
if not value:
|
||||
raise click.ClickException("密码不能为空")
|
||||
return hashlib.md5(value.encode()).hexdigest()
|
||||
|
||||
|
||||
def _validate_timezone(value: str) -> str:
|
||||
"""验证时区"""
|
||||
try:
|
||||
zoneinfo.ZoneInfo(value)
|
||||
except Exception:
|
||||
raise click.ClickException(f"无效的时区: {value},请使用有效的IANA时区名称")
|
||||
return value
|
||||
|
||||
|
||||
def _validate_callback_api_base(value: str) -> str:
|
||||
"""验证回调接口基址"""
|
||||
if not value.startswith("http://") and not value.startswith("https://"):
|
||||
raise click.ClickException("回调接口基址必须以 http:// 或 https:// 开头")
|
||||
return value
|
||||
|
||||
|
||||
# 可通过CLI设置的配置项,配置键到验证器函数的映射
|
||||
CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
|
||||
"timezone": _validate_timezone,
|
||||
"log_level": _validate_log_level,
|
||||
"dashboard.port": _validate_dashboard_port,
|
||||
"dashboard.username": _validate_dashboard_username,
|
||||
"dashboard.password": _validate_dashboard_password,
|
||||
"callback_api_base": _validate_callback_api_base,
|
||||
}
|
||||
|
||||
|
||||
def _load_config() -> dict[str, Any]:
|
||||
"""加载或初始化配置文件"""
|
||||
root = get_astrbot_root()
|
||||
if not check_astrbot_root(root):
|
||||
raise click.ClickException(
|
||||
f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
|
||||
)
|
||||
|
||||
config_path = root / "data" / "cmd_config.json"
|
||||
if not config_path.exists():
|
||||
from astrbot.core.config.default import DEFAULT_CONFIG
|
||||
|
||||
config_path.write_text(
|
||||
json.dumps(DEFAULT_CONFIG, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8-sig",
|
||||
)
|
||||
|
||||
try:
|
||||
return json.loads(config_path.read_text(encoding="utf-8-sig"))
|
||||
except json.JSONDecodeError as e:
|
||||
raise click.ClickException(f"配置文件解析失败: {str(e)}")
|
||||
|
||||
|
||||
def _save_config(config: dict[str, Any]) -> None:
|
||||
"""保存配置文件"""
|
||||
config_path = get_astrbot_root() / "data" / "cmd_config.json"
|
||||
|
||||
config_path.write_text(
|
||||
json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8-sig"
|
||||
)
|
||||
|
||||
|
||||
def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
|
||||
"""设置嵌套字典中的值"""
|
||||
parts = path.split(".")
|
||||
for part in parts[:-1]:
|
||||
if part not in obj:
|
||||
obj[part] = {}
|
||||
elif not isinstance(obj[part], dict):
|
||||
raise click.ClickException(
|
||||
f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典"
|
||||
)
|
||||
obj = obj[part]
|
||||
obj[parts[-1]] = value
|
||||
|
||||
|
||||
def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
|
||||
"""获取嵌套字典中的值"""
|
||||
parts = path.split(".")
|
||||
for part in parts:
|
||||
obj = obj[part]
|
||||
return obj
|
||||
|
||||
|
||||
@click.group(name="conf")
|
||||
def conf():
|
||||
"""配置管理命令
|
||||
|
||||
支持的配置项:
|
||||
|
||||
- timezone: 时区设置 (例如: Asia/Shanghai)
|
||||
|
||||
- log_level: 日志级别 (DEBUG/INFO/WARNING/ERROR/CRITICAL)
|
||||
|
||||
- dashboard.port: Dashboard 端口
|
||||
|
||||
- dashboard.username: Dashboard 用户名
|
||||
|
||||
- dashboard.password: Dashboard 密码
|
||||
|
||||
- callback_api_base: 回调接口基址
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@conf.command(name="set")
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
def set_config(key: str, value: str):
|
||||
"""设置配置项的值"""
|
||||
if key not in CONFIG_VALIDATORS.keys():
|
||||
raise click.ClickException(f"不支持的配置项: {key}")
|
||||
|
||||
config = _load_config()
|
||||
|
||||
try:
|
||||
old_value = _get_nested_item(config, key)
|
||||
validated_value = CONFIG_VALIDATORS[key](value)
|
||||
_set_nested_item(config, key, validated_value)
|
||||
_save_config(config)
|
||||
|
||||
click.echo(f"配置已更新: {key}")
|
||||
if key == "dashboard.password":
|
||||
click.echo(" 原值: ********")
|
||||
click.echo(" 新值: ********")
|
||||
else:
|
||||
click.echo(f" 原值: {old_value}")
|
||||
click.echo(f" 新值: {validated_value}")
|
||||
|
||||
except KeyError:
|
||||
raise click.ClickException(f"未知的配置项: {key}")
|
||||
except Exception as e:
|
||||
raise click.UsageError(f"设置配置失败: {str(e)}")
|
||||
|
||||
|
||||
@conf.command(name="get")
|
||||
@click.argument("key", required=False)
|
||||
def get_config(key: str = None):
|
||||
"""获取配置项的值,不提供key则显示所有可配置项"""
|
||||
config = _load_config()
|
||||
|
||||
if key:
|
||||
if key not in CONFIG_VALIDATORS.keys():
|
||||
raise click.ClickException(f"不支持的配置项: {key}")
|
||||
|
||||
try:
|
||||
value = _get_nested_item(config, key)
|
||||
if key == "dashboard.password":
|
||||
value = "********"
|
||||
click.echo(f"{key}: {value}")
|
||||
except KeyError:
|
||||
raise click.ClickException(f"未知的配置项: {key}")
|
||||
except Exception as e:
|
||||
raise click.UsageError(f"获取配置失败: {str(e)}")
|
||||
else:
|
||||
click.echo("当前配置:")
|
||||
for key in CONFIG_VALIDATORS.keys():
|
||||
try:
|
||||
value = (
|
||||
"********"
|
||||
if key == "dashboard.password"
|
||||
else _get_nested_item(config, key)
|
||||
)
|
||||
click.echo(f" {key}: {value}")
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
55
astrbot/cli/commands/cmd_init.py
Normal file
55
astrbot/cli/commands/cmd_init.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import asyncio
|
||||
|
||||
import click
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from ..utils import check_dashboard, get_astrbot_root
|
||||
|
||||
|
||||
async def initialize_astrbot(astrbot_root) -> None:
|
||||
"""执行 AstrBot 初始化逻辑"""
|
||||
dot_astrbot = astrbot_root / ".astrbot"
|
||||
|
||||
if not dot_astrbot.exists():
|
||||
click.echo(f"Current Directory: {astrbot_root}")
|
||||
click.echo(
|
||||
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。"
|
||||
)
|
||||
if click.confirm(
|
||||
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
dot_astrbot.touch()
|
||||
click.echo(f"Created {dot_astrbot}")
|
||||
|
||||
paths = {
|
||||
"data": astrbot_root / "data",
|
||||
"config": astrbot_root / "data" / "config",
|
||||
"plugins": astrbot_root / "data" / "plugins",
|
||||
"temp": astrbot_root / "data" / "temp",
|
||||
}
|
||||
|
||||
for name, path in paths.items():
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
click.echo(f"{'Created' if not path.exists() else 'Directory exists'}: {path}")
|
||||
|
||||
await check_dashboard(astrbot_root / "data")
|
||||
|
||||
|
||||
@click.command()
|
||||
def init() -> None:
|
||||
"""初始化 AstrBot"""
|
||||
click.echo("Initializing AstrBot...")
|
||||
astrbot_root = get_astrbot_root()
|
||||
lock_file = astrbot_root / "astrbot.lock"
|
||||
lock = FileLock(lock_file, timeout=5)
|
||||
|
||||
try:
|
||||
with lock.acquire():
|
||||
asyncio.run(initialize_astrbot(astrbot_root))
|
||||
except Timeout:
|
||||
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
|
||||
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"初始化失败: {e!s}")
|
||||
247
astrbot/cli/commands/cmd_plug.py
Normal file
247
astrbot/cli/commands/cmd_plug.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import shutil
|
||||
|
||||
|
||||
from ..utils import (
|
||||
get_git_repo,
|
||||
build_plug_list,
|
||||
manage_plugin,
|
||||
PluginStatus,
|
||||
check_astrbot_root,
|
||||
get_astrbot_root,
|
||||
)
|
||||
|
||||
|
||||
@click.group()
|
||||
def plug():
|
||||
"""插件管理"""
|
||||
pass
|
||||
|
||||
|
||||
def _get_data_path() -> Path:
|
||||
base = get_astrbot_root()
|
||||
if not check_astrbot_root(base):
|
||||
raise click.ClickException(
|
||||
f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
|
||||
)
|
||||
return (base / "data").resolve()
|
||||
|
||||
|
||||
def display_plugins(plugins, title=None, color=None):
|
||||
if title:
|
||||
click.echo(click.style(title, fg=color, bold=True))
|
||||
|
||||
click.echo(f"{'名称':<20} {'版本':<10} {'状态':<10} {'作者':<15} {'描述':<30}")
|
||||
click.echo("-" * 85)
|
||||
|
||||
for p in plugins:
|
||||
desc = p["desc"][:30] + ("..." if len(p["desc"]) > 30 else "")
|
||||
click.echo(
|
||||
f"{p['name']:<20} {p['version']:<10} {p['status']:<10} "
|
||||
f"{p['author']:<15} {desc:<30}"
|
||||
)
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
def new(name: str):
|
||||
"""创建新插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins" / name
|
||||
|
||||
if plug_path.exists():
|
||||
raise click.ClickException(f"插件 {name} 已存在")
|
||||
|
||||
author = click.prompt("请输入插件作者", type=str)
|
||||
desc = click.prompt("请输入插件描述", type=str)
|
||||
version = click.prompt("请输入插件版本", type=str)
|
||||
if not re.match(r"^\d+\.\d+(\.\d+)?$", version.lower().lstrip("v")):
|
||||
raise click.ClickException("版本号必须为 x.y 或 x.y.z 格式")
|
||||
repo = click.prompt("请输入插件仓库:", type=str)
|
||||
if not repo.startswith("http"):
|
||||
raise click.ClickException("仓库地址必须以 http 开头")
|
||||
|
||||
click.echo("下载插件模板...")
|
||||
get_git_repo(
|
||||
"https://github.com/Soulter/helloworld",
|
||||
plug_path,
|
||||
)
|
||||
|
||||
click.echo("重写插件信息...")
|
||||
# 重写 metadata.yaml
|
||||
with open(plug_path / "metadata.yaml", "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
f"name: {name}\n"
|
||||
f"desc: {desc}\n"
|
||||
f"version: {version}\n"
|
||||
f"author: {author}\n"
|
||||
f"repo: {repo}\n"
|
||||
)
|
||||
|
||||
# 重写 README.md
|
||||
with open(plug_path / "README.md", "w", encoding="utf-8") as f:
|
||||
f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n")
|
||||
|
||||
# 重写 main.py
|
||||
with open(plug_path / "main.py", "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
new_content = content.replace(
|
||||
'@register("helloworld", "YourName", "一个简单的 Hello World 插件", "1.0.0")',
|
||||
f'@register("{name}", "{author}", "{desc}", "{version}")',
|
||||
)
|
||||
|
||||
with open(plug_path / "main.py", "w", encoding="utf-8") as f:
|
||||
f.write(new_content)
|
||||
|
||||
click.echo(f"插件 {name} 创建成功")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.option("--all", "-a", is_flag=True, help="列出未安装的插件")
|
||||
def list(all: bool):
|
||||
"""列出插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
# 未发布的插件
|
||||
not_published_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NOT_PUBLISHED
|
||||
]
|
||||
if not_published_plugins:
|
||||
display_plugins(not_published_plugins, "未发布的插件", "red")
|
||||
|
||||
# 需要更新的插件
|
||||
need_update_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
|
||||
]
|
||||
if need_update_plugins:
|
||||
display_plugins(need_update_plugins, "需要更新的插件", "yellow")
|
||||
|
||||
# 已安装的插件
|
||||
installed_plugins = [p for p in plugins if p["status"] == PluginStatus.INSTALLED]
|
||||
if installed_plugins:
|
||||
display_plugins(installed_plugins, "已安装的插件", "green")
|
||||
|
||||
# 未安装的插件
|
||||
not_installed_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NOT_INSTALLED
|
||||
]
|
||||
if not_installed_plugins and all:
|
||||
display_plugins(not_installed_plugins, "未安装的插件", "blue")
|
||||
|
||||
if (
|
||||
not any([not_published_plugins, need_update_plugins, installed_plugins])
|
||||
and not all
|
||||
):
|
||||
click.echo("未安装任何插件")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
@click.option("--proxy", help="代理服务器地址")
|
||||
def install(name: str, proxy: str | None):
|
||||
"""安装插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins"
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
plugin = next(
|
||||
(
|
||||
p
|
||||
for p in plugins
|
||||
if p["name"] == name and p["status"] == PluginStatus.NOT_INSTALLED
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not plugin:
|
||||
raise click.ClickException(f"未找到可安装的插件 {name},可能是不存在或已安装")
|
||||
|
||||
manage_plugin(plugin, plug_path, is_update=False, proxy=proxy)
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
def remove(name: str):
|
||||
"""卸载插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
plugin = next((p for p in plugins if p["name"] == name), None)
|
||||
|
||||
if not plugin or not plugin.get("local_path"):
|
||||
raise click.ClickException(f"插件 {name} 不存在或未安装")
|
||||
|
||||
plugin_path = plugin["local_path"]
|
||||
|
||||
click.confirm(f"确定要卸载插件 {name} 吗?", default=False, abort=True)
|
||||
|
||||
try:
|
||||
shutil.rmtree(plugin_path)
|
||||
click.echo(f"插件 {name} 已卸载")
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"卸载插件 {name} 失败: {e}")
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name", required=False)
|
||||
@click.option("--proxy", help="Github代理地址")
|
||||
def update(name: str, proxy: str | None):
|
||||
"""更新插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins"
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
if name:
|
||||
plugin = next(
|
||||
(
|
||||
p
|
||||
for p in plugins
|
||||
if p["name"] == name and p["status"] == PluginStatus.NEED_UPDATE
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not plugin:
|
||||
raise click.ClickException(f"插件 {name} 不需要更新或无法更新")
|
||||
|
||||
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
|
||||
else:
|
||||
need_update_plugins = [
|
||||
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
|
||||
]
|
||||
|
||||
if not need_update_plugins:
|
||||
click.echo("没有需要更新的插件")
|
||||
return
|
||||
|
||||
click.echo(f"发现 {len(need_update_plugins)} 个插件需要更新")
|
||||
for plugin in need_update_plugins:
|
||||
plugin_name = plugin["name"]
|
||||
click.echo(f"正在更新插件 {plugin_name}...")
|
||||
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
|
||||
|
||||
|
||||
@plug.command()
|
||||
@click.argument("query")
|
||||
def search(query: str):
|
||||
"""搜索插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
matched_plugins = [
|
||||
p
|
||||
for p in plugins
|
||||
if query.lower() in p["name"].lower()
|
||||
or query.lower() in p["desc"].lower()
|
||||
or query.lower() in p["author"].lower()
|
||||
]
|
||||
|
||||
if not matched_plugins:
|
||||
click.echo(f"未找到匹配 '{query}' 的插件")
|
||||
return
|
||||
|
||||
display_plugins(matched_plugins, f"搜索结果: '{query}'", "cyan")
|
||||
63
astrbot/cli/commands/cmd_run.py
Normal file
63
astrbot/cli/commands/cmd_run.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from ..utils import check_dashboard, check_astrbot_root, get_astrbot_root
|
||||
|
||||
|
||||
async def run_astrbot(astrbot_root: Path):
|
||||
"""运行 AstrBot"""
|
||||
from astrbot.core import logger, LogManager, LogBroker, db_helper
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
|
||||
await check_dashboard(astrbot_root / "data")
|
||||
|
||||
log_broker = LogBroker()
|
||||
LogManager.set_queue_handler(logger, log_broker)
|
||||
db = db_helper
|
||||
|
||||
core_lifecycle = InitialLoader(db, log_broker)
|
||||
|
||||
await core_lifecycle.start()
|
||||
|
||||
|
||||
@click.option("--reload", "-r", is_flag=True, help="插件自动重载")
|
||||
@click.option("--port", "-p", help="Astrbot Dashboard端口", required=False, type=str)
|
||||
@click.command()
|
||||
def run(reload: bool, port: str) -> None:
|
||||
"""运行 AstrBot"""
|
||||
try:
|
||||
os.environ["ASTRBOT_CLI"] = "1"
|
||||
astrbot_root = get_astrbot_root()
|
||||
|
||||
if not check_astrbot_root(astrbot_root):
|
||||
raise click.ClickException(
|
||||
f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
|
||||
)
|
||||
|
||||
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
|
||||
sys.path.insert(0, str(astrbot_root))
|
||||
|
||||
if port:
|
||||
os.environ["DASHBOARD_PORT"] = port
|
||||
|
||||
if reload:
|
||||
click.echo("启用插件自动重载")
|
||||
os.environ["ASTRBOT_RELOAD"] = "1"
|
||||
|
||||
lock_file = astrbot_root / "astrbot.lock"
|
||||
lock = FileLock(lock_file, timeout=5)
|
||||
with lock.acquire():
|
||||
asyncio.run(run_astrbot(astrbot_root))
|
||||
except KeyboardInterrupt:
|
||||
click.echo("AstrBot 已关闭...")
|
||||
except Timeout:
|
||||
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
|
||||
except Exception as e:
|
||||
raise click.ClickException(f"运行时出现错误: {e}\n{traceback.format_exc()}")
|
||||
18
astrbot/cli/utils/__init__.py
Normal file
18
astrbot/cli/utils/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from .basic import (
|
||||
get_astrbot_root,
|
||||
check_astrbot_root,
|
||||
check_dashboard,
|
||||
)
|
||||
from .plugin import get_git_repo, manage_plugin, build_plug_list, PluginStatus
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
__all__ = [
|
||||
"get_astrbot_root",
|
||||
"check_astrbot_root",
|
||||
"check_dashboard",
|
||||
"get_git_repo",
|
||||
"manage_plugin",
|
||||
"build_plug_list",
|
||||
"VersionComparator",
|
||||
"PluginStatus",
|
||||
]
|
||||
67
astrbot/cli/utils/basic.py
Normal file
67
astrbot/cli/utils/basic.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
|
||||
def check_astrbot_root(path: str | Path) -> bool:
|
||||
"""检查路径是否为 AstrBot 根目录"""
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
if not path.exists() or not path.is_dir():
|
||||
return False
|
||||
if not (path / ".astrbot").exists():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_astrbot_root() -> Path:
|
||||
"""获取Astrbot根目录路径"""
|
||||
return Path.cwd()
|
||||
|
||||
|
||||
async def check_dashboard(astrbot_root: Path) -> None:
|
||||
"""检查是否安装了dashboard"""
|
||||
from astrbot.core.utils.io import get_dashboard_version, download_dashboard
|
||||
from astrbot.core.config.default import VERSION
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
try:
|
||||
dashboard_version = await get_dashboard_version()
|
||||
match dashboard_version:
|
||||
case None:
|
||||
click.echo("未安装管理面板")
|
||||
if click.confirm(
|
||||
"是否安装管理面板?",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
click.echo("正在安装管理面板...")
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||
)
|
||||
click.echo("管理面板安装完成")
|
||||
|
||||
case str():
|
||||
if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
|
||||
click.echo("管理面板已是最新版本")
|
||||
return
|
||||
else:
|
||||
try:
|
||||
version = dashboard_version.split("v")[1]
|
||||
click.echo(f"管理面板版本: {version}")
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"下载管理面板失败: {e}")
|
||||
return
|
||||
except FileNotFoundError:
|
||||
click.echo("初始化管理面板目录...")
|
||||
try:
|
||||
await download_dashboard(
|
||||
path=str(astrbot_root / "dashboard.zip"), extract_path=str(astrbot_root)
|
||||
)
|
||||
click.echo("管理面板初始化完成")
|
||||
except Exception as e:
|
||||
click.echo(f"下载管理面板失败: {e}")
|
||||
return
|
||||
230
astrbot/cli/utils/plugin.py
Normal file
230
astrbot/cli/utils/plugin.py
Normal file
@@ -0,0 +1,230 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
||||
import click
|
||||
from .version_comparator import VersionComparator
|
||||
|
||||
|
||||
class PluginStatus(str, Enum):
|
||||
INSTALLED = "已安装"
|
||||
NEED_UPDATE = "需更新"
|
||||
NOT_INSTALLED = "未安装"
|
||||
NOT_PUBLISHED = "未发布"
|
||||
|
||||
|
||||
def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
|
||||
"""从 Git 仓库下载代码并解压到指定路径"""
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
try:
|
||||
# 解析仓库信息
|
||||
repo_namespace = url.split("/")[-2:]
|
||||
author = repo_namespace[0]
|
||||
repo = repo_namespace[1]
|
||||
|
||||
# 尝试获取最新的 release
|
||||
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
|
||||
try:
|
||||
with httpx.Client(
|
||||
proxy=proxy if proxy else None, follow_redirects=True
|
||||
) as client:
|
||||
resp = client.get(release_url)
|
||||
resp.raise_for_status()
|
||||
releases = resp.json()
|
||||
|
||||
if releases:
|
||||
# 使用最新的 release
|
||||
download_url = releases[0]["zipball_url"]
|
||||
else:
|
||||
# 没有 release,使用默认分支
|
||||
click.echo(f"正在从默认分支下载 {author}/{repo}")
|
||||
download_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||
except Exception as e:
|
||||
click.echo(f"获取 release 信息失败: {e},将直接使用提供的 URL")
|
||||
download_url = url
|
||||
|
||||
# 应用代理
|
||||
if proxy:
|
||||
download_url = f"{proxy}/{download_url}"
|
||||
|
||||
# 下载并解压
|
||||
with httpx.Client(
|
||||
proxy=proxy if proxy else None, follow_redirects=True
|
||||
) as client:
|
||||
resp = client.get(download_url)
|
||||
if (
|
||||
resp.status_code == 404
|
||||
and "archive/refs/heads/master.zip" in download_url
|
||||
):
|
||||
alt_url = download_url.replace("master.zip", "main.zip")
|
||||
click.echo("master 分支不存在,尝试下载 main 分支")
|
||||
resp = client.get(alt_url)
|
||||
resp.raise_for_status()
|
||||
else:
|
||||
resp.raise_for_status()
|
||||
zip_content = BytesIO(resp.content)
|
||||
with ZipFile(zip_content) as z:
|
||||
z.extractall(temp_dir)
|
||||
namelist = z.namelist()
|
||||
root_dir = Path(namelist[0]).parts[0] if namelist else ""
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path)
|
||||
shutil.move(temp_dir / root_dir, target_path)
|
||||
finally:
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def load_yaml_metadata(plugin_dir: Path) -> dict:
|
||||
"""从 metadata.yaml 文件加载插件元数据
|
||||
|
||||
Args:
|
||||
plugin_dir: 插件目录路径
|
||||
|
||||
Returns:
|
||||
dict: 包含元数据的字典,如果读取失败则返回空字典
|
||||
"""
|
||||
yaml_path = plugin_dir / "metadata.yaml"
|
||||
if yaml_path.exists():
|
||||
try:
|
||||
return yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
|
||||
except Exception as e:
|
||||
click.echo(f"读取 {yaml_path} 失败: {e}", err=True)
|
||||
return {}
|
||||
|
||||
|
||||
def build_plug_list(plugins_dir: Path) -> list:
|
||||
"""构建插件列表,包含本地和在线插件信息
|
||||
|
||||
Args:
|
||||
plugins_dir (Path): 插件目录路径
|
||||
|
||||
Returns:
|
||||
list: 包含插件信息的字典列表
|
||||
"""
|
||||
# 获取本地插件信息
|
||||
result = []
|
||||
if plugins_dir.exists():
|
||||
for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]:
|
||||
plugin_dir = plugins_dir / plugin_name
|
||||
|
||||
# 从 metadata.yaml 加载元数据
|
||||
metadata = load_yaml_metadata(plugin_dir)
|
||||
|
||||
# 如果成功加载元数据,添加到结果列表
|
||||
if metadata and all(
|
||||
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
||||
):
|
||||
result.append({
|
||||
"name": str(metadata.get("name", "")),
|
||||
"desc": str(metadata.get("desc", "")),
|
||||
"version": str(metadata.get("version", "")),
|
||||
"author": str(metadata.get("author", "")),
|
||||
"repo": str(metadata.get("repo", "")),
|
||||
"status": PluginStatus.INSTALLED,
|
||||
"local_path": str(plugin_dir),
|
||||
})
|
||||
|
||||
# 获取在线插件列表
|
||||
online_plugins = []
|
||||
try:
|
||||
with httpx.Client() as client:
|
||||
resp = client.get("https://api.soulter.top/astrbot/plugins")
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
for plugin_id, plugin_info in data.items():
|
||||
online_plugins.append({
|
||||
"name": str(plugin_id),
|
||||
"desc": str(plugin_info.get("desc", "")),
|
||||
"version": str(plugin_info.get("version", "")),
|
||||
"author": str(plugin_info.get("author", "")),
|
||||
"repo": str(plugin_info.get("repo", "")),
|
||||
"status": PluginStatus.NOT_INSTALLED,
|
||||
"local_path": None,
|
||||
})
|
||||
except Exception as e:
|
||||
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
||||
|
||||
# 与在线插件比对,更新状态
|
||||
online_plugin_names = {plugin["name"] for plugin in online_plugins}
|
||||
for local_plugin in result:
|
||||
if local_plugin["name"] in online_plugin_names:
|
||||
# 查找对应的在线插件
|
||||
online_plugin = next(
|
||||
p for p in online_plugins if p["name"] == local_plugin["name"]
|
||||
)
|
||||
if (
|
||||
VersionComparator.compare_version(
|
||||
local_plugin["version"], online_plugin["version"]
|
||||
)
|
||||
< 0
|
||||
):
|
||||
local_plugin["status"] = PluginStatus.NEED_UPDATE
|
||||
else:
|
||||
# 本地插件未在线上发布
|
||||
local_plugin["status"] = PluginStatus.NOT_PUBLISHED
|
||||
|
||||
# 添加未安装的在线插件
|
||||
for online_plugin in online_plugins:
|
||||
if not any(plugin["name"] == online_plugin["name"] for plugin in result):
|
||||
result.append(online_plugin)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def manage_plugin(
|
||||
plugin: dict, plugins_dir: Path, is_update: bool = False, proxy: str | None = None
|
||||
) -> None:
|
||||
"""安装或更新插件
|
||||
|
||||
Args:
|
||||
plugin (dict): 插件信息字典
|
||||
plugins_dir (Path): 插件目录
|
||||
is_update (bool, optional): 是否为更新操作. 默认为 False
|
||||
proxy (str, optional): 代理服务器地址
|
||||
"""
|
||||
plugin_name = plugin["name"]
|
||||
repo_url = plugin["repo"]
|
||||
|
||||
# 如果是更新且有本地路径,直接使用本地路径
|
||||
if is_update and plugin.get("local_path"):
|
||||
target_path = Path(plugin["local_path"])
|
||||
else:
|
||||
target_path = plugins_dir / plugin_name
|
||||
|
||||
backup_path = Path(f"{target_path}_backup") if is_update else None
|
||||
|
||||
# 检查插件是否存在
|
||||
if is_update and not target_path.exists():
|
||||
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
|
||||
|
||||
# 备份现有插件
|
||||
if is_update and backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
if is_update:
|
||||
shutil.copytree(target_path, backup_path)
|
||||
|
||||
try:
|
||||
click.echo(
|
||||
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}..."
|
||||
)
|
||||
get_git_repo(repo_url, target_path, proxy)
|
||||
|
||||
# 更新成功,删除备份
|
||||
if is_update and backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
|
||||
except Exception as e:
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
if is_update and backup_path.exists():
|
||||
shutil.move(backup_path, target_path)
|
||||
raise click.ClickException(
|
||||
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}"
|
||||
)
|
||||
92
astrbot/cli/utils/version_comparator.py
Normal file
92
astrbot/cli/utils/version_comparator.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
拷贝自 astrbot.core.utils.version_comparator
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
class VersionComparator:
|
||||
@staticmethod
|
||||
def compare_version(v1: str, v2: str) -> int:
|
||||
"""根据 Semver 语义版本规范来比较版本号的大小。支持不仅局限于 3 个数字的版本号,并处理预发布标签。
|
||||
|
||||
参考: https://semver.org/lang/zh-CN/
|
||||
|
||||
返回 1 表示 v1 > v2,返回 -1 表示 v1 < v2,返回 0 表示 v1 = v2。
|
||||
"""
|
||||
v1 = v1.lower().replace("v", "")
|
||||
v2 = v2.lower().replace("v", "")
|
||||
|
||||
def split_version(version):
|
||||
match = re.match(
|
||||
r"^([0-9]+(?:\.[0-9]+)*)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(.+))?$",
|
||||
version,
|
||||
)
|
||||
if not match:
|
||||
return [], None
|
||||
major_minor_patch = match.group(1).split(".")
|
||||
prerelease = match.group(2)
|
||||
# buildmetadata = match.group(3) # 构建元数据在比较时忽略
|
||||
parts = [int(x) for x in major_minor_patch]
|
||||
prerelease = VersionComparator._split_prerelease(prerelease)
|
||||
return parts, prerelease
|
||||
|
||||
v1_parts, v1_prerelease = split_version(v1)
|
||||
v2_parts, v2_prerelease = split_version(v2)
|
||||
|
||||
# 比较数字部分
|
||||
length = max(len(v1_parts), len(v2_parts))
|
||||
v1_parts.extend([0] * (length - len(v1_parts)))
|
||||
v2_parts.extend([0] * (length - len(v2_parts)))
|
||||
|
||||
for i in range(length):
|
||||
if v1_parts[i] > v2_parts[i]:
|
||||
return 1
|
||||
elif v1_parts[i] < v2_parts[i]:
|
||||
return -1
|
||||
|
||||
# 比较预发布标签
|
||||
if v1_prerelease is None and v2_prerelease is not None:
|
||||
return 1 # 没有预发布标签的版本高于有预发布标签的版本
|
||||
elif v1_prerelease is not None and v2_prerelease is None:
|
||||
return -1 # 有预发布标签的版本低于没有预发布标签的版本
|
||||
elif v1_prerelease is not None and v2_prerelease is not None:
|
||||
len_pre = max(len(v1_prerelease), len(v2_prerelease))
|
||||
for i in range(len_pre):
|
||||
p1 = v1_prerelease[i] if i < len(v1_prerelease) else None
|
||||
p2 = v2_prerelease[i] if i < len(v2_prerelease) else None
|
||||
|
||||
if p1 is None and p2 is not None:
|
||||
return -1
|
||||
elif p1 is not None and p2 is None:
|
||||
return 1
|
||||
elif isinstance(p1, int) and isinstance(p2, str):
|
||||
return -1
|
||||
elif isinstance(p1, str) and isinstance(p2, int):
|
||||
return 1
|
||||
elif isinstance(p1, int) and isinstance(p2, int):
|
||||
if p1 > p2:
|
||||
return 1
|
||||
elif p1 < p2:
|
||||
return -1
|
||||
elif isinstance(p1, str) and isinstance(p2, str):
|
||||
if p1 > p2:
|
||||
return 1
|
||||
elif p1 < p2:
|
||||
return -1
|
||||
return 0 # 预发布标签完全相同
|
||||
|
||||
return 0 # 数字部分和预发布标签都相同
|
||||
|
||||
@staticmethod
|
||||
def _split_prerelease(prerelease):
|
||||
if not prerelease:
|
||||
return None
|
||||
parts = prerelease.split(".")
|
||||
result = []
|
||||
for part in parts:
|
||||
if part.isdigit():
|
||||
result.append(int(part))
|
||||
else:
|
||||
result.append(part)
|
||||
return result
|
||||
@@ -7,24 +7,28 @@ from astrbot.core.utils.pip_installer import PipInstaller
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.file_token_service import FileTokenService
|
||||
from .utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
# 初始化数据存储文件夹
|
||||
os.makedirs("data", exist_ok=True)
|
||||
os.makedirs(get_astrbot_data_path(), exist_ok=True)
|
||||
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||
|
||||
astrbot_config = AstrBotConfig()
|
||||
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
|
||||
html_renderer = HtmlRenderer(t2i_base_url)
|
||||
logger = LogManager.GetLogger(log_name="astrbot")
|
||||
|
||||
if os.environ.get("TESTING", ""):
|
||||
logger.setLevel("DEBUG")
|
||||
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
sp = (
|
||||
SharedPreferences()
|
||||
) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||
pip_installer = PipInstaller(astrbot_config.get("pip_install_arg", ""))
|
||||
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||
sp = SharedPreferences()
|
||||
# 文件令牌服务
|
||||
file_token_service = FileTokenService()
|
||||
pip_installer = PipInstaller(
|
||||
astrbot_config.get("pip_install_arg", ""),
|
||||
astrbot_config.get("pypi_index_url", None),
|
||||
)
|
||||
web_chat_queue = asyncio.Queue(maxsize=32)
|
||||
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||
|
||||
|
||||
@@ -4,8 +4,9 @@ import logging
|
||||
import enum
|
||||
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
|
||||
from typing import Dict
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
|
||||
ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
@@ -45,8 +46,6 @@ class AstrBotConfig(dict):
|
||||
|
||||
with open(config_path, "r", encoding="utf-8-sig") as f:
|
||||
conf_str = f.read()
|
||||
if conf_str.startswith("/ufeff"): # remove BOM
|
||||
conf_str = conf_str.encode("utf8")[3:].decode("utf8")
|
||||
conf = json.loads(conf_str)
|
||||
|
||||
# 检查配置完整性,并插入
|
||||
|
||||
@@ -2,8 +2,11 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
VERSION = "3.5.3.2"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
import os
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "3.5.13"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db")
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_CONFIG = {
|
||||
@@ -38,6 +41,7 @@ DEFAULT_CONFIG = {
|
||||
"no_permission_reply": True,
|
||||
"empty_mention_waiting": True,
|
||||
"friend_message_needs_wake_prefix": False,
|
||||
"ignore_bot_self_message": False,
|
||||
},
|
||||
"provider": [],
|
||||
"provider_settings": {
|
||||
@@ -52,6 +56,7 @@ DEFAULT_CONFIG = {
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"streaming_segmented": False,
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
@@ -60,6 +65,8 @@ DEFAULT_CONFIG = {
|
||||
"provider_tts_settings": {
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
"dual_output": False,
|
||||
"use_file_service": False,
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
"group_icl_enable": False,
|
||||
@@ -85,6 +92,7 @@ DEFAULT_CONFIG = {
|
||||
"t2i_word_threshold": 150,
|
||||
"t2i_strategy": "remote",
|
||||
"t2i_endpoint": "",
|
||||
"t2i_use_file_service": False,
|
||||
"http_proxy": "",
|
||||
"dashboard": {
|
||||
"enable": True,
|
||||
@@ -97,10 +105,11 @@ DEFAULT_CONFIG = {
|
||||
"wake_prefix": ["/"],
|
||||
"log_level": "INFO",
|
||||
"pip_install_arg": "",
|
||||
"plugin_repo_mirror": "",
|
||||
"pypi_index_url": "https://mirrors.aliyun.com/pypi/simple/",
|
||||
"knowledge_db": {},
|
||||
"persona": [],
|
||||
"timezone": "",
|
||||
"callback_api_base": "",
|
||||
}
|
||||
|
||||
|
||||
@@ -137,6 +146,7 @@ CONFIG_METADATA_2 = {
|
||||
"enable": False,
|
||||
"ws_reverse_host": "0.0.0.0",
|
||||
"ws_reverse_port": 6199,
|
||||
"ws_reverse_token": "",
|
||||
},
|
||||
"gewechat(微信)": {
|
||||
"id": "gwchat",
|
||||
@@ -147,6 +157,29 @@ CONFIG_METADATA_2 = {
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"port": 11451,
|
||||
},
|
||||
"wechatpadpro(微信)": {
|
||||
"id": "wechatpadpro",
|
||||
"type": "wechatpadpro",
|
||||
"enable": False,
|
||||
"admin_key": "stay33",
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"port": 8059,
|
||||
"wpp_active_message_poll": False,
|
||||
"wpp_active_message_poll_interval": 3,
|
||||
},
|
||||
"weixin_official_account(微信公众平台)": {
|
||||
"id": "weixin_official_account",
|
||||
"type": "weixin_official_account",
|
||||
"enable": False,
|
||||
"appid": "",
|
||||
"secret": "",
|
||||
"token": "",
|
||||
"encoding_aes_key": "",
|
||||
"api_base_url": "https://api.weixin.qq.com/cgi-bin/",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6194,
|
||||
"active_send_mode": False,
|
||||
},
|
||||
"wecom(企业微信)": {
|
||||
"id": "wecom",
|
||||
"type": "wecom",
|
||||
@@ -155,6 +188,7 @@ CONFIG_METADATA_2 = {
|
||||
"secret": "",
|
||||
"token": "",
|
||||
"encoding_aes_key": "",
|
||||
"kf_name": "",
|
||||
"api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6195,
|
||||
@@ -183,19 +217,57 @@ CONFIG_METADATA_2 = {
|
||||
"start_message": "Hello, I'm AstrBot!",
|
||||
"telegram_api_base_url": "https://api.telegram.org/bot",
|
||||
"telegram_file_base_url": "https://api.telegram.org/file/bot",
|
||||
"telegram_command_register": True,
|
||||
"telegram_command_auto_refresh": True,
|
||||
"telegram_command_register_interval": 300,
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"active_send_mode": {
|
||||
"description": "是否换用主动发送接口",
|
||||
"type": "bool",
|
||||
"desc": "只有企业认证的公众号才能主动发送。主动发送接口的限制会少一些。",
|
||||
},
|
||||
"wpp_active_message_poll": {
|
||||
"description": "是否启用主动消息轮询",
|
||||
"type": "bool",
|
||||
"hint": "只有当你发现微信消息没有按时同步到 AstrBot 时,才需要启用这个功能,默认不启用。",
|
||||
},
|
||||
"wpp_active_message_poll_interval": {
|
||||
"description": "主动消息轮询间隔",
|
||||
"type": "int",
|
||||
"hint": "主动消息轮询间隔,单位为秒,默认 3 秒,最大不要超过 60 秒,否则可能被认为是旧消息。",
|
||||
},
|
||||
"kf_name": {
|
||||
"description": "微信客服账号名",
|
||||
"type": "string",
|
||||
"hint": "可选。微信客服账号名(不是 ID)。可在 https://kf.weixin.qq.com/kf/frame#/accounts 获取",
|
||||
},
|
||||
"telegram_token": {
|
||||
"description": "Bot Token",
|
||||
"type": "string",
|
||||
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
|
||||
},
|
||||
"telegram_command_register": {
|
||||
"description": "Telegram 命令注册",
|
||||
"type": "bool",
|
||||
"hint": "启用后,AstrBot 将会自动注册 Telegram 命令。",
|
||||
},
|
||||
"telegram_command_auto_refresh": {
|
||||
"description": "Telegram 命令自动刷新",
|
||||
"type": "bool",
|
||||
"hint": "启用后,AstrBot 将会在运行时自动刷新 Telegram 命令。(单独设置此项无效)",
|
||||
},
|
||||
"telegram_command_register_interval": {
|
||||
"description": "Telegram 命令自动刷新间隔",
|
||||
"type": "int",
|
||||
"hint": "Telegram 命令自动刷新间隔,单位为秒。",
|
||||
},
|
||||
"id": {
|
||||
"description": "ID",
|
||||
"description": "机器人名称",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "ID 不能和其它的平台适配器重复,否则将发生严重冲突。",
|
||||
"hint": "机器人名称(ID)不能和其它的平台适配器重复。",
|
||||
},
|
||||
"type": {
|
||||
"description": "适配器类型",
|
||||
@@ -215,7 +287,7 @@ CONFIG_METADATA_2 = {
|
||||
"secret": {
|
||||
"description": "secret",
|
||||
"type": "string",
|
||||
"hint": "必填项。QQ 官方机器人平台的 secret。如何获取请参考文档。",
|
||||
"hint": "必填项。",
|
||||
},
|
||||
"enable_group_c2c": {
|
||||
"description": "启用消息列表单聊",
|
||||
@@ -237,6 +309,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "int",
|
||||
"hint": "aiocqhttp 适配器的反向 Websocket 端口。",
|
||||
},
|
||||
"ws_reverse_token": {
|
||||
"description": "反向 Websocket Token",
|
||||
"type": "string",
|
||||
"hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。",
|
||||
},
|
||||
"lark_bot_name": {
|
||||
"description": "飞书机器人的名字",
|
||||
"type": "string",
|
||||
@@ -287,6 +364,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "bool",
|
||||
"hint": "启用后,私聊消息需要唤醒前缀才会被处理,同群聊一样。",
|
||||
},
|
||||
"ignore_bot_self_message": {
|
||||
"description": "是否忽略机器人自身的消息",
|
||||
"type": "bool",
|
||||
"hint": "某些平台如 gewechat 会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
|
||||
},
|
||||
"segmented_reply": {
|
||||
"description": "分段回复",
|
||||
"type": "object",
|
||||
@@ -443,6 +525,7 @@ CONFIG_METADATA_2 = {
|
||||
"OpenAI": {
|
||||
"id": "openai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
@@ -451,9 +534,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
},
|
||||
"Azure_OpenAI": {
|
||||
"Azure OpenAI": {
|
||||
"id": "azure",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"key": [],
|
||||
@@ -463,9 +547,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
},
|
||||
"xAI(grok)": {
|
||||
"xAI": {
|
||||
"id": "xai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
@@ -474,9 +559,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "grok-2-latest",
|
||||
},
|
||||
},
|
||||
"Anthropic(claude)": {
|
||||
"Anthropic": {
|
||||
"id": "claude",
|
||||
"type": "anthropic_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.anthropic.com/v1",
|
||||
@@ -489,6 +575,7 @@ CONFIG_METADATA_2 = {
|
||||
"Ollama": {
|
||||
"id": "ollama_default",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://localhost:11434/v1",
|
||||
@@ -496,9 +583,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "llama3.1-8b",
|
||||
},
|
||||
},
|
||||
"LM_Studio": {
|
||||
"LM Studio": {
|
||||
"id": "lm_studio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["lmstudio"],
|
||||
"api_base": "http://localhost:1234/v1",
|
||||
@@ -509,6 +597,7 @@ CONFIG_METADATA_2 = {
|
||||
"Gemini(OpenAI兼容)": {
|
||||
"id": "gemini_default",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
@@ -517,9 +606,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "gemini-1.5-flash",
|
||||
},
|
||||
},
|
||||
"Gemini(googlegenai原生)": {
|
||||
"Gemini": {
|
||||
"id": "gemini_default",
|
||||
"type": "googlegenai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/",
|
||||
@@ -528,16 +618,22 @@ CONFIG_METADATA_2 = {
|
||||
"model": "gemini-2.0-flash-exp",
|
||||
},
|
||||
"gm_resp_image_modal": False,
|
||||
"gm_native_search": False,
|
||||
"gm_native_coderunner": False,
|
||||
"gm_safety_settings": {
|
||||
"harassment": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"hate_speech": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
},
|
||||
"gm_thinking_config": {
|
||||
"budget": 0,
|
||||
},
|
||||
},
|
||||
"DeepSeek": {
|
||||
"id": "deepseek_default",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
@@ -546,9 +642,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "deepseek-chat",
|
||||
},
|
||||
},
|
||||
"Zhipu(智谱)": {
|
||||
"智谱 AI": {
|
||||
"id": "zhipu_default",
|
||||
"type": "zhipu_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
@@ -557,9 +654,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
},
|
||||
"SiliconFlow(硅基流动)": {
|
||||
"硅基流动": {
|
||||
"id": "siliconflow",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
@@ -568,9 +666,10 @@ CONFIG_METADATA_2 = {
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
},
|
||||
},
|
||||
"MoonShot(Kimi)": {
|
||||
"Kimi": {
|
||||
"id": "moonshot",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
@@ -579,9 +678,22 @@ CONFIG_METADATA_2 = {
|
||||
"model": "moonshot-v1-8k",
|
||||
},
|
||||
},
|
||||
"PPIO派欧云": {
|
||||
"id": "ppio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.ppinfra.com/v3/openai",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "deepseek/deepseek-r1",
|
||||
},
|
||||
},
|
||||
"LLMTuner": {
|
||||
"id": "llmtuner_default",
|
||||
"type": "llm_tuner",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"base_model_path": "",
|
||||
"adapter_model_path": "",
|
||||
@@ -592,6 +704,7 @@ CONFIG_METADATA_2 = {
|
||||
"Dify": {
|
||||
"id": "dify_app_default",
|
||||
"type": "dify",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"dify_api_type": "chat",
|
||||
"dify_api_key": "",
|
||||
@@ -601,9 +714,10 @@ CONFIG_METADATA_2 = {
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
},
|
||||
"Dashscope(阿里云百炼应用)": {
|
||||
"阿里云百炼应用": {
|
||||
"id": "dashscope",
|
||||
"type": "dashscope",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"dashscope_app_type": "agent",
|
||||
"dashscope_api_key": "",
|
||||
@@ -619,6 +733,7 @@ CONFIG_METADATA_2 = {
|
||||
"FastGPT": {
|
||||
"id": "fastgpt",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.fastgpt.in/api/v1",
|
||||
@@ -627,6 +742,7 @@ CONFIG_METADATA_2 = {
|
||||
"Whisper(API)": {
|
||||
"id": "whisper",
|
||||
"type": "openai_whisper_api",
|
||||
"provider_type": "speech_to_text",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
@@ -634,22 +750,25 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"Whisper(本地加载)": {
|
||||
"whisper_hint": "(不用修改我)",
|
||||
"type": "openai_whisper_selfhost",
|
||||
"provider_type": "speech_to_text",
|
||||
"enable": False,
|
||||
"id": "whisper",
|
||||
"type": "openai_whisper_selfhost",
|
||||
"model": "tiny",
|
||||
},
|
||||
"sensevoice(本地加载)": {
|
||||
"SenseVoice(本地加载)": {
|
||||
"sensevoice_hint": "(不用修改我)",
|
||||
"type": "sensevoice_stt_selfhost",
|
||||
"provider_type": "speech_to_text",
|
||||
"enable": False,
|
||||
"id": "sensevoice",
|
||||
"type": "sensevoice_stt_selfhost",
|
||||
"stt_model": "iic/SenseVoiceSmall",
|
||||
"is_emotion": False,
|
||||
},
|
||||
"OpenAI_TTS(API)": {
|
||||
"OpenAI TTS(API)": {
|
||||
"id": "openai_tts",
|
||||
"type": "openai_tts_api",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
@@ -657,43 +776,218 @@ CONFIG_METADATA_2 = {
|
||||
"openai-tts-voice": "alloy",
|
||||
"timeout": "20",
|
||||
},
|
||||
"Edge_TTS": {
|
||||
"Edge TTS": {
|
||||
"edgetts_hint": "提示:使用这个服务前需要安装有 ffmpeg,并且可以直接在终端调用 ffmpeg 指令。",
|
||||
"id": "edge_tts",
|
||||
"type": "edge_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"edge-tts-voice": "zh-CN-XiaoxiaoNeural",
|
||||
"timeout": 20,
|
||||
},
|
||||
"GSVI_TTS(API)": {
|
||||
"GSVI TTS(API)": {
|
||||
"id": "gsvi_tts",
|
||||
"type": "gsvi_tts_api",
|
||||
"provider_type": "text_to_speech",
|
||||
"api_base": "http://127.0.0.1:5000",
|
||||
"character": "",
|
||||
"emotion": "default",
|
||||
"enable": False,
|
||||
"timeout": 20,
|
||||
},
|
||||
"FishAudio_TTS(API)": {
|
||||
"FishAudio TTS(API)": {
|
||||
"id": "fishaudio_tts",
|
||||
"type": "fishaudio_tts_api",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "https://api.fish.audio/v1",
|
||||
"fishaudio-tts-character": "可莉",
|
||||
"timeout": "20",
|
||||
},
|
||||
"阿里云百炼_TTS(API)": {
|
||||
"阿里云百炼 TTS(API)": {
|
||||
"id": "dashscope_tts",
|
||||
"type": "dashscope_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"model": "cosyvoice-v1",
|
||||
"dashscope_tts_voice": "loongstella",
|
||||
"timeout": "20",
|
||||
},
|
||||
"Azure TTS": {
|
||||
"id": "azure_tts",
|
||||
"type": "azure_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": True,
|
||||
"azure_tts_voice": "zh-CN-YunxiaNeural",
|
||||
"azure_tts_style": "cheerful",
|
||||
"azure_tts_role": "Boy",
|
||||
"azure_tts_rate": "1",
|
||||
"azure_tts_volume": "100",
|
||||
"azure_tts_subscription_key": "",
|
||||
"azure_tts_region": "eastus",
|
||||
},
|
||||
"MiniMax TTS(API)": {
|
||||
"id": "minimax_tts",
|
||||
"type": "minimax_tts_api",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "https://api.minimax.chat/v1/t2a_v2",
|
||||
"minimax-group-id": "",
|
||||
"model": "speech-02-turbo",
|
||||
"minimax-langboost": "auto",
|
||||
"minimax-voice-speed": 1.0,
|
||||
"minimax-voice-vol": 1.0,
|
||||
"minimax-voice-pitch": 0,
|
||||
"minimax-is-timber-weight": False,
|
||||
"minimax-voice-id": "female-shaonv",
|
||||
"minimax-timber-weight": '[\n {\n "voice_id": "Chinese (Mandarin)_Warm_Girl",\n "weight": 25\n },\n {\n "voice_id": "Chinese (Mandarin)_BashfulGirl",\n "weight": 50\n }\n]',
|
||||
"minimax-voice-emotion": "neutral",
|
||||
"minimax-voice-latex": False,
|
||||
"minimax-voice-english-normalization": False,
|
||||
"timeout": 20,
|
||||
},
|
||||
"火山引擎_TTS(API)": {
|
||||
"id": "volcengine_tts",
|
||||
"type": "volcengine_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"appid": "",
|
||||
"volcengine_cluster": "volcano_tts",
|
||||
"volcengine_voice_type": "",
|
||||
"volcengine_speed_ratio": 1.0,
|
||||
"api_base": "https://openspeech.bytedance.com/api/v1/tts",
|
||||
"timeout": 20,
|
||||
},
|
||||
"OpenAI Embedding": {
|
||||
"id": "openai_embedding",
|
||||
"type": "openai_embedding",
|
||||
"provider_type": "embedding",
|
||||
"enable": True,
|
||||
"embedding_api_key": "",
|
||||
"embedding_api_base": "",
|
||||
"embedding_model": "",
|
||||
"embedding_dimensions": 1536,
|
||||
"timeout": 20,
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"embedding_dimensions": {
|
||||
"description": "嵌入维度",
|
||||
"type": "int",
|
||||
"hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。",
|
||||
},
|
||||
"embedding_model": {
|
||||
"description": "嵌入模型",
|
||||
"type": "string",
|
||||
"hint": "嵌入模型名称。",
|
||||
},
|
||||
"embedding_api_key": {
|
||||
"description": "API Key",
|
||||
"type": "string",
|
||||
"hint": "API Key",
|
||||
},
|
||||
"volcengine_cluster": {
|
||||
"type": "string",
|
||||
"description": "火山引擎集群",
|
||||
"hint": "若使用语音复刻大模型,可选volcano_icl或volcano_icl_concurr,默认使用volcano_tts",
|
||||
},
|
||||
"volcengine_voice_type": {
|
||||
"type": "string",
|
||||
"description": "火山引擎音色",
|
||||
"hint": "输入声音id(Voice_type)",
|
||||
},
|
||||
"volcengine_speed_ratio": {
|
||||
"type": "float",
|
||||
"description": "语速设置",
|
||||
"hint": "语速设置,范围为 0.2 到 3.0,默认值为 1.0",
|
||||
},
|
||||
"volcengine_volume_ratio": {
|
||||
"type": "float",
|
||||
"description": "音量设置",
|
||||
"hint": "音量设置,范围为 0.0 到 2.0,默认值为 1.0",
|
||||
},
|
||||
"azure_tts_voice": {
|
||||
"type": "string",
|
||||
"description": "音色设置",
|
||||
"hint": "API 音色",
|
||||
},
|
||||
"azure_tts_style": {
|
||||
"type": "string",
|
||||
"description": "风格设置",
|
||||
"hint": "声音特定的讲话风格。 可以表达快乐、同情和平静等情绪。",
|
||||
},
|
||||
"azure_tts_role": {
|
||||
"type": "string",
|
||||
"description": "模仿设置(可选)",
|
||||
"hint": "讲话角色扮演。 声音可以模仿不同的年龄和性别,但声音名称不会更改。 例如,男性语音可以提高音调和改变语调来模拟女性语音,但语音名称不会更改。 如果角色缺失或不受声音的支持,则会忽略此属性。",
|
||||
"options": [
|
||||
"Boy",
|
||||
"Girl",
|
||||
"YoungAdultFemale",
|
||||
"YoungAdultMale",
|
||||
"OlderAdultFemale",
|
||||
"OlderAdultMale",
|
||||
"SeniorFemale",
|
||||
"SeniorMale",
|
||||
"禁用",
|
||||
],
|
||||
},
|
||||
"azure_tts_rate": {
|
||||
"type": "string",
|
||||
"description": "语速设置",
|
||||
"hint": "指示文本的讲出速率。可在字词或句子层面应用语速。 速率变化应为原始音频的 0.5 到 2 倍。",
|
||||
},
|
||||
"azure_tts_volume": {
|
||||
"type": "string",
|
||||
"description": "语音音量设置",
|
||||
"hint": "指示语音的音量级别。 可在句子层面应用音量的变化。以从 0.0 到 100.0(从最安静到最大声,例如 75)的数字表示。 默认值为 100.0。",
|
||||
},
|
||||
"azure_tts_region": {
|
||||
"type": "string",
|
||||
"description": "API 地区",
|
||||
"hint": "Azure_TTS 处理数据所在区域,具体参考 https://learn.microsoft.com/zh-cn/azure/ai-services/speech-service/regions",
|
||||
"options": [
|
||||
"southafricanorth",
|
||||
"eastasia",
|
||||
"southeastasia",
|
||||
"australiaeast",
|
||||
"centralindia",
|
||||
"japaneast",
|
||||
"japanwest",
|
||||
"koreacentral",
|
||||
"canadacentral",
|
||||
"northeurope",
|
||||
"westeurope",
|
||||
"francecentral",
|
||||
"germanywestcentral",
|
||||
"norwayeast",
|
||||
"swedencentral",
|
||||
"switzerlandnorth",
|
||||
"switzerlandwest",
|
||||
"uksouth",
|
||||
"uaenorth",
|
||||
"brazilsouth",
|
||||
"qatarcentral",
|
||||
"centralus",
|
||||
"eastus",
|
||||
"eastus2",
|
||||
"northcentralus",
|
||||
"southcentralus",
|
||||
"westcentralus",
|
||||
"westus",
|
||||
"westus2",
|
||||
"westus3",
|
||||
],
|
||||
},
|
||||
"azure_tts_subscription_key": {
|
||||
"type": "string",
|
||||
"description": "服务订阅密钥",
|
||||
"hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)",
|
||||
},
|
||||
"dashscope_tts_voice": {
|
||||
"description": "语音合成模型",
|
||||
"type": "string",
|
||||
@@ -704,6 +998,18 @@ CONFIG_METADATA_2 = {
|
||||
"type": "bool",
|
||||
"hint": "启用后,将支持返回图片内容。需要模型支持,否则会报错。具体支持模型请查看 Google Gemini 官方网站。温馨提示,如果您需要生成图片,请关闭 `启用群员识别` 配置获得更好的效果。",
|
||||
},
|
||||
"gm_native_search": {
|
||||
"description": "启用原生搜索功能",
|
||||
"type": "bool",
|
||||
"hint": "启用后所有函数工具将全部失效,免费次数限制请查阅官方文档",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gm_native_coderunner": {
|
||||
"description": "启用原生代码执行器",
|
||||
"type": "bool",
|
||||
"hint": "启用后所有函数工具将全部失效",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gm_safety_settings": {
|
||||
"description": "安全过滤器",
|
||||
"type": "object",
|
||||
@@ -755,6 +1061,109 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"gm_thinking_config": {
|
||||
"description": "Gemini思考设置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"budget": {
|
||||
"description": "思考预算",
|
||||
"type": "int",
|
||||
"hint": "模型应该生成的思考Token的数量,设为0关闭思考。除gemini-2.5-flash外的模型会静默忽略此参数。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"minimax-group-id": {
|
||||
"type": "string",
|
||||
"description": "用户组",
|
||||
"hint": "于账户管理->基本信息中可见",
|
||||
},
|
||||
"minimax-langboost": {
|
||||
"type": "string",
|
||||
"description": "指定语言/方言",
|
||||
"hint": "增强对指定的小语种和方言的识别能力,设置后可以提升在指定小语种/方言场景下的语音表现",
|
||||
"options": [
|
||||
"Chinese",
|
||||
"Chinese,Yue",
|
||||
"English",
|
||||
"Arabic",
|
||||
"Russian",
|
||||
"Spanish",
|
||||
"French",
|
||||
"Portuguese",
|
||||
"German",
|
||||
"Turkish",
|
||||
"Dutch",
|
||||
"Ukrainian",
|
||||
"Vietnamese",
|
||||
"Indonesian",
|
||||
"Japanese",
|
||||
"Italian",
|
||||
"Korean",
|
||||
"Thai",
|
||||
"Polish",
|
||||
"Romanian",
|
||||
"Greek",
|
||||
"Czech",
|
||||
"Finnish",
|
||||
"Hindi",
|
||||
"auto",
|
||||
],
|
||||
},
|
||||
"minimax-voice-speed": {
|
||||
"type": "float",
|
||||
"description": "语速",
|
||||
"hint": "生成声音的语速, 取值[0.5, 2], 默认为1.0, 取值越大,语速越快",
|
||||
},
|
||||
"minimax-voice-vol": {
|
||||
"type": "float",
|
||||
"description": "音量",
|
||||
"hint": "生成声音的音量, 取值(0, 10], 默认为1.0, 取值越大,音量越高",
|
||||
},
|
||||
"minimax-voice-pitch": {
|
||||
"type": "int",
|
||||
"description": "语调",
|
||||
"hint": "生成声音的语调, 取值[-12, 12], 默认为0",
|
||||
},
|
||||
"minimax-is-timber-weight": {
|
||||
"type": "bool",
|
||||
"description": "启用混合音色",
|
||||
"hint": "启用混合音色, 支持以自定义权重混合最多四种音色, 启用后自动忽略单一音色设置",
|
||||
},
|
||||
"minimax-timber-weight": {
|
||||
"type": "string",
|
||||
"description": "混合音色",
|
||||
"editor_mode": True,
|
||||
"hint": "混合音色及其权重, 最多支持四种音色, 权重为整数, 取值[1, 100]. 可在官网API语音调试台预览代码获得预设以及编写模板, 需要严格按照json字符串格式编写, 可以查看控制台判断是否解析成功. 具体结构可参照默认值以及官网代码预览.",
|
||||
},
|
||||
"minimax-voice-id": {
|
||||
"type": "string",
|
||||
"description": "单一音色",
|
||||
"hint": "单一音色编号, 详见官网文档",
|
||||
},
|
||||
"minimax-voice-emotion": {
|
||||
"type": "string",
|
||||
"description": "情绪",
|
||||
"hint": "控制合成语音的情绪",
|
||||
"options": [
|
||||
"happy",
|
||||
"sad",
|
||||
"angry",
|
||||
"fearful",
|
||||
"disgusted",
|
||||
"surprised",
|
||||
"neutral",
|
||||
],
|
||||
},
|
||||
"minimax-voice-latex": {
|
||||
"type": "bool",
|
||||
"description": "支持朗读latex公式",
|
||||
"hint": "朗读latex公式, 但是需要确保输入文本按官网要求格式化",
|
||||
},
|
||||
"minimax-voice-english-normalization": {
|
||||
"type": "bool",
|
||||
"description": "支持英语文本规范化",
|
||||
"hint": "可提升数字阅读场景的性能,但会略微增加延迟",
|
||||
},
|
||||
"rag_options": {
|
||||
"description": "RAG 选项",
|
||||
"type": "object",
|
||||
@@ -852,7 +1261,12 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "ID 不能和其它的服务提供商重复,否则将发生严重冲突。",
|
||||
},
|
||||
"type": {
|
||||
"description": "模型提供商类型",
|
||||
"description": "模型提供商种类",
|
||||
"type": "string",
|
||||
"invisible": True,
|
||||
},
|
||||
"provider_type": {
|
||||
"description": "模型提供商能力种类",
|
||||
"type": "string",
|
||||
"invisible": True,
|
||||
},
|
||||
@@ -1008,6 +1422,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "bool",
|
||||
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
|
||||
},
|
||||
"streaming_segmented": {
|
||||
"description": "不支持流式回复的平台分段输出",
|
||||
"type": "bool",
|
||||
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 和 gewechat 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
|
||||
},
|
||||
},
|
||||
},
|
||||
"persona": {
|
||||
@@ -1082,6 +1501,17 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
|
||||
},
|
||||
"dual_output": {
|
||||
"description": "启用语音和文字双输出",
|
||||
"type": "bool",
|
||||
"hint": "启用后,Bot 将同时输出语音和文字消息。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"use_file_service": {
|
||||
"description": "使用文件服务提供 TTS 语音文件",
|
||||
"type": "bool",
|
||||
"hint": "启用后,如已配置 callback_api_base ,将会使用文件服务提供TTS语音文件",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
@@ -1194,6 +1624,12 @@ CONFIG_METADATA_2 = {
|
||||
"obvious_hint": True,
|
||||
"hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab",
|
||||
},
|
||||
"callback_api_base": {
|
||||
"description": "对外可达的回调接口地址",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址(host),因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185,https://example.com 等。",
|
||||
},
|
||||
"log_level": {
|
||||
"description": "控制台日志级别",
|
||||
"type": "string",
|
||||
@@ -1211,21 +1647,20 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "当 t2i_strategy 为 remote 时生效。为空时使用 AstrBot API 服务",
|
||||
},
|
||||
"t2i_use_file_service": {
|
||||
"description": "本地文本转图像使用文件服务提供文件",
|
||||
"type": "bool",
|
||||
"hint": "当 t2i_strategy 为 local 并且配置 callback_api_base 时生效。是否使用文件服务提供文件。",
|
||||
},
|
||||
"pip_install_arg": {
|
||||
"description": "pip 安装参数",
|
||||
"type": "string",
|
||||
"hint": "安装插件依赖时,会使用 Python 的 pip 工具。这里可以填写额外的参数,如 `--break-system-package` 等。",
|
||||
},
|
||||
"plugin_repo_mirror": {
|
||||
"description": "插件仓库镜像",
|
||||
"pypi_index_url": {
|
||||
"description": "PyPI 软件仓库地址",
|
||||
"type": "string",
|
||||
"hint": "已废弃,请使用管理面板->设置页的代理地址选择",
|
||||
"obvious_hint": True,
|
||||
"options": [
|
||||
"default",
|
||||
"https://ghp.ci/",
|
||||
"https://github-mirror.us.kg/",
|
||||
],
|
||||
"hint": "安装 Python 依赖时请求的 PyPI 软件仓库地址。默认为 https://mirrors.aliyun.com/pypi/simple/",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
|
||||
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||
|
||||
工作流程:
|
||||
@@ -28,7 +28,6 @@ from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
@@ -37,7 +36,7 @@ from astrbot.core.star.star_handler import star_map
|
||||
class AstrBotCoreLifecycle:
|
||||
"""
|
||||
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、
|
||||
EventBus 等。
|
||||
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||
"""
|
||||
@@ -54,7 +53,7 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||
"""
|
||||
|
||||
# 初始化日志代理
|
||||
@@ -73,9 +72,6 @@ class AstrBotCoreLifecycle:
|
||||
# 初始化平台管理器
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
||||
|
||||
# 初始化对话管理器
|
||||
self.conversation_manager = ConversationManager(self.db)
|
||||
|
||||
@@ -87,7 +83,6 @@ class AstrBotCoreLifecycle:
|
||||
self.provider_manager,
|
||||
self.platform_manager,
|
||||
self.conversation_manager,
|
||||
self.knowledge_db_manager,
|
||||
)
|
||||
|
||||
# 初始化插件管理器
|
||||
@@ -106,7 +101,7 @@ class AstrBotCoreLifecycle:
|
||||
await self.pipeline_scheduler.initialize()
|
||||
|
||||
# 初始化更新器
|
||||
self.astrbot_updator = AstrBotUpdator(self.astrbot_config["plugin_repo_mirror"])
|
||||
self.astrbot_updator = AstrBotUpdator()
|
||||
|
||||
# 初始化事件总线
|
||||
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
import json
|
||||
import aiosqlite
|
||||
import os
|
||||
from typing import Any
|
||||
from .plugin_storage import PluginStorage
|
||||
|
||||
DBPATH = "data/plugin_data/sqlite/plugin_data.db"
|
||||
|
||||
|
||||
class SQLitePluginStorage(PluginStorage):
|
||||
"""插件数据的 SQLite 存储实现类。
|
||||
|
||||
该类提供异步方式将插件数据存储到 SQLite 数据库中,支持数据的增删改查操作。
|
||||
所有数据以 (plugin, key) 作为复合主键进行索引。
|
||||
"""
|
||||
|
||||
_instance = None # Standalone instance of the class
|
||||
_db_conn = None
|
||||
db_path = None
|
||||
|
||||
def __new__(cls):
|
||||
"""
|
||||
创建或获取 SQLitePluginStorage 的单例实例。
|
||||
如果实例已存在,则返回现有实例;否则创建一个新实例。
|
||||
数据在 `data/plugin_data/sqlite/plugin_data.db` 下。
|
||||
"""
|
||||
os.makedirs(os.path.dirname(DBPATH), exist_ok=True)
|
||||
if cls._instance is None:
|
||||
cls._instance = super(SQLitePluginStorage, cls).__new__(cls)
|
||||
cls._instance.db_path = DBPATH
|
||||
return cls._instance
|
||||
|
||||
async def _init_db(self):
|
||||
"""初始化数据库连接(只执行一次)"""
|
||||
if SQLitePluginStorage._db_conn is None:
|
||||
SQLitePluginStorage._db_conn = await aiosqlite.connect(self.db_path)
|
||||
await self._setup_db()
|
||||
|
||||
async def _setup_db(self):
|
||||
"""
|
||||
异步初始化数据库。
|
||||
|
||||
创建插件数据表,如果表不存在则创建,表结构包含 plugin、key 和 value 字段,
|
||||
其中 plugin 和 key 组合作为主键。
|
||||
"""
|
||||
await self._db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS plugin_data (
|
||||
plugin TEXT,
|
||||
key TEXT,
|
||||
value TEXT,
|
||||
PRIMARY KEY (plugin, key)
|
||||
)
|
||||
""")
|
||||
await self._db_conn.commit()
|
||||
|
||||
async def set(self, plugin: str, key: str, value: Any):
|
||||
"""
|
||||
异步存储数据。
|
||||
|
||||
将指定插件的键值对存入数据库,如果键已存在则更新值。
|
||||
值会被序列化为 JSON 字符串后存储。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 数据键名
|
||||
value: 要存储的数据值(任意类型,将被 JSON 序列化)
|
||||
"""
|
||||
await self._init_db()
|
||||
await self._db_conn.execute(
|
||||
"INSERT INTO plugin_data (plugin, key, value) VALUES (?, ?, ?) "
|
||||
"ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value",
|
||||
(plugin, key, json.dumps(value)),
|
||||
)
|
||||
await self._db_conn.commit()
|
||||
|
||||
async def get(self, plugin: str, key: str) -> Any:
|
||||
"""
|
||||
异步获取数据。
|
||||
|
||||
从数据库中获取指定插件和键名对应的值,
|
||||
返回的值会从 JSON 字符串反序列化为原始数据类型。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 数据键名
|
||||
|
||||
Returns:
|
||||
Any: 存储的数据值,如果未找到则返回 None
|
||||
"""
|
||||
await self._init_db()
|
||||
async with self._db_conn.execute(
|
||||
"SELECT value FROM plugin_data WHERE plugin = ? AND key = ?",
|
||||
(plugin, key),
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return json.loads(row[0]) if row else None
|
||||
|
||||
async def delete(self, plugin: str, key: str):
|
||||
"""
|
||||
异步删除数据。
|
||||
|
||||
从数据库中删除指定插件和键名对应的数据项。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 要删除的数据键名
|
||||
"""
|
||||
await self._init_db()
|
||||
await self._db_conn.execute(
|
||||
"DELETE FROM plugin_data WHERE plugin = ? AND key = ?", (plugin, key)
|
||||
)
|
||||
await self._db_conn.commit()
|
||||
46
astrbot/core/db/vec_db/base.py
Normal file
46
astrbot/core/db/vec_db/base.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
similarity: float
|
||||
data: dict
|
||||
|
||||
|
||||
class BaseVecDB:
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化向量数据库
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
|
||||
"""
|
||||
搜索最相似的文档。
|
||||
Args:
|
||||
query (str): 查询文本
|
||||
top_k (int): 返回的最相似文档的数量
|
||||
Returns:
|
||||
List[Result]: 查询结果
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete(self, doc_id: str) -> bool:
|
||||
"""
|
||||
删除指定文档。
|
||||
Args:
|
||||
doc_id (str): 要删除的文档 ID
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
...
|
||||
3
astrbot/core/db/vec_db/faiss_impl/__init__.py
Normal file
3
astrbot/core/db/vec_db/faiss_impl/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .vec_db import FaissVecDB
|
||||
|
||||
__all__ = ["FaissVecDB"]
|
||||
121
astrbot/core/db/vec_db/faiss_impl/document_storage.py
Normal file
121
astrbot/core/db/vec_db/faiss_impl/document_storage.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import aiosqlite
|
||||
import os
|
||||
|
||||
|
||||
class DocumentStorage:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self.connection = None
|
||||
self.sqlite_init_path = os.path.join(
|
||||
os.path.dirname(__file__), "sqlite_init.sql"
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
|
||||
if not os.path.exists(self.db_path):
|
||||
await self.connect()
|
||||
async with self.connection.cursor() as cursor:
|
||||
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
|
||||
sql_script = f.read()
|
||||
await cursor.executescript(sql_script)
|
||||
await self.connection.commit()
|
||||
else:
|
||||
await self.connect()
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the SQLite database."""
|
||||
self.connection = await aiosqlite.connect(self.db_path)
|
||||
|
||||
async def get_documents(self, metadata_filters: dict, ids: list = None):
|
||||
"""Retrieve documents by metadata filters and ids.
|
||||
|
||||
Args:
|
||||
metadata_filters (dict): The metadata filters to apply.
|
||||
|
||||
Returns:
|
||||
list: The list of document IDs(primary key, not doc_id) that match the filters.
|
||||
"""
|
||||
# metadata filter -> SQL WHERE clause
|
||||
where_clauses = []
|
||||
values = []
|
||||
for key, val in metadata_filters.items():
|
||||
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
|
||||
values.append(val)
|
||||
if ids is not None and len(ids) > 0:
|
||||
ids = [str(i) for i in ids if i != -1]
|
||||
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
|
||||
values.extend(ids)
|
||||
where_sql = " AND ".join(where_clauses) or "1=1"
|
||||
|
||||
result = []
|
||||
async with self.connection.cursor() as cursor:
|
||||
sql = "SELECT * FROM documents WHERE " + where_sql
|
||||
await cursor.execute(sql, values)
|
||||
for row in await cursor.fetchall():
|
||||
result.append(await self.tuple_to_dict(row))
|
||||
return result
|
||||
|
||||
async def get_document_by_doc_id(self, doc_id: str):
|
||||
"""Retrieve a document by its doc_id.
|
||||
|
||||
Args:
|
||||
doc_id (str): The doc_id of the document to retrieve.
|
||||
|
||||
Returns:
|
||||
dict: The document data.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return await self.tuple_to_dict(row)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
|
||||
"""Retrieve a document by its doc_id.
|
||||
|
||||
Args:
|
||||
doc_id (str): The doc_id.
|
||||
new_text (str): The new text to update the document with.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
|
||||
)
|
||||
await self.connection.commit()
|
||||
|
||||
async def get_user_ids(self) -> list[str]:
|
||||
"""Retrieve all user IDs from the documents table.
|
||||
|
||||
Returns:
|
||||
list: A list of user IDs.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT DISTINCT user_id FROM documents")
|
||||
rows = await cursor.fetchall()
|
||||
return [row[0] for row in rows]
|
||||
|
||||
async def tuple_to_dict(self, row):
|
||||
"""Convert a tuple to a dictionary.
|
||||
|
||||
Args:
|
||||
row (tuple): The row to convert.
|
||||
|
||||
Returns:
|
||||
dict: The converted dictionary.
|
||||
"""
|
||||
return {
|
||||
"id": row[0],
|
||||
"doc_id": row[1],
|
||||
"text": row[2],
|
||||
"metadata": row[3],
|
||||
"created_at": row[4],
|
||||
"updated_at": row[5],
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the SQLite database."""
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
self.connection = None
|
||||
59
astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
Normal file
59
astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
Normal file
@@ -0,0 +1,59 @@
|
||||
try:
|
||||
import faiss
|
||||
except ModuleNotFoundError:
|
||||
raise ImportError(
|
||||
"faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。"
|
||||
)
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
class EmbeddingStorage:
|
||||
def __init__(self, dimension: int, path: str = None):
|
||||
self.dimension = dimension
|
||||
self.path = path
|
||||
self.index = None
|
||||
if path and os.path.exists(path):
|
||||
self.index = faiss.read_index(path)
|
||||
else:
|
||||
base_index = faiss.IndexFlatL2(dimension)
|
||||
self.index = faiss.IndexIDMap(base_index)
|
||||
self.storage = {}
|
||||
|
||||
async def insert(self, vector: np.ndarray, id: int):
|
||||
"""插入向量
|
||||
|
||||
Args:
|
||||
vector (np.ndarray): 要插入的向量
|
||||
id (int): 向量的ID
|
||||
Raises:
|
||||
ValueError: 如果向量的维度与存储的维度不匹配
|
||||
"""
|
||||
if vector.shape[0] != self.dimension:
|
||||
raise ValueError(
|
||||
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
|
||||
)
|
||||
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
|
||||
self.storage[id] = vector
|
||||
await self.save_index()
|
||||
|
||||
async def search(self, vector: np.ndarray, k: int) -> tuple:
|
||||
"""搜索最相似的向量
|
||||
|
||||
Args:
|
||||
vector (np.ndarray): 查询向量
|
||||
k (int): 返回的最相似向量的数量
|
||||
Returns:
|
||||
tuple: (距离, 索引)
|
||||
"""
|
||||
faiss.normalize_L2(vector)
|
||||
distances, indices = self.index.search(vector, k)
|
||||
return distances, indices
|
||||
|
||||
async def save_index(self):
|
||||
"""保存索引
|
||||
|
||||
Args:
|
||||
path (str): 保存索引的路径
|
||||
"""
|
||||
faiss.write_index(self.index, self.path)
|
||||
17
astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql
Normal file
17
astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql
Normal file
@@ -0,0 +1,17 @@
|
||||
-- 创建文档存储表,包含 faiss 中文档的 id,文档文本,create_at,updated_at
|
||||
CREATE TABLE documents (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
doc_id TEXT NOT NULL,
|
||||
text TEXT NOT NULL,
|
||||
metadata TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
ALTER TABLE documents
|
||||
ADD COLUMN group_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.group_id')) STORED;
|
||||
ALTER TABLE documents
|
||||
ADD COLUMN user_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED;
|
||||
|
||||
CREATE INDEX idx_documents_user_id ON documents(user_id);
|
||||
CREATE INDEX idx_documents_group_id ON documents(group_id);
|
||||
117
astrbot/core/db/vec_db/faiss_impl/vec_db.py
Normal file
117
astrbot/core/db/vec_db/faiss_impl/vec_db.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import uuid
|
||||
import json
|
||||
import numpy as np
|
||||
from .document_storage import DocumentStorage
|
||||
from .embedding_storage import EmbeddingStorage
|
||||
from ..base import Result, BaseVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider
|
||||
|
||||
|
||||
class FaissVecDB(BaseVecDB):
|
||||
"""
|
||||
A class to represent a vector database.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
doc_store_path: str,
|
||||
index_store_path: str,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
):
|
||||
self.doc_store_path = doc_store_path
|
||||
self.index_store_path = index_store_path
|
||||
self.embedding_provider = embedding_provider
|
||||
self.document_storage = DocumentStorage(doc_store_path)
|
||||
self.embedding_storage = EmbeddingStorage(
|
||||
embedding_provider.get_dim(), index_store_path
|
||||
)
|
||||
self.embedding_provider = embedding_provider
|
||||
|
||||
async def initialize(self):
|
||||
await self.document_storage.initialize()
|
||||
|
||||
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
metadata = metadata or {}
|
||||
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
|
||||
|
||||
vector = await self.embedding_provider.get_embedding(content)
|
||||
vector = np.array(vector, dtype=np.float32)
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
|
||||
(str_id, content, json.dumps(metadata)),
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
result = await self.document_storage.get_document_by_doc_id(str_id)
|
||||
int_id = result["id"]
|
||||
|
||||
# 插入向量到 FAISS
|
||||
await self.embedding_storage.insert(vector, int_id)
|
||||
return int_id
|
||||
|
||||
async def retrieve(
|
||||
self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None
|
||||
) -> list[Result]:
|
||||
"""
|
||||
搜索最相似的文档。
|
||||
|
||||
Args:
|
||||
query (str): 查询文本
|
||||
k (int): 返回的最相似文档的数量
|
||||
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
|
||||
metadata_filters (dict): 元数据过滤器
|
||||
|
||||
Returns:
|
||||
List[Result]: 查询结果
|
||||
"""
|
||||
embedding = await self.embedding_provider.get_embedding(query)
|
||||
scores, indices = await self.embedding_storage.search(
|
||||
vector=np.array([embedding]).astype("float32"),
|
||||
k=fetch_k if metadata_filters else k,
|
||||
)
|
||||
# TODO: rerank
|
||||
if len(indices[0]) == 0 or indices[0][0] == -1:
|
||||
return []
|
||||
# normalize scores
|
||||
scores[0] = 1.0 - (scores[0] / 2.0)
|
||||
# NOTE: maybe the size is less than k.
|
||||
fetched_docs = await self.document_storage.get_documents(
|
||||
metadata_filters=metadata_filters or {}, ids=indices[0]
|
||||
)
|
||||
if not fetched_docs:
|
||||
return []
|
||||
result_docs = []
|
||||
|
||||
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
|
||||
for i, indice_idx in enumerate(indices[0]):
|
||||
pos = idx_pos.get(indice_idx)
|
||||
if pos is None:
|
||||
continue
|
||||
fetch_doc = fetched_docs[pos]
|
||||
score = scores[0][i]
|
||||
result_docs.append(Result(similarity=float(score), data=fetch_doc))
|
||||
return result_docs[:k]
|
||||
|
||||
async def delete(self, doc_id: int):
|
||||
"""
|
||||
删除一条文档
|
||||
"""
|
||||
await self.document_storage.connection.execute(
|
||||
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
|
||||
async def close(self):
|
||||
await self.document_storage.close()
|
||||
|
||||
async def count_documents(self) -> int:
|
||||
"""
|
||||
计算文档数量
|
||||
"""
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT COUNT(*) FROM documents")
|
||||
count = await cursor.fetchone()
|
||||
return count[0] if count else 0
|
||||
68
astrbot/core/file_token_service.py
Normal file
68
astrbot/core/file_token_service.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
import time
|
||||
|
||||
|
||||
class FileTokenService:
|
||||
"""维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。"""
|
||||
|
||||
def __init__(self, default_timeout: float = 300):
|
||||
self.lock = asyncio.Lock()
|
||||
self.staged_files = {} # token: (file_path, expire_time)
|
||||
self.default_timeout = default_timeout
|
||||
|
||||
async def _cleanup_expired_tokens(self):
|
||||
"""清理过期的令牌"""
|
||||
now = time.time()
|
||||
expired_tokens = [token for token, (_, expire) in self.staged_files.items() if expire < now]
|
||||
for token in expired_tokens:
|
||||
self.staged_files.pop(token, None)
|
||||
|
||||
async def register_file(self, file_path: str, timeout: float = None) -> str:
|
||||
"""向令牌服务注册一个文件。
|
||||
|
||||
Args:
|
||||
file_path(str): 文件路径
|
||||
timeout(float): 超时时间,单位秒(可选)
|
||||
|
||||
Returns:
|
||||
str: 一个单次令牌
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 当路径不存在时抛出
|
||||
"""
|
||||
async with self.lock:
|
||||
await self._cleanup_expired_tokens()
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
file_token = str(uuid.uuid4())
|
||||
expire_time = time.time() + (timeout if timeout is not None else self.default_timeout)
|
||||
self.staged_files[file_token] = (file_path, expire_time)
|
||||
return file_token
|
||||
|
||||
async def handle_file(self, file_token: str) -> str:
|
||||
"""根据令牌获取文件路径,使用后令牌失效。
|
||||
|
||||
Args:
|
||||
file_token(str): 注册时返回的令牌
|
||||
|
||||
Returns:
|
||||
str: 文件路径
|
||||
|
||||
Raises:
|
||||
KeyError: 当令牌不存在或已过期时抛出
|
||||
FileNotFoundError: 当文件本身已被删除时抛出
|
||||
"""
|
||||
async with self.lock:
|
||||
await self._cleanup_expired_tokens()
|
||||
|
||||
if file_token not in self.staged_files:
|
||||
raise KeyError(f"无效或过期的文件 token: {file_token}")
|
||||
|
||||
file_path, _ = self.staged_files.pop(file_token)
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
return file_path
|
||||
@@ -26,13 +26,14 @@ class InitialLoader:
|
||||
async def start(self):
|
||||
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
||||
|
||||
core_task = []
|
||||
try:
|
||||
await core_lifecycle.initialize()
|
||||
core_task = core_lifecycle.start()
|
||||
except Exception as e:
|
||||
logger.critical(traceback.format_exc())
|
||||
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
|
||||
return
|
||||
|
||||
core_task = core_lifecycle.start()
|
||||
|
||||
self.dashboard_server = AstrBotDashboard(
|
||||
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
|
||||
|
||||
@@ -25,6 +25,7 @@ import logging
|
||||
import colorlog
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from collections import deque
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
@@ -171,7 +172,9 @@ class LogManager:
|
||||
if logger.hasHandlers():
|
||||
return logger
|
||||
# 如果logger没有处理器
|
||||
console_handler = logging.StreamHandler() # 创建一个StreamHandler用于控制台输出
|
||||
console_handler = logging.StreamHandler(
|
||||
sys.stdout
|
||||
) # 创建一个StreamHandler用于控制台输出
|
||||
console_handler.setLevel(
|
||||
logging.DEBUG
|
||||
) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
|
||||
|
||||
@@ -22,14 +22,19 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import typing as T
|
||||
import uuid
|
||||
from enum import Enum
|
||||
|
||||
from pydantic.v1 import BaseModel
|
||||
from astrbot.core.utils.io import download_image_by_url, file_to_base64
|
||||
|
||||
from astrbot.core import astrbot_config, file_token_service, logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
|
||||
|
||||
|
||||
class ComponentType(Enum):
|
||||
@@ -97,6 +102,10 @@ class BaseMessageComponent(BaseModel):
|
||||
data[k] = v
|
||||
return {"type": self.type.lower(), "data": data}
|
||||
|
||||
async def to_dict(self) -> dict:
|
||||
# 默认情况下,回退到旧的同步 toDict()
|
||||
return self.toDict()
|
||||
|
||||
|
||||
class Plain(BaseMessageComponent):
|
||||
type: ComponentType = "Plain"
|
||||
@@ -113,6 +122,9 @@ class Plain(BaseMessageComponent):
|
||||
self.text.replace("&", "&").replace("[", "[").replace("]", "]")
|
||||
)
|
||||
|
||||
def toDict(self):
|
||||
return {"type": "text", "data": {"text": self.text.strip()}}
|
||||
|
||||
|
||||
class Face(BaseMessageComponent):
|
||||
type: ComponentType = "Face"
|
||||
@@ -165,7 +177,8 @@ class Record(BaseMessageComponent):
|
||||
elif self.file and self.file.startswith("base64://"):
|
||||
bs64_data = self.file.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
file_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(file_path)
|
||||
@@ -193,8 +206,32 @@ class Record(BaseMessageComponent):
|
||||
bs64_data = file_to_base64(self.file)
|
||||
else:
|
||||
raise Exception(f"not a valid file: {self.file}")
|
||||
bs64_data = bs64_data.removeprefix("base64://")
|
||||
return bs64_data
|
||||
|
||||
async def register_to_file_service(self) -> str:
|
||||
"""
|
||||
将语音注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.convert_to_file_path()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
|
||||
class Video(BaseMessageComponent):
|
||||
type: ComponentType = "Video"
|
||||
@@ -205,9 +242,6 @@ class Video(BaseMessageComponent):
|
||||
path: T.Optional[str] = ""
|
||||
|
||||
def __init__(self, file: str, **_):
|
||||
# for k in _.keys():
|
||||
# if k == "c" and _[k] not in [2, 3]:
|
||||
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
|
||||
super().__init__(file=file, **_)
|
||||
|
||||
@staticmethod
|
||||
@@ -220,6 +254,70 @@ class Video(BaseMessageComponent):
|
||||
return Video(file=url, **_)
|
||||
raise Exception("not a valid url")
|
||||
|
||||
async def convert_to_file_path(self) -> str:
|
||||
"""将这个视频统一转换为本地文件路径。这个方法避免了手动判断视频数据类型,直接返回视频数据的本地路径(如果是网络 URL,则会自动进行下载)。
|
||||
|
||||
Returns:
|
||||
str: 视频的本地路径,以绝对路径表示。
|
||||
"""
|
||||
url = self.file
|
||||
if url and url.startswith("file:///"):
|
||||
return url[8:]
|
||||
elif url and url.startswith("http"):
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
await download_file(url, video_file_path)
|
||||
if os.path.exists(video_file_path):
|
||||
return os.path.abspath(video_file_path)
|
||||
else:
|
||||
raise Exception(f"download failed: {url}")
|
||||
elif os.path.exists(url):
|
||||
return os.path.abspath(url)
|
||||
else:
|
||||
raise Exception(f"not a valid file: {url}")
|
||||
|
||||
async def register_to_file_service(self):
|
||||
"""
|
||||
将视频注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.convert_to_file_path()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
async def to_dict(self):
|
||||
"""需要和 toDict 区分开,toDict 是同步方法"""
|
||||
url_or_path = self.file
|
||||
if url_or_path.startswith("http"):
|
||||
payload_file = url_or_path
|
||||
elif callback_host := astrbot_config.get("callback_api_base"):
|
||||
callback_host = str(callback_host).removesuffix("/")
|
||||
token = await file_token_service.register_file(url_or_path)
|
||||
payload_file = f"{callback_host}/api/file/{token}"
|
||||
logger.debug(f"Generated video file callback link: {payload_file}")
|
||||
else:
|
||||
payload_file = url_or_path
|
||||
return {
|
||||
"type": "video",
|
||||
"data": {
|
||||
"file": payload_file,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class At(BaseMessageComponent):
|
||||
type: ComponentType = "At"
|
||||
@@ -229,6 +327,12 @@ class At(BaseMessageComponent):
|
||||
def __init__(self, **_):
|
||||
super().__init__(**_)
|
||||
|
||||
def toDict(self):
|
||||
return {
|
||||
"type": "at",
|
||||
"data": {"qq": str(self.qq)},
|
||||
}
|
||||
|
||||
|
||||
class AtAll(At):
|
||||
qq: str = "all"
|
||||
@@ -368,7 +472,8 @@ class Image(BaseMessageComponent):
|
||||
elif url and url.startswith("base64://"):
|
||||
bs64_data = url.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
image_file_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
|
||||
with open(image_file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(image_file_path)
|
||||
@@ -397,25 +502,47 @@ class Image(BaseMessageComponent):
|
||||
bs64_data = file_to_base64(url)
|
||||
else:
|
||||
raise Exception(f"not a valid file: {url}")
|
||||
bs64_data = bs64_data.removeprefix("base64://")
|
||||
return bs64_data
|
||||
|
||||
async def register_to_file_service(self) -> str:
|
||||
"""
|
||||
将图片注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.convert_to_file_path()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
|
||||
class Reply(BaseMessageComponent):
|
||||
type: ComponentType = "Reply"
|
||||
id: T.Union[str, int]
|
||||
"""所引用的消息 ID"""
|
||||
chain: T.Optional[T.List["BaseMessageComponent"]] = []
|
||||
"""引用的消息段列表"""
|
||||
"""被引用的消息段列表"""
|
||||
sender_id: T.Optional[int] | T.Optional[str] = 0
|
||||
"""引用的消息发送者 ID"""
|
||||
"""被引用的消息对应的发送者的 ID"""
|
||||
sender_nickname: T.Optional[str] = ""
|
||||
"""引用的消息发送者昵称"""
|
||||
"""被引用的消息对应的发送者的昵称"""
|
||||
time: T.Optional[int] = 0
|
||||
"""引用的消息发送时间"""
|
||||
"""被引用的消息发送时间"""
|
||||
message_str: T.Optional[str] = ""
|
||||
"""解析后的纯文本消息字符串"""
|
||||
sender_str: T.Optional[str] = ""
|
||||
"""被引用的消息纯文本"""
|
||||
"""被引用的消息解析后的纯文本消息字符串"""
|
||||
|
||||
text: T.Optional[str] = ""
|
||||
"""deprecated"""
|
||||
@@ -460,28 +587,48 @@ class Node(BaseMessageComponent):
|
||||
type: ComponentType = "Node"
|
||||
id: T.Optional[int] = 0 # 忽略
|
||||
name: T.Optional[str] = "" # qq昵称
|
||||
uin: T.Optional[int] = 0 # qq号
|
||||
content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表
|
||||
uin: T.Optional[str] = "0" # qq号
|
||||
content: T.Optional[list[BaseMessageComponent]] = []
|
||||
seq: T.Optional[T.Union[str, list]] = "" # 忽略
|
||||
time: T.Optional[int] = 0
|
||||
time: T.Optional[int] = 0 # 忽略
|
||||
|
||||
def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_):
|
||||
if isinstance(content, list):
|
||||
_content = None
|
||||
if all(isinstance(item, Node) for item in content):
|
||||
_content = [node.toDict() for node in content]
|
||||
else:
|
||||
_content = ""
|
||||
for chain in content:
|
||||
_content += chain.toString()
|
||||
content = _content
|
||||
elif isinstance(content, Node):
|
||||
content = content.toDict()
|
||||
def __init__(self, content: list[BaseMessageComponent], **_):
|
||||
if isinstance(content, Node):
|
||||
# back
|
||||
content = [content]
|
||||
super().__init__(content=content, **_)
|
||||
|
||||
def toString(self):
|
||||
# logger.warn("Protocol: node doesn't support stringify")
|
||||
return ""
|
||||
async def to_dict(self):
|
||||
data_content = []
|
||||
for comp in self.content:
|
||||
if isinstance(comp, (Image, Record)):
|
||||
# For Image and Record segments, we convert them to base64
|
||||
bs64 = await comp.convert_to_base64()
|
||||
data_content.append(
|
||||
{
|
||||
"type": comp.type.lower(),
|
||||
"data": {"file": f"base64://{bs64}"},
|
||||
}
|
||||
)
|
||||
elif isinstance(comp, File):
|
||||
# For File segments, we need to handle the file differently
|
||||
d = await comp.to_dict()
|
||||
data_content.append(d)
|
||||
elif isinstance(comp, (Node, Nodes)):
|
||||
# For Node segments, we recursively convert them to dict
|
||||
d = await comp.to_dict()
|
||||
data_content.append(d)
|
||||
else:
|
||||
d = comp.toDict()
|
||||
data_content.append(d)
|
||||
return {
|
||||
"type": "node",
|
||||
"data": {
|
||||
"user_id": str(self.uin),
|
||||
"nickname": self.name,
|
||||
"content": data_content,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Nodes(BaseMessageComponent):
|
||||
@@ -492,7 +639,22 @@ class Nodes(BaseMessageComponent):
|
||||
super().__init__(nodes=nodes, **_)
|
||||
|
||||
def toDict(self):
|
||||
return {"messages": [node.toDict() for node in self.nodes]}
|
||||
"""Deprecated. Use to_dict instead"""
|
||||
ret = {
|
||||
"messages": [],
|
||||
}
|
||||
for node in self.nodes:
|
||||
d = node.toDict()
|
||||
ret["messages"].append(d)
|
||||
return ret
|
||||
|
||||
async def to_dict(self):
|
||||
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
|
||||
ret = {"messages": []}
|
||||
for node in self.nodes:
|
||||
d = await node.to_dict()
|
||||
ret["messages"].append(d)
|
||||
return ret
|
||||
|
||||
|
||||
class Xml(BaseMessageComponent):
|
||||
@@ -552,15 +714,136 @@ class Unknown(BaseMessageComponent):
|
||||
|
||||
class File(BaseMessageComponent):
|
||||
"""
|
||||
目前此消息段只适配了 Napcat。
|
||||
文件消息段
|
||||
"""
|
||||
|
||||
type: ComponentType = "File"
|
||||
name: T.Optional[str] = "" # 名字
|
||||
file: T.Optional[str] = "" # url(本地路径)
|
||||
file_: T.Optional[str] = "" # 本地路径
|
||||
url: T.Optional[str] = "" # url
|
||||
|
||||
def __init__(self, name: str, file: str):
|
||||
super().__init__(name=name, file=file)
|
||||
def __init__(self, name: str, file: str = "", url: str = ""):
|
||||
"""文件消息段。"""
|
||||
super().__init__(name=name, file_=file, url=url)
|
||||
|
||||
@property
|
||||
def file(self) -> str:
|
||||
"""
|
||||
获取文件路径,如果文件不存在但有URL,则同步下载文件
|
||||
|
||||
Returns:
|
||||
str: 文件路径
|
||||
"""
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
if self.url:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
logger.warning(
|
||||
(
|
||||
"不可以在异步上下文中同步等待下载! "
|
||||
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
|
||||
"请使用 await get_file() 代替直接获取 <File>.file 字段"
|
||||
)
|
||||
)
|
||||
return ""
|
||||
else:
|
||||
# 等待下载完成
|
||||
loop.run_until_complete(self._download_file())
|
||||
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
except Exception as e:
|
||||
logger.error(f"文件下载失败: {e}")
|
||||
|
||||
return ""
|
||||
|
||||
@file.setter
|
||||
def file(self, value: str):
|
||||
"""
|
||||
向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
|
||||
|
||||
Args:
|
||||
value (str): 文件路径或URL
|
||||
"""
|
||||
if value.startswith("http://") or value.startswith("https://"):
|
||||
self.url = value
|
||||
else:
|
||||
self.file_ = value
|
||||
|
||||
async def get_file(self, allow_return_url: bool = False) -> str:
|
||||
"""异步获取文件。请注意在使用后清理下载的文件, 以免占用过多空间
|
||||
|
||||
Args:
|
||||
allow_return_url: 是否允许以文件 http 下载链接的形式返回,这允许您自行控制是否需要下载文件。
|
||||
注意,如果为 True,也可能返回文件路径。
|
||||
Returns:
|
||||
str: 文件路径或者 http 下载链接
|
||||
"""
|
||||
if allow_return_url and self.url:
|
||||
return self.url
|
||||
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
if self.url:
|
||||
await self._download_file()
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
return ""
|
||||
|
||||
async def _download_file(self):
|
||||
"""下载文件"""
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
await download_file(self.url, file_path)
|
||||
self.file_ = os.path.abspath(file_path)
|
||||
|
||||
async def register_to_file_service(self):
|
||||
"""
|
||||
将文件注册到文件服务。
|
||||
|
||||
Returns:
|
||||
str: 注册后的URL
|
||||
|
||||
Raises:
|
||||
Exception: 如果未配置 callback_api_base
|
||||
"""
|
||||
callback_host = astrbot_config.get("callback_api_base")
|
||||
|
||||
if not callback_host:
|
||||
raise Exception("未配置 callback_api_base,文件服务不可用")
|
||||
|
||||
file_path = await self.get_file()
|
||||
|
||||
token = await file_token_service.register_file(file_path)
|
||||
|
||||
logger.debug(f"已注册:{callback_host}/api/file/{token}")
|
||||
|
||||
return f"{callback_host}/api/file/{token}"
|
||||
|
||||
async def to_dict(self):
|
||||
"""需要和 toDict 区分开,toDict 是同步方法"""
|
||||
url_or_path = await self.get_file(allow_return_url=True)
|
||||
if url_or_path.startswith("http"):
|
||||
payload_file = url_or_path
|
||||
elif callback_host := astrbot_config.get("callback_api_base"):
|
||||
callback_host = str(callback_host).removesuffix("/")
|
||||
token = await file_token_service.register_file(url_or_path)
|
||||
payload_file = f"{callback_host}/api/file/{token}"
|
||||
logger.debug(f"Generated file callback link: {payload_file}")
|
||||
else:
|
||||
payload_file = url_or_path
|
||||
return {
|
||||
"type": "file",
|
||||
"data": {
|
||||
"name": self.name,
|
||||
"file": payload_file,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class WechatEmoji(BaseMessageComponent):
|
||||
|
||||
@@ -46,28 +46,29 @@ class PreProcessStage(Stage):
|
||||
stt_provider = (
|
||||
self.plugin_manager.context.provider_manager.curr_stt_provider_inst
|
||||
)
|
||||
if stt_provider:
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, Record) and component.url:
|
||||
path = component.url.removeprefix("file://")
|
||||
retry = 5
|
||||
for i in range(retry):
|
||||
try:
|
||||
result = await stt_provider.get_text(audio_url=path)
|
||||
if result:
|
||||
logger.info("语音转文本结果: " + result)
|
||||
message_chain[idx] = Plain(result)
|
||||
event.message_str += result
|
||||
event.message_obj.message_str += result
|
||||
break
|
||||
except FileNotFoundError as e:
|
||||
# napcat workaround
|
||||
logger.warning(e)
|
||||
logger.warning(f"重试中: {i + 1}/{retry}")
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"语音转文本失败: {e}")
|
||||
break
|
||||
if not stt_provider:
|
||||
return
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, Record) and component.url:
|
||||
path = component.url.removeprefix("file://")
|
||||
retry = 5
|
||||
for i in range(retry):
|
||||
try:
|
||||
result = await stt_provider.get_text(audio_url=path)
|
||||
if result:
|
||||
logger.info("语音转文本结果: " + result)
|
||||
message_chain[idx] = Plain(result)
|
||||
event.message_str += result
|
||||
event.message_obj.message_str += result
|
||||
break
|
||||
except FileNotFoundError as e:
|
||||
# napcat workaround
|
||||
logger.warning(e)
|
||||
logger.warning(f"重试中: {i + 1}/{retry}")
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"语音转文本失败: {e}")
|
||||
break
|
||||
|
||||
@@ -26,6 +26,13 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from mcp.types import (
|
||||
TextContent,
|
||||
ImageContent,
|
||||
EmbeddedResource,
|
||||
TextResourceContents,
|
||||
BlobResourceContents,
|
||||
)
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -60,15 +67,19 @@ class LLMRequestSubStage(Stage):
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
req: ProviderRequest = None
|
||||
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
if provider is None:
|
||||
return
|
||||
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(
|
||||
req, ProviderRequest
|
||||
), "provider_request 必须是 ProviderRequest 类型。"
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
all_contexts = json.loads(req.conversation.history)
|
||||
@@ -146,8 +157,19 @@ class LLMRequestSubStage(Stage):
|
||||
):
|
||||
logger.debug("上下文长度超过限制,将截断。")
|
||||
req.contexts = req.contexts[
|
||||
-(self.max_context_length - self.dequeue_context_length) * 2 :
|
||||
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||
]
|
||||
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(
|
||||
i
|
||||
for i, item in enumerate(req.contexts)
|
||||
if item.get("role") == "user"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
req.contexts = req.contexts[index:]
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
@@ -261,6 +283,12 @@ class LLMRequestSubStage(Stage):
|
||||
event.set_extra("tool_call_result", None)
|
||||
yield
|
||||
|
||||
# 暂时直接发出去
|
||||
if img_b64 := event.get_extra("tool_call_img_respond"):
|
||||
await event.send(MessageChain(chain=[Image.fromBase64(img_b64)]))
|
||||
event.set_extra("tool_call_img_respond", None)
|
||||
yield
|
||||
|
||||
async def _handle_llm_response(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
@@ -371,21 +399,68 @@ class LLMRequestSubStage(Stage):
|
||||
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
|
||||
res = await client.session.call_tool(func_tool.name, func_tool_args)
|
||||
if res:
|
||||
# TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=res.content[0].text,
|
||||
# TODO 仅对ImageContent | EmbeddedResource进行了简单的Fallback
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=res.content[0].text,
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
event.set_extra(
|
||||
"tool_call_img_respond",
|
||||
res.content[0].data,
|
||||
)
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resource.text,
|
||||
)
|
||||
)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
and resource.mimeType.startswith("image/")
|
||||
):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
event.set_extra(
|
||||
"tool_call_img_respond",
|
||||
res.content[0].data,
|
||||
)
|
||||
else:
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
)
|
||||
else:
|
||||
# 获取处理器,过滤掉平台不兼容的处理器
|
||||
platform_id = event.get_platform_id()
|
||||
star_md = star_map.get(func_tool.handler_module_path)
|
||||
if (
|
||||
star_md and
|
||||
platform_id in star_md.supported_platforms
|
||||
star_md
|
||||
and platform_id in star_md.supported_platforms
|
||||
and not star_md.supported_platforms[platform_id]
|
||||
):
|
||||
logger.debug(
|
||||
|
||||
@@ -58,33 +58,30 @@ class RateLimitStage(Stage):
|
||||
now = datetime.now()
|
||||
|
||||
async with self.locks[session_id]: # 确保同一会话不会并发修改队列
|
||||
timestamps = self.event_timestamps[session_id]
|
||||
# 检查并处理限流,可能需要多次检查直到满足条件
|
||||
while True:
|
||||
timestamps = self.event_timestamps[session_id]
|
||||
self._remove_expired_timestamps(timestamps, now)
|
||||
|
||||
self._remove_expired_timestamps(timestamps, now)
|
||||
if len(timestamps) < self.rate_limit_count:
|
||||
timestamps.append(now)
|
||||
break
|
||||
else:
|
||||
next_window_time = timestamps[0] + self.rate_limit_time
|
||||
stall_duration = (next_window_time - now).total_seconds() + 0.3
|
||||
|
||||
if len(timestamps) >= self.rate_limit_count:
|
||||
# 达到限流阈值,计算下一个窗口的时间
|
||||
next_window_time = timestamps[0] + self.rate_limit_time
|
||||
stall_duration = (next_window_time - now).total_seconds()
|
||||
|
||||
match self.rl_strategy:
|
||||
case RateLimitStrategy.STALL.value:
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
|
||||
)
|
||||
await asyncio.sleep(stall_duration)
|
||||
case RateLimitStrategy.DISCARD.value:
|
||||
# event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
|
||||
)
|
||||
return event.stop_event()
|
||||
|
||||
self._remove_expired_timestamps(
|
||||
timestamps, now + timedelta(seconds=stall_duration)
|
||||
)
|
||||
|
||||
timestamps.append(now)
|
||||
match self.rl_strategy:
|
||||
case RateLimitStrategy.STALL.value:
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
|
||||
)
|
||||
await asyncio.sleep(stall_duration)
|
||||
now = datetime.now()
|
||||
case RateLimitStrategy.DISCARD.value:
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
|
||||
)
|
||||
return event.stop_event()
|
||||
|
||||
def _remove_expired_timestamps(
|
||||
self, timestamps: Deque[datetime], now: datetime
|
||||
|
||||
@@ -12,6 +12,7 @@ from astrbot.core import logger
|
||||
from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.utils.path_util import path_Mapping
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -25,37 +26,18 @@ class RespondStage(Stage):
|
||||
Comp.Record: lambda comp: bool(comp.file), # 语音
|
||||
Comp.Video: lambda comp: bool(comp.file), # 视频
|
||||
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
|
||||
Comp.AtAll: lambda comp: True, # @所有人
|
||||
Comp.RPS: lambda comp: True, # 不知道是啥(未完成)
|
||||
Comp.Dice: lambda comp: True, # 骰子(未完成)
|
||||
Comp.Shake: lambda comp: True, # 摇一摇(未完成)
|
||||
Comp.Anonymous: lambda comp: True, # 匿名(未完成)
|
||||
Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享
|
||||
Comp.Contact: lambda comp: True, # 联系人(未完成)
|
||||
Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置
|
||||
Comp.Music: lambda comp: bool(comp._type)
|
||||
and bool(comp.url)
|
||||
and bool(comp.audio), # 音乐
|
||||
Comp.Image: lambda comp: bool(comp.file), # 图片
|
||||
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
|
||||
Comp.RedBag: lambda comp: bool(comp.title), # 红包
|
||||
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
|
||||
Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发
|
||||
Comp.Node: lambda comp: bool(comp.name)
|
||||
and comp.uin != 0
|
||||
and bool(comp.content), # 一个转发节点
|
||||
Comp.Node: lambda comp: bool(comp.content), # 转发节点
|
||||
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
|
||||
Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML
|
||||
Comp.Json: lambda comp: bool(comp.data), # JSON
|
||||
Comp.CardImage: lambda comp: bool(comp.file), # 卡片图片
|
||||
Comp.TTS: lambda comp: bool(comp.text and comp.text.strip()), # 语音合成
|
||||
Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), # 未知消息
|
||||
Comp.File: lambda comp: bool(comp.file), # 文件
|
||||
Comp.WechatEmoji: lambda comp: bool(comp.md5), # 微信表情
|
||||
Comp.File: lambda comp: bool(comp.file_ or comp.url),
|
||||
}
|
||||
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
self.ctx = ctx
|
||||
self.config = ctx.astrbot_config
|
||||
self.platform_settings: dict = self.config.get("platform_settings", {})
|
||||
|
||||
self.reply_with_mention = ctx.astrbot_config["platform_settings"][
|
||||
"reply_with_mention"
|
||||
@@ -126,8 +108,6 @@ class RespondStage(Stage):
|
||||
if comp_type in self._component_validators:
|
||||
if self._component_validators[comp_type](comp):
|
||||
return False
|
||||
else:
|
||||
logger.info(f"空内容检查: 无法识别的组件类型: {comp_type.__name__}")
|
||||
|
||||
# 如果所有组件都为空
|
||||
return True
|
||||
@@ -143,12 +123,23 @@ class RespondStage(Stage):
|
||||
|
||||
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||
# 流式结果直接交付平台适配器处理
|
||||
use_fallback = self.config.get("provider_settings", {}).get(
|
||||
"streaming_segmented", False
|
||||
)
|
||||
logger.info(f"应用流式输出({event.get_platform_name()})")
|
||||
await event._pre_send()
|
||||
await event.send_streaming(result.async_stream)
|
||||
await event.send_streaming(result.async_stream, use_fallback)
|
||||
await event._post_send()
|
||||
return
|
||||
elif len(result.chain) > 0:
|
||||
# 检查路径映射
|
||||
if mappings := self.platform_settings.get("path_mapping", []):
|
||||
for idx, component in enumerate(result.chain):
|
||||
if isinstance(component, Comp.File) and component.file:
|
||||
# 支持 File 消息段的路径映射。
|
||||
component.file = path_Mapping(mappings, component.file)
|
||||
event.get_result().chain[idx] = component
|
||||
|
||||
await event._pre_send()
|
||||
|
||||
# 检查消息链是否为空
|
||||
@@ -161,6 +152,11 @@ class RespondStage(Stage):
|
||||
except Exception as e:
|
||||
logger.warning(f"空内容检查异常: {e}")
|
||||
|
||||
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
|
||||
non_record_comps = [
|
||||
c for c in result.chain if not isinstance(c, Comp.Record)
|
||||
]
|
||||
|
||||
if self.enable_seg and (
|
||||
(self.only_llm_result and result.is_llm_result())
|
||||
or not self.only_llm_result
|
||||
@@ -178,8 +174,18 @@ class RespondStage(Stage):
|
||||
decorated_comps.append(comp)
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
|
||||
for rcomp in record_comps:
|
||||
i = await self._calc_comp_interval(rcomp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
|
||||
# 分段回复
|
||||
for comp in result.chain:
|
||||
for comp in non_record_comps:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
@@ -188,10 +194,18 @@ class RespondStage(Stage):
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
break
|
||||
else:
|
||||
for rcomp in record_comps:
|
||||
try:
|
||||
await event.send(MessageChain([rcomp]))
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
|
||||
try:
|
||||
await event.send(result)
|
||||
await event.send(MessageChain(non_record_comps))
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
|
||||
await event._post_send()
|
||||
logger.info(
|
||||
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import time
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage, registered_stages
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
from astrbot.core import html_renderer, logger, file_token_service
|
||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage, registered_stages
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -151,9 +152,9 @@ class ResultDecorateStage(Stage):
|
||||
# 不分段回复
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
split_response = []
|
||||
for line in comp.text.split("\n"):
|
||||
split_response.extend(re.findall(self.regex, line))
|
||||
split_response = re.findall(
|
||||
self.regex, comp.text, re.DOTALL | re.MULTILINE
|
||||
)
|
||||
if not split_response:
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
@@ -168,28 +169,55 @@ class ResultDecorateStage(Stage):
|
||||
result.chain = new_chain
|
||||
|
||||
# TTS
|
||||
tts_provider = (
|
||||
self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||
)
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
and result.is_llm_result()
|
||||
and tts_provider
|
||||
):
|
||||
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info("TTS 请求: " + comp.text)
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info("TTS 结果: " + audio_path)
|
||||
if audio_path:
|
||||
new_chain.append(
|
||||
Record(file=audio_path, url=audio_path)
|
||||
)
|
||||
else:
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}"
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
|
||||
)
|
||||
new_chain.append(comp)
|
||||
except BaseException:
|
||||
continue
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
)
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
@@ -223,6 +251,14 @@ class ResultDecorateStage(Stage):
|
||||
if url:
|
||||
if url.startswith("http"):
|
||||
result.chain = [Image.fromURL(url)]
|
||||
elif (
|
||||
self.ctx.astrbot_config["t2i_use_file_service"]
|
||||
and self.ctx.astrbot_config["callback_api_base"]
|
||||
):
|
||||
token = await file_token_service.register_file(url)
|
||||
url = f"{self.ctx.astrbot_config['callback_api_base']}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
result.chain = [Image.fromURL(url)]
|
||||
else:
|
||||
result.chain = [Image.fromFileSystem(url)]
|
||||
|
||||
|
||||
@@ -35,10 +35,21 @@ class WakingCheckStage(Stage):
|
||||
self.friend_message_needs_wake_prefix = self.ctx.astrbot_config[
|
||||
"platform_settings"
|
||||
].get("friend_message_needs_wake_prefix", False)
|
||||
# 是否忽略机器人自己发送的消息
|
||||
self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get(
|
||||
"ignore_bot_self_message", False
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
if (
|
||||
self.ignore_bot_self_message
|
||||
and event.get_self_id() == event.get_sender_id()
|
||||
):
|
||||
# 忽略机器人自己发送的消息
|
||||
event.stop_event()
|
||||
return
|
||||
# 设置 sender 身份
|
||||
event.message_str = event.message_str.strip()
|
||||
for admin_id in self.ctx.astrbot_config["admins_id"]:
|
||||
@@ -126,7 +137,7 @@ class WakingCheckStage(Stage):
|
||||
if self.no_permission_reply:
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"
|
||||
f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。"
|
||||
)
|
||||
)
|
||||
await event._post_send()
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import abc
|
||||
import asyncio
|
||||
import re
|
||||
import hashlib
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union, Optional, AsyncGenerator
|
||||
|
||||
@@ -205,9 +208,26 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
return self.role == "admin"
|
||||
|
||||
async def send_streaming(self, generator: AsyncGenerator[MessageChain, None]):
|
||||
async def process_buffer(self, buffer: str, pattern: re.Pattern) -> str:
|
||||
"""
|
||||
将消息缓冲区中的文本按指定正则表达式分割后发送至消息平台,作为不支持流式输出平台的Fallback。
|
||||
"""
|
||||
while True:
|
||||
match = re.search(pattern, buffer)
|
||||
if not match:
|
||||
break
|
||||
matched_text = match.group()
|
||||
await self.send(MessageChain([Plain(matched_text)]))
|
||||
buffer = buffer[match.end() :]
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
return buffer
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
||||
):
|
||||
"""发送流式消息到消息平台,使用异步生成器。
|
||||
目前仅支持: telegram,qq official 私聊。
|
||||
Fallback仅支持 aiocqhttp, gewechat。
|
||||
"""
|
||||
asyncio.create_task(
|
||||
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||
@@ -384,8 +404,13 @@ class AstrMessageEvent(abc.ABC):
|
||||
Args:
|
||||
message (MessageChain): 消息链,具体使用方式请参考文档。
|
||||
"""
|
||||
# Leverage BLAKE2 hash function to generate a non-reversible hash of the sender ID for privacy.
|
||||
hash_obj = hashlib.blake2b(self.get_sender_id().encode("utf-8"), digest_size=16)
|
||||
sid = str(uuid.UUID(bytes=hash_obj.digest()))
|
||||
asyncio.create_task(
|
||||
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||
Metric.upload(
|
||||
msg_event_tick=1, adapter_name=self.platform_meta.name, sid=sid
|
||||
)
|
||||
)
|
||||
self._has_send_oper = True
|
||||
|
||||
|
||||
@@ -62,6 +62,10 @@ class PlatformManager:
|
||||
from .sources.gewechat.gewechat_platform_adapter import (
|
||||
GewechatPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "wechatpadpro":
|
||||
from .sources.wechatpadpro.wechatpadpro_adapter import (
|
||||
WeChatPadProAdapter, # noqa: F401
|
||||
)
|
||||
case "lark":
|
||||
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
|
||||
case "dingtalk":
|
||||
@@ -72,6 +76,8 @@ class PlatformManager:
|
||||
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
|
||||
case "wecom":
|
||||
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
|
||||
case "weixin_official_account":
|
||||
from .sources.weixin_official_account.weixin_offacc_adapter import WeixinOfficialAccountPlatformAdapter # noqa
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.error(
|
||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
import asyncio
|
||||
import typing
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import Group, MessageMember
|
||||
from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes
|
||||
import re
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
from aiocqhttp import CQHttp
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import (
|
||||
Image,
|
||||
Node,
|
||||
Nodes,
|
||||
Plain,
|
||||
Record,
|
||||
Video,
|
||||
File,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
from astrbot.api.platform import Group, MessageMember
|
||||
|
||||
|
||||
class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
@@ -13,44 +23,46 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
|
||||
@staticmethod
|
||||
async def _from_segment_to_dict(segment: BaseMessageComponent) -> dict:
|
||||
"""修复部分字段"""
|
||||
if isinstance(segment, (Image, Record)):
|
||||
# For Image and Record segments, we convert them to base64
|
||||
bs64 = await segment.convert_to_base64()
|
||||
return {
|
||||
"type": segment.type.lower(),
|
||||
"data": {
|
||||
"file": f"base64://{bs64}",
|
||||
},
|
||||
}
|
||||
elif isinstance(segment, File):
|
||||
# For File segments, we need to handle the file differently
|
||||
d = await segment.to_dict()
|
||||
return d
|
||||
elif isinstance(segment, Video):
|
||||
d = await segment.to_dict()
|
||||
return d
|
||||
else:
|
||||
# For other segments, we simply convert them to a dict by calling toDict
|
||||
return segment.toDict()
|
||||
|
||||
@staticmethod
|
||||
async def _parse_onebot_json(message_chain: MessageChain):
|
||||
"""解析成 OneBot json 格式"""
|
||||
ret = []
|
||||
for segment in message_chain.chain:
|
||||
d = segment.toDict()
|
||||
if isinstance(segment, Plain):
|
||||
d["type"] = "text"
|
||||
d["data"]["text"] = segment.text.strip()
|
||||
# 如果是空文本或者只带换行符的文本,不发送
|
||||
if not d["data"]["text"]:
|
||||
if not segment.text.strip():
|
||||
continue
|
||||
elif isinstance(segment, (Image, Record)):
|
||||
# convert to base64
|
||||
bs64 = await segment.convert_to_base64()
|
||||
d["data"] = {
|
||||
"file": bs64,
|
||||
}
|
||||
elif isinstance(segment, At):
|
||||
d["data"] = {
|
||||
"qq": str(segment.qq) # 转换为字符串
|
||||
}
|
||||
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
|
||||
ret.append(d)
|
||||
return ret
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||
|
||||
if not ret:
|
||||
return
|
||||
|
||||
send_one_by_one = False
|
||||
for seg in message.chain:
|
||||
if isinstance(seg, (Node, Nodes)):
|
||||
# 转发消息不能和普通消息混在一起发送
|
||||
send_one_by_one = True
|
||||
break
|
||||
|
||||
# 转发消息、文件消息不能和普通消息混在一起发送
|
||||
send_one_by_one = any(
|
||||
isinstance(seg, (Node, Nodes, File)) for seg in message.chain
|
||||
)
|
||||
if send_one_by_one:
|
||||
for seg in message.chain:
|
||||
if isinstance(seg, (Node, Nodes)):
|
||||
@@ -60,7 +72,8 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
nodes = Nodes([seg])
|
||||
seg = nodes
|
||||
|
||||
payload = seg.toDict()
|
||||
payload = await seg.to_dict()
|
||||
|
||||
if self.get_group_id():
|
||||
payload["group_id"] = self.get_group_id()
|
||||
await self.bot.call_action("send_group_forward_msg", **payload)
|
||||
@@ -69,6 +82,12 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
await self.bot.call_action(
|
||||
"send_private_forward_msg", **payload
|
||||
)
|
||||
elif isinstance(seg, File):
|
||||
d = await AiocqhttpMessageEvent._from_segment_to_dict(seg)
|
||||
await self.bot.send(
|
||||
self.message_obj.raw_message,
|
||||
[d],
|
||||
)
|
||||
else:
|
||||
await self.bot.send(
|
||||
self.message_obj.raw_message,
|
||||
@@ -78,22 +97,46 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
await asyncio.sleep(0.5)
|
||||
else:
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||
if not ret:
|
||||
return
|
||||
await self.bot.send(self.message_obj.raw_message, ret)
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||
):
|
||||
if not use_fallback:
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
buffer = ""
|
||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
buffer = await self.process_buffer(buffer, pattern)
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def get_group(self, group_id=None, **kwargs):
|
||||
if isinstance(group_id, str) and group_id.isdigit():
|
||||
@@ -108,7 +151,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
members: typing.List[typing.Dict] = await self.bot.call_action(
|
||||
members: List[Dict] = await self.bot.call_action(
|
||||
"get_group_member_list",
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
import itertools
|
||||
from typing import Awaitable, Any
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from astrbot.api.platform import (
|
||||
@@ -20,7 +20,6 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from aiocqhttp.exceptions import ActionFailed
|
||||
from astrbot.core.utils.io import download_file
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
@@ -45,7 +44,12 @@ class AiocqhttpAdapter(Platform):
|
||||
)
|
||||
|
||||
self.bot = CQHttp(
|
||||
use_ws_reverse=True, import_name="aiocqhttp", api_timeout_sec=180
|
||||
use_ws_reverse=True,
|
||||
import_name="aiocqhttp",
|
||||
api_timeout_sec=180,
|
||||
access_token=platform_config.get(
|
||||
"ws_reverse_token"
|
||||
), # 以防旧版本配置不存在
|
||||
)
|
||||
|
||||
@self.bot.on_request()
|
||||
@@ -99,6 +103,9 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
if event["post_type"] == "message":
|
||||
abm = await self._convert_handle_message_event(event)
|
||||
if abm.sender.user_id == "2854196310":
|
||||
# 屏蔽 QQ 管家的消息
|
||||
return
|
||||
elif event["post_type"] == "notice":
|
||||
abm = await self._convert_handle_notice_event(event)
|
||||
elif event["post_type"] == "request":
|
||||
@@ -119,6 +126,12 @@ class AiocqhttpAdapter(Platform):
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id)
|
||||
else:
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
abm.timestamp = int(time.time())
|
||||
@@ -155,7 +168,9 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
if "sub_type" in event:
|
||||
if event["sub_type"] == "poke" and "target_id" in event:
|
||||
abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) # noqa: F405
|
||||
abm.message.append(
|
||||
Poke(qq=str(event["target_id"]), type="poke")
|
||||
) # noqa: F405
|
||||
|
||||
return abm
|
||||
|
||||
@@ -202,82 +217,119 @@ class AiocqhttpAdapter(Platform):
|
||||
return
|
||||
|
||||
# 按消息段类型类型适配
|
||||
for m in event.message:
|
||||
t = m["type"]
|
||||
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
|
||||
a = None
|
||||
if t == "text":
|
||||
message_str += m["data"]["text"].strip()
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
current_text = "".join(m["data"]["text"] for m in m_group).strip()
|
||||
message_str += current_text
|
||||
a = ComponentTypes[t](text=current_text) # noqa: F405
|
||||
abm.message.append(a)
|
||||
|
||||
elif t == "file":
|
||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
for m in m_group:
|
||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
file_name = m["data"].get("file_name", "file")
|
||||
abm.message.append(File(name=file_name, url=m["data"]["url"]))
|
||||
else:
|
||||
try:
|
||||
# Napcat
|
||||
ret = None
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
ret = await self.bot.call_action(
|
||||
action="get_group_file_url",
|
||||
file_id=event.message[0]["data"]["file_id"],
|
||||
group_id=event.group_id,
|
||||
)
|
||||
elif abm.type == MessageType.FRIEND_MESSAGE:
|
||||
ret = await self.bot.call_action(
|
||||
action="get_private_file_url",
|
||||
file_id=event.message[0]["data"]["file_id"],
|
||||
)
|
||||
if ret and "url" in ret:
|
||||
file_url = ret["url"] # https
|
||||
a = File(name="", url=file_url)
|
||||
abm.message.append(a)
|
||||
else:
|
||||
logger.error(f"获取文件失败: {ret}")
|
||||
|
||||
file_name = m["data"].get("file_name", "file")
|
||||
path = os.path.join("data/temp", file_name)
|
||||
await download_file(m["data"]["url"], path)
|
||||
|
||||
m["data"] = {"file": path, "name": file_name}
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
|
||||
else:
|
||||
try:
|
||||
# Napcat, LLBot
|
||||
ret = await self.bot.call_action(
|
||||
action="get_file",
|
||||
file_id=event.message[0]["data"]["file_id"],
|
||||
)
|
||||
if not ret.get("file", None):
|
||||
raise ValueError(f"无法解析文件响应: {ret}")
|
||||
if not os.path.exists(ret["file"]):
|
||||
raise FileNotFoundError(
|
||||
f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot"
|
||||
)
|
||||
|
||||
m["data"] = {"file": ret["file"], "name": ret["file_name"]}
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
|
||||
elif t == "reply":
|
||||
if not get_reply:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
else:
|
||||
try:
|
||||
reply_event_data = await self.bot.call_action(
|
||||
action="get_msg",
|
||||
message_id=int(m["data"]["id"]),
|
||||
)
|
||||
abm_reply = await self._convert_handle_message_event(
|
||||
Event.from_payload(reply_event_data), get_reply=False
|
||||
)
|
||||
|
||||
reply_seg = Reply(
|
||||
id=abm_reply.message_id,
|
||||
chain=abm_reply.message,
|
||||
sender_id=abm_reply.sender.user_id,
|
||||
sender_nickname=abm_reply.sender.nickname,
|
||||
time=abm_reply.timestamp,
|
||||
message_str=abm_reply.message_str,
|
||||
text=abm_reply.message_str, # for compatibility
|
||||
qq=abm_reply.sender.user_id, # for compatibility
|
||||
)
|
||||
|
||||
abm.message.append(reply_seg)
|
||||
except BaseException as e:
|
||||
logger.error(f"获取引用消息失败: {e}。")
|
||||
for m in m_group:
|
||||
if not get_reply:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
else:
|
||||
try:
|
||||
reply_event_data = await self.bot.call_action(
|
||||
action="get_msg",
|
||||
message_id=int(m["data"]["id"]),
|
||||
)
|
||||
abm_reply = await self._convert_handle_message_event(
|
||||
Event.from_payload(reply_event_data), get_reply=False
|
||||
)
|
||||
|
||||
reply_seg = Reply(
|
||||
id=abm_reply.message_id,
|
||||
chain=abm_reply.message,
|
||||
sender_id=abm_reply.sender.user_id,
|
||||
sender_nickname=abm_reply.sender.nickname,
|
||||
time=abm_reply.timestamp,
|
||||
message_str=abm_reply.message_str,
|
||||
text=abm_reply.message_str, # for compatibility
|
||||
qq=abm_reply.sender.user_id, # for compatibility
|
||||
)
|
||||
|
||||
abm.message.append(reply_seg)
|
||||
except BaseException as e:
|
||||
logger.error(f"获取引用消息失败: {e}。")
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
elif t == "at":
|
||||
first_at_self_processed = False
|
||||
|
||||
for m in m_group:
|
||||
try:
|
||||
if m["data"]["qq"] == "all":
|
||||
abm.message.append(At(qq="all", name="全体成员"))
|
||||
continue
|
||||
|
||||
at_info = await self.bot.call_action(
|
||||
action="get_stranger_info",
|
||||
user_id=int(m["data"]["qq"]),
|
||||
)
|
||||
if at_info:
|
||||
nickname = at_info.get("nick", "")
|
||||
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
|
||||
|
||||
abm.message.append(
|
||||
At(
|
||||
qq=m["data"]["qq"],
|
||||
name=nickname,
|
||||
)
|
||||
)
|
||||
|
||||
if is_at_self and not first_at_self_processed:
|
||||
# 第一个@是机器人,不添加到message_str
|
||||
first_at_self_processed = True
|
||||
else:
|
||||
# 非第一个@机器人或@其他用户,添加到message_str
|
||||
message_str += f" @{nickname} "
|
||||
else:
|
||||
abm.message.append(At(qq=str(m["data"]["qq"]), name=""))
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
|
||||
else:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
for m in m_group:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
import dingtalk_stream
|
||||
@@ -19,6 +20,7 @@ from ...register import register_platform_adapter
|
||||
from astrbot import logger
|
||||
from dingtalk_stream import AckMessage
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class MyEventHandler(dingtalk_stream.EventHandler):
|
||||
@@ -152,7 +154,8 @@ class DingtalkPlatformAdapter(Platform):
|
||||
"downloadCode": download_code,
|
||||
"robotCode": robot_code,
|
||||
}
|
||||
f_path = f"data/dingtalk_file_{uuid.uuid4()}.{ext}"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
f_path = os.path.join(temp_dir, f"dingtalk_file_{uuid.uuid4()}.{ext}")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
"https://api.dingtalk.com/v1.0/robot/messageFiles/download",
|
||||
|
||||
@@ -61,7 +61,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
await self.send_with_client(self.client, message)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
@@ -72,4 +72,4 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -3,6 +3,7 @@ import base64
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import threading
|
||||
|
||||
import aiohttp
|
||||
@@ -14,6 +15,7 @@ from astrbot.api.message_components import Plain, Image, At, Record, Video
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from .downloader import GeweDownloader
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
try:
|
||||
from .xml_data_parser import GeweDataParser
|
||||
@@ -63,7 +65,7 @@ class SimpleGewechatClient:
|
||||
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
|
||||
)
|
||||
self.server.add_url_rule(
|
||||
"/astrbot-gewechat/file/<file_id>",
|
||||
"/astrbot-gewechat/file/<file_token>",
|
||||
view_func=self._handle_file,
|
||||
methods=["GET"],
|
||||
)
|
||||
@@ -81,6 +83,11 @@ class SimpleGewechatClient:
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
self.staged_files = {}
|
||||
"""存储了允许外部访问的文件列表。auth_token: file_path。通过 register_file 方法注册。"""
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def get_token_id(self):
|
||||
"""获取 Gewechat Token。"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -143,18 +150,25 @@ class SimpleGewechatClient:
|
||||
content = d["Content"]["string"] # 消息内容
|
||||
|
||||
at_me = False
|
||||
at_wxids = []
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
_t = content.split(":\n")
|
||||
user_id = _t[0]
|
||||
content = _t[1]
|
||||
# at
|
||||
msg_source = d["MsgSource"]
|
||||
if "\u2005" in content:
|
||||
# at
|
||||
# content = content.split('\u2005')[1]
|
||||
content = re.sub(r"@[^\u2005]*\u2005", "", content)
|
||||
at_wxids = re.findall(
|
||||
r"<atuserlist><!\[CDATA\[.*?(?:,|\b)([^,]+?)(?=,|\]\]></atuserlist>)",
|
||||
msg_source,
|
||||
)
|
||||
|
||||
abm.group_id = from_user_name
|
||||
# at
|
||||
msg_source = d["MsgSource"]
|
||||
|
||||
if (
|
||||
f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source
|
||||
or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source
|
||||
@@ -167,13 +181,12 @@ class SimpleGewechatClient:
|
||||
user_id = from_user_name
|
||||
|
||||
# 检查消息是否由自己发送,若是则忽略
|
||||
if user_id == abm.self_id:
|
||||
logger.info("忽略自己发送的消息")
|
||||
return None
|
||||
# 已经有可配置项专门配置是否需要响应自己的消息,因此这里注释掉。
|
||||
# if user_id == abm.self_id:
|
||||
# logger.info("忽略自己发送的消息")
|
||||
# return None
|
||||
|
||||
abm.message = []
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id))
|
||||
|
||||
# 解析用户真实名字
|
||||
user_real_name = "unknown"
|
||||
@@ -197,7 +210,19 @@ class SimpleGewechatClient:
|
||||
else:
|
||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||
else:
|
||||
user_real_name = d.get("PushContent", "unknown : ").split(" : ")[0]
|
||||
try:
|
||||
info = (await self.get_user_or_group_info(user_id))["data"][0]
|
||||
user_real_name = info["nickName"]
|
||||
except Exception as e:
|
||||
logger.debug(f"获取用户 {user_id} 昵称失败: {e}")
|
||||
user_real_name = user_id
|
||||
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id, name=self.nickname))
|
||||
for wxid in at_wxids:
|
||||
# 群聊里 At 其他人的列表
|
||||
_username = self.userrealnames.get(abm.group_id, {}).get(wxid, wxid)
|
||||
abm.message.append(At(qq=wxid, name=_username))
|
||||
|
||||
abm.sender = MessageMember(user_id, user_real_name)
|
||||
abm.raw_message = d
|
||||
@@ -226,7 +251,10 @@ class SimpleGewechatClient:
|
||||
# 语音消息
|
||||
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
|
||||
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
|
||||
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(
|
||||
temp_dir, f"gewe_voice_{abm.message_id}.silk"
|
||||
)
|
||||
|
||||
async with await anyio.open_file(file_path, "wb") as f:
|
||||
await f.write(voice_data)
|
||||
@@ -248,9 +276,12 @@ class SimpleGewechatClient:
|
||||
logger.info("消息类型(48):地理位置")
|
||||
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
|
||||
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||
abm_data = data_parser.parse_mutil_49()
|
||||
if abm_data:
|
||||
abm.message.append(abm_data)
|
||||
segments = data_parser.parse_mutil_49()
|
||||
if segments:
|
||||
abm.message.extend(segments)
|
||||
for seg in segments:
|
||||
if isinstance(seg, Plain):
|
||||
abm.message_str += seg.text
|
||||
case 51: # 帐号消息同步?
|
||||
logger.info("消息类型(51):帐号消息同步?")
|
||||
case 10000: # 被踢出群聊/更换群主/修改群名称
|
||||
@@ -289,9 +320,33 @@ class SimpleGewechatClient:
|
||||
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
async def _handle_file(self, file_id):
|
||||
file_path = f"data/temp/{file_id}"
|
||||
return await quart.send_file(file_path)
|
||||
async def _register_file(self, file_path: str) -> str:
|
||||
"""向 AstrBot 回调服务器 注册一个允许外部访问的文件。
|
||||
|
||||
Args:
|
||||
file_path (str): 文件路径。
|
||||
Returns:
|
||||
str: 返回一个 auth_token,文件路径为 file_path。通过 /astrbot-gewechat/file/auth_token 得到文件。
|
||||
"""
|
||||
async with self.lock:
|
||||
if not os.path.exists(file_path):
|
||||
raise Exception(f"文件不存在: {file_path}")
|
||||
|
||||
file_token = str(uuid.uuid4())
|
||||
self.staged_files[file_token] = file_path
|
||||
return file_token
|
||||
|
||||
async def _handle_file(self, file_token):
|
||||
async with self.lock:
|
||||
if file_token not in self.staged_files:
|
||||
logger.warning(f"请求的文件 {file_token} 不存在。")
|
||||
return quart.abort(404)
|
||||
if not os.path.exists(self.staged_files[file_token]):
|
||||
logger.warning(f"请求的文件 {self.staged_files[file_token]} 不存在。")
|
||||
return quart.abort(404)
|
||||
file_path = self.staged_files[file_token]
|
||||
self.staged_files.pop(file_token, None)
|
||||
return await quart.send_file(file_path)
|
||||
|
||||
async def _set_callback_url(self):
|
||||
logger.info("设置回调,请等待...")
|
||||
@@ -407,8 +462,10 @@ class SimpleGewechatClient:
|
||||
retry_cnt -= 1
|
||||
|
||||
# 需要验证码
|
||||
if os.path.exists("data/temp/gewe_code"):
|
||||
with open("data/temp/gewe_code", "r") as f:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
code_file_path = os.path.join(temp_dir, "gewe_code")
|
||||
if os.path.exists(code_file_path):
|
||||
with open(code_file_path, "r") as f:
|
||||
code = f.read().strip()
|
||||
if not code:
|
||||
logger.warning(
|
||||
@@ -419,9 +476,9 @@ class SimpleGewechatClient:
|
||||
payload["captchCode"] = code
|
||||
logger.info(f"使用验证码: {code}")
|
||||
try:
|
||||
os.remove("data/temp/gewe_code")
|
||||
os.remove(code_file_path)
|
||||
except Exception:
|
||||
logger.warning("删除验证码文件 data/temp/gewe_code 失败。")
|
||||
logger.warning(f"删除验证码文件 {code_file_path} 失败。")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
@@ -441,17 +498,18 @@ class SimpleGewechatClient:
|
||||
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
||||
)
|
||||
else:
|
||||
status = json_blob["data"]["status"]
|
||||
nickname = json_blob["data"].get("nickName", "")
|
||||
if status == 1:
|
||||
logger.info(f"等待确认...{nickname}")
|
||||
elif status == 2:
|
||||
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
||||
break
|
||||
elif status == 0:
|
||||
logger.info("等待扫码...")
|
||||
else:
|
||||
logger.warning(f"未知状态: {status}")
|
||||
if "status" in json_blob["data"]:
|
||||
status = json_blob["data"]["status"]
|
||||
nickname = json_blob["data"].get("nickName", "")
|
||||
if status == 1:
|
||||
logger.info(f"等待确认...{nickname}")
|
||||
elif status == 2:
|
||||
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
||||
break
|
||||
elif status == 0:
|
||||
logger.info("等待扫码...")
|
||||
else:
|
||||
logger.warning(f"未知状态: {status}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
if appid:
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import asyncio
|
||||
import re
|
||||
import wave
|
||||
import uuid
|
||||
import traceback
|
||||
import os
|
||||
|
||||
from astrbot.core.utils.io import save_temp_img, download_file
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -18,6 +21,7 @@ from astrbot.api.message_components import (
|
||||
WechatEmoji as Emoji,
|
||||
)
|
||||
from .client import SimpleGewechatClient
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
def get_wav_duration(file_path):
|
||||
@@ -80,15 +84,9 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
# 检查 record_path 是否在 data/temp 目录中
|
||||
temp_directory = os.path.abspath("data/temp")
|
||||
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
|
||||
with open(img_path, "rb") as f:
|
||||
img_path = save_temp_img(f.read())
|
||||
|
||||
file_id = os.path.basename(img_path)
|
||||
img_url = f"{client.file_server_url}/{file_id}"
|
||||
# 为了安全,向 AstrBot 回调服务注册可被 gewechat 访问的文件,并获得文件 token
|
||||
token = await client._register_file(img_path)
|
||||
img_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback img url: {img_url}")
|
||||
await client.post_image(to_wxid, img_url)
|
||||
elif isinstance(comp, Video):
|
||||
@@ -107,20 +105,33 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
|
||||
video_url = comp.file
|
||||
# 根据 url 下载视频
|
||||
video_filename = f"{uuid.uuid4()}.mp4"
|
||||
video_path = f"data/temp/{video_filename}"
|
||||
await download_file(video_url, video_path)
|
||||
if video_url.startswith("http"):
|
||||
video_filename = f"{uuid.uuid4()}.mp4"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
video_path = os.path.join(temp_dir, video_filename)
|
||||
await download_file(video_url, video_path)
|
||||
else:
|
||||
video_path = video_url
|
||||
|
||||
video_token = await client._register_file(video_path)
|
||||
video_callback_url = f"{client.file_server_url}/{video_token}"
|
||||
|
||||
# 获取视频第一帧
|
||||
thumb_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
thumb_path = os.path.join(
|
||||
temp_dir, f"gewechat_video_thumb_{uuid.uuid4()}.jpg"
|
||||
)
|
||||
|
||||
video_path = video_path.replace(" ", "\\ ")
|
||||
try:
|
||||
ff = FFmpeg()
|
||||
command = f'-i "{video_path}" -ss 0 -vframes 1 "{thumb_path}"'
|
||||
command = f"-i {video_path} -ss 0 -vframes 1 {thumb_path}"
|
||||
ff.options(command)
|
||||
thumb_file_id = os.path.basename(thumb_path)
|
||||
thumb_url = f"{client.file_server_url}/{thumb_file_id}"
|
||||
thumb_token = await client._register_file(thumb_path)
|
||||
thumb_url = f"{client.file_server_url}/{thumb_token}"
|
||||
except Exception as e:
|
||||
logger.error(f"获取视频第一帧失败: {e}")
|
||||
|
||||
# 获取视频时长
|
||||
try:
|
||||
from pyffmpeg import FFprobe
|
||||
@@ -135,15 +146,12 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
logger.error(f"获取时长失败: {e}")
|
||||
video_duration = 10
|
||||
|
||||
file_id = os.path.basename(video_path)
|
||||
video_url = f"{client.file_server_url}/{file_id}"
|
||||
# 发送视频
|
||||
await client.post_video(
|
||||
to_wxid, video_url, thumb_url, video_duration
|
||||
to_wxid, video_callback_url, thumb_url, video_duration
|
||||
)
|
||||
|
||||
# 删除临时视频和缩略图文件
|
||||
if os.path.exists(video_path):
|
||||
os.remove(video_path)
|
||||
# 删除临时缩略图文件
|
||||
if os.path.exists(thumb_path):
|
||||
os.remove(thumb_path)
|
||||
elif isinstance(comp, Record):
|
||||
@@ -151,7 +159,8 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
record_url = comp.file
|
||||
record_path = await comp.convert_to_file_path()
|
||||
|
||||
silk_path = f"data/temp/{uuid.uuid4()}.silk"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
|
||||
try:
|
||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||
except Exception as e:
|
||||
@@ -160,8 +169,8 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
logger.info("Silk 语音文件格式转换至: " + record_path)
|
||||
if duration == 0:
|
||||
duration = get_wav_duration(record_path)
|
||||
file_id = os.path.basename(silk_path)
|
||||
record_url = f"{client.file_server_url}/{file_id}"
|
||||
token = await client._register_file(silk_path)
|
||||
record_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback record url: {record_url}")
|
||||
await client.post_voice(to_wxid, record_url, duration * 1000)
|
||||
elif isinstance(comp, File):
|
||||
@@ -170,14 +179,17 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
if file_path.startswith("file:///"):
|
||||
file_path = file_path[8:]
|
||||
elif file_path.startswith("http"):
|
||||
await download_file(file_path, f"data/temp/{file_name}")
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
temp_file_path = os.path.join(temp_dir, file_name)
|
||||
await download_file(file_path, temp_file_path)
|
||||
file_path = temp_file_path
|
||||
else:
|
||||
file_path = file_path
|
||||
|
||||
file_id = os.path.basename(file_path)
|
||||
file_url = f"{client.file_server_url}/{file_id}"
|
||||
token = await client._register_file(file_path)
|
||||
file_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback file url: {file_url}")
|
||||
await client.post_file(to_wxid, file_url, file_id)
|
||||
await client.post_file(to_wxid, file_url, file_name)
|
||||
elif isinstance(comp, Emoji):
|
||||
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
|
||||
elif isinstance(comp, At):
|
||||
@@ -217,15 +229,36 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
members=members,
|
||||
)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||
):
|
||||
if not use_fallback:
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
buffer = ""
|
||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
buffer = await self.process_buffer(buffer, pattern)
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from defusedxml import ElementTree as eT
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.message_components import WechatEmoji as Emoji, Reply, Plain
|
||||
from astrbot.api.message_components import (
|
||||
WechatEmoji as Emoji,
|
||||
Reply,
|
||||
Plain,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
|
||||
|
||||
class GeweDataParser:
|
||||
@@ -11,7 +16,7 @@ class GeweDataParser:
|
||||
def _format_to_xml(self):
|
||||
return eT.fromstring(self.data)
|
||||
|
||||
def parse_mutil_49(self):
|
||||
def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
|
||||
appmsg_type = self._format_to_xml().find(".//appmsg/type")
|
||||
if appmsg_type is None:
|
||||
return
|
||||
@@ -34,13 +39,18 @@ class GeweDataParser:
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: parse_emoji failed, {e}")
|
||||
|
||||
def parse_reply(self) -> Reply | None:
|
||||
def parse_reply(self) -> list[Reply, Plain] | None:
|
||||
"""解析引用消息
|
||||
|
||||
Returns:
|
||||
list[Reply, Plain]: 一个包含两个元素的列表。Reply 消息对象和引用者说的文本内容。微信平台下引用消息时只能发送文本消息。
|
||||
"""
|
||||
try:
|
||||
replied_id = -1
|
||||
replied_uid = 0
|
||||
replied_nickname = ""
|
||||
replied_content = ""
|
||||
content = ""
|
||||
replied_content = "" # 被引用者说的内容
|
||||
content = "" # 引用者说的内容
|
||||
|
||||
root = self._format_to_xml()
|
||||
refermsg = root.find(".//refermsg")
|
||||
@@ -57,22 +67,44 @@ class GeweDataParser:
|
||||
if displayname is not None:
|
||||
replied_nickname = displayname.text
|
||||
if refermsg_content is not None:
|
||||
replied_content = refermsg_content.text
|
||||
# 处理引用嵌套,包括嵌套公众号消息
|
||||
if refermsg_content.text.startswith(
|
||||
"<msg>"
|
||||
) or refermsg_content.text.startswith("<?xml"):
|
||||
try:
|
||||
logger.debug("gewechat: Reference message is nested")
|
||||
refer_root = eT.fromstring(refermsg_content.text)
|
||||
img = refer_root.find("img")
|
||||
if img is not None:
|
||||
replied_content = "[图片]"
|
||||
else:
|
||||
app_msg = refer_root.find("appmsg")
|
||||
refermsg_content_title = app_msg.find("title")
|
||||
logger.debug(
|
||||
f"gewechat: Reference message nesting: {refermsg_content_title.text}"
|
||||
)
|
||||
replied_content = refermsg_content_title.text
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: nested failed, {e}")
|
||||
# 处理异常情况
|
||||
replied_content = refermsg_content.text
|
||||
else:
|
||||
replied_content = refermsg_content.text
|
||||
|
||||
# 提取引用者说的内容
|
||||
title = root.find(".//appmsg/title")
|
||||
if title is not None:
|
||||
content = title.text
|
||||
|
||||
r = Reply(
|
||||
reply_seg = Reply(
|
||||
id=replied_id,
|
||||
chain=[Plain(content)],
|
||||
chain=[Plain(replied_content)],
|
||||
sender_id=replied_uid,
|
||||
sender_nickname=replied_nickname,
|
||||
sender_str=replied_content,
|
||||
message_str=content,
|
||||
message_str=replied_content,
|
||||
)
|
||||
return r
|
||||
plain_seg = Plain(content)
|
||||
return [reply_seg, plain_seg]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: parse_reply failed, {e}")
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import base64
|
||||
import lark_oapi as lark
|
||||
from io import BytesIO
|
||||
from typing import List
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image as AstrBotImage, At
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from lark_oapi.api.im.v1 import *
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class LarkMessageEvent(AstrMessageEvent):
|
||||
@@ -27,22 +31,33 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
_stage.append({"tag": "at", "user_id": comp.qq, "style": []})
|
||||
elif isinstance(comp, AstrBotImage):
|
||||
file_path = ""
|
||||
image_file = None
|
||||
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
file_path = comp.file.replace("file:///", "")
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(comp.file)
|
||||
file_path = image_file_path
|
||||
elif comp.file and comp.file.startswith("base64://"):
|
||||
pass
|
||||
base64_str = comp.file.removeprefix("base64://")
|
||||
image_data = base64.b64decode(base64_str)
|
||||
# save as temp file
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}_test.jpg")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(BytesIO(image_data).getvalue())
|
||||
else:
|
||||
file_path = comp.file
|
||||
|
||||
if image_file is None:
|
||||
image_file = open(file_path, "rb")
|
||||
|
||||
request = (
|
||||
CreateImageRequest.builder()
|
||||
.request_body(
|
||||
CreateImageRequestBody.builder()
|
||||
.image_type("message")
|
||||
.image(open(file_path, "rb"))
|
||||
.image(image_file)
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
@@ -51,7 +66,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
if not response.success():
|
||||
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
|
||||
image_key = response.data.image_key
|
||||
print(image_key)
|
||||
logger.debug(image_key)
|
||||
ret.append(_stage)
|
||||
ret.append([{"tag": "img", "image_key": image_key}])
|
||||
_stage.clear()
|
||||
@@ -92,7 +107,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
@@ -103,4 +118,4 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -33,7 +33,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
else:
|
||||
self.send_buffer.chain.extend(message.chain)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
"""流式输出仅支持消息列表私聊"""
|
||||
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
|
||||
last_edit_time = 0 # 上次编辑消息的时间
|
||||
@@ -66,7 +66,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
|
||||
self.send_buffer = None
|
||||
|
||||
return await super().send_streaming(generator)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def _post_send(self, stream: dict = None):
|
||||
if not self.send_buffer:
|
||||
@@ -97,7 +97,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
"msg_id": self.message_obj.message_id,
|
||||
}
|
||||
|
||||
if not isinstance(source, (botpy.message.Message,botpy.message.DirectMessage)):
|
||||
if not isinstance(source, (botpy.message.Message, botpy.message.DirectMessage)):
|
||||
payload["msg_seq"] = random.randint(1, 10000)
|
||||
|
||||
match type(source):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
@@ -57,6 +58,14 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
self.base_url = base_url
|
||||
|
||||
self.enable_command_register = self.config.get(
|
||||
"telegram_command_register", True
|
||||
)
|
||||
self.enable_command_refresh = self.config.get(
|
||||
"telegram_command_auto_refresh", True
|
||||
)
|
||||
self.last_command_hash = None
|
||||
|
||||
self.application = (
|
||||
ApplicationBuilder()
|
||||
.token(self.config["telegram_token"])
|
||||
@@ -94,17 +103,19 @@ class TelegramPlatformAdapter(Platform):
|
||||
async def run(self):
|
||||
await self.application.initialize()
|
||||
await self.application.start()
|
||||
await self.register_commands()
|
||||
|
||||
# TODO 使用更优雅的方式重新注册命令
|
||||
self.scheduler.add_job(
|
||||
self.register_commands,
|
||||
"interval",
|
||||
minutes=5,
|
||||
id="telegram_command_register",
|
||||
misfire_grace_time=60,
|
||||
)
|
||||
self.scheduler.start()
|
||||
if self.enable_command_register:
|
||||
await self.register_commands()
|
||||
|
||||
if self.enable_command_refresh and self.enable_command_register:
|
||||
self.scheduler.add_job(
|
||||
self.register_commands,
|
||||
"interval",
|
||||
seconds=self.config.get("telegram_command_register_interval", 300),
|
||||
id="telegram_command_register",
|
||||
misfire_grace_time=60,
|
||||
)
|
||||
self.scheduler.start()
|
||||
|
||||
queue = self.application.updater.start_polling()
|
||||
logger.info("Telegram Platform Adapter is running.")
|
||||
@@ -113,13 +124,17 @@ class TelegramPlatformAdapter(Platform):
|
||||
async def register_commands(self):
|
||||
"""收集所有注册的指令并注册到 Telegram"""
|
||||
try:
|
||||
await self.client.delete_my_commands()
|
||||
commands = self.collect_commands()
|
||||
|
||||
if commands:
|
||||
current_hash = hash(
|
||||
tuple((cmd.command, cmd.description) for cmd in commands)
|
||||
)
|
||||
if current_hash == self.last_command_hash:
|
||||
return
|
||||
self.last_command_hash = current_hash
|
||||
await self.client.delete_my_commands()
|
||||
await self.client.set_my_commands(commands)
|
||||
for cmd in commands:
|
||||
logger.debug(f"已注册指令: /{cmd.command} - {cmd.description}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"向 Telegram 注册指令时发生错误: {e!s}")
|
||||
@@ -129,8 +144,8 @@ class TelegramPlatformAdapter(Platform):
|
||||
command_dict = {}
|
||||
skip_commands = {"start"}
|
||||
|
||||
for handler_md in star_handlers_registry._handlers:
|
||||
handler_metadata = handler_md[1]
|
||||
for handler_md in star_handlers_registry:
|
||||
handler_metadata = handler_md
|
||||
if not star_map[handler_metadata.handler_module_path].activated:
|
||||
continue
|
||||
for event_filter in handler_metadata.event_filters:
|
||||
@@ -167,6 +182,10 @@ class TelegramPlatformAdapter(Platform):
|
||||
if not cmd_name or cmd_name in skip_commands:
|
||||
return None
|
||||
|
||||
if not re.match(r"^[a-z0-9_]+$", cmd_name) or len(cmd_name) > 32:
|
||||
logger.debug(f"跳过无法注册的命令: {cmd_name}")
|
||||
return None
|
||||
|
||||
# Build description.
|
||||
description = handler_metadata.desc or (
|
||||
f"指令组: {cmd_name} (包含多个子指令)" if is_group else f"指令: {cmd_name}"
|
||||
@@ -263,10 +282,12 @@ class TelegramPlatformAdapter(Platform):
|
||||
entity.offset + 1 : entity.offset + entity.length
|
||||
]
|
||||
message.message.append(Comp.At(qq=name, name=name))
|
||||
plain_text = (
|
||||
plain_text[: entity.offset]
|
||||
+ plain_text[entity.offset + entity.length :]
|
||||
)
|
||||
# 如果mention是当前bot则移除;否则保留
|
||||
if name.lower() == context.bot.username.lower():
|
||||
plain_text = (
|
||||
plain_text[: entity.offset]
|
||||
+ plain_text[entity.offset + entity.length :]
|
||||
)
|
||||
|
||||
if plain_text:
|
||||
message.message.append(Comp.Plain(plain_text))
|
||||
@@ -339,7 +360,9 @@ class TelegramPlatformAdapter(Platform):
|
||||
self.scheduler.shutdown()
|
||||
|
||||
await self.application.stop()
|
||||
await self.client.delete_my_commands()
|
||||
|
||||
if self.enable_command_register:
|
||||
await self.client.delete_my_commands()
|
||||
|
||||
# 保险起见先判断是否存在updater对象
|
||||
if self.application.updater is not None:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import asyncio
|
||||
import telegramify_markdown
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -13,9 +15,20 @@ from astrbot.api.message_components import (
|
||||
from telegram.ext import ExtBot
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class TelegramPlatformEvent(AstrMessageEvent):
|
||||
# Telegram 的最大消息长度限制
|
||||
MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
SPLIT_PATTERNS = {
|
||||
"paragraph": re.compile(r"\n\n"),
|
||||
"line": re.compile(r"\n"),
|
||||
"sentence": re.compile(r"[.!?。!?]"),
|
||||
"word": re.compile(r"\s"),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
@@ -27,8 +40,33 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(client: ExtBot, message: MessageChain, user_name: str):
|
||||
def _split_message(self, text: str) -> list[str]:
|
||||
if len(text) <= self.MAX_MESSAGE_LENGTH:
|
||||
return [text]
|
||||
|
||||
chunks = []
|
||||
while text:
|
||||
if len(text) <= self.MAX_MESSAGE_LENGTH:
|
||||
chunks.append(text)
|
||||
break
|
||||
|
||||
split_point = self.MAX_MESSAGE_LENGTH
|
||||
segment = text[: self.MAX_MESSAGE_LENGTH]
|
||||
|
||||
for _, pattern in self.SPLIT_PATTERNS.items():
|
||||
if matches := list(pattern.finditer(segment)):
|
||||
last_match = matches[-1]
|
||||
split_point = last_match.end()
|
||||
break
|
||||
|
||||
chunks.append(text[:split_point])
|
||||
text = text[split_point:].lstrip()
|
||||
|
||||
return chunks
|
||||
|
||||
async def send_with_client(
|
||||
self, client: ExtBot, message: MessageChain, user_name: str
|
||||
):
|
||||
image_path = None
|
||||
|
||||
has_reply = False
|
||||
@@ -57,25 +95,29 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
|
||||
if isinstance(i, Plain):
|
||||
if at_user_id and not at_flag:
|
||||
i.text = f"@{at_user_id} " + i.text
|
||||
i.text = f"@{at_user_id} {i.text}"
|
||||
at_flag = True
|
||||
text = i.text
|
||||
try:
|
||||
text = telegramify_markdown.markdownify(
|
||||
i.text, max_line_length=None, normalize_whitespace=False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"MarkdownV2 conversion failed: {e}. Using plain text instead."
|
||||
)
|
||||
return
|
||||
await client.send_message(text=text, parse_mode="MarkdownV2", **payload)
|
||||
chunks = self._split_message(i.text)
|
||||
for chunk in chunks:
|
||||
try:
|
||||
md_text = telegramify_markdown.markdownify(
|
||||
chunk, max_line_length=None, normalize_whitespace=False
|
||||
)
|
||||
await client.send_message(
|
||||
text=md_text, parse_mode="MarkdownV2", **payload
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"MarkdownV2 send failed: {e}. Using plain text instead."
|
||||
)
|
||||
await client.send_message(text=chunk, **payload)
|
||||
elif isinstance(i, Image):
|
||||
image_path = await i.convert_to_file_path()
|
||||
await client.send_photo(photo=image_path, **payload)
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
path = "data/temp/" + i.name
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, i.name)
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
|
||||
@@ -91,7 +133,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
await self.send_with_client(self.client, message, self.get_sender_id())
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
message_thread_id = None
|
||||
|
||||
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
@@ -126,7 +168,8 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
continue
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
path = "data/temp/" + i.name
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, i.name)
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
|
||||
@@ -143,17 +186,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
continue
|
||||
|
||||
# Plain
|
||||
if not message_id:
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
current_content = delta
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||
message_id = msg.message_id
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 记录初始消息发送时间
|
||||
else:
|
||||
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
time_since_last_edit = current_time - last_edit_time
|
||||
|
||||
@@ -172,6 +205,18 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 更新上次编辑的时间
|
||||
else:
|
||||
# delta 长度一般不会大于 4096,因此这里直接发送
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
current_content = delta
|
||||
delta = ""
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||
message_id = msg.message_id
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 记录初始消息发送时间
|
||||
|
||||
try:
|
||||
if delta and current_content != delta:
|
||||
@@ -183,16 +228,14 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
text=markdown_text,
|
||||
chat_id=payload["chat_id"],
|
||||
message_id=message_id,
|
||||
parse_mode="MarkdownV2"
|
||||
parse_mode="MarkdownV2",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
|
||||
await self.client.edit_message_text(
|
||||
text=delta,
|
||||
chat_id=payload["chat_id"],
|
||||
message_id=message_id
|
||||
text=delta, chat_id=payload["chat_id"], message_id=message_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
||||
|
||||
return await super().send_streaming(generator)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -17,6 +17,7 @@ from astrbot.core import web_chat_queue
|
||||
from .webchat_event import WebChatMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class QueueListener:
|
||||
@@ -40,7 +41,8 @@ class WebChatAdapter(Platform):
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.imgs_dir = "data/webchat/imgs"
|
||||
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="webchat", description="webchat", id=self.config.get("id")
|
||||
|
||||
@@ -6,8 +6,9 @@ from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core import web_chat_back_queue
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
imgs_dir = "data/webchat/imgs"
|
||||
imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
|
||||
|
||||
class WebChatMessageEvent(AstrMessageEvent):
|
||||
@@ -106,7 +107,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
final_data = ""
|
||||
async for chain in generator:
|
||||
final_data += await WebChatMessageEvent._send(
|
||||
@@ -121,4 +122,4 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
"cid": self.session_id.split("!")[-1],
|
||||
}
|
||||
)
|
||||
await super().send_streaming(generator)
|
||||
await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -0,0 +1,707 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import websockets
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot.api.platform import Platform, PlatformMetadata
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform.astrbot_message import (
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
MessageType,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from .wechatpadpro_message_event import WeChatPadProMessageEvent
|
||||
|
||||
|
||||
@register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器")
|
||||
class WeChatPadProAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self._shutdown_event = None
|
||||
self.wxnewpass = None
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="wechatpadpro",
|
||||
description="WeChatPadPro 消息平台适配器",
|
||||
id=self.config.get("id", "wechatpadpro"),
|
||||
)
|
||||
|
||||
# 保存配置信息
|
||||
self.admin_key = self.config.get("admin_key")
|
||||
self.host = self.config.get("host")
|
||||
self.port = self.config.get("port")
|
||||
self.active_mesasge_poll: bool = self.config.get(
|
||||
"wpp_active_message_poll", False
|
||||
)
|
||||
self.active_message_poll_interval: int = self.config.get(
|
||||
"wpp_active_message_poll_interval", 5
|
||||
)
|
||||
self.base_url = f"http://{self.host}:{self.port}"
|
||||
self.auth_key = None # 用于保存生成的授权码
|
||||
self.wxid = None # 用于保存登录成功后的 wxid
|
||||
self.credentials_file = os.path.join(
|
||||
get_astrbot_data_path(), "wechatpadpro_credentials.json"
|
||||
) # 持久化文件路径
|
||||
self.ws_handle_task = None
|
||||
|
||||
async def run(self) -> None:
|
||||
"""
|
||||
启动平台适配器的运行实例。
|
||||
"""
|
||||
logger.info("WeChatPadPro 适配器正在启动...")
|
||||
|
||||
if loaded_credentials := self.load_credentials():
|
||||
self.auth_key = loaded_credentials.get("auth_key")
|
||||
self.wxid = loaded_credentials.get("wxid")
|
||||
|
||||
isLoginIn = await self.check_online_status()
|
||||
|
||||
# 检查在线状态
|
||||
if self.auth_key and isLoginIn:
|
||||
logger.info("WeChatPadPro 设备已在线,凭据存在,跳过扫码登录。")
|
||||
# 如果在线,连接 WebSocket 接收消息
|
||||
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
|
||||
else:
|
||||
# 1. 生成授权码
|
||||
if not self.auth_key:
|
||||
logger.info("WeChatPadPro 无可用凭据,将生成新的授权码。")
|
||||
await self.generate_auth_key()
|
||||
|
||||
# 2. 获取登录二维码
|
||||
if not isLoginIn:
|
||||
logger.info("WeChatPadPro 设备已离线,开始扫码登录。")
|
||||
qr_code_url = await self.get_login_qr_code()
|
||||
|
||||
if qr_code_url:
|
||||
logger.info(f"请扫描以下二维码登录: {qr_code_url}")
|
||||
else:
|
||||
logger.error("无法获取登录二维码。")
|
||||
return
|
||||
|
||||
# 3. 检测扫码状态
|
||||
login_successful = await self.check_login_status()
|
||||
|
||||
if login_successful:
|
||||
logger.info("登录成功,WeChatPadPro适配器已连接。")
|
||||
else:
|
||||
logger.warning("登录失败或超时,WeChatPadPro 适配器将关闭。")
|
||||
await self.terminate()
|
||||
return
|
||||
|
||||
# 登录成功后,连接 WebSocket 接收消息
|
||||
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
|
||||
|
||||
self._shutdown_event = asyncio.Event()
|
||||
await self._shutdown_event.wait()
|
||||
logger.info("WeChatPadPro 适配器已停止。")
|
||||
|
||||
def load_credentials(self):
|
||||
"""
|
||||
从文件中加载 auth_key 和 wxid。
|
||||
"""
|
||||
if os.path.exists(self.credentials_file):
|
||||
try:
|
||||
with open(self.credentials_file, "r") as f:
|
||||
credentials = json.load(f)
|
||||
logger.info("成功加载 WeChatPadPro 凭据。")
|
||||
return credentials
|
||||
except Exception as e:
|
||||
logger.error(f"加载 WeChatPadPro 凭据失败: {e}")
|
||||
return None
|
||||
|
||||
def save_credentials(self):
|
||||
"""
|
||||
将 auth_key 和 wxid 保存到文件。
|
||||
"""
|
||||
credentials = {
|
||||
"auth_key": self.auth_key,
|
||||
"wxid": self.wxid,
|
||||
}
|
||||
try:
|
||||
# 确保数据目录存在
|
||||
data_dir = os.path.dirname(self.credentials_file)
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
with open(self.credentials_file, "w") as f:
|
||||
json.dump(credentials, f)
|
||||
logger.info("成功保存 WeChatPadPro 凭据。")
|
||||
except Exception as e:
|
||||
logger.error(f"保存 WeChatPadPro 凭据失败: {e}")
|
||||
|
||||
async def check_online_status(self):
|
||||
"""
|
||||
检查 WeChatPadPro 设备是否在线。
|
||||
"""
|
||||
url = f"{self.base_url}/login/GetLoginStatus"
|
||||
params = {"key": self.auth_key}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.get(url, params=params) as response:
|
||||
response_data = await response.json()
|
||||
# 根据提供的在线接口返回示例,成功状态码是 200,loginState 为 1 表示在线
|
||||
if response.status == 200 and response_data.get("Code") == 200:
|
||||
login_state = response_data.get("Data", {}).get("loginState")
|
||||
if login_state == 1:
|
||||
logger.info("WeChatPadPro 设备当前在线。")
|
||||
return True
|
||||
# login_state == 3 为离线状态
|
||||
elif login_state == 3:
|
||||
logger.info(
|
||||
"WeChatPadPro 设备不在线。"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
logger.error(
|
||||
f"未知的在线状态: {login_state:}"
|
||||
)
|
||||
return False
|
||||
# Code == 300 为微信退出状态。
|
||||
elif response.status == 200 and response_data.get("Code") == 300:
|
||||
logger.info(
|
||||
"WeChatPadPro 设备已退出。"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
logger.error(
|
||||
f"检查在线状态失败: {response.status}, {response_data}"
|
||||
)
|
||||
return False
|
||||
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"检查在线状态时发生错误: {e}")
|
||||
return False
|
||||
|
||||
async def generate_auth_key(self):
|
||||
"""
|
||||
生成授权码。
|
||||
"""
|
||||
url = f"{self.base_url}/admin/GenAuthKey1"
|
||||
params = {"key": self.admin_key}
|
||||
payload = {"Count": 1, "Days": 365} # 生成一个有效期365天的授权码
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
response_data = await response.json()
|
||||
# 修正成功判断条件和授权码提取路径
|
||||
if response.status == 200 and response_data.get("Code") == 200:
|
||||
# 授权码在 Data 字段的列表中
|
||||
if (
|
||||
response_data.get("Data")
|
||||
and isinstance(response_data["Data"], list)
|
||||
and len(response_data["Data"]) > 0
|
||||
):
|
||||
self.auth_key = response_data["Data"][0]
|
||||
logger.info("成功获取授权码")
|
||||
else:
|
||||
logger.error(
|
||||
f"生成授权码成功但未找到授权码: {response_data}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"生成授权码失败: {response.status}, {response_data}"
|
||||
)
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"生成授权码时发生错误: {e}")
|
||||
|
||||
async def get_login_qr_code(self):
|
||||
"""
|
||||
获取登录二维码地址。
|
||||
"""
|
||||
url = f"{self.base_url}/login/GetLoginQrCodeNew"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {} # 根据文档,这个接口的 body 可以为空
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
response_data = await response.json()
|
||||
# 修正成功判断条件和数据提取路径
|
||||
if response.status == 200 and response_data.get("Code") == 200:
|
||||
# 二维码地址在 Data.QrCodeUrl 字段中
|
||||
if response_data.get("Data") and response_data["Data"].get(
|
||||
"QrCodeUrl"
|
||||
):
|
||||
return response_data["Data"]["QrCodeUrl"]
|
||||
else:
|
||||
logger.error(
|
||||
f"获取登录二维码成功但未找到二维码地址: {response_data}"
|
||||
)
|
||||
return None
|
||||
else:
|
||||
logger.error(
|
||||
f"获取登录二维码失败: {response.status}, {response_data}"
|
||||
)
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取登录二维码时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def check_login_status(self):
|
||||
"""
|
||||
循环检测扫码状态。
|
||||
尝试 6 次后跳出循环,添加倒计时。
|
||||
返回 True 如果登录成功,否则返回 False。
|
||||
"""
|
||||
url = f"{self.base_url}/login/CheckLoginStatus"
|
||||
params = {"key": self.auth_key}
|
||||
|
||||
attempts = 0 # 初始化尝试次数
|
||||
max_attempts = 36 # 最大尝试次数
|
||||
countdown = 180 # 倒计时时长
|
||||
logger.info(f"请在 {countdown} 秒内扫码登录。")
|
||||
while attempts < max_attempts:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.get(url, params=params) as response:
|
||||
response_data = await response.json()
|
||||
# 成功判断条件和数据提取路径
|
||||
if response.status == 200 and response_data.get("Code") == 200:
|
||||
if (
|
||||
response_data.get("Data")
|
||||
and response_data["Data"].get("state") is not None
|
||||
):
|
||||
status = response_data["Data"]["state"]
|
||||
logger.info(
|
||||
f"第 {attempts + 1} 次尝试,当前登录状态: {status},还剩{countdown - attempts * 5}秒"
|
||||
)
|
||||
if status == 2: # 状态 2 表示登录成功
|
||||
self.wxid = response_data["Data"].get("wxid")
|
||||
self.wxnewpass = response_data["Data"].get(
|
||||
"wxnewpass"
|
||||
)
|
||||
logger.info(
|
||||
f"登录成功,wxid: {self.wxid}, wxnewpass: {self.wxnewpass}"
|
||||
)
|
||||
self.save_credentials() # 登录成功后保存凭据
|
||||
return True
|
||||
elif status == -2: # 二维码过期
|
||||
logger.error("二维码已过期,请重新获取。")
|
||||
return False
|
||||
else:
|
||||
logger.error(
|
||||
f"检测登录状态成功但未找到登录状态: {response_data}"
|
||||
)
|
||||
elif response_data.get("Code") == 300:
|
||||
# "不存在状态"
|
||||
pass
|
||||
else:
|
||||
logger.info(
|
||||
f"检测登录状态失败: {response.status}, {response_data}"
|
||||
)
|
||||
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
await asyncio.sleep(5)
|
||||
attempts += 1
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"检测登录状态时发生错误: {e}")
|
||||
attempts += 1
|
||||
continue
|
||||
|
||||
attempts += 1
|
||||
await asyncio.sleep(5) # 每隔5秒检测一次
|
||||
logger.warning("登录检测超过最大尝试次数,退出检测。")
|
||||
return False
|
||||
|
||||
async def connect_websocket(self):
|
||||
"""
|
||||
建立 WebSocket 连接并处理接收到的消息。
|
||||
"""
|
||||
os.environ["no_proxy"] = f"localhost,127.0.0.1,{self.host}"
|
||||
ws_url = f"ws://{self.host}:{self.port}/ws/GetSyncMsg?key={self.auth_key}"
|
||||
logger.info(
|
||||
f"正在连接 WebSocket: ws://{self.host}:{self.port}/ws/GetSyncMsg?key=***"
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
async with websockets.connect(ws_url) as websocket:
|
||||
logger.info("WebSocket 连接成功。")
|
||||
# 设置空闲超时重连
|
||||
wait_time = (
|
||||
self.active_message_poll_interval
|
||||
if self.active_mesasge_poll
|
||||
else 120
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
message = await asyncio.wait_for(
|
||||
websocket.recv(), timeout=wait_time
|
||||
)
|
||||
# logger.debug(message) # 不显示原始消息内容
|
||||
asyncio.create_task(self.handle_websocket_message(message))
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"WebSocket 连接空闲超过 {wait_time} s")
|
||||
break
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
logger.info("WebSocket 连接正常关闭。")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"处理 WebSocket 消息时发生错误: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态,或尝试重启WeChatPadPro适配器。")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def handle_websocket_message(self, message: str):
|
||||
"""
|
||||
处理从 WebSocket 接收到的消息。
|
||||
"""
|
||||
logger.debug(f"收到 WebSocket 消息: {message}")
|
||||
try:
|
||||
message_data = json.loads(message)
|
||||
if (
|
||||
message_data.get("msg_id") is not None
|
||||
and message_data.get("from_user_name") is not None
|
||||
):
|
||||
abm = await self.convert_message(message_data)
|
||||
if abm:
|
||||
# 创建 WeChatPadProMessageEvent 实例
|
||||
message_event = WeChatPadProMessageEvent(
|
||||
message_str=abm.message_str,
|
||||
message_obj=abm,
|
||||
platform_meta=self.meta(),
|
||||
session_id=abm.session_id,
|
||||
# 传递适配器实例,以便在事件中调用 send 方法
|
||||
adapter=self,
|
||||
)
|
||||
# 提交事件到事件队列
|
||||
self.commit_event(message_event)
|
||||
else:
|
||||
logger.warning(f"收到未知结构的 WebSocket 消息: {message_data}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"无法解析 WebSocket 消息为 JSON: {message}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理 WebSocket 消息时发生错误: {e}")
|
||||
|
||||
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
|
||||
"""
|
||||
将 WeChatPadPro 原始消息转换为 AstrBotMessage。
|
||||
"""
|
||||
abm = AstrBotMessage()
|
||||
abm.raw_message = raw_message
|
||||
abm.message_id = str(raw_message.get("msg_id"))
|
||||
abm.timestamp = raw_message.get("create_time")
|
||||
abm.self_id = self.wxid
|
||||
|
||||
if int(time.time()) - abm.timestamp > 180:
|
||||
logger.warning(
|
||||
f"忽略 3 分钟前的旧消息:消息时间戳 {abm.timestamp} 超过当前时间 {int(time.time())}。"
|
||||
)
|
||||
return None
|
||||
|
||||
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
content = raw_message.get("content", {}).get("str", "")
|
||||
push_content = raw_message.get("push_content", "")
|
||||
msg_type = raw_message.get("msg_type")
|
||||
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
|
||||
# 如果是机器人自己发送的消息、回显消息或系统消息,忽略
|
||||
if from_user_name == self.wxid:
|
||||
logger.info("忽略来自自己的消息。")
|
||||
return None
|
||||
|
||||
if from_user_name in ["weixin", "newsapp", "newsapp_wechat"]:
|
||||
logger.info("忽略来自微信团队的消息。")
|
||||
return None
|
||||
|
||||
# 先判断群聊/私聊并设置基本属性
|
||||
if await self._process_chat_type(
|
||||
abm, raw_message, from_user_name, to_user_name, content, push_content
|
||||
):
|
||||
# 再根据消息类型处理消息内容
|
||||
await self._process_message_content(abm, raw_message, msg_type, content)
|
||||
|
||||
return abm
|
||||
return None
|
||||
|
||||
async def _process_chat_type(
|
||||
self,
|
||||
abm: AstrBotMessage,
|
||||
raw_message: dict,
|
||||
from_user_name: str,
|
||||
to_user_name: str,
|
||||
content: str,
|
||||
push_content: str,
|
||||
):
|
||||
"""
|
||||
判断消息是群聊还是私聊,并设置 AstrBotMessage 的基本属性。
|
||||
"""
|
||||
if from_user_name == "weixin":
|
||||
return False
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = from_user_name
|
||||
|
||||
parts = content.split(":\n", 1)
|
||||
sender_wxid = parts[0] if len(parts) == 2 else ""
|
||||
abm.sender = MessageMember(user_id=sender_wxid, nickname="")
|
||||
|
||||
# 获取群聊发送者的nickname
|
||||
if sender_wxid:
|
||||
accurate_nickname = await self._get_group_member_nickname(
|
||||
abm.group_id, sender_wxid
|
||||
)
|
||||
if accurate_nickname:
|
||||
abm.sender.nickname = accurate_nickname
|
||||
|
||||
# 对于群聊,session_id 可以是群聊 ID 或发送者 ID + 群聊 ID (如果 unique_session 为 True)
|
||||
if self.unique_session:
|
||||
abm.session_id = f"{from_user_name}_{to_user_name}"
|
||||
else:
|
||||
abm.session_id = from_user_name
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.group_id = ""
|
||||
nick_name = ""
|
||||
if push_content and " : " in push_content:
|
||||
nick_name = push_content.split(" : ")[0]
|
||||
abm.sender = MessageMember(user_id=from_user_name, nickname=nick_name)
|
||||
abm.session_id = from_user_name
|
||||
return True
|
||||
|
||||
async def _get_group_member_nickname(
|
||||
self, group_id: str, member_wxid: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
通过接口获取群成员的昵称。
|
||||
"""
|
||||
url = f"{self.base_url}/group/GetChatroomMemberDetail"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {
|
||||
"ChatRoomName": group_id,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
response_data = await response.json()
|
||||
if response.status == 200 and response_data.get("Code") == 200:
|
||||
# 从返回数据中查找对应成员的昵称
|
||||
member_list = (
|
||||
response_data.get("Data", {})
|
||||
.get("member_data", {})
|
||||
.get("chatroom_member_list", [])
|
||||
)
|
||||
for member in member_list:
|
||||
if member.get("user_name") == member_wxid:
|
||||
return member.get("nick_name")
|
||||
logger.warning(
|
||||
f"在群 {group_id} 中未找到成员 {member_wxid} 的昵称"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"获取群成员详情失败: {response.status}, {response_data}"
|
||||
)
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取群成员详情时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def _download_raw_image(
|
||||
self, from_user_name: str, to_user_name: str, msg_id: int
|
||||
):
|
||||
"""下载原始图片。"""
|
||||
url = f"{self.base_url}/message/GetMsgBigImg"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {
|
||||
"CompressType": 0,
|
||||
"FromUserName": from_user_name,
|
||||
"MsgId": msg_id,
|
||||
"Section": {"DataLen": 61440, "StartPos": 0},
|
||||
"ToUserName": to_user_name,
|
||||
"TotalLen": 0,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
logger.error(f"下载图片失败: {response.status}")
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"下载图片时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def _process_message_content(
|
||||
self, abm: AstrBotMessage, raw_message: dict, msg_type: int, content: str
|
||||
):
|
||||
"""
|
||||
根据消息类型处理消息内容,填充 AstrBotMessage 的 message 列表。
|
||||
"""
|
||||
if msg_type == 1: # 文本消息
|
||||
abm.message_str = content
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
parts = content.split(":\n", 1)
|
||||
if len(parts) == 2:
|
||||
abm.message_str = parts[1]
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
else:
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
else: # 私聊消息
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
elif msg_type == 3:
|
||||
# 图片消息
|
||||
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
msg_id = raw_message.get("msg_id")
|
||||
image_resp = await self._download_raw_image(
|
||||
from_user_name, to_user_name, msg_id
|
||||
)
|
||||
image_bs64_data = (
|
||||
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
|
||||
)
|
||||
if image_bs64_data:
|
||||
abm.message.append(Image.fromBase64(image_bs64_data))
|
||||
elif msg_type == 47:
|
||||
# 视频消息 (注意:表情消息也是 47,需要区分)
|
||||
logger.warning("收到视频消息,待实现。")
|
||||
elif msg_type == 50:
|
||||
# 语音/视频
|
||||
logger.warning("收到语音/视频消息,待实现。")
|
||||
elif msg_type == 49:
|
||||
# 引用消息
|
||||
logger.warning("收到引用消息,待实现。")
|
||||
else:
|
||||
logger.warning(f"收到未处理的消息类型: {msg_type}。")
|
||||
|
||||
async def terminate(self):
|
||||
"""
|
||||
终止一个平台的运行实例。
|
||||
"""
|
||||
logger.info("终止 WeChatPadPro 适配器。")
|
||||
try:
|
||||
if self.ws_handle_task:
|
||||
self.ws_handle_task.cancel()
|
||||
self._shutdown_event.set()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
"""
|
||||
得到一个平台的元数据。
|
||||
"""
|
||||
return self.metadata
|
||||
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
dummy_message_obj = AstrBotMessage()
|
||||
dummy_message_obj.session_id = session.session_id
|
||||
# 根据 session_id 判断消息类型
|
||||
if "@chatroom" in session.session_id:
|
||||
dummy_message_obj.type = MessageType.GROUP_MESSAGE
|
||||
dummy_message_obj.group_id = session.session_id
|
||||
dummy_message_obj.sender = MessageMember(user_id="", nickname="")
|
||||
else:
|
||||
dummy_message_obj.type = MessageType.FRIEND_MESSAGE
|
||||
dummy_message_obj.group_id = ""
|
||||
dummy_message_obj.sender = MessageMember(user_id="", nickname="")
|
||||
sending_event = WeChatPadProMessageEvent(
|
||||
message_str="",
|
||||
message_obj=dummy_message_obj,
|
||||
platform_meta=self.meta(),
|
||||
session_id=session.session_id,
|
||||
adapter=self,
|
||||
)
|
||||
# 调用实例方法 send
|
||||
await sending_event.send(message_chain)
|
||||
|
||||
async def get_contact_list(self):
|
||||
"""
|
||||
获取联系人列表。
|
||||
"""
|
||||
url = f"{self.base_url}/friend/GetContactList"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {"CurrentChatRoomContactSeq": 0, "CurrentWxcontactSeq": 0}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"获取联系人列表失败: {response.status}")
|
||||
return None
|
||||
result = await response.json()
|
||||
if result.get("Code") == 200 and result.get("Data"):
|
||||
contact_list = (
|
||||
result.get("Data", {})
|
||||
.get("ContactList", {})
|
||||
.get("contactUsernameList", [])
|
||||
)
|
||||
return contact_list
|
||||
else:
|
||||
logger.error(f"获取联系人列表失败: {result}")
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取联系人列表时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def get_contact_details_list(
|
||||
self, room_wx_id_list: list[str] = None, user_names: list[str] = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
获取联系人详情列表。
|
||||
"""
|
||||
if room_wx_id_list is None:
|
||||
room_wx_id_list = []
|
||||
if user_names is None:
|
||||
user_names = []
|
||||
url = f"{self.base_url}/friend/GetContactDetailsList"
|
||||
params = {"key": self.auth_key}
|
||||
payload = {"RoomWxIDList": room_wx_id_list, "UserNames": user_names}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as response:
|
||||
if response.status != 200:
|
||||
logger.error(f"获取联系人详情列表失败: {response.status}")
|
||||
return None
|
||||
result = await response.json()
|
||||
if result.get("Code") == 200 and result.get("Data"):
|
||||
contact_list = result.get("Data", {}).get("contactList", {})
|
||||
return contact_list
|
||||
else:
|
||||
logger.error(f"获取联系人详情列表失败: {result}")
|
||||
return None
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取联系人详情列表时发生错误: {e}")
|
||||
return None
|
||||
@@ -0,0 +1,117 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
from PIL import Image as PILImage # 使用别名避免冲突
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.message.components import Image, Plain # Import Image
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType
|
||||
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .wechatpadpro_adapter import WeChatPadProAdapter
|
||||
|
||||
|
||||
class WeChatPadProMessageEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
adapter: "WeChatPadProAdapter", # 传递适配器实例
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.message_obj = message_obj # Save the full message object
|
||||
self.adapter = adapter # Save the adapter instance
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for comp in message.chain:
|
||||
await asyncio.sleep(1)
|
||||
if isinstance(comp, Plain):
|
||||
await self._send_text(session, comp.text)
|
||||
elif isinstance(comp, Image):
|
||||
await self._send_image(session, comp)
|
||||
await super().send(message)
|
||||
|
||||
async def _send_image(self, session: aiohttp.ClientSession, comp: Image):
|
||||
b64 = await comp.convert_to_base64()
|
||||
raw = self._validate_base64(b64)
|
||||
b64c = self._compress_image(raw)
|
||||
payload = {
|
||||
"MsgItem": [
|
||||
{"ImageContent": b64c, "MsgType": 3, "ToUserName": self.session_id}
|
||||
]
|
||||
}
|
||||
url = f"{self.adapter.base_url}/message/SendImageNewMessage"
|
||||
await self._post(session, url, payload)
|
||||
|
||||
async def _send_text(self, session: aiohttp.ClientSession, text: str):
|
||||
if (
|
||||
self.message_obj.type == MessageType.GROUP_MESSAGE # 确保是群聊消息
|
||||
and self.adapter.settings.get(
|
||||
"reply_with_mention", False
|
||||
) # 检查适配器设置是否启用 reply_with_mention
|
||||
and self.message_obj.sender # 确保有发送者信息
|
||||
and (
|
||||
self.message_obj.sender.user_id or self.message_obj.sender.nickname
|
||||
) # 确保发送者有 ID 或昵称
|
||||
):
|
||||
# 优先使用 nickname,如果没有则使用 user_id
|
||||
mention_text = (
|
||||
self.message_obj.sender.nickname or self.message_obj.sender.user_id
|
||||
)
|
||||
message_text = f"@{mention_text} {text}"
|
||||
# logger.info(f"已添加 @ 信息: {message_text}")
|
||||
else:
|
||||
message_text = text
|
||||
payload = {
|
||||
"MsgItem": [
|
||||
{"MsgType": 1, "TextContent": message_text, "ToUserName": self.session_id}
|
||||
]
|
||||
}
|
||||
url = f"{self.adapter.base_url}/message/SendTextMessage"
|
||||
await self._post(session, url, payload)
|
||||
|
||||
@staticmethod
|
||||
def _validate_base64(b64: str) -> bytes:
|
||||
return base64.b64decode(b64, validate=True)
|
||||
|
||||
@staticmethod
|
||||
def _compress_image(data: bytes) -> str:
|
||||
img = PILImage.open(io.BytesIO(data))
|
||||
buf = io.BytesIO()
|
||||
if img.format == "JPEG":
|
||||
img.save(buf, "JPEG", quality=80)
|
||||
else:
|
||||
if img.mode in ("RGBA", "P"):
|
||||
img = img.convert("RGB")
|
||||
img.save(buf, "JPEG", quality=80)
|
||||
# logger.info("图片处理完成!!!")
|
||||
return base64.b64encode(buf.getvalue()).decode()
|
||||
|
||||
async def _post(self, session, url, payload):
|
||||
params = {"key": self.adapter.auth_key}
|
||||
try:
|
||||
async with session.post(url, params=params, json=payload) as resp:
|
||||
data = await resp.json()
|
||||
if resp.status != 200 or data.get("Code") != 200:
|
||||
logger.error(f"{url} failed: {resp.status} {data}")
|
||||
except Exception as e:
|
||||
logger.error(f"{url} error: {e}")
|
||||
|
||||
|
||||
# TODO: 添加对其他消息组件类型的处理 (Record, Video, At等)
|
||||
# elif isinstance(component, Record):
|
||||
# pass
|
||||
# elif isinstance(component, Video):
|
||||
# pass
|
||||
# elif isinstance(component, At):
|
||||
# pass
|
||||
# ...
|
||||
@@ -1,28 +1,33 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import asyncio
|
||||
import quart
|
||||
|
||||
import quart
|
||||
from requests import Response
|
||||
from wechatpy.enterprise import WeChatClient, parse_message
|
||||
from wechatpy.enterprise.crypto import WeChatCrypto
|
||||
from wechatpy.enterprise.messages import ImageMessage, TextMessage, VoiceMessage
|
||||
from wechatpy.exceptions import InvalidSignatureException
|
||||
from wechatpy.messages import BaseMessage
|
||||
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Image, Plain, Record
|
||||
from astrbot.api.platform import (
|
||||
Platform,
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
PlatformMetadata,
|
||||
MessageType,
|
||||
Platform,
|
||||
PlatformMetadata,
|
||||
register_platform_adapter,
|
||||
)
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.api.platform import register_platform_adapter
|
||||
from astrbot.core import logger
|
||||
from requests import Response
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from wechatpy.enterprise.crypto import WeChatCrypto
|
||||
from wechatpy.enterprise import WeChatClient
|
||||
from wechatpy.enterprise.messages import TextMessage, ImageMessage, VoiceMessage
|
||||
from wechatpy.exceptions import InvalidSignatureException
|
||||
from wechatpy.enterprise import parse_message
|
||||
from .wecom_event import WecomPlatformEvent
|
||||
from .wecom_kf import WeChatKF
|
||||
from .wecom_kf_message import WeChatKFMessage
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
@@ -131,9 +136,40 @@ class WecomPlatformAdapter(Platform):
|
||||
self.config["corpid"].strip(),
|
||||
self.config["secret"].strip(),
|
||||
)
|
||||
|
||||
# 微信客服
|
||||
self.kf_name = self.config.get("kf_name", None)
|
||||
if self.kf_name:
|
||||
# inject
|
||||
self.wechat_kf_api = WeChatKF(client=self.client)
|
||||
self.wechat_kf_message_api = WeChatKFMessage(self.client)
|
||||
self.client.kf = self.wechat_kf_api
|
||||
self.client.kf_message = self.wechat_kf_message_api
|
||||
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
|
||||
async def callback(msg):
|
||||
async def callback(msg: BaseMessage):
|
||||
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
|
||||
|
||||
def get_latest_msg_item() -> dict | None:
|
||||
token = msg._data["Token"]
|
||||
kfid = msg._data["OpenKfId"]
|
||||
has_more = 1
|
||||
ret = {}
|
||||
while has_more:
|
||||
ret = self.wechat_kf_api.sync_msg(token, kfid)
|
||||
has_more = ret["has_more"]
|
||||
msg_list = ret.get("msg_list", [])
|
||||
if msg_list:
|
||||
return msg_list[-1]
|
||||
return None
|
||||
|
||||
msg_new = await asyncio.get_event_loop().run_in_executor(
|
||||
None, get_latest_msg_item
|
||||
)
|
||||
if msg_new:
|
||||
await self.convert_wechat_kf_message(msg_new)
|
||||
return
|
||||
await self.convert_message(msg)
|
||||
|
||||
self.server.callback = callback
|
||||
@@ -153,9 +189,39 @@ class WecomPlatformAdapter(Platform):
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
if self.kf_name:
|
||||
try:
|
||||
acc_list = (
|
||||
await loop.run_in_executor(
|
||||
None, self.wechat_kf_api.get_account_list
|
||||
)
|
||||
).get("account_list", [])
|
||||
logger.debug(f"获取到微信客服列表: {str(acc_list)}")
|
||||
for acc in acc_list:
|
||||
name = acc.get("name", None)
|
||||
if name != self.kf_name:
|
||||
continue
|
||||
open_kfid = acc.get("open_kfid", None)
|
||||
if not open_kfid:
|
||||
logger.error("获取微信客服失败,open_kfid 为空。")
|
||||
logger.debug(f"Found open_kfid: {str(open_kfid)}")
|
||||
kf_url = (
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
self.wechat_kf_api.add_contact_way,
|
||||
open_kfid,
|
||||
"astrbot_placeholder",
|
||||
)
|
||||
).get("url", "")
|
||||
logger.info(
|
||||
f"请打开以下链接,在微信扫码以获取客服微信: https://api.cl2wm.cn/api/qrcode/code?text={kf_url}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
await self.server.start_polling()
|
||||
|
||||
async def convert_message(self, msg):
|
||||
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
|
||||
abm = AstrBotMessage()
|
||||
if msg.type == "text":
|
||||
assert isinstance(msg, TextMessage)
|
||||
@@ -191,14 +257,15 @@ class WecomPlatformAdapter(Platform):
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.client.media.download, msg.media_id
|
||||
)
|
||||
path = f"data/temp/wecom_{msg.media_id}.amr"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr")
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
|
||||
try:
|
||||
from pydub import AudioSegment
|
||||
|
||||
path_wav = f"data/temp/wecom_{msg.media_id}.wav"
|
||||
path_wav = os.path.join(temp_dir, f"wecom_{msg.media_id}.wav")
|
||||
audio = AudioSegment.from_file(path)
|
||||
audio.export(path_wav, format="wav")
|
||||
except Exception as e:
|
||||
@@ -218,10 +285,43 @@ class WecomPlatformAdapter(Platform):
|
||||
abm.timestamp = msg.time
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
else:
|
||||
logger.warning(f"暂未实现的事件: {msg.type}")
|
||||
return
|
||||
|
||||
logger.info(f"abm: {abm}")
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
|
||||
msgtype = msg.get("msgtype", None)
|
||||
external_userid = msg.get("external_userid", None)
|
||||
abm = AstrBotMessage()
|
||||
abm.raw_message = msg
|
||||
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
|
||||
abm.self_id = msg["open_kfid"]
|
||||
abm.sender = MessageMember(external_userid, external_userid)
|
||||
abm.session_id = external_userid
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.message_id = msg.get("msgid", uuid.uuid4().hex[:8])
|
||||
if msgtype == "text":
|
||||
text = msg.get("text", {}).get("content", "").strip()
|
||||
abm.message = [Plain(text=text)]
|
||||
abm.message_str = text
|
||||
elif msgtype == "image":
|
||||
media_id = msg.get("image", {}).get("media_id", "")
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.client.media.download, media_id
|
||||
)
|
||||
path = f"data/temp/wechat_kf_{media_id}.jpg"
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
abm.message = [Image(file=path, url=path)]
|
||||
abm.message_str = "[图片]"
|
||||
else:
|
||||
logger.warning(f"未实现的微信客服消息事件: {msg}")
|
||||
return
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
message_event = WecomPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import os
|
||||
import uuid
|
||||
import asyncio
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from wechatpy.enterprise import WeChatClient
|
||||
from .wecom_kf_message import WeChatKFMessage
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
try:
|
||||
import pydub
|
||||
@@ -33,59 +37,150 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
):
|
||||
pass
|
||||
|
||||
async def split_plain(self, plain: str) -> list[str]:
|
||||
"""将长文本分割成多个小文本, 每个小文本长度不超过 2048 字符
|
||||
|
||||
Args:
|
||||
plain (str): 要分割的长文本
|
||||
Returns:
|
||||
list[str]: 分割后的文本列表
|
||||
"""
|
||||
if len(plain) <= 2048:
|
||||
return [plain]
|
||||
else:
|
||||
result = []
|
||||
start = 0
|
||||
while start < len(plain):
|
||||
# 剩下的字符串长度<2048时结束
|
||||
if start + 2048 >= len(plain):
|
||||
result.append(plain[start:])
|
||||
break
|
||||
|
||||
# 向前搜索分割标点符号
|
||||
end = min(start + 2048, len(plain))
|
||||
cut_position = end
|
||||
for i in range(end, start, -1):
|
||||
if i < len(plain) and plain[i - 1] in [
|
||||
"。",
|
||||
"!",
|
||||
"?",
|
||||
".",
|
||||
"!",
|
||||
"?",
|
||||
"\n",
|
||||
";",
|
||||
";",
|
||||
]:
|
||||
cut_position = i
|
||||
break
|
||||
|
||||
# 没找到合适的位置分割, 直接切分
|
||||
if cut_position == end and end < len(plain):
|
||||
cut_position = end
|
||||
|
||||
result.append(plain[start:cut_position])
|
||||
start = cut_position
|
||||
|
||||
return result
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
message_obj = self.message_obj
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
self.client.message.send_text(
|
||||
message_obj.self_id, message_obj.session_id, comp.text
|
||||
)
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
is_wechat_kf = hasattr(self.client, "kf_message")
|
||||
if is_wechat_kf:
|
||||
# 微信客服
|
||||
kf_message_api = getattr(self.client, "kf_message", None)
|
||||
if not kf_message_api:
|
||||
logger.warning("未找到微信客服发送消息方法。")
|
||||
return
|
||||
assert isinstance(kf_message_api, WeChatKFMessage)
|
||||
user_id = self.get_sender_id()
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
# Split long text messages if needed
|
||||
plain_chunks = await self.split_plain(comp.text)
|
||||
for chunk in plain_chunks:
|
||||
kf_message_api.send_text(user_id, self.get_self_id(), chunk)
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
with open(img_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("image", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传图片失败: {e}")
|
||||
with open(img_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("image", f)
|
||||
except Exception as e:
|
||||
logger.error(f"微信客服上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信客服上传图片失败: {e}")
|
||||
)
|
||||
return
|
||||
logger.debug(f"微信客服上传图片返回: {response}")
|
||||
kf_message_api.send_image(
|
||||
user_id,
|
||||
self.get_self_id(),
|
||||
response["media_id"],
|
||||
)
|
||||
return
|
||||
logger.info(f"企业微信上传图片返回: {response}")
|
||||
self.client.message.send_image(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 转成amr
|
||||
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
|
||||
pydub.AudioSegment.from_wav(record_path).export(
|
||||
record_path_amr, format="amr"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||
else:
|
||||
# 企业微信应用
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
# Split long text messages if needed
|
||||
plain_chunks = await self.split_plain(comp.text)
|
||||
for chunk in plain_chunks:
|
||||
self.client.message.send_text(
|
||||
message_obj.self_id, message_obj.session_id, chunk
|
||||
)
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
with open(record_path_amr, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("voice", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传语音失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传语音失败: {e}")
|
||||
with open(img_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("image", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传图片失败: {e}")
|
||||
)
|
||||
return
|
||||
logger.debug(f"企业微信上传图片返回: {response}")
|
||||
self.client.message.send_image(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
return
|
||||
logger.info(f"企业微信上传语音返回: {response}")
|
||||
self.client.message.send_voice(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 转成amr
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
record_path_amr = os.path.join(temp_dir, f"{uuid.uuid4()}.amr")
|
||||
pydub.AudioSegment.from_wav(record_path).export(
|
||||
record_path_amr, format="amr"
|
||||
)
|
||||
|
||||
with open(record_path_amr, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("voice", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传语音失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传语音失败: {e}")
|
||||
)
|
||||
return
|
||||
logger.info(f"企业微信上传语音返回: {response}")
|
||||
self.client.message.send_voice(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
@@ -96,4 +191,4 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
278
astrbot/core/platform/sources/wecom/wecom_kf.py
Normal file
278
astrbot/core/platform/sources/wecom/wecom_kf.py
Normal file
@@ -0,0 +1,278 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014-2020 messense
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
from wechatpy.client.api.base import BaseWeChatAPI
|
||||
|
||||
|
||||
class WeChatKF(BaseWeChatAPI):
|
||||
"""
|
||||
微信客服接口
|
||||
|
||||
https://work.weixin.qq.com/api/doc/90000/90135/94670
|
||||
"""
|
||||
|
||||
def sync_msg(self, token, open_kfid, cursor="", limit=1000):
|
||||
"""
|
||||
微信客户发送的消息、接待人员在企业微信回复的消息、发送消息接口发送失败事件(如被用户拒收)
|
||||
、客户点击菜单消息的回复消息,可以通过该接口获取具体的消息内容和事件。不支持读取通过发送消息接口发送的消息。
|
||||
支持的消息类型:文本、图片、语音、视频、文件、位置、链接、名片、小程序、事件。
|
||||
|
||||
|
||||
:param token: 回调事件返回的token字段,10分钟内有效;可不填,如果不填接口有严格的频率限制。不多于128字节
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param cursor: 上一次调用时返回的next_cursor,第一次拉取可以不填。不多于64字节
|
||||
:param limit: 期望请求的数据量,默认值和最大值都为1000。
|
||||
注意:可能会出现返回条数少于limit的情况,需结合返回的has_more字段判断是否继续请求。
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid}
|
||||
return self._post("kf/sync_msg", data=data)
|
||||
|
||||
def get_service_state(self, open_kfid, external_userid):
|
||||
"""
|
||||
获取会话状态
|
||||
|
||||
ID 状态 说明
|
||||
0 未处理 新会话接入。可选择:1.直接用API自动回复消息。2.放进待接入池等待接待人员接待。3.指定接待人员进行接待
|
||||
1 由智能助手接待 可使用API回复消息。可选择转入待接入池或者指定接待人员处理。
|
||||
2 待接入池排队中 在待接入池中排队等待接待人员接入。可选择转为指定人员接待
|
||||
3 由人工接待 人工接待中。可选择结束会话
|
||||
4 已结束 会话已经结束。不允许变更会话状态,等待用户重新发起咨询
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param external_userid: 微信客户的external_userid
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"external_userid": external_userid,
|
||||
}
|
||||
return self._post("kf/service_state/get", data=data)
|
||||
|
||||
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""):
|
||||
"""
|
||||
变更会话状态
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param external_userid: 微信客户的external_userid
|
||||
:param service_state: 当前的会话状态,状态定义参考概述中的表格
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"external_userid": external_userid,
|
||||
"service_state": service_state,
|
||||
}
|
||||
if servicer_userid:
|
||||
data["servicer_userid"] = servicer_userid
|
||||
return self._post("kf/service_state/trans", data=data)
|
||||
|
||||
def get_servicer_list(self, open_kfid):
|
||||
"""
|
||||
获取接待人员列表
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
}
|
||||
return self._get("kf/servicer/list", params=data)
|
||||
|
||||
def add_servicer(self, open_kfid, userid_list):
|
||||
"""
|
||||
添加接待人员
|
||||
添加指定客服帐号的接待人员。
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param userid_list: 接待人员userid列表
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
if not isinstance(userid_list, list):
|
||||
userid_list = [userid_list]
|
||||
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"userid_list": userid_list,
|
||||
}
|
||||
return self._post("kf/servicer/add", data=data)
|
||||
|
||||
def del_servicer(self, open_kfid, userid_list):
|
||||
"""
|
||||
删除接待人员
|
||||
从客服帐号删除接待人员
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param userid_list: 接待人员userid列表
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
if not isinstance(userid_list, list):
|
||||
userid_list = [userid_list]
|
||||
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"userid_list": userid_list,
|
||||
}
|
||||
return self._post("kf/servicer/del", data=data)
|
||||
|
||||
def batchget_customer(self, external_userid_list):
|
||||
"""
|
||||
客户基本信息获取
|
||||
|
||||
:param external_userid_list: external_userid列表
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
if not isinstance(external_userid_list, list):
|
||||
external_userid_list = [external_userid_list]
|
||||
|
||||
data = {
|
||||
"external_userid_list": external_userid_list,
|
||||
}
|
||||
return self._post("kf/customer/batchget", data=data)
|
||||
|
||||
def get_account_list(self):
|
||||
"""
|
||||
获取客服帐号列表
|
||||
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
return self._get("kf/account/list")
|
||||
|
||||
def add_contact_way(self, open_kfid, scene):
|
||||
"""
|
||||
获取客服帐号链接
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param scene: 场景值,字符串类型,由开发者自定义。不多于32字节;字符串取值范围(正则表达式):[0-9a-zA-Z_-]*
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {"open_kfid": open_kfid, "scene": scene}
|
||||
return self._post("kf/add_contact_way", data=data)
|
||||
|
||||
def get_upgrade_service_config(self):
|
||||
"""
|
||||
获取配置的专员与客户群
|
||||
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
return self._get("kf/customer/get_upgrade_service_config")
|
||||
|
||||
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None):
|
||||
"""
|
||||
为客户升级为专员或客户群服务
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param external_userid: 微信客户的external_userid
|
||||
:param service_type: 表示是升级到专员服务还是客户群服务。1:专员服务。2:客户群服务
|
||||
:param member: 推荐的服务专员,type等于1时有效
|
||||
:param groupchat: 推荐的客户群,type等于2时有效
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"external_userid": external_userid,
|
||||
"type": service_type,
|
||||
}
|
||||
if service_type == 1:
|
||||
data["member"] = member
|
||||
else:
|
||||
data["groupchat"] = groupchat
|
||||
return self._post("kf/customer/upgrade_service", data=data)
|
||||
|
||||
def cancel_upgrade_service(self, open_kfid, external_userid):
|
||||
"""
|
||||
为客户取消推荐
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param external_userid: 微信客户的external_userid
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
|
||||
data = {"open_kfid": open_kfid, "external_userid": external_userid}
|
||||
return self._post("kf/customer/cancel_upgrade_service", data=data)
|
||||
|
||||
def send_msg_on_event(self, code, msgtype, msg_content, msgid=None):
|
||||
"""
|
||||
当特定的事件回调消息包含code字段,可以此code为凭证,调用该接口给用户发送相应事件场景下的消息,如客服欢迎语。
|
||||
支持发送消息类型:文本、菜单消息。
|
||||
|
||||
:param code: 事件响应消息对应的code。通过事件回调下发,仅可使用一次。
|
||||
:param msgtype: 消息类型。对不同的msgtype,有相应的结构描述,详见消息类型
|
||||
:param msg_content: 目前支持文本与菜单消息,具体查看文档
|
||||
:param msgid: 消息ID。如果请求参数指定了msgid,则原样返回,否则系统自动生成并返回。不多于32字节;
|
||||
字符串取值范围(正则表达式):[0-9a-zA-Z_-]*
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
|
||||
data = {"code": code, "msgtype": msgtype}
|
||||
if msgid:
|
||||
data["msgid"] = msgid
|
||||
data.update(msg_content)
|
||||
return self._post("kf/send_msg_on_event", data=data)
|
||||
|
||||
def get_corp_statistic(self, start_time, end_time, open_kfid=None):
|
||||
"""
|
||||
获取「客户数据统计」企业汇总数据
|
||||
|
||||
:param start_time: 开始时间
|
||||
:param end_time: 结束时间
|
||||
:param open_kfid: 客服帐号ID
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
|
||||
return self._post("kf/get_corp_statistic", data=data)
|
||||
|
||||
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None):
|
||||
"""
|
||||
获取「客户数据统计」接待人员明细数据
|
||||
|
||||
:param start_time: 开始时间
|
||||
:param end_time: 结束时间
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param servicer_userid: 接待人员
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"servicer_userid": servicer_userid,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
}
|
||||
return self._post("kf/get_servicer_statistic", data=data)
|
||||
|
||||
def account_update(self, open_kfid, name, media_id):
|
||||
"""
|
||||
修改客服账号
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param name: 客服名称
|
||||
:param media_id: 客服头像临时素材
|
||||
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {"open_kfid": open_kfid, "name": name, "media_id": media_id}
|
||||
return self._post("kf/account/update", data=data)
|
||||
159
astrbot/core/platform/sources/wecom/wecom_kf_message.py
Normal file
159
astrbot/core/platform/sources/wecom/wecom_kf_message.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014-2020 messense
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
from optionaldict import optionaldict
|
||||
|
||||
from wechatpy.client.api.base import BaseWeChatAPI
|
||||
|
||||
class WeChatKFMessage(BaseWeChatAPI):
|
||||
"""
|
||||
发送微信客服消息
|
||||
|
||||
https://work.weixin.qq.com/api/doc/90000/90135/94677
|
||||
|
||||
支持:
|
||||
* 文本消息
|
||||
* 图片消息
|
||||
* 语音消息
|
||||
* 视频消息
|
||||
* 文件消息
|
||||
* 图文链接
|
||||
* 小程序
|
||||
* 菜单消息
|
||||
* 地理位置
|
||||
"""
|
||||
|
||||
def send(self, user_id, open_kfid, msgid="", msg=None):
|
||||
"""
|
||||
当微信客户处于“新接入待处理”或“由智能助手接待”状态下,可调用该接口给用户发送消息。
|
||||
注意仅当微信客户在主动发送消息给客服后的48小时内,企业可发送消息给客户,最多可发送5条消息;若用户继续发送消息,企业可再次下发消息。
|
||||
支持发送消息类型:文本、图片、语音、视频、文件、图文、小程序、菜单消息、地理位置。
|
||||
|
||||
:param user_id: 指定接收消息的客户UserID
|
||||
:param open_kfid: 指定发送消息的客服帐号ID
|
||||
:param msgid: 指定消息ID
|
||||
:param tag_ids: 标签ID列表。
|
||||
:param msg: 发送消息的 dict 对象
|
||||
:type msg: dict | None
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
msg = msg or {}
|
||||
data = {
|
||||
"touser": user_id,
|
||||
"open_kfid": open_kfid,
|
||||
}
|
||||
if msgid:
|
||||
data["msgid"] = msgid
|
||||
data.update(msg)
|
||||
return self._post("kf/send_msg", data=data)
|
||||
|
||||
def send_text(self, user_id, open_kfid, content, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "text", "text": {"content": content}},
|
||||
)
|
||||
|
||||
def send_image(self, user_id, open_kfid, media_id, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "image", "image": {"media_id": media_id}},
|
||||
)
|
||||
|
||||
def send_voice(self, user_id, open_kfid, media_id, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "voice", "voice": {"media_id": media_id}},
|
||||
)
|
||||
|
||||
def send_video(self, user_id, open_kfid, media_id, msgid=""):
|
||||
video_data = optionaldict()
|
||||
video_data["media_id"] = media_id
|
||||
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "video", "video": dict(video_data)},
|
||||
)
|
||||
|
||||
def send_file(self, user_id, open_kfid, media_id, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "file", "file": {"media_id": media_id}},
|
||||
)
|
||||
|
||||
def send_articles_link(self, user_id, open_kfid, article, msgid=""):
|
||||
articles_data = {
|
||||
"title": article["title"],
|
||||
"desc": article["desc"],
|
||||
"url": article["url"],
|
||||
"thumb_media_id": article["thumb_media_id"],
|
||||
}
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "news", "link": {"link": articles_data}},
|
||||
)
|
||||
|
||||
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "msgmenu",
|
||||
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content},
|
||||
},
|
||||
)
|
||||
|
||||
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "location",
|
||||
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude},
|
||||
},
|
||||
)
|
||||
|
||||
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "miniprogram",
|
||||
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath},
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,286 @@
|
||||
import sys
|
||||
import uuid
|
||||
import asyncio
|
||||
import quart
|
||||
|
||||
from astrbot.api.platform import (
|
||||
Platform,
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
PlatformMetadata,
|
||||
MessageType,
|
||||
)
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.api.platform import register_platform_adapter
|
||||
from astrbot.core import logger
|
||||
from requests import Response
|
||||
|
||||
from wechatpy.utils import check_signature
|
||||
from wechatpy.crypto import WeChatCrypto
|
||||
from wechatpy import WeChatClient
|
||||
from wechatpy.messages import TextMessage, ImageMessage, VoiceMessage, BaseMessage
|
||||
from wechatpy.exceptions import InvalidSignatureException
|
||||
from wechatpy import parse_message
|
||||
from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class WecomServer:
|
||||
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
||||
self.server = quart.Quart(__name__)
|
||||
self.port = int(config.get("port"))
|
||||
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||
self.token = config.get("token")
|
||||
self.encoding_aes_key = config.get("encoding_aes_key")
|
||||
self.appid = config.get("appid")
|
||||
self.server.add_url_rule(
|
||||
"/callback/command", view_func=self.verify, methods=["GET"]
|
||||
)
|
||||
self.server.add_url_rule(
|
||||
"/callback/command", view_func=self.callback_command, methods=["POST"]
|
||||
)
|
||||
self.crypto = WeChatCrypto(self.token, self.encoding_aes_key, self.appid)
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.callback = None
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
async def verify(self):
|
||||
logger.info(f"验证请求有效性: {quart.request.args}")
|
||||
|
||||
args = quart.request.args
|
||||
if not args.get("signature", None):
|
||||
logger.error("未知的响应,请检查回调地址是否填写正确。")
|
||||
return "err"
|
||||
try:
|
||||
check_signature(
|
||||
self.token,
|
||||
args.get("signature"),
|
||||
args.get("timestamp"),
|
||||
args.get("nonce"),
|
||||
)
|
||||
logger.info("验证请求有效性成功。")
|
||||
return args.get("echostr", "empty")
|
||||
except InvalidSignatureException:
|
||||
logger.error("验证请求有效性失败,签名异常,请检查配置。")
|
||||
return "err"
|
||||
|
||||
async def callback_command(self):
|
||||
data = await quart.request.get_data()
|
||||
msg_signature = quart.request.args.get("msg_signature")
|
||||
timestamp = quart.request.args.get("timestamp")
|
||||
nonce = quart.request.args.get("nonce")
|
||||
try:
|
||||
xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce)
|
||||
except InvalidSignatureException:
|
||||
logger.error("解密失败,签名异常,请检查配置。")
|
||||
raise
|
||||
else:
|
||||
msg = parse_message(xml)
|
||||
logger.info(f"解析成功: {msg}")
|
||||
|
||||
if self.callback:
|
||||
result_xml = await self.callback(msg)
|
||||
if not result_xml:
|
||||
return "success"
|
||||
if isinstance(result_xml, str):
|
||||
return result_xml
|
||||
|
||||
return "success"
|
||||
|
||||
async def start_polling(self):
|
||||
logger.info(
|
||||
f"将在 {self.callback_server_host}:{self.port} 端口启动 微信公众平台 适配器。"
|
||||
)
|
||||
await self.server.run_task(
|
||||
host=self.callback_server_host,
|
||||
port=self.port,
|
||||
shutdown_trigger=self.shutdown_trigger,
|
||||
)
|
||||
|
||||
async def shutdown_trigger(self):
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
|
||||
@register_platform_adapter("weixin_official_account", "微信公众平台 适配器")
|
||||
class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settingss = platform_settings
|
||||
self.client_self_id = uuid.uuid4().hex[:8]
|
||||
self.api_base_url = platform_config.get(
|
||||
"api_base_url", "https://api.weixin.qq.com/cgi-bin/"
|
||||
)
|
||||
self.active_send_mode = self.config.get("active_send_mode", False)
|
||||
|
||||
if not self.api_base_url:
|
||||
self.api_base_url = "https://api.weixin.qq.com/cgi-bin/"
|
||||
|
||||
if self.api_base_url.endswith("/"):
|
||||
self.api_base_url = self.api_base_url[:-1]
|
||||
if not self.api_base_url.endswith("/cgi-bin"):
|
||||
self.api_base_url += "/cgi-bin"
|
||||
|
||||
if not self.api_base_url.endswith("/"):
|
||||
self.api_base_url += "/"
|
||||
|
||||
self.server = WecomServer(self._event_queue, self.config)
|
||||
|
||||
self.client = WeChatClient(
|
||||
self.config["appid"].strip(),
|
||||
self.config["secret"].strip(),
|
||||
)
|
||||
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
|
||||
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
|
||||
# msgid -> Future
|
||||
self.wexin_event_workers: dict[str, asyncio.Future] = {}
|
||||
|
||||
async def callback(msg: BaseMessage):
|
||||
try:
|
||||
if self.active_send_mode:
|
||||
await self.convert_message(msg, None)
|
||||
else:
|
||||
if msg.id in self.wexin_event_workers:
|
||||
future = self.wexin_event_workers[msg.id]
|
||||
logger.debug(f"duplicate message id checked: {msg.id}")
|
||||
else:
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.wexin_event_workers[msg.id] = future
|
||||
await self.convert_message(msg, future)
|
||||
# I love shield so much!
|
||||
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
|
||||
logger.debug(f"Got future result: {result}")
|
||||
self.wexin_event_workers.pop(msg.id, None)
|
||||
return result # xml. see weixin_offacc_event.py
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"转换消息时出现异常: {e}")
|
||||
|
||||
self.server.callback = callback
|
||||
|
||||
@override
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"weixin_official_account",
|
||||
"微信公众平台 适配器",
|
||||
)
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
await self.server.start_polling()
|
||||
|
||||
async def convert_message(
|
||||
self, msg, future: asyncio.Future = None
|
||||
) -> AstrBotMessage | None:
|
||||
abm = AstrBotMessage()
|
||||
if isinstance(msg, TextMessage):
|
||||
abm.message_str = msg.content
|
||||
abm.self_id = str(msg.target)
|
||||
abm.message = [Plain(msg.content)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.session_id = abm.sender.user_id
|
||||
elif msg.type == "image":
|
||||
assert isinstance(msg, ImageMessage)
|
||||
abm.message_str = "[图片]"
|
||||
abm.self_id = str(msg.target)
|
||||
abm.message = [Image(file=msg.image, url=msg.image)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.session_id = abm.sender.user_id
|
||||
elif msg.type == "voice":
|
||||
assert isinstance(msg, VoiceMessage)
|
||||
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.client.media.download, msg.media_id
|
||||
)
|
||||
path = f"data/temp/wecom_{msg.media_id}.amr"
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
|
||||
try:
|
||||
from pydub import AudioSegment
|
||||
|
||||
path_wav = f"data/temp/wecom_{msg.media_id}.wav"
|
||||
audio = AudioSegment.from_file(path)
|
||||
audio.export(path_wav, format="wav")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"转换音频失败: {e}。如果没有安装 pydub 和 ffmpeg 请先安装。"
|
||||
)
|
||||
path_wav = path
|
||||
return
|
||||
|
||||
abm.message_str = ""
|
||||
abm.self_id = str(msg.target)
|
||||
abm.message = [Record(file=path_wav, url=path_wav)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.session_id = abm.sender.user_id
|
||||
else:
|
||||
logger.warning(f"暂未实现的事件: {msg.type}")
|
||||
future.set_result(None)
|
||||
return
|
||||
# 很不优雅 :(
|
||||
abm.raw_message = {
|
||||
"message": msg,
|
||||
"future": future,
|
||||
"active_send_mode": self.active_send_mode,
|
||||
}
|
||||
logger.info(f"abm: {abm}")
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
message_event = WeixinOfficialAccountPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client,
|
||||
)
|
||||
self.commit_event(message_event)
|
||||
|
||||
def get_client(self) -> WeChatClient:
|
||||
return self.client
|
||||
|
||||
async def terminate(self):
|
||||
self.server.shutdown_event.set()
|
||||
try:
|
||||
await self.server.server.shutdown()
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info("微信公众平台 适配器已被优雅地关闭")
|
||||
@@ -0,0 +1,185 @@
|
||||
import uuid
|
||||
import asyncio
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from wechatpy import WeChatClient
|
||||
from wechatpy.replies import TextReply, ImageReply, VoiceReply
|
||||
|
||||
|
||||
from astrbot.api import logger
|
||||
|
||||
try:
|
||||
import pydub
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。"
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client: WeChatClient,
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(
|
||||
client: WeChatClient, message: MessageChain, user_name: str
|
||||
):
|
||||
pass
|
||||
|
||||
async def split_plain(self, plain: str) -> list[str]:
|
||||
"""将长文本分割成多个小文本, 每个小文本长度不超过 2048 字符
|
||||
|
||||
Args:
|
||||
plain (str): 要分割的长文本
|
||||
Returns:
|
||||
list[str]: 分割后的文本列表
|
||||
"""
|
||||
if len(plain) <= 2048:
|
||||
return [plain]
|
||||
else:
|
||||
result = []
|
||||
start = 0
|
||||
while start < len(plain):
|
||||
# 剩下的字符串长度<2048时结束
|
||||
if start + 2048 >= len(plain):
|
||||
result.append(plain[start:])
|
||||
break
|
||||
|
||||
# 向前搜索分割标点符号
|
||||
end = min(start + 2048, len(plain))
|
||||
cut_position = end
|
||||
for i in range(end, start, -1):
|
||||
if i < len(plain) and plain[i - 1] in [
|
||||
"。",
|
||||
"!",
|
||||
"?",
|
||||
".",
|
||||
"!",
|
||||
"?",
|
||||
"\n",
|
||||
";",
|
||||
";",
|
||||
]:
|
||||
cut_position = i
|
||||
break
|
||||
|
||||
# 没找到合适的位置分割, 直接切分
|
||||
if cut_position == end and end < len(plain):
|
||||
cut_position = end
|
||||
|
||||
result.append(plain[start:cut_position])
|
||||
start = cut_position
|
||||
|
||||
return result
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
message_obj = self.message_obj
|
||||
active_send_mode = message_obj.raw_message.get("active_send_mode", False)
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
# Split long text messages if needed
|
||||
plain_chunks = await self.split_plain(comp.text)
|
||||
for chunk in plain_chunks:
|
||||
if active_send_mode:
|
||||
self.client.message.send_text(message_obj.sender.user_id, chunk)
|
||||
else:
|
||||
reply = TextReply(
|
||||
content=chunk,
|
||||
message=self.message_obj.raw_message["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
with open(img_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("image", f)
|
||||
except Exception as e:
|
||||
logger.error(f"微信公众平台上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信公众平台上传图片失败: {e}")
|
||||
)
|
||||
return
|
||||
logger.debug(f"微信公众平台上传图片返回: {response}")
|
||||
|
||||
if active_send_mode:
|
||||
self.client.message.send_image(
|
||||
message_obj.sender.user_id,
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
reply = ImageReply(
|
||||
media_id=response["media_id"],
|
||||
message=self.message_obj.raw_message["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 转成amr
|
||||
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
|
||||
pydub.AudioSegment.from_wav(record_path).export(
|
||||
record_path_amr, format="amr"
|
||||
)
|
||||
|
||||
with open(record_path_amr, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("voice", f)
|
||||
except Exception as e:
|
||||
logger.error(f"微信公众平台上传语音失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信公众平台上传语音失败: {e}")
|
||||
)
|
||||
return
|
||||
logger.info(f"微信公众平台上传语音返回: {response}")
|
||||
|
||||
|
||||
if active_send_mode:
|
||||
self.client.message.send_voice(
|
||||
message_obj.sender.user_id,
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
reply = VoiceReply(
|
||||
media_id=response["media_id"],
|
||||
message=self.message_obj.raw_message["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
else:
|
||||
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
@@ -19,6 +19,7 @@ class ProviderType(enum.Enum):
|
||||
CHAT_COMPLETION = "chat_completion"
|
||||
SPEECH_TO_TEXT = "speech_to_text"
|
||||
TEXT_TO_SPEECH = "text_to_speech"
|
||||
EMBEDDING = "embedding"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -155,7 +156,9 @@ class ProviderRequest:
|
||||
if self.image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": self.prompt}],
|
||||
"content": [
|
||||
{"type": "text", "text": self.prompt if self.prompt else "[图片]"}
|
||||
],
|
||||
}
|
||||
for image_url in self.image_urls:
|
||||
if image_url.startswith("http"):
|
||||
|
||||
@@ -3,19 +3,31 @@ import json
|
||||
import textwrap
|
||||
import os
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from typing import Dict, List, Awaitable, Literal, Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from contextlib import AsyncExitStack
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
try:
|
||||
import mcp
|
||||
from mcp.client.sse import sse_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning(
|
||||
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
|
||||
)
|
||||
|
||||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||
|
||||
SUPPORTED_TYPES = [
|
||||
@@ -87,26 +99,87 @@ class MCPClient:
|
||||
self.name = None
|
||||
self.active: bool = True
|
||||
self.tools: List[mcp.Tool] = []
|
||||
self.server_errlogs: List[str] = []
|
||||
|
||||
async def connect_to_server(self, mcp_server_config: dict):
|
||||
"""Connect to an MCP server
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
"""连接到 MCP 服务器
|
||||
|
||||
如果 `url` 参数存在:
|
||||
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
|
||||
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
|
||||
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
|
||||
|
||||
Args:
|
||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
"""
|
||||
cfg = mcp_server_config.copy()
|
||||
cfg.pop("active", None)
|
||||
server_params = mcp.StdioServerParameters(
|
||||
**cfg,
|
||||
)
|
||||
if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
|
||||
key_0 = list(cfg["mcpServers"].keys())[0]
|
||||
cfg = cfg["mcpServers"][key_0]
|
||||
cfg.pop("active", None) # Remove active flag from config
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
mcp.stdio_client(server_params)
|
||||
)
|
||||
self.stdio, self.write = stdio_transport
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(self.stdio, self.write)
|
||||
)
|
||||
if "url" in cfg:
|
||||
is_sse = True
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
is_sse = False
|
||||
if is_sse:
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=cfg.get("timeout", 5),
|
||||
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
streams = await self._streams_context.__aenter__()
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*streams)
|
||||
)
|
||||
else:
|
||||
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||
sse_read_timeout = timedelta(
|
||||
seconds=cfg.get("sse_read_timeout", 60 * 5)
|
||||
)
|
||||
self._streams_context = streamablehttp_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||
)
|
||||
read_s, write_s, _ = await self._streams_context.__aenter__()
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(read_stream=read_s, write_stream=write_s)
|
||||
)
|
||||
|
||||
else:
|
||||
server_params = mcp.StdioServerParameters(
|
||||
**cfg,
|
||||
)
|
||||
|
||||
def callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
mcp.stdio_client(
|
||||
server_params,
|
||||
errlog=LogPipe(
|
||||
level=logging.ERROR,
|
||||
logger=logger,
|
||||
identifier=f"MCPServer-{name}",
|
||||
callback=callback,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*stdio_transport)
|
||||
)
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
@@ -204,8 +277,7 @@ class FuncCall:
|
||||
}
|
||||
```
|
||||
"""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
data_dir = os.path.abspath(os.path.join(current_dir, "../../../data"))
|
||||
data_dir = get_astrbot_data_path()
|
||||
|
||||
mcp_json_file = os.path.join(data_dir, "mcp_server.json")
|
||||
if not os.path.exists(mcp_json_file):
|
||||
@@ -260,6 +332,13 @@ class FuncCall:
|
||||
if data["name"] in self.mcp_client_event:
|
||||
self.mcp_client_event[data["name"]].set()
|
||||
self.mcp_client_event.pop(data["name"], None)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (
|
||||
f.origin == "mcp" and f.mcp_server_name == data["name"]
|
||||
)
|
||||
]
|
||||
else:
|
||||
for name in self.mcp_client_dict.keys():
|
||||
# await self._terminate_mcp_client(name)
|
||||
@@ -267,6 +346,7 @@ class FuncCall:
|
||||
if name in self.mcp_client_event:
|
||||
self.mcp_client_event[name].set()
|
||||
self.mcp_client_event.pop(name, None)
|
||||
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
|
||||
|
||||
async def _init_mcp_client_task_wrapper(
|
||||
self, name: str, cfg: dict, event: asyncio.Event
|
||||
@@ -278,6 +358,9 @@ class FuncCall:
|
||||
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||
await self._terminate_mcp_client(name)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||
|
||||
async def _init_mcp_client(self, name: str, config: dict) -> None:
|
||||
@@ -289,10 +372,10 @@ class FuncCall:
|
||||
|
||||
mcp_client = MCPClient()
|
||||
mcp_client.name = name
|
||||
await mcp_client.connect_to_server(config)
|
||||
self.mcp_client_dict[name] = mcp_client
|
||||
await mcp_client.connect_to_server(config, name)
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
tool_names = [tool.name for tool in tools_res.tools]
|
||||
self.mcp_client_dict[name] = mcp_client
|
||||
|
||||
# 移除该MCP服务之前的工具(如有)
|
||||
self.func_list = [
|
||||
@@ -314,13 +397,16 @@ class FuncCall:
|
||||
self.func_list.append(func_tool)
|
||||
|
||||
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||
return True
|
||||
return
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||
# 发生错误时确保客户端被清理
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
return False
|
||||
return
|
||||
|
||||
async def _terminate_mcp_client(self, name: str) -> None:
|
||||
"""关闭并清理MCP客户端"""
|
||||
@@ -339,7 +425,7 @@ class FuncCall:
|
||||
]
|
||||
logger.info(f"已关闭 MCP 服务 {name}")
|
||||
|
||||
def get_func_desc_openai_style(self, omit_empty_parameter_field = False) -> list:
|
||||
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
|
||||
"""
|
||||
获得 OpenAI API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
@@ -386,28 +472,86 @@ class FuncCall:
|
||||
tools.append(tool)
|
||||
return tools
|
||||
|
||||
def get_func_desc_google_genai_style(self) -> Dict:
|
||||
def get_func_desc_google_genai_style(self) -> dict:
|
||||
"""
|
||||
获得 Google GenAI API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
|
||||
# Gemini API 支持的数据类型和格式
|
||||
supported_types = {
|
||||
"string",
|
||||
"number",
|
||||
"integer",
|
||||
"boolean",
|
||||
"array",
|
||||
"object",
|
||||
"null",
|
||||
}
|
||||
supported_formats = {
|
||||
"string": {"enum", "date-time"},
|
||||
"integer": {"int32", "int64"},
|
||||
"number": {"float", "double"},
|
||||
}
|
||||
|
||||
def convert_schema(schema: dict) -> dict:
|
||||
"""转换 schema 为 Gemini API 格式"""
|
||||
|
||||
# 如果 schema 包含 anyOf,则只返回 anyOf 字段
|
||||
if "anyOf" in schema:
|
||||
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
|
||||
|
||||
result = {}
|
||||
|
||||
if "type" in schema and schema["type"] in supported_types:
|
||||
result["type"] = schema["type"]
|
||||
if "format" in schema and schema["format"] in supported_formats.get(
|
||||
result["type"], set()
|
||||
):
|
||||
result["format"] = schema["format"]
|
||||
else:
|
||||
# 暂时指定默认为null
|
||||
result["type"] = "null"
|
||||
|
||||
support_fields = {
|
||||
"title",
|
||||
"description",
|
||||
"enum",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"maxItems",
|
||||
"minItems",
|
||||
"nullable",
|
||||
"required",
|
||||
}
|
||||
result.update({k: schema[k] for k in support_fields if k in schema})
|
||||
|
||||
if "properties" in schema:
|
||||
properties = {}
|
||||
for key, value in schema["properties"].items():
|
||||
prop_value = convert_schema(value)
|
||||
if "default" in prop_value:
|
||||
del prop_value["default"]
|
||||
properties[key] = prop_value
|
||||
|
||||
if properties: # 只在有非空属性时添加
|
||||
result["properties"] = properties
|
||||
|
||||
if "items" in schema:
|
||||
result["items"] = convert_schema(schema["items"])
|
||||
|
||||
return result
|
||||
|
||||
tools = [
|
||||
{
|
||||
"name": f.name,
|
||||
"description": f.description,
|
||||
**({"parameters": convert_schema(f.parameters)}),
|
||||
}
|
||||
for f in self.func_list
|
||||
if f.active
|
||||
]
|
||||
|
||||
declarations = {}
|
||||
tools = []
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
|
||||
func_declaration = {"name": f.name, "description": f.description}
|
||||
|
||||
# 检查并添加非空的properties参数
|
||||
params = f.parameters if isinstance(f.parameters, dict) else {}
|
||||
params = copy.deepcopy(params)
|
||||
if params.get("properties", {}):
|
||||
properties = params["properties"]
|
||||
for key, value in properties.items():
|
||||
if "default" in value:
|
||||
del value["default"]
|
||||
params["properties"] = properties
|
||||
func_declaration["parameters"] = params
|
||||
|
||||
tools.append(func_declaration)
|
||||
|
||||
if tools:
|
||||
declarations["function_declarations"] = tools
|
||||
return declarations
|
||||
|
||||
@@ -21,9 +21,9 @@ class ProviderManager:
|
||||
self.selected_provider_id = sp.get("curr_provider")
|
||||
self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
|
||||
self.selected_tts_provider_id = self.provider_settings.get("provider_id")
|
||||
self.provider_enabled = self.provider_settings.get("enable", False)
|
||||
self.stt_enabled = self.provider_stt_settings.get("enable", False)
|
||||
self.tts_enabled = self.provider_tts_settings.get("enable", False)
|
||||
# self.provider_enabled = self.provider_settings.get("enable", False)
|
||||
# self.stt_enabled = self.provider_stt_settings.get("enable", False)
|
||||
# self.tts_enabled = self.provider_tts_settings.get("enable", False)
|
||||
|
||||
# 人格情景管理
|
||||
# 目前没有拆成独立的模块
|
||||
@@ -98,9 +98,13 @@ class ProviderManager:
|
||||
"""加载的 Speech To Text Provider 的实例"""
|
||||
self.tts_provider_insts: List[TTSProvider] = []
|
||||
"""加载的 Text To Speech Provider 的实例"""
|
||||
self.embedding_provider_insts: List[Provider] = []
|
||||
"""加载的 Embedding Provider 的实例"""
|
||||
self.inst_map = {}
|
||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||
self.llm_tools = llm_tools
|
||||
self.default_provider_inst: Provider = None
|
||||
"""默认的 Provider 实例。第 0 个或者用户以前指定的 Provider 实例"""
|
||||
self.curr_provider_inst: Provider = None
|
||||
"""当前使用的 Provider 实例"""
|
||||
self.curr_stt_provider_inst: STTProvider = None
|
||||
@@ -119,14 +123,9 @@ class ProviderManager:
|
||||
for provider_config in self.providers_config:
|
||||
await self.load_provider(provider_config)
|
||||
|
||||
if not self.curr_provider_inst:
|
||||
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
|
||||
|
||||
if self.stt_enabled and not self.curr_stt_provider_inst:
|
||||
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
|
||||
|
||||
if self.tts_enabled and not self.curr_tts_provider_inst:
|
||||
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
|
||||
self.default_provider_inst = self.inst_map.get(self.selected_provider_id)
|
||||
if not self.default_provider_inst and self.provider_insts:
|
||||
self.default_provider_inst = self.provider_insts[0]
|
||||
|
||||
# 初始化 MCP Client 连接
|
||||
asyncio.create_task(
|
||||
@@ -202,6 +201,22 @@ class ProviderManager:
|
||||
from .sources.dashscope_tts import (
|
||||
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
|
||||
)
|
||||
case "azure_tts":
|
||||
from .sources.azure_tts_source import (
|
||||
AzureTTSProvider as AzureTTSProvider,
|
||||
)
|
||||
case "minimax_tts_api":
|
||||
from .sources.minimax_tts_api_source import (
|
||||
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
|
||||
)
|
||||
case "volcengine_tts":
|
||||
from .sources.volcengine_tts import (
|
||||
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
||||
)
|
||||
case "openai_embedding":
|
||||
from .sources.openai_embedding_source import (
|
||||
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(
|
||||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
||||
@@ -233,15 +248,12 @@ class ProviderManager:
|
||||
await inst.initialize()
|
||||
|
||||
self.stt_provider_insts.append(inst)
|
||||
if (
|
||||
self.selected_stt_provider_id == provider_config["id"]
|
||||
and self.stt_enabled
|
||||
):
|
||||
if self.selected_stt_provider_id == provider_config["id"]:
|
||||
self.curr_stt_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。"
|
||||
)
|
||||
if not self.curr_stt_provider_inst and self.stt_enabled:
|
||||
if not self.curr_stt_provider_inst:
|
||||
self.curr_stt_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
@@ -254,15 +266,12 @@ class ProviderManager:
|
||||
await inst.initialize()
|
||||
|
||||
self.tts_provider_insts.append(inst)
|
||||
if (
|
||||
self.selected_tts_provider_id == provider_config["id"]
|
||||
and self.tts_enabled
|
||||
):
|
||||
if self.selected_tts_provider_id == provider_config["id"]:
|
||||
self.curr_tts_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。"
|
||||
)
|
||||
if not self.curr_tts_provider_inst and self.tts_enabled:
|
||||
if not self.curr_tts_provider_inst:
|
||||
self.curr_tts_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||
@@ -279,17 +288,22 @@ class ProviderManager:
|
||||
await inst.initialize()
|
||||
|
||||
self.provider_insts.append(inst)
|
||||
if (
|
||||
self.selected_provider_id == provider_config["id"]
|
||||
and self.provider_enabled
|
||||
):
|
||||
if self.selected_provider_id == provider_config["id"]:
|
||||
self.curr_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。"
|
||||
)
|
||||
if not self.curr_provider_inst and self.provider_enabled:
|
||||
if not self.curr_provider_inst:
|
||||
self.curr_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config, self.provider_settings
|
||||
)
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
self.embedding_provider_insts.append(inst)
|
||||
|
||||
self.inst_map[provider_config["id"]] = inst
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -310,11 +324,7 @@ class ProviderManager:
|
||||
|
||||
if len(self.provider_insts) == 0:
|
||||
self.curr_provider_inst = None
|
||||
elif (
|
||||
self.curr_provider_inst is None
|
||||
and len(self.provider_insts) > 0
|
||||
and self.provider_enabled
|
||||
):
|
||||
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
self.selected_provider_id = self.curr_provider_inst.meta().id
|
||||
logger.info(
|
||||
@@ -323,11 +333,7 @@ class ProviderManager:
|
||||
|
||||
if len(self.stt_provider_insts) == 0:
|
||||
self.curr_stt_provider_inst = None
|
||||
elif (
|
||||
self.curr_stt_provider_inst is None
|
||||
and len(self.stt_provider_insts) > 0
|
||||
and self.stt_enabled
|
||||
):
|
||||
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id
|
||||
logger.info(
|
||||
@@ -336,11 +342,7 @@ class ProviderManager:
|
||||
|
||||
if len(self.tts_provider_insts) == 0:
|
||||
self.curr_tts_provider_inst = None
|
||||
elif (
|
||||
self.curr_tts_provider_inst is None
|
||||
and len(self.tts_provider_insts) > 0
|
||||
and self.tts_enabled
|
||||
):
|
||||
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id
|
||||
logger.info(
|
||||
|
||||
@@ -179,3 +179,25 @@ class TTSProvider(AbstractProvider):
|
||||
async def get_audio(self, text: str) -> str:
|
||||
"""获取文本的音频,返回音频文件路径"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class EmbeddingProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""获取文本的向量"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
||||
"""批量获取文本的向量"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
...
|
||||
|
||||
210
astrbot/core/provider/sources/azure_tts_source.py
Normal file
210
astrbot/core/provider/sources/azure_tts_source.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import uuid
|
||||
import time
|
||||
import json
|
||||
import re
|
||||
import hashlib
|
||||
import random
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from xml.sax.saxutils import escape
|
||||
|
||||
from httpx import AsyncClient, Timeout
|
||||
from astrbot.core.config.default import VERSION
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
TEMP_DIR = Path("data/temp/azure_tts")
|
||||
TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
class OTTSProvider:
|
||||
def __init__(self, config: Dict):
|
||||
self.skey = config["OTTS_SKEY"]
|
||||
self.api_url = config["OTTS_URL"]
|
||||
self.auth_time_url = config["OTTS_AUTH_TIME"]
|
||||
self.time_offset = 0
|
||||
self.last_sync_time = 0
|
||||
self.timeout = Timeout(10.0)
|
||||
self.retry_count = 3
|
||||
self.client = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.client = AsyncClient(timeout=self.timeout)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
async def _sync_time(self):
|
||||
try:
|
||||
response = await self.client.get(self.auth_time_url)
|
||||
response.raise_for_status()
|
||||
server_time = int(response.json()["timestamp"])
|
||||
local_time = int(time.time())
|
||||
self.time_offset = server_time - local_time
|
||||
self.last_sync_time = local_time
|
||||
except Exception as e:
|
||||
if time.time() - self.last_sync_time > 3600:
|
||||
raise RuntimeError("时间同步失败") from e
|
||||
|
||||
async def _generate_signature(self) -> str:
|
||||
await self._sync_time()
|
||||
timestamp = int(time.time()) + self.time_offset
|
||||
nonce = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=10))
|
||||
path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/"
|
||||
return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}"
|
||||
|
||||
async def get_audio(self, text: str, voice_params: Dict) -> str:
|
||||
file_path = TEMP_DIR / f"otts-{uuid.uuid4()}.wav"
|
||||
signature = await self._generate_signature()
|
||||
for attempt in range(self.retry_count):
|
||||
try:
|
||||
response = await self.client.post(
|
||||
f"{self.api_url}?sign={signature}",
|
||||
data={
|
||||
"text": text,
|
||||
"voice": voice_params["voice"],
|
||||
"style": voice_params["style"],
|
||||
"role": voice_params["role"],
|
||||
"rate": voice_params["rate"],
|
||||
"volume": voice_params["volume"]
|
||||
},
|
||||
headers={
|
||||
"User-Agent": f"AstrBot/{VERSION}",
|
||||
"UAK": "AstrBot/AzureTTS"
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with file_path.open("wb") as f:
|
||||
async for chunk in response.aiter_bytes(4096):
|
||||
f.write(chunk)
|
||||
return str(file_path.resolve())
|
||||
except Exception as e:
|
||||
if attempt == self.retry_count - 1:
|
||||
raise RuntimeError(f"OTTS请求失败: {str(e)}") from e
|
||||
await asyncio.sleep(0.5 * (attempt + 1))
|
||||
|
||||
class AzureNativeProvider(TTSProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict):
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.subscription_key = provider_config.get("azure_tts_subscription_key", "").strip()
|
||||
if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
|
||||
raise ValueError("无效的Azure订阅密钥")
|
||||
self.region = provider_config.get("azure_tts_region", "eastus").strip()
|
||||
self.endpoint = f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
self.client = None
|
||||
self.token = None
|
||||
self.token_expire = 0
|
||||
self.voice_params = {
|
||||
"voice": provider_config.get("azure_tts_voice", "zh-CN-YunxiaNeural"),
|
||||
"style": provider_config.get("azure_tts_style", "cheerful"),
|
||||
"role": provider_config.get("azure_tts_role", "Boy"),
|
||||
"rate": provider_config.get("azure_tts_rate", "1"),
|
||||
"volume": provider_config.get("azure_tts_volume", "100")
|
||||
}
|
||||
|
||||
async def __aenter__(self):
|
||||
self.client = AsyncClient(headers={
|
||||
"User-Agent": f"AstrBot/{VERSION}",
|
||||
"Content-Type": "application/ssml+xml",
|
||||
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm"
|
||||
})
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
async def _refresh_token(self):
|
||||
token_url = f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
|
||||
response = await self.client.post(
|
||||
token_url,
|
||||
headers={"Ocp-Apim-Subscription-Key": self.subscription_key}
|
||||
)
|
||||
response.raise_for_status()
|
||||
self.token = response.text
|
||||
self.token_expire = time.time() + 540
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
if not self.token or time.time() > self.token_expire:
|
||||
await self._refresh_token()
|
||||
file_path = TEMP_DIR / f"azure-{uuid.uuid4()}.wav"
|
||||
ssml = f"""<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis'
|
||||
xmlns:mstts='http://www.w3.org/2001/mstts' xml:lang='zh-CN'>
|
||||
<voice name='{escape(self.voice_params["voice"])}'>
|
||||
<mstts:express-as style='{escape(self.voice_params["style"])}'
|
||||
role='{escape(self.voice_params["role"])}'>
|
||||
<prosody rate='{escape(self.voice_params["rate"])}'
|
||||
volume='{escape(self.voice_params["volume"])}'>
|
||||
{escape(text)}
|
||||
</prosody>
|
||||
</mstts:express-as>
|
||||
</voice>
|
||||
</speak>"""
|
||||
response = await self.client.post(
|
||||
self.endpoint,
|
||||
content=ssml,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.token}",
|
||||
"User-Agent": f"AstrBot/{VERSION}"
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with file_path.open("wb") as f:
|
||||
for chunk in response.iter_bytes(4096):
|
||||
f.write(chunk)
|
||||
return str(file_path.resolve())
|
||||
|
||||
@register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH)
|
||||
class AzureTTSProvider(TTSProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict):
|
||||
super().__init__(provider_config, provider_settings)
|
||||
key_value = provider_config.get("azure_tts_subscription_key", "")
|
||||
self.provider = self._parse_provider(key_value, provider_config)
|
||||
|
||||
def _parse_provider(self, key_value: str, config: dict) -> TTSProvider:
|
||||
if key_value.lower().startswith("other["):
|
||||
try:
|
||||
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
|
||||
if not match:
|
||||
raise ValueError("无效的other[...]格式,应形如 other[{...}]")
|
||||
json_str = match.group(1).strip()
|
||||
otts_config = json.loads(json_str)
|
||||
required = {"OTTS_SKEY", "OTTS_URL", "OTTS_AUTH_TIME"}
|
||||
if missing := required - otts_config.keys():
|
||||
raise ValueError(f"缺少OTTS参数: {', '.join(missing)}")
|
||||
return OTTSProvider(otts_config)
|
||||
except json.JSONDecodeError as e:
|
||||
error_msg = (
|
||||
f"JSON解析失败,请检查格式(错误位置:行 {e.lineno} 列 {e.colno})\n"
|
||||
f"错误详情: {e.msg}\n"
|
||||
f"错误上下文: {json_str[max(0, e.pos-30):e.pos+30]}"
|
||||
)
|
||||
raise ValueError(error_msg) from e
|
||||
except KeyError as e:
|
||||
raise ValueError(f"配置错误: 缺少必要参数 {e}") from e
|
||||
if re.fullmatch(r"^[a-zA-Z0-9]{32}$", key_value):
|
||||
return AzureNativeProvider(config, self.provider_settings)
|
||||
raise ValueError("订阅密钥格式无效,应为32位字母数字或other[...]格式")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
if isinstance(self.provider, OTTSProvider):
|
||||
async with self.provider as provider:
|
||||
return await provider.get_audio(
|
||||
text,
|
||||
{
|
||||
"voice": self.provider_config.get("azure_tts_voice"),
|
||||
"style": self.provider_config.get("azure_tts_style"),
|
||||
"role": self.provider_config.get("azure_tts_role"),
|
||||
"rate": self.provider_config.get("azure_tts_rate"),
|
||||
"volume": self.provider_config.get("azure_tts_volume")
|
||||
}
|
||||
)
|
||||
else:
|
||||
async with self.provider as provider:
|
||||
return await provider.get_audio(text)
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import dashscope
|
||||
import uuid
|
||||
import asyncio
|
||||
@@ -5,6 +6,7 @@ from dashscope.audio.tts_v2 import *
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -21,16 +23,16 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
|
||||
self.set_model(provider_config.get("model", None))
|
||||
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
|
||||
|
||||
dashscope.api_key = self.chosen_api_key
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}.wav")
|
||||
self.synthesizer = SpeechSynthesizer(
|
||||
model=self.get_model(),
|
||||
voice=self.voice,
|
||||
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
|
||||
)
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/dashscope_tts_{uuid.uuid4()}.wav"
|
||||
audio = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.synthesizer.call, text, self.timeout_ms
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import astrbot.core.message.components as Comp
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
from .. import Provider, Personality
|
||||
from ..entities import LLMResponse
|
||||
@@ -10,6 +10,7 @@ from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||
from astrbot.core.utils.io import download_image_by_url, download_file
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter("dify", "Dify APP 适配器。")
|
||||
@@ -227,7 +228,8 @@ class ProviderDify(Provider):
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "audio":
|
||||
# 仅支持 wav
|
||||
path = f"data/temp/{item['filename']}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"{item['filename']}.wav")
|
||||
await download_file(item["url"], path)
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
case "video":
|
||||
|
||||
@@ -7,6 +7,7 @@ from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
"""
|
||||
edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库
|
||||
@@ -40,9 +41,9 @@ class ProviderEdgeTTS(TTSProvider):
|
||||
self.set_model("edge_tts")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
mp3_path = f"data/temp/edge_tts_temp_{uuid.uuid4()}.mp3"
|
||||
wav_path = f"data/temp/edge_tts_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3")
|
||||
wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav")
|
||||
|
||||
# 构建 Edge TTS 参数
|
||||
kwargs = {"text": text, "voice": self.voice}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import uuid
|
||||
import ormsgpack
|
||||
from pydantic import BaseModel, conint
|
||||
@@ -6,6 +7,7 @@ from typing import Annotated, Literal
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class ServeReferenceAudio(BaseModel):
|
||||
@@ -87,7 +89,8 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
)
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/fishaudio_tts_api_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
|
||||
self.headers["content-type"] = "application/msgpack"
|
||||
request = await self._generate_request(text)
|
||||
async with AsyncClient(base_url=self.api_base).stream(
|
||||
|
||||
@@ -1,121 +1,55 @@
|
||||
import base64
|
||||
import aiohttp
|
||||
import json
|
||||
import random
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from typing import Optional
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.errors import APIError
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Personality, Provider
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
|
||||
class SimpleGoogleGenAIClient:
|
||||
def __init__(self, api_key: str, api_base: str, timeout: int = 120) -> None:
|
||||
self.api_key = api_key
|
||||
if api_base.endswith("/"):
|
||||
self.api_base = api_base[:-1]
|
||||
else:
|
||||
self.api_base = api_base
|
||||
self.client = aiohttp.ClientSession(trust_env=True)
|
||||
self.timeout = timeout
|
||||
class SuppressNonTextPartsWarning(logging.Filter):
|
||||
"""过滤 Gemini SDK 中的非文本部分警告"""
|
||||
|
||||
async def models_list(self) -> List[str]:
|
||||
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
|
||||
async with self.client.get(request_url, timeout=self.timeout) as resp:
|
||||
response = await resp.json()
|
||||
def filter(self, record):
|
||||
return "there are non-text parts in the response" not in record.getMessage()
|
||||
|
||||
models = []
|
||||
for model in response["models"]:
|
||||
if "generateContent" in model["supportedGenerationMethods"]:
|
||||
models.append(model["name"].replace("models/", ""))
|
||||
return models
|
||||
|
||||
async def generate_content(
|
||||
self,
|
||||
contents: List[dict],
|
||||
model: str = "gemini-1.5-flash",
|
||||
system_instruction: str = "",
|
||||
tools: dict = None,
|
||||
modalities: List[str] = ["Text"],
|
||||
safety_settings: List[dict] = [],
|
||||
):
|
||||
payload = {}
|
||||
if system_instruction:
|
||||
payload["system_instruction"] = {"parts": {"text": system_instruction}}
|
||||
if tools:
|
||||
payload["tools"] = [tools]
|
||||
payload["contents"] = contents
|
||||
payload["generationConfig"] = {
|
||||
"responseModalities": modalities,
|
||||
}
|
||||
payload["safetySettings"] = [
|
||||
{"category": s["category"], "threshold": s["threshold"]}
|
||||
for s in safety_settings
|
||||
]
|
||||
logger.debug(f"payload: {payload}")
|
||||
request_url = (
|
||||
f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
|
||||
)
|
||||
async with self.client.post(
|
||||
request_url, json=payload, timeout=self.timeout
|
||||
) as resp:
|
||||
if "application/json" in resp.headers.get("Content-Type"):
|
||||
try:
|
||||
response = await resp.json()
|
||||
except Exception as e:
|
||||
text = await resp.text()
|
||||
logger.error(f"Gemini 返回了非 json 数据: {text}")
|
||||
raise e
|
||||
return response
|
||||
else:
|
||||
text = await resp.text()
|
||||
logger.error(f"Gemini 返回了非 json 数据: {text}")
|
||||
raise Exception("Gemini 返回了非 json 数据: ")
|
||||
logging.getLogger("google_genai.types").addFilter(SuppressNonTextPartsWarning())
|
||||
|
||||
async def stream_generate_content(
|
||||
self,
|
||||
contents: List[dict],
|
||||
model: str = "gemini-1.5-flash",
|
||||
system_instruction: str = "",
|
||||
tools: dict = None,
|
||||
modalities: List[str] = ["Text"],
|
||||
safety_settings: List[dict] = [],
|
||||
):
|
||||
payload = {}
|
||||
if system_instruction:
|
||||
payload["system_instruction"] = {"parts": {"text": system_instruction}}
|
||||
if tools:
|
||||
payload["tools"] = [tools]
|
||||
payload["contents"] = contents
|
||||
payload["generationConfig"] = {
|
||||
"responseModalities": modalities,
|
||||
"stream": True,
|
||||
}
|
||||
payload["safetySettings"] = [
|
||||
{"category": s["category"], "threshold": s["threshold"]}
|
||||
for s in safety_settings
|
||||
]
|
||||
logger.debug(f"payload: {payload}")
|
||||
request_url = (
|
||||
f"{self.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.api_key}"
|
||||
)
|
||||
async with self.client.post(
|
||||
request_url, json=payload, timeout=self.timeout
|
||||
) as resp:
|
||||
async for line in resp.content:
|
||||
if line:
|
||||
yield line
|
||||
|
||||
@register_provider_adapter(
|
||||
"googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器"
|
||||
)
|
||||
class ProviderGoogleGenAI(Provider):
|
||||
CATEGORY_MAPPING = {
|
||||
"harassment": types.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
"hate_speech": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
"sexually_explicit": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
"dangerous_content": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
}
|
||||
|
||||
THRESHOLD_MAPPING = {
|
||||
"BLOCK_NONE": types.HarmBlockThreshold.BLOCK_NONE,
|
||||
"BLOCK_ONLY_HIGH": types.HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
"BLOCK_MEDIUM_AND_ABOVE": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
||||
"BLOCK_LOW_AND_ABOVE": types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
@@ -131,196 +65,416 @@ class ProviderGoogleGenAI(Provider):
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.timeout = provider_config.get("timeout", 180)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.client = SimpleGoogleGenAIClient(
|
||||
api_key=self.chosen_api_key,
|
||||
api_base=provider_config.get("api_base", None),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
self.api_keys: list = provider_config.get("key", [])
|
||||
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.timeout: int = int(provider_config.get("timeout", 180))
|
||||
|
||||
self.api_base: Optional[str] = provider_config.get("api_base", None)
|
||||
if self.api_base and self.api_base.endswith("/"):
|
||||
self.api_base = self.api_base[:-1]
|
||||
|
||||
self._init_client()
|
||||
self.set_model(provider_config["model_config"]["model"])
|
||||
self._init_safety_settings()
|
||||
|
||||
safety_mapping = {
|
||||
"harassment": "HARM_CATEGORY_HARASSMENT",
|
||||
"hate_speech": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
}
|
||||
def _init_client(self) -> None:
|
||||
"""初始化Gemini客户端"""
|
||||
self.client = genai.Client(
|
||||
api_key=self.chosen_api_key,
|
||||
http_options=types.HttpOptions(
|
||||
base_url=self.api_base,
|
||||
timeout=self.timeout * 1000, # 毫秒
|
||||
),
|
||||
).aio
|
||||
|
||||
self.safety_settings = []
|
||||
def _init_safety_settings(self) -> None:
|
||||
"""初始化安全设置"""
|
||||
user_safety_config = self.provider_config.get("gm_safety_settings", {})
|
||||
for config_key, harm_category in safety_mapping.items():
|
||||
if threshold := user_safety_config.get(config_key):
|
||||
self.safety_settings.append(
|
||||
{"category": harm_category, "threshold": threshold}
|
||||
)
|
||||
self.safety_settings = [
|
||||
types.SafetySetting(
|
||||
category=harm_category, threshold=self.THRESHOLD_MAPPING[threshold_str]
|
||||
)
|
||||
for config_key, harm_category in self.CATEGORY_MAPPING.items()
|
||||
if (threshold_str := user_safety_config.get(config_key))
|
||||
and threshold_str in self.THRESHOLD_MAPPING
|
||||
]
|
||||
|
||||
async def get_models(self):
|
||||
return await self.client.models_list()
|
||||
async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool:
|
||||
"""处理API错误,返回是否需要重试"""
|
||||
if e.code == 429 or "API key not valid" in e.message:
|
||||
keys.remove(self.chosen_api_key)
|
||||
if len(keys) > 0:
|
||||
self.set_key(random.choice(keys))
|
||||
logger.info(
|
||||
f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..."
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
return True
|
||||
else:
|
||||
logger.error(
|
||||
f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..."
|
||||
)
|
||||
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
|
||||
else:
|
||||
logger.error(
|
||||
f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
|
||||
)
|
||||
raise e
|
||||
|
||||
async def _prepare_query_config(
|
||||
self,
|
||||
payloads: dict,
|
||||
tools: Optional[FuncCall] = None,
|
||||
system_instruction: Optional[str] = None,
|
||||
modalities: Optional[list[str]] = None,
|
||||
temperature: float = 0.7,
|
||||
) -> types.GenerateContentConfig:
|
||||
"""准备查询配置"""
|
||||
if not modalities:
|
||||
modalities = ["Text"]
|
||||
|
||||
# 流式输出不支持图片模态
|
||||
if (
|
||||
self.provider_settings.get("streaming_response", False)
|
||||
and "Image" in modalities
|
||||
):
|
||||
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
|
||||
modalities = ["Text"]
|
||||
|
||||
tool_list = None
|
||||
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
||||
native_search = self.provider_config.get("gm_native_search", False)
|
||||
|
||||
if native_coderunner:
|
||||
tool_list = [types.Tool(code_execution=types.ToolCodeExecution())]
|
||||
if native_search:
|
||||
logger.warning("已启用代码执行工具,搜索工具将被忽略")
|
||||
if tools:
|
||||
logger.warning("已启用代码执行工具,函数工具将被忽略")
|
||||
elif native_search:
|
||||
tool_list = [types.Tool(google_search=types.GoogleSearch())]
|
||||
if tools:
|
||||
logger.warning("已启用搜索工具,函数工具将被忽略")
|
||||
elif tools and (func_desc := tools.get_func_desc_google_genai_style()):
|
||||
tool_list = [
|
||||
types.Tool(function_declarations=func_desc["function_declarations"])
|
||||
]
|
||||
return types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
temperature=temperature,
|
||||
max_output_tokens=payloads.get("max_tokens")
|
||||
or payloads.get("maxOutputTokens"),
|
||||
top_p=payloads.get("top_p") or payloads.get("topP"),
|
||||
top_k=payloads.get("top_k") or payloads.get("topK"),
|
||||
frequency_penalty=payloads.get("frequency_penalty")
|
||||
or payloads.get("frequencyPenalty"),
|
||||
presence_penalty=payloads.get("presence_penalty")
|
||||
or payloads.get("presencePenalty"),
|
||||
stop_sequences=payloads.get("stop") or payloads.get("stopSequences"),
|
||||
response_logprobs=payloads.get("response_logprobs")
|
||||
or payloads.get("responseLogprobs"),
|
||||
logprobs=payloads.get("logprobs"),
|
||||
seed=payloads.get("seed"),
|
||||
response_modalities=modalities,
|
||||
tools=tool_list,
|
||||
safety_settings=self.safety_settings if self.safety_settings else None,
|
||||
thinking_config=types.ThinkingConfig(
|
||||
thinking_budget=min(
|
||||
int(
|
||||
self.provider_config.get("gm_thinking_config", {}).get(
|
||||
"budget", 0
|
||||
)
|
||||
),
|
||||
24576,
|
||||
),
|
||||
)
|
||||
if "gemini-2.5-flash" in self.get_model()
|
||||
and hasattr(types.ThinkingConfig, "thinking_budget")
|
||||
else None,
|
||||
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
||||
disable=True
|
||||
),
|
||||
)
|
||||
|
||||
def _prepare_conversation(self, payloads: dict) -> list[types.Content]:
|
||||
"""准备 Gemini SDK 的 Content 列表"""
|
||||
|
||||
def create_text_part(text: str) -> types.Part:
|
||||
content_a = text if text else " "
|
||||
if not text:
|
||||
logger.warning("文本内容为空,已添加空格占位")
|
||||
return types.Part.from_text(text=content_a)
|
||||
|
||||
def process_image_url(image_url_dict: dict) -> types.Part:
|
||||
url = image_url_dict["url"]
|
||||
mime_type = url.split(":")[1].split(";")[0]
|
||||
image_bytes = base64.b64decode(url.split(",", 1)[1])
|
||||
return types.Part.from_bytes(data=image_bytes, mime_type=mime_type)
|
||||
|
||||
def append_or_extend(
|
||||
contents: list[types.Content],
|
||||
part: list[types.Part],
|
||||
content_cls: type[types.Content],
|
||||
) -> None:
|
||||
if contents and isinstance(contents[-1], content_cls):
|
||||
contents[-1].parts.extend(part)
|
||||
else:
|
||||
contents.append(content_cls(parts=part))
|
||||
|
||||
gemini_contents: list[types.Content] = []
|
||||
native_tool_enabled = any(
|
||||
[
|
||||
self.provider_config.get("gm_native_coderunner", False),
|
||||
self.provider_config.get("gm_native_search", False),
|
||||
]
|
||||
)
|
||||
for message in payloads["messages"]:
|
||||
role, content = message["role"], message.get("content")
|
||||
|
||||
if role == "user":
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
types.Part.from_text(text=item["text"] or " ")
|
||||
if item["type"] == "text"
|
||||
else process_image_url(item["image_url"])
|
||||
for item in content
|
||||
]
|
||||
else:
|
||||
parts = [create_text_part(content)]
|
||||
append_or_extend(gemini_contents, parts, types.UserContent)
|
||||
|
||||
elif role == "assistant":
|
||||
if content:
|
||||
parts = [types.Part.from_text(text=content)]
|
||||
append_or_extend(gemini_contents, parts, types.ModelContent)
|
||||
elif not native_tool_enabled and "tool_calls" in message:
|
||||
parts = [
|
||||
types.Part.from_function_call(
|
||||
name=tool["function"]["name"],
|
||||
args=json.loads(tool["function"]["arguments"]),
|
||||
)
|
||||
for tool in message["tool_calls"]
|
||||
]
|
||||
append_or_extend(gemini_contents, parts, types.ModelContent)
|
||||
else:
|
||||
logger.warning("assistant 角色的消息内容为空,已添加空格占位")
|
||||
if native_tool_enabled and "tool_calls" in message:
|
||||
logger.warning(
|
||||
"检测到启用Gemini原生工具,且上下文中存在函数调用,建议使用 /reset 重置上下文"
|
||||
)
|
||||
parts = [types.Part.from_text(text=" ")]
|
||||
append_or_extend(gemini_contents, parts, types.ModelContent)
|
||||
|
||||
elif role == "tool" and not native_tool_enabled:
|
||||
parts = [
|
||||
types.Part.from_function_response(
|
||||
name=message["tool_call_id"],
|
||||
response={
|
||||
"name": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
},
|
||||
)
|
||||
]
|
||||
append_or_extend(gemini_contents, parts, types.UserContent)
|
||||
|
||||
if gemini_contents and isinstance(gemini_contents[0], types.ModelContent):
|
||||
gemini_contents.pop()
|
||||
|
||||
return gemini_contents
|
||||
|
||||
@staticmethod
|
||||
def _process_content_parts(
|
||||
result: types.GenerateContentResponse, llm_response: LLMResponse
|
||||
) -> MessageChain:
|
||||
"""处理内容部分并构建消息链"""
|
||||
finish_reason = result.candidates[0].finish_reason
|
||||
result_parts: Optional[types.Part] = result.candidates[0].content.parts
|
||||
|
||||
if finish_reason == types.FinishReason.SAFETY:
|
||||
raise Exception("模型生成内容未通过 Gemini 平台的安全检查")
|
||||
|
||||
if finish_reason in {
|
||||
types.FinishReason.PROHIBITED_CONTENT,
|
||||
types.FinishReason.SPII,
|
||||
types.FinishReason.BLOCKLIST,
|
||||
}:
|
||||
raise Exception("模型生成内容违反 Gemini 平台政策")
|
||||
|
||||
# 防止旧版本SDK不存在IMAGE_SAFETY
|
||||
if hasattr(types.FinishReason, "IMAGE_SAFETY"):
|
||||
if finish_reason == types.FinishReason.IMAGE_SAFETY:
|
||||
raise Exception("模型生成内容违反 Gemini 平台政策")
|
||||
|
||||
if not result_parts:
|
||||
logger.debug(result.candidates)
|
||||
raise Exception("API 返回的内容为空。")
|
||||
|
||||
chain = []
|
||||
part: types.Part
|
||||
|
||||
# 暂时这样Fallback
|
||||
if all(
|
||||
part.inline_data and part.inline_data.mime_type.startswith("image/")
|
||||
for part in result_parts
|
||||
):
|
||||
chain.append(Comp.Plain("这是图片"))
|
||||
for part in result_parts:
|
||||
if part.text:
|
||||
chain.append(Comp.Plain(part.text))
|
||||
elif part.function_call:
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_name.append(part.function_call.name)
|
||||
llm_response.tools_call_args.append(part.function_call.args)
|
||||
# gemini 返回的 function_call.id 可能为 None
|
||||
llm_response.tools_call_ids.append(
|
||||
part.function_call.id or part.function_call.name
|
||||
)
|
||||
elif part.inline_data and part.inline_data.mime_type.startswith("image/"):
|
||||
chain.append(Comp.Image.fromBytes(part.inline_data.data))
|
||||
return MessageChain(chain=chain)
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
tool = None
|
||||
if tools:
|
||||
tool = tools.get_func_desc_google_genai_style()
|
||||
if not tool:
|
||||
tool = None
|
||||
"""非流式请求 Gemini API"""
|
||||
system_instruction = next(
|
||||
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||
None,
|
||||
)
|
||||
|
||||
modalities = ["Text"]
|
||||
if self.provider_config.get("gm_resp_image_modal", False):
|
||||
modalities.append("Image")
|
||||
|
||||
conversation = self._prepare_conversation(payloads)
|
||||
temperature = payloads.get("temperature", 0.7)
|
||||
|
||||
result: Optional[types.GenerateContentResponse] = None
|
||||
while True:
|
||||
try:
|
||||
config = await self._prepare_query_config(
|
||||
payloads, tools, system_instruction, modalities, temperature
|
||||
)
|
||||
result = await self.client.models.generate_content(
|
||||
model=self.get_model(),
|
||||
contents=conversation,
|
||||
config=config,
|
||||
)
|
||||
|
||||
if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
|
||||
if temperature > 2:
|
||||
raise Exception("温度参数已超过最大值2,仍然发生recitation")
|
||||
temperature += 0.2
|
||||
logger.warning(
|
||||
f"发生了recitation,正在提高温度至{temperature:.1f}重试..."
|
||||
)
|
||||
continue
|
||||
|
||||
system_instruction = ""
|
||||
for message in payloads["messages"]:
|
||||
if message["role"] == "system":
|
||||
system_instruction = message["content"]
|
||||
break
|
||||
|
||||
google_genai_conversation = []
|
||||
for message in payloads["messages"]:
|
||||
if message["role"] == "user":
|
||||
if isinstance(message["content"], str):
|
||||
if not message["content"]:
|
||||
message["content"] = " "
|
||||
|
||||
google_genai_conversation.append(
|
||||
{"role": "user", "parts": [{"text": message["content"]}]}
|
||||
except APIError as e:
|
||||
if "Developer instruction is not enabled" in e.message:
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
|
||||
)
|
||||
elif isinstance(message["content"], list):
|
||||
# images
|
||||
parts = []
|
||||
for part in message["content"]:
|
||||
if part["type"] == "text":
|
||||
if not part["text"]:
|
||||
part["text"] = ""
|
||||
parts.append({"text": part["text"]})
|
||||
elif part["type"] == "image_url":
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": "image/jpeg",
|
||||
"data": part["image_url"]["url"].replace(
|
||||
"data:image/jpeg;base64,", ""
|
||||
), # base64
|
||||
}
|
||||
}
|
||||
)
|
||||
google_genai_conversation.append({"role": "user", "parts": parts})
|
||||
|
||||
elif message["role"] == "assistant":
|
||||
if "content" in message:
|
||||
if not message["content"]:
|
||||
message["content"] = " "
|
||||
google_genai_conversation.append(
|
||||
{"role": "model", "parts": [{"text": message["content"]}]}
|
||||
system_instruction = None
|
||||
elif "Function calling is not enabled" in e.message:
|
||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||
tools = None
|
||||
elif (
|
||||
"Multi-modal output is not supported" in e.message
|
||||
or "Model does not support the requested response modalities"
|
||||
in e.message
|
||||
or "only supports text output" in e.message
|
||||
):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持多模态输出,降级为文本模态"
|
||||
)
|
||||
elif "tool_calls" in message:
|
||||
# tool calls in the last turn
|
||||
parts = []
|
||||
for tool_call in message["tool_calls"]:
|
||||
parts.append(
|
||||
{
|
||||
"functionCall": {
|
||||
"name": tool_call["function"]["name"],
|
||||
"args": json.loads(
|
||||
tool_call["function"]["arguments"]
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
google_genai_conversation.append({"role": "model", "parts": parts})
|
||||
elif message["role"] == "tool":
|
||||
parts = []
|
||||
parts.append(
|
||||
{
|
||||
"functionResponse": {
|
||||
"name": message["tool_call_id"],
|
||||
"response": {
|
||||
"name": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
google_genai_conversation.append({"role": "user", "parts": parts})
|
||||
modalities = ["Text"]
|
||||
else:
|
||||
raise
|
||||
continue
|
||||
|
||||
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
|
||||
|
||||
modalites = ["Text"]
|
||||
if self.provider_config.get("gm_resp_image_modal", False):
|
||||
modalites.append("Image")
|
||||
|
||||
loop = True
|
||||
while loop:
|
||||
loop = False
|
||||
result = await self.client.generate_content(
|
||||
contents=google_genai_conversation,
|
||||
model=self.get_model(),
|
||||
system_instruction=system_instruction,
|
||||
tools=tool,
|
||||
modalities=modalites,
|
||||
safety_settings=self.safety_settings,
|
||||
)
|
||||
logger.debug(f"result: {result}")
|
||||
|
||||
# Developer instruction is not enabled for models/gemini-2.0-flash-exp
|
||||
if "Developer instruction is not enabled" in str(result):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt, 已自动去除, 将会影响人格设置。"
|
||||
)
|
||||
system_instruction = ""
|
||||
loop = True
|
||||
|
||||
elif "Function calling is not enabled" in str(result):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持函数调用,已自动去除,不影响使用。"
|
||||
)
|
||||
tool = None
|
||||
loop = True
|
||||
|
||||
elif "Multi-modal output is not supported" in str(result):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持多模态输出,降级为文本模态重新请求。"
|
||||
)
|
||||
modalites = ["Text"]
|
||||
loop = True
|
||||
|
||||
elif "candidates" not in result:
|
||||
raise Exception("Gemini 返回异常结果: " + str(result))
|
||||
|
||||
candidates = result["candidates"][0]["content"]["parts"]
|
||||
llm_response = LLMResponse("assistant")
|
||||
chain = []
|
||||
for candidate in candidates:
|
||||
if "text" in candidate:
|
||||
chain.append(Comp.Plain(candidate["text"]))
|
||||
elif "functionCall" in candidate:
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_args.append(candidate["functionCall"]["args"])
|
||||
llm_response.tools_call_name.append(candidate["functionCall"]["name"])
|
||||
llm_response.tools_call_ids.append(
|
||||
candidate["functionCall"]["name"]
|
||||
) # 没有 tool id
|
||||
elif "inlineData" in candidate:
|
||||
mime_type: str = candidate["inlineData"]["mimeType"]
|
||||
if mime_type.startswith("image/"):
|
||||
chain.append(Comp.Image.fromBase64(candidate["inlineData"]["data"]))
|
||||
|
||||
llm_response.result_chain = MessageChain(chain=chain)
|
||||
llm_response.result_chain = self._process_content_parts(result, llm_response)
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
self, payloads: dict, tools: FuncCall
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式请求 Gemini API"""
|
||||
system_instruction = next(
|
||||
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||
None,
|
||||
)
|
||||
|
||||
conversation = self._prepare_conversation(payloads)
|
||||
|
||||
result = None
|
||||
while True:
|
||||
try:
|
||||
config = await self._prepare_query_config(
|
||||
payloads, tools, system_instruction
|
||||
)
|
||||
result = await self.client.models.generate_content_stream(
|
||||
model=self.get_model(),
|
||||
contents=conversation,
|
||||
config=config,
|
||||
)
|
||||
break
|
||||
except APIError as e:
|
||||
if "Developer instruction is not enabled" in e.message:
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
|
||||
)
|
||||
system_instruction = None
|
||||
elif "Function calling is not enabled" in e.message:
|
||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||
tools = None
|
||||
else:
|
||||
raise
|
||||
continue
|
||||
|
||||
async for chunk in result:
|
||||
llm_response = LLMResponse("assistant", is_chunk=True)
|
||||
|
||||
if chunk.candidates[0].content.parts and any(
|
||||
part.function_call for part in chunk.candidates[0].content.parts
|
||||
):
|
||||
llm_response = LLMResponse("assistant", is_chunk=False)
|
||||
llm_response.result_chain = self._process_content_parts(
|
||||
chunk, llm_response
|
||||
)
|
||||
yield llm_response
|
||||
break
|
||||
|
||||
if chunk.text:
|
||||
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
|
||||
yield llm_response
|
||||
|
||||
if chunk.candidates[0].finish_reason:
|
||||
llm_response = LLMResponse("assistant", is_chunk=False)
|
||||
if not chunk.candidates[0].content.parts:
|
||||
llm_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
|
||||
else:
|
||||
llm_response.result_chain = self._process_content_parts(
|
||||
chunk, llm_response
|
||||
)
|
||||
yield llm_response
|
||||
break
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
@@ -337,82 +491,92 @@ class ProviderGoogleGenAI(Provider):
|
||||
model_config["model"] = self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
llm_response = None
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(keys)
|
||||
|
||||
for i in range(retry):
|
||||
for _ in range(retry):
|
||||
try:
|
||||
self.client.api_key = chosen_key
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
return await self._query(payloads, func_tool)
|
||||
except APIError as e:
|
||||
if await self._handle_api_error(e, keys):
|
||||
continue
|
||||
break
|
||||
except Exception as e:
|
||||
if "429" in str(e) or "API key not valid" in str(e):
|
||||
keys.remove(chosen_key)
|
||||
if len(keys) > 0:
|
||||
chosen_key = random.choice(keys)
|
||||
logger.info(
|
||||
f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}..."
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}..."
|
||||
)
|
||||
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
|
||||
else:
|
||||
logger.error(
|
||||
f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
|
||||
)
|
||||
raise e
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: str = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
**kwargs,
|
||||
):
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
for part in context_query:
|
||||
if "_no_save" in part:
|
||||
del part["_no_save"]
|
||||
|
||||
# tool calls result
|
||||
if tool_calls_result:
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
|
||||
for _ in range(retry):
|
||||
try:
|
||||
async for response in self._query_stream(payloads, func_tool):
|
||||
yield response
|
||||
break
|
||||
except APIError as e:
|
||||
if await self._handle_api_error(e, keys):
|
||||
continue
|
||||
break
|
||||
|
||||
async def get_models(self):
|
||||
try:
|
||||
models = await self.client.models.list()
|
||||
return [
|
||||
m.name.replace("models/", "")
|
||||
for m in models
|
||||
if "generateContent" in m.supported_actions
|
||||
]
|
||||
except APIError as e:
|
||||
raise Exception(f"获取模型列表失败: {e.message}")
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.client.api_key
|
||||
return self.chosen_api_key
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
def get_keys(self) -> list[str]:
|
||||
return self.api_keys
|
||||
|
||||
def set_key(self, key):
|
||||
self.client.api_key = key
|
||||
self.chosen_api_key = key
|
||||
self._init_client()
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
async def assemble_context(self, text: str, image_urls: list[str] = None):
|
||||
"""
|
||||
组装上下文。
|
||||
"""
|
||||
if image_urls:
|
||||
user_content = {"role": "user", "content": [{"type": "text", "text": text}]}
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||
}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
@@ -444,5 +608,4 @@ class ProviderGoogleGenAI(Provider):
|
||||
return ""
|
||||
|
||||
async def terminate(self):
|
||||
await self.client.client.close()
|
||||
logger.info("Google GenAI 适配器已终止。")
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
import urllib.parse
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -23,7 +25,8 @@ class ProviderGSVITTS(TTSProvider):
|
||||
self.emotion = provider_config.get("emotion")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/gsvi_tts_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
|
||||
params = {"text": text}
|
||||
|
||||
if self.character:
|
||||
|
||||
149
astrbot/core/provider/sources/minimax_tts_api_source.py
Normal file
149
astrbot/core/provider/sources/minimax_tts_api_source.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import aiohttp
|
||||
from typing import Dict, List, Union, AsyncIterator
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.api import logger
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"minimax_tts_api", "MiniMax TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
|
||||
)
|
||||
class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key: str = provider_config.get("api_key", "")
|
||||
self.api_base: str = provider_config.get(
|
||||
"api_base", "https://api.minimax.chat/v1/t2a_v2"
|
||||
)
|
||||
self.group_id: str = provider_config.get("minimax-group-id", "")
|
||||
self.set_model(provider_config.get("model", ""))
|
||||
self.lang_boost: str = provider_config.get("minimax-langboost", "auto")
|
||||
self.is_timber_weight: bool = provider_config.get(
|
||||
"minimax-is-timber-weight", False
|
||||
)
|
||||
self.timber_weight: List[Dict[str, Union[str, int]]] = json.loads(
|
||||
provider_config.get(
|
||||
"minimax-timber-weight",
|
||||
'[{"voice_id": "Chinese (Mandarin)_Warm_Girl", "weight": 1}]',
|
||||
)
|
||||
)
|
||||
|
||||
self.voice_setting: dict = {
|
||||
"speed": provider_config.get("minimax-voice-speed", 1.0),
|
||||
"vol": provider_config.get("minimax-voice-vol", 1.0),
|
||||
"pitch": provider_config.get("minimax-voice-pitch", 0),
|
||||
"voice_id": ""
|
||||
if self.is_timber_weight
|
||||
else provider_config.get("minimax-voice-id", ""),
|
||||
"emotion": provider_config.get("minimax-voice-emotion", "neutral"),
|
||||
"latex_read": provider_config.get("minimax-voice-latex", False),
|
||||
"english_normalization": provider_config.get(
|
||||
"minimax-voice-english-normalization", False
|
||||
),
|
||||
}
|
||||
|
||||
self.audio_setting: dict = {
|
||||
"sample_rate": 32000,
|
||||
"bitrate": 128000,
|
||||
"format": "mp3",
|
||||
}
|
||||
|
||||
self.concat_base_url: str = f"{self.api_base}?GroupId={self.group_id}"
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||
"accept": "application/json, text/plain, */*",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
def _build_tts_stream_body(self, text: str):
|
||||
"""构建流式请求体"""
|
||||
dict_body: Dict[str, object] = {
|
||||
"model": self.model_name,
|
||||
"text": text,
|
||||
"stream": True,
|
||||
"language_boost": self.lang_boost,
|
||||
"voice_setting": self.voice_setting,
|
||||
"audio_setting": self.audio_setting,
|
||||
}
|
||||
if self.is_timber_weight:
|
||||
dict_body["timber_weights"] = self.timber_weight
|
||||
|
||||
return json.dumps(dict_body)
|
||||
|
||||
async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
|
||||
"""进行流式请求"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
self.concat_base_url,
|
||||
headers=self.headers,
|
||||
data=self._build_tts_stream_body(text),
|
||||
timeout=aiohttp.ClientTimeout(total=60),
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
buffer = b""
|
||||
while True:
|
||||
chunk = await response.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
buffer += chunk
|
||||
|
||||
while b"\n\n" in buffer:
|
||||
try:
|
||||
message, buffer = buffer.split(b"\n\n", 1)
|
||||
if message.startswith(b"data: "):
|
||||
try:
|
||||
data = json.loads(message[6:])
|
||||
if "extra_info" in data:
|
||||
continue
|
||||
audio = data.get("data", {}).get("audio")
|
||||
if audio is not None:
|
||||
yield audio
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Failed to parse JSON data from SSE message"
|
||||
)
|
||||
continue
|
||||
except ValueError:
|
||||
buffer = buffer[-1024:]
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise Exception(f"MiniMax TTS API请求失败: {str(e)}")
|
||||
|
||||
async def _audio_play(self, audio_stream: AsyncIterator[str]) -> bytes:
|
||||
"""解码数据流到 audio 比特流"""
|
||||
chunks = []
|
||||
async for chunk in audio_stream:
|
||||
if chunk.strip():
|
||||
chunks.append(bytes.fromhex(chunk.strip()))
|
||||
return b"".join(chunks)
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.mp3")
|
||||
|
||||
try:
|
||||
# 直接将异步生成器传递给 _audio_play 方法
|
||||
audio_stream = self._call_tts_stream(text)
|
||||
audio = await self._audio_play(audio_stream)
|
||||
|
||||
# 结果保存至文件
|
||||
with open(path, "wb") as file:
|
||||
file.write(audio)
|
||||
|
||||
return path
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise e
|
||||
42
astrbot/core/provider/sources/openai_embedding_source.py
Normal file
42
astrbot/core/provider/sources/openai_embedding_source.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from openai import AsyncOpenAI
|
||||
from ..provider import EmbeddingProvider
|
||||
from ..register import register_provider_adapter
|
||||
from ..entities import ProviderType
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"openai_embedding",
|
||||
"OpenAI API Embedding 提供商适配器",
|
||||
provider_type=ProviderType.EMBEDDING,
|
||||
)
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=provider_config.get("embedding_api_key"),
|
||||
base_url=provider_config.get(
|
||||
"embedding_api_base", "https://api.openai.com/v1"
|
||||
),
|
||||
)
|
||||
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
|
||||
self.dimension = provider_config.get("embedding_dimensions", 1536)
|
||||
|
||||
async def get_embedding(self, text: str) -> list[float]:
|
||||
"""
|
||||
获取文本的嵌入
|
||||
"""
|
||||
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
||||
return embedding.data[0].embedding
|
||||
|
||||
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
批量获取文本的嵌入
|
||||
"""
|
||||
embeddings = await self.client.embeddings.create(input=texts, model=self.model)
|
||||
return [item.embedding for item in embeddings.data]
|
||||
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
return self.dimension
|
||||
@@ -21,7 +21,7 @@ from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List, AsyncGenerator
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -195,7 +195,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for tool_call in choice.message.tool_calls:
|
||||
for tool in tools.func_list:
|
||||
if tool.name == tool_call.function.name:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
# workaround for #1454
|
||||
if isinstance(tool_call.function.arguments, str):
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
else:
|
||||
args = tool_call.function.arguments
|
||||
args_ls.append(args)
|
||||
func_name_ls.append(tool_call.function.name)
|
||||
tool_call_ids.append(tool_call.id)
|
||||
@@ -221,14 +225,16 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
**kwargs,
|
||||
) -> tuple:
|
||||
"""准备聊天所需的有效载荷和上下文"""
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
@@ -337,11 +343,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
@@ -362,7 +368,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
available_api_keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
|
||||
e = None
|
||||
last_exception = None
|
||||
retry_cnt = 0
|
||||
for retry_cnt in range(max_retries):
|
||||
try:
|
||||
@@ -376,6 +382,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
payloads["messages"] = new_contexts
|
||||
context_query = new_contexts
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
(
|
||||
success,
|
||||
chosen_key,
|
||||
@@ -398,7 +405,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
if retry_cnt == max_retries - 1:
|
||||
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
||||
raise e
|
||||
if last_exception is None:
|
||||
raise Exception("未知错误")
|
||||
raise last_exception
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
@@ -428,7 +437,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
available_api_keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
|
||||
e = None
|
||||
last_exception = None
|
||||
retry_cnt = 0
|
||||
for retry_cnt in range(max_retries):
|
||||
try:
|
||||
@@ -443,6 +452,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
payloads["messages"] = new_contexts
|
||||
context_query = new_contexts
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
(
|
||||
success,
|
||||
chosen_key,
|
||||
@@ -465,7 +475,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
if retry_cnt == max_retries - 1:
|
||||
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
||||
raise e
|
||||
if last_exception is None:
|
||||
raise Exception("未知错误")
|
||||
raise last_exception
|
||||
|
||||
async def _remove_image_from_context(self, contexts: List):
|
||||
"""
|
||||
@@ -505,7 +517,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict:
|
||||
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
||||
if image_urls:
|
||||
user_content = {"role": "user", "content": [{"type": "text", "text": text}]}
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||
}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import uuid
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -31,7 +33,8 @@ class ProviderOpenAITTSAPI(TTSProvider):
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/openai_tts_api_{uuid.uuid4()}.wav"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav")
|
||||
async with self.client.audio.speech.with_streaming_response.create(
|
||||
model=self.model_name, voice=self.voice, response_format="wav", input=text
|
||||
) as response:
|
||||
|
||||
107
astrbot/core/provider/sources/volcengine_tts.py
Normal file
107
astrbot/core/provider/sources/volcengine_tts.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import uuid
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import requests
|
||||
from ..provider import TTSProvider
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot import logger
|
||||
|
||||
@register_provider_adapter(
|
||||
"volcengine_tts", "火山引擎 TTS", provider_type=ProviderType.TEXT_TO_SPEECH
|
||||
)
|
||||
class ProviderVolcengineTTS(TTSProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.api_key = provider_config.get("api_key", "")
|
||||
self.appid = provider_config.get("appid", "")
|
||||
self.cluster = provider_config.get("volcengine_cluster", "")
|
||||
self.voice_type = provider_config.get("volcengine_voice_type", "")
|
||||
self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0)
|
||||
self.api_base = provider_config.get("api_base", f"https://openspeech.bytedance.com/api/v1/tts")
|
||||
self.timeout = provider_config.get("timeout", 20)
|
||||
|
||||
def _build_request_payload(self, text: str) -> dict:
|
||||
return {
|
||||
"app": {
|
||||
"appid": self.appid,
|
||||
"token": self.api_key,
|
||||
"cluster": self.cluster
|
||||
},
|
||||
"user": {
|
||||
"uid": str(uuid.uuid4())
|
||||
},
|
||||
"audio": {
|
||||
"voice_type": self.voice_type,
|
||||
"encoding": "mp3",
|
||||
"speed_ratio": self.speed_ratio,
|
||||
"volume_ratio": 1.0,
|
||||
"pitch_ratio": 1.0,
|
||||
},
|
||||
"request": {
|
||||
"reqid": str(uuid.uuid4()),
|
||||
"text": text,
|
||||
"text_type": "plain",
|
||||
"operation": "query",
|
||||
"with_frontend": 1,
|
||||
"frontend_type": "unitTson"
|
||||
}
|
||||
}
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
"""异步方法获取语音文件路径"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer; {self.api_key}"
|
||||
}
|
||||
|
||||
payload = self._build_request_payload(text)
|
||||
|
||||
logger.debug(f"请求头: {headers}")
|
||||
logger.debug(f"请求 URL: {self.api_base}")
|
||||
logger.debug(f"请求体: {json.dumps(payload, ensure_ascii=False)[:100]}...")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
self.api_base,
|
||||
data=json.dumps(payload),
|
||||
headers=headers,
|
||||
timeout=self.timeout
|
||||
) as response:
|
||||
logger.debug(f"响应状态码: {response.status}")
|
||||
|
||||
response_text = await response.text()
|
||||
logger.debug(f"响应内容: {response_text[:200]}...")
|
||||
|
||||
if response.status == 200:
|
||||
resp_data = json.loads(response_text)
|
||||
|
||||
if "data" in resp_data:
|
||||
audio_data = base64.b64decode(resp_data["data"])
|
||||
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
|
||||
file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: open(file_path, "wb").write(audio_data)
|
||||
)
|
||||
|
||||
return file_path
|
||||
else:
|
||||
error_msg = resp_data.get("message", "未知错误")
|
||||
raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
|
||||
else:
|
||||
raise Exception(f"火山引擎 TTS API 请求失败: {response.status}, {response_text}")
|
||||
|
||||
except Exception as e:
|
||||
error_details = traceback.format_exc()
|
||||
logger.debug(f"火山引擎 TTS 异常详情: {error_details}")
|
||||
raise Exception(f"火山引擎 TTS 异常: {str(e)}")
|
||||
@@ -7,6 +7,7 @@ from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -50,7 +51,8 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
is_tencent = True
|
||||
|
||||
name = str(uuid.uuid4())
|
||||
path = os.path.join("data/temp", name)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, name)
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
@@ -61,7 +63,8 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
output_path = os.path.join("data/temp", str(uuid.uuid4()) + ".wav")
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -53,7 +54,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
is_tencent = True
|
||||
|
||||
name = str(uuid.uuid4())
|
||||
path = os.path.join("data/temp", name)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, name)
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
@@ -64,7 +66,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
output_path = os.path.join("data/temp", str(uuid.uuid4()) + ".wav")
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
from typing import List
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
class SimpleOpenAIEmbedding:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
api_key,
|
||||
api_base=None,
|
||||
) -> None:
|
||||
self.client = AsyncOpenAI(api_key=api_key, base_url=api_base)
|
||||
self.model = model
|
||||
|
||||
async def get_embedding(self, text) -> List[float]:
|
||||
"""
|
||||
获取文本的嵌入
|
||||
"""
|
||||
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
||||
return embedding.data[0].embedding
|
||||
@@ -1,94 +0,0 @@
|
||||
import os
|
||||
from typing import List, Dict
|
||||
from astrbot.core import logger
|
||||
from .store import Store
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
|
||||
class KnowledgeDBManager:
|
||||
def __init__(self, astrbot_config: AstrBotConfig) -> None:
|
||||
self.db_path = "data/knowledge_db/"
|
||||
self.config = astrbot_config.get("knowledge_db", {})
|
||||
self.astrbot_config = astrbot_config
|
||||
if not os.path.exists(self.db_path):
|
||||
os.makedirs(self.db_path)
|
||||
self.store_insts: Dict[str, Store] = {}
|
||||
for name, cfg in self.config.items():
|
||||
if cfg["strategy"] == "embedding":
|
||||
logger.info(f"加载 Chroma Vector Store:{name}")
|
||||
try:
|
||||
from .store.chroma_db import ChromaVectorStore
|
||||
except ImportError as ie:
|
||||
logger.error(f"{ie} 可能未安装 chromadb 库。")
|
||||
continue
|
||||
self.store_insts[name] = ChromaVectorStore(
|
||||
name, cfg["embedding_config"]
|
||||
)
|
||||
else:
|
||||
logger.error(f"不支持的策略:{cfg['strategy']}")
|
||||
|
||||
async def list_knowledge_db(self) -> List[str]:
|
||||
return [
|
||||
f
|
||||
for f in os.listdir(self.db_path)
|
||||
if os.path.isfile(os.path.join(self.db_path, f))
|
||||
]
|
||||
|
||||
async def create_knowledge_db(self, name: str, config: Dict):
|
||||
"""
|
||||
config 格式:
|
||||
```
|
||||
{
|
||||
"strategy": "embedding", # 目前只支持 embedding
|
||||
"chunk_method": {
|
||||
"strategy": "fixed",
|
||||
"chunk_size": 100,
|
||||
"overlap_size": 10
|
||||
},
|
||||
"embedding_config": {
|
||||
"strategy": "openai",
|
||||
"base_url": "",
|
||||
"model": "",
|
||||
"api_key": ""
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
if name in self.config:
|
||||
raise ValueError(f"知识库已存在:{name}")
|
||||
|
||||
self.config[name] = config
|
||||
self.astrbot_config["knowledge_db"] = self.config
|
||||
self.astrbot_config.save_config()
|
||||
|
||||
async def insert_record(self, name: str, text: str):
|
||||
if name not in self.store_insts:
|
||||
raise ValueError(f"未找到知识库:{name}")
|
||||
|
||||
ret = []
|
||||
match self.config[name]["chunk_method"]["strategy"]:
|
||||
case "fixed":
|
||||
chunk_size = self.config[name]["chunk_method"]["chunk_size"]
|
||||
chunk_overlap = self.config[name]["chunk_method"]["overlap_size"]
|
||||
ret = self._fixed_chunk(text, chunk_size, chunk_overlap)
|
||||
case _:
|
||||
pass
|
||||
|
||||
for chunk in ret:
|
||||
await self.store_insts[name].save(chunk)
|
||||
|
||||
async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]:
|
||||
if name not in self.store_insts:
|
||||
raise ValueError(f"未找到知识库:{name}")
|
||||
|
||||
inst = self.store_insts[name]
|
||||
return await inst.query(query, top_n)
|
||||
|
||||
def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
|
||||
chunks = []
|
||||
start = 0
|
||||
while start < len(text):
|
||||
end = start + chunk_size
|
||||
chunks.append(text[start:end])
|
||||
start += chunk_size - chunk_overlap
|
||||
return chunks
|
||||
@@ -1,9 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
|
||||
class Store:
|
||||
async def save(self, text: str):
|
||||
pass
|
||||
|
||||
async def query(self, query: str, top_n: int = 3) -> List[str]:
|
||||
pass
|
||||
@@ -1,42 +0,0 @@
|
||||
import chromadb
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
from astrbot.api import logger
|
||||
from ..embedding.openai_source import SimpleOpenAIEmbedding
|
||||
from . import Store
|
||||
|
||||
|
||||
class ChromaVectorStore(Store):
|
||||
def __init__(self, name: str, embedding_cfg: Dict) -> None:
|
||||
self.chroma_client = chromadb.PersistentClient(
|
||||
path="data/long_term_memory_chroma.db"
|
||||
)
|
||||
self.collection = self.chroma_client.get_or_create_collection(name=name)
|
||||
self.embedding = None
|
||||
if embedding_cfg["strategy"] == "openai":
|
||||
self.embedding = SimpleOpenAIEmbedding(
|
||||
model=embedding_cfg["model"],
|
||||
api_key=embedding_cfg["api_key"],
|
||||
api_base=embedding_cfg.get("base_url", None),
|
||||
)
|
||||
|
||||
async def save(self, text: str, metadata: Dict = None):
|
||||
logger.debug(f"Saving text: {text}")
|
||||
embedding = await self.embedding.get_embedding(text)
|
||||
|
||||
self.collection.upsert(
|
||||
documents=text,
|
||||
metadatas=metadata,
|
||||
ids=str(uuid.uuid4()),
|
||||
embeddings=embedding,
|
||||
)
|
||||
|
||||
async def query(
|
||||
self, query: str, top_n=3, metadata_filter: Dict = None
|
||||
) -> List[str]:
|
||||
embedding = await self.embedding.get_embedding(query)
|
||||
|
||||
results = self.collection.query(
|
||||
query_embeddings=embedding, n_results=top_n, where=metadata_filter
|
||||
)
|
||||
return results["documents"][0]
|
||||
@@ -5,6 +5,7 @@
|
||||
from typing import Union
|
||||
import os
|
||||
import json
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
def load_config(namespace: str) -> Union[dict, bool]:
|
||||
@@ -13,7 +14,7 @@ def load_config(namespace: str) -> Union[dict, bool]:
|
||||
namespace: str, 配置的唯一识别符,也就是配置文件的名字。
|
||||
返回值: 当配置文件存在时,返回 namespace 对应配置文件的内容dict,否则返回 False。
|
||||
"""
|
||||
path = f"data/config/{namespace}.json"
|
||||
path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json")
|
||||
if not os.path.exists(path):
|
||||
return False
|
||||
with open(path, "r", encoding="utf-8-sig") as f:
|
||||
@@ -43,7 +44,10 @@ def put_config(namespace: str, name: str, key: str, value, description: str):
|
||||
raise ValueError("key 只支持 str 类型。")
|
||||
if not isinstance(value, (str, int, float, bool, list)):
|
||||
raise ValueError("value 只支持 str, int, float, bool, list 类型。")
|
||||
path = f"data/config/{namespace}.json"
|
||||
|
||||
config_dir = os.path.join(get_astrbot_data_path(), "config")
|
||||
path = os.path.join(config_dir, f"{namespace}.json")
|
||||
|
||||
if not os.path.exists(path):
|
||||
with open(path, "w", encoding="utf-8-sig") as f:
|
||||
f.write("{}")
|
||||
@@ -71,7 +75,7 @@ def update_config(namespace: str, key: str, value):
|
||||
key: str, 配置项的键。
|
||||
value: str, int, float, bool, list, 配置项的值。
|
||||
"""
|
||||
path = f"data/config/{namespace}.json"
|
||||
path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json")
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError(f"配置文件 {namespace}.json 不存在。")
|
||||
with open(path, "r", encoding="utf-8-sig") as f:
|
||||
|
||||
@@ -16,7 +16,6 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.star.filter.platform_adapter_type import (
|
||||
PlatformAdapterType,
|
||||
@@ -42,6 +41,8 @@ class Context:
|
||||
|
||||
platform_manager: PlatformManager = None
|
||||
|
||||
registered_web_apis: list = []
|
||||
|
||||
# back compatibility
|
||||
_register_tasks: List[Awaitable] = []
|
||||
_star_manager = None
|
||||
@@ -54,14 +55,12 @@ class Context:
|
||||
provider_manager: ProviderManager = None,
|
||||
platform_manager: PlatformManager = None,
|
||||
conversation_manager: ConversationManager = None,
|
||||
knowledge_db_manager: KnowledgeDBManager = None,
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
self._config = config
|
||||
self._db = db
|
||||
self.provider_manager = provider_manager
|
||||
self.platform_manager = platform_manager
|
||||
self.knowledge_db_manager = knowledge_db_manager
|
||||
self.conversation_manager = conversation_manager
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
@@ -126,11 +125,8 @@ class Context:
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider:
|
||||
"""通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
||||
for provider in self.provider_manager.provider_insts:
|
||||
if provider.meta().id == provider_id:
|
||||
return provider
|
||||
return None
|
||||
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
|
||||
return self.provider_manager.inst_map.get(provider_id)
|
||||
|
||||
def get_all_providers(self) -> List[Provider]:
|
||||
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
||||
@@ -301,3 +297,12 @@ class Context:
|
||||
注册一个异步任务。
|
||||
"""
|
||||
self._register_tasks.append(task)
|
||||
|
||||
def register_web_api(
|
||||
self, route: str, view_handler: Awaitable, methods: list, desc: str
|
||||
):
|
||||
for idx, api in enumerate(self.registered_web_apis):
|
||||
if api[0] == route and methods == api[2]:
|
||||
self.registered_web_apis[idx] = (route, view_handler, methods, desc)
|
||||
return
|
||||
self.registered_web_apis.append((route, view_handler, methods, desc))
|
||||
|
||||
@@ -113,7 +113,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
||||
)
|
||||
raise ValueError(
|
||||
f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"
|
||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n"
|
||||
+ tree
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import heapq
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Awaitable, List, Dict, TypeVar, Generic
|
||||
from .filter import HandlerFilter
|
||||
@@ -8,100 +7,66 @@ from .star import star_map
|
||||
|
||||
T = TypeVar("T", bound="StarHandlerMetadata")
|
||||
|
||||
|
||||
class StarHandlerRegistry(Generic[T]):
|
||||
"""用于存储所有的 Star Handler"""
|
||||
|
||||
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
"""用于快速查找。key 是 handler_full_name"""
|
||||
_handlers = []
|
||||
def __init__(self):
|
||||
self.star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
self._handlers: List[StarHandlerMetadata] = []
|
||||
|
||||
def append(self, handler: StarHandlerMetadata):
|
||||
"""添加一个 Handler"""
|
||||
"""添加一个 Handler,并保持按优先级有序"""
|
||||
if "priority" not in handler.extras_configs:
|
||||
handler.extras_configs["priority"] = 0
|
||||
|
||||
heapq.heappush(self._handlers, (-handler.extras_configs["priority"], handler))
|
||||
self.star_handlers_map[handler.handler_full_name] = handler
|
||||
self._handlers.append(handler)
|
||||
self._handlers.sort(key=lambda h: -h.extras_configs["priority"])
|
||||
|
||||
def _print_handlers(self):
|
||||
"""打印所有的 Handler"""
|
||||
for _, handler in self._handlers:
|
||||
for handler in self._handlers:
|
||||
print(handler.handler_full_name)
|
||||
|
||||
def get_handlers_by_event_type(
|
||||
self, event_type: EventType, only_activated=True, platform_id=None
|
||||
) -> List[StarHandlerMetadata]:
|
||||
"""通过事件类型获取 Handler
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
only_activated: 是否只返回已激活的插件的处理器
|
||||
platform_id: 平台ID,如果提供此参数,将过滤掉在此平台不兼容的处理器
|
||||
|
||||
Returns:
|
||||
List[StarHandlerMetadata]: 处理器列表
|
||||
"""
|
||||
handlers = []
|
||||
for _, handler in self._handlers:
|
||||
for handler in self._handlers:
|
||||
if handler.event_type != event_type:
|
||||
continue
|
||||
|
||||
# 只激活的插件处理器
|
||||
if only_activated:
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
if not (plugin and plugin.activated):
|
||||
continue
|
||||
|
||||
# 平台兼容性过滤
|
||||
if platform_id and event_type != EventType.OnAstrBotLoadedEvent:
|
||||
if not handler.is_enabled_for_platform(platform_id):
|
||||
continue
|
||||
|
||||
handlers.append(handler)
|
||||
|
||||
return handlers
|
||||
|
||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
||||
"""通过 Handler 的全名获取 Handler"""
|
||||
return self.star_handlers_map.get(full_name, None)
|
||||
|
||||
def get_handlers_by_module_name(
|
||||
self, module_name: str
|
||||
) -> List[StarHandlerMetadata]:
|
||||
"""通过模块名获取 Handler"""
|
||||
return [
|
||||
handler
|
||||
for _, handler in self._handlers
|
||||
handler for handler in self._handlers
|
||||
if handler.handler_module_path == module_name
|
||||
]
|
||||
|
||||
def clear(self):
|
||||
"""清空所有的 Handler"""
|
||||
self.star_handlers_map.clear()
|
||||
self._handlers.clear()
|
||||
|
||||
def remove(self, handler: StarHandlerMetadata):
|
||||
"""删除一个 Handler"""
|
||||
# self._handlers.remove(handler)
|
||||
for i, h in enumerate(self._handlers):
|
||||
if h[1] == handler:
|
||||
self._handlers.pop(i)
|
||||
break
|
||||
try:
|
||||
del self.star_handlers_map[handler.handler_full_name]
|
||||
except KeyError:
|
||||
pass
|
||||
self.star_handlers_map.pop(handler.handler_full_name, None)
|
||||
self._handlers = [h for h in self._handlers if h != handler]
|
||||
|
||||
def __iter__(self):
|
||||
"""使 StarHandlerRegistry 支持迭代"""
|
||||
return (handler for _, handler in self._handlers)
|
||||
return iter(self._handlers)
|
||||
|
||||
def __len__(self):
|
||||
"""返回 Handler 的数量"""
|
||||
return len(self._handlers)
|
||||
|
||||
|
||||
star_handlers_registry = StarHandlerRegistry()
|
||||
|
||||
|
||||
|
||||
@@ -2,49 +2,59 @@
|
||||
插件的重载、启停、安装、卸载等操作。
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import traceback
|
||||
import yaml
|
||||
import logging
|
||||
import asyncio
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core import logger, sp, pip_installer
|
||||
from .context import Context
|
||||
from . import StarMetadata
|
||||
from .updator import PluginUpdator
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
from .star import star_registry, star_map
|
||||
from .star_handler import star_handlers_registry
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
from .filter.permission import PermissionTypeFilter, PermissionType
|
||||
import yaml
|
||||
|
||||
from astrbot.core import logger, pip_installer, sp
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_config_path,
|
||||
get_astrbot_plugin_path,
|
||||
)
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
|
||||
from . import StarMetadata
|
||||
from .context import Context
|
||||
from .filter.permission import PermissionType, PermissionTypeFilter
|
||||
from .star import star_map, star_registry
|
||||
from .star_handler import star_handlers_registry
|
||||
from .updator import PluginUpdator
|
||||
|
||||
try:
|
||||
from watchfiles import PythonFilter, awatch
|
||||
except ImportError:
|
||||
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
|
||||
logger.warning("未安装 watchfiles,无法实现插件的热重载。")
|
||||
|
||||
try:
|
||||
import nh3
|
||||
except ImportError:
|
||||
logger.warning("未安装 nh3 库,无法清理插件 README.md 中的 HTML 标签。")
|
||||
nh3 = None
|
||||
|
||||
|
||||
class PluginManager:
|
||||
def __init__(self, context: Context, config: AstrBotConfig):
|
||||
self.updator = PluginUpdator(config["plugin_repo_mirror"])
|
||||
self.updator = PluginUpdator()
|
||||
|
||||
self.context = context
|
||||
self.context._star_manager = self
|
||||
|
||||
self.config = config
|
||||
self.plugin_store_path = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"
|
||||
)
|
||||
)
|
||||
self.plugin_store_path = get_astrbot_plugin_path()
|
||||
"""存储插件的路径。即 data/plugins"""
|
||||
self.plugin_config_path = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "../../../data/config"
|
||||
)
|
||||
)
|
||||
self.plugin_config_path = get_astrbot_config_path()
|
||||
"""存储插件配置的路径。data/config"""
|
||||
self.reserved_plugin_path = os.path.abspath(
|
||||
os.path.join(
|
||||
@@ -56,6 +66,58 @@ class PluginManager:
|
||||
"""插件配置 Schema 文件名"""
|
||||
|
||||
self.failed_plugin_info = ""
|
||||
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
|
||||
asyncio.create_task(self._watch_plugins_changes())
|
||||
|
||||
async def _watch_plugins_changes(self):
|
||||
"""监视插件文件变化"""
|
||||
try:
|
||||
async for changes in awatch(
|
||||
self.plugin_store_path,
|
||||
self.reserved_plugin_path,
|
||||
watch_filter=PythonFilter(),
|
||||
recursive=True,
|
||||
):
|
||||
# 处理文件变化
|
||||
await self._handle_file_changes(changes)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"插件热重载监视任务异常: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _handle_file_changes(self, changes):
|
||||
"""处理文件变化"""
|
||||
logger.info(f"检测到文件变化: {changes}")
|
||||
plugins_to_check = []
|
||||
|
||||
for star in star_registry:
|
||||
if not star.activated:
|
||||
continue
|
||||
if star.root_dir_name is None:
|
||||
continue
|
||||
if star.reserved:
|
||||
plugin_dir_path = os.path.join(
|
||||
self.reserved_plugin_path, star.root_dir_name
|
||||
)
|
||||
else:
|
||||
plugin_dir_path = os.path.join(
|
||||
self.plugin_store_path, star.root_dir_name
|
||||
)
|
||||
plugins_to_check.append((plugin_dir_path, star.name))
|
||||
reloaded_plugins = set()
|
||||
for change in changes:
|
||||
_, file_path = change
|
||||
for plugin_dir_path, plugin_name in plugins_to_check:
|
||||
if (
|
||||
os.path.commonpath([plugin_dir_path])
|
||||
== os.path.commonpath([plugin_dir_path, file_path])
|
||||
and plugin_name not in reloaded_plugins
|
||||
):
|
||||
logger.info(f"检测到插件 {plugin_name} 文件变化,正在重载...")
|
||||
await self.reload(plugin_name)
|
||||
reloaded_plugins.add(plugin_name)
|
||||
break
|
||||
|
||||
def _get_classes(self, arg: ModuleType):
|
||||
"""获取指定模块(可以理解为一个 python 文件)下所有的类"""
|
||||
@@ -104,7 +166,7 @@ class PluginManager:
|
||||
plugins.extend(_p)
|
||||
return plugins
|
||||
|
||||
def _check_plugin_dept_update(self, target_plugin: str = None):
|
||||
async def _check_plugin_dept_update(self, target_plugin: str = None):
|
||||
"""检查插件的依赖
|
||||
如果 target_plugin 为 None,则检查所有插件的依赖
|
||||
"""
|
||||
@@ -123,7 +185,7 @@ class PluginManager:
|
||||
pth = os.path.join(plugin_path, "requirements.txt")
|
||||
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
|
||||
try:
|
||||
pip_installer.install(requirements_path=pth)
|
||||
await pip_installer.install(requirements_path=pth)
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
|
||||
|
||||
@@ -345,7 +407,7 @@ class PluginManager:
|
||||
module = __import__(path, fromlist=[module_str])
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
# 尝试安装依赖
|
||||
self._check_plugin_dept_update(target_plugin=root_dir_name)
|
||||
await self._check_plugin_dept_update(target_plugin=root_dir_name)
|
||||
module = __import__(path, fromlist=[module_str])
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -580,16 +642,17 @@ class PluginManager:
|
||||
if not os.path.exists(readme_path):
|
||||
readme_path = os.path.join(plugin_path, "readme.md")
|
||||
|
||||
if os.path.exists(readme_path):
|
||||
if os.path.exists(readme_path) and nh3:
|
||||
try:
|
||||
with open(readme_path, "r", encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
cleaned_content = nh3.clean(readme_content)
|
||||
except Exception as e:
|
||||
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
|
||||
|
||||
plugin_info = None
|
||||
if plugin:
|
||||
plugin_info = {"repo": plugin.repo, "readme": readme_content}
|
||||
plugin_info = {"repo": plugin.repo, "readme": cleaned_content}
|
||||
|
||||
return plugin_info
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union, Awaitable, List, Optional, ClassVar
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
@@ -6,7 +8,7 @@ from astrbot.api.platform import MessageMember, AstrBotMessage
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star import star_map
|
||||
from pathlib import Path
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
class StarTools:
|
||||
@@ -180,7 +182,7 @@ class StarTools:
|
||||
|
||||
plugin_name = metadata.name
|
||||
|
||||
data_dir = Path("data/plugin_data") / plugin_name
|
||||
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name))
|
||||
|
||||
try:
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -6,16 +6,13 @@ from ..updator import RepoZipUpdator
|
||||
from astrbot.core.utils.io import remove_dir, on_error
|
||||
from ..star.star import StarMetadata
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_plugin_path
|
||||
|
||||
|
||||
class PluginUpdator(RepoZipUpdator):
|
||||
def __init__(self, repo_mirror: str = "") -> None:
|
||||
super().__init__(repo_mirror)
|
||||
self.plugin_store_path = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"
|
||||
)
|
||||
)
|
||||
self.plugin_store_path = get_astrbot_plugin_path()
|
||||
|
||||
def get_plugin_store_path(self) -> str:
|
||||
return self.plugin_store_path
|
||||
|
||||
@@ -6,6 +6,7 @@ from .zip_updator import ReleaseInfo, RepoZipUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
|
||||
|
||||
class AstrBotUpdator(RepoZipUpdator):
|
||||
@@ -16,9 +17,7 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
|
||||
def __init__(self, repo_mirror: str = "") -> None:
|
||||
super().__init__(repo_mirror)
|
||||
self.MAIN_PATH = os.path.abspath(
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")
|
||||
)
|
||||
self.MAIN_PATH = get_astrbot_path()
|
||||
self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"
|
||||
|
||||
def terminate_child_processes(self):
|
||||
@@ -45,13 +44,26 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
def _reboot(self, delay: int = 3):
|
||||
"""重启当前程序
|
||||
在指定的延迟后,终止所有子进程并重新启动程序
|
||||
这里只能使用 os.exec* 来重启程序
|
||||
"""
|
||||
py = sys.executable
|
||||
time.sleep(delay)
|
||||
self.terminate_child_processes()
|
||||
py = py.replace(" ", "\\ ")
|
||||
if os.name == "nt":
|
||||
py = f'"{sys.executable}"'
|
||||
else:
|
||||
py = sys.executable
|
||||
|
||||
try:
|
||||
os.execl(py, py, *sys.argv)
|
||||
if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli
|
||||
if os.name == "nt":
|
||||
args = [
|
||||
f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]
|
||||
]
|
||||
else:
|
||||
args = sys.argv[1:]
|
||||
os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args)
|
||||
else:
|
||||
os.execl(sys.executable, py, *sys.argv)
|
||||
except Exception as e:
|
||||
logger.error(f"重启失败({py}, {e}),请尝试手动重启。")
|
||||
raise e
|
||||
@@ -67,6 +79,9 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest)
|
||||
file_url = None
|
||||
|
||||
if os.environ.get("ASTRBOT_CLI"):
|
||||
raise Exception("不支持更新CLI启动的AstrBot") # 避免版本管理混乱
|
||||
|
||||
if latest:
|
||||
latest_version = update_data[0]["tag_name"]
|
||||
if self.compare_version(VERSION, latest_version) >= 0:
|
||||
|
||||
41
astrbot/core/utils/astrbot_path.py
Normal file
41
astrbot/core/utils/astrbot_path.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Astrbot统一路径获取
|
||||
|
||||
项目路径:固定为源码所在路径
|
||||
根目录路径:默认为当前工作目录,可通过环境变量 ASTRBOT_ROOT 指定
|
||||
数据目录路径:固定为根目录下的 data 目录
|
||||
配置文件路径:固定为数据目录下的 config 目录
|
||||
插件目录路径:固定为数据目录下的 plugins 目录
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def get_astrbot_path() -> str:
|
||||
"""获取Astrbot项目路径"""
|
||||
return os.path.realpath(
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../")
|
||||
)
|
||||
|
||||
|
||||
def get_astrbot_root() -> str:
|
||||
"""获取Astrbot根目录路径"""
|
||||
if path := os.environ.get("ASTRBOT_ROOT"):
|
||||
return os.path.realpath(path)
|
||||
else:
|
||||
return os.path.realpath(os.getcwd())
|
||||
|
||||
|
||||
def get_astrbot_data_path() -> str:
|
||||
"""获取Astrbot数据目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_root(), "data"))
|
||||
|
||||
|
||||
def get_astrbot_config_path() -> str:
|
||||
"""获取Astrbot配置文件路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "config"))
|
||||
|
||||
|
||||
def get_astrbot_plugin_path() -> str:
|
||||
"""获取Astrbot插件目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins"))
|
||||
@@ -14,6 +14,7 @@ import certifi
|
||||
from typing import Union
|
||||
|
||||
from PIL import Image
|
||||
from .astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
def on_error(func, path, exc_info):
|
||||
@@ -49,11 +50,11 @@ def port_checker(port: int, host: str = "localhost"):
|
||||
|
||||
|
||||
def save_temp_img(img: Union[Image.Image, str]) -> str:
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
# 获得文件创建时间,清除超过 12 小时的
|
||||
try:
|
||||
for f in os.listdir("data/temp"):
|
||||
path = os.path.join("data/temp", f)
|
||||
for f in os.listdir(temp_dir):
|
||||
path = os.path.join(temp_dir, f)
|
||||
if os.path.isfile(path):
|
||||
ctime = os.path.getctime(path)
|
||||
if time.time() - ctime > 3600 * 12:
|
||||
@@ -63,7 +64,7 @@ def save_temp_img(img: Union[Image.Image, str]) -> str:
|
||||
|
||||
# 获得时间戳
|
||||
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||
p = f"data/temp/{timestamp}.jpg"
|
||||
p = os.path.join(temp_dir, f"{timestamp}.jpg")
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
img.save(p)
|
||||
@@ -201,28 +202,29 @@ def get_local_ip_addresses():
|
||||
|
||||
|
||||
async def get_dashboard_version():
|
||||
if os.path.exists("data/dist"):
|
||||
if os.path.exists("data/dist/assets/version"):
|
||||
with open("data/dist/assets/version", "r") as f:
|
||||
dist_dir = os.path.join(get_astrbot_data_path(), "dist")
|
||||
if os.path.exists(dist_dir):
|
||||
version_file = os.path.join(dist_dir, "assets", "version")
|
||||
if os.path.exists(version_file):
|
||||
with open(version_file, "r") as f:
|
||||
v = f.read().strip()
|
||||
return v
|
||||
return None
|
||||
|
||||
|
||||
async def download_dashboard():
|
||||
async def download_dashboard(path: str = None, extract_path: str = "data"):
|
||||
"""下载管理面板文件"""
|
||||
if path is None:
|
||||
path = os.path.join(get_astrbot_data_path(), "dashboard.zip")
|
||||
|
||||
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
|
||||
try:
|
||||
await download_file(
|
||||
dashboard_release_url, "data/dashboard.zip", show_progress=True
|
||||
)
|
||||
await download_file(dashboard_release_url, path, show_progress=True)
|
||||
except BaseException as _:
|
||||
dashboard_release_url = (
|
||||
"https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
|
||||
)
|
||||
await download_file(
|
||||
dashboard_release_url, "data/dashboard.zip", show_progress=True
|
||||
)
|
||||
await download_file(dashboard_release_url, path, show_progress=True)
|
||||
print("解压管理面板文件中...")
|
||||
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
|
||||
z.extractall("data")
|
||||
with zipfile.ZipFile(path, "r") as z:
|
||||
z.extractall(extract_path)
|
||||
|
||||
36
astrbot/core/utils/log_pipe.py
Normal file
36
astrbot/core/utils/log_pipe.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import threading
|
||||
import os
|
||||
from logging import Logger
|
||||
|
||||
|
||||
class LogPipe(threading.Thread):
|
||||
def __init__(
|
||||
self,
|
||||
level,
|
||||
logger: Logger,
|
||||
identifier=None,
|
||||
callback=None,
|
||||
):
|
||||
threading.Thread.__init__(self)
|
||||
self.daemon = True
|
||||
self.level = level
|
||||
self.fd_read, self.fd_write = os.pipe()
|
||||
self.identifier = identifier
|
||||
self.logger = logger
|
||||
self.callback = callback
|
||||
self.reader = os.fdopen(self.fd_read)
|
||||
self.start()
|
||||
|
||||
def fileno(self):
|
||||
return self.fd_write
|
||||
|
||||
def run(self):
|
||||
for line in iter(self.reader.readline, ""):
|
||||
if self.callback:
|
||||
self.callback(line.strip())
|
||||
self.logger.log(self.level, f"[{self.identifier}] {line.strip()}")
|
||||
|
||||
self.reader.close()
|
||||
|
||||
def close(self):
|
||||
os.close(self.fd_write)
|
||||
@@ -1,10 +1,42 @@
|
||||
import aiohttp
|
||||
import sys
|
||||
import os
|
||||
import socket
|
||||
import uuid
|
||||
from astrbot.core.config import VERSION
|
||||
from astrbot.core import db_helper, logger
|
||||
|
||||
|
||||
class Metric:
|
||||
_iid_cache = None
|
||||
|
||||
@staticmethod
|
||||
def get_installation_id():
|
||||
"""获取或创建一个唯一的安装ID"""
|
||||
if Metric._iid_cache is not None:
|
||||
return Metric._iid_cache
|
||||
|
||||
config_dir = os.path.join(os.path.expanduser("~"), ".astrbot")
|
||||
id_file = os.path.join(config_dir, ".installation_id")
|
||||
|
||||
if os.path.exists(id_file):
|
||||
try:
|
||||
with open(id_file, "r") as f:
|
||||
Metric._iid_cache = f.read().strip()
|
||||
return Metric._iid_cache
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
os.makedirs(config_dir, exist_ok=True)
|
||||
installation_id = str(uuid.uuid4())
|
||||
with open(id_file, "w") as f:
|
||||
f.write(installation_id)
|
||||
Metric._iid_cache = installation_id
|
||||
return installation_id
|
||||
except Exception:
|
||||
Metric._iid_cache = "null"
|
||||
return "null"
|
||||
|
||||
@staticmethod
|
||||
async def upload(**kwargs):
|
||||
"""
|
||||
@@ -16,6 +48,14 @@ class Metric:
|
||||
kwargs["v"] = VERSION
|
||||
kwargs["os"] = sys.platform
|
||||
payload = {"metrics_data": kwargs}
|
||||
try:
|
||||
kwargs["hn"] = socket.gethostname()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
kwargs["iid"] = Metric.get_installation_id()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if "adapter_name" in kwargs:
|
||||
db_helper.insert_platform_metrics({kwargs["adapter_name"]: 1})
|
||||
|
||||
72
astrbot/core/utils/path_util.py
Normal file
72
astrbot/core/utils/path_util.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import os
|
||||
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
def path_Mapping(mappings, srcPath: str) -> str:
|
||||
"""路径映射处理函数。尝试支援 Windows 和 Linux 的路径映射。
|
||||
Args:
|
||||
mappings: 映射规则列表
|
||||
srcPath: 原路径
|
||||
Returns:
|
||||
str: 处理后的路径
|
||||
"""
|
||||
for mapping in mappings:
|
||||
rule = mapping.split(":")
|
||||
if len(rule) == 2:
|
||||
from_, to_ = mapping.split(":")
|
||||
elif len(rule) > 4 or len(rule) == 1:
|
||||
# 切割后大于4个项目,或者只有1个项目,那肯定是错误的,只能是2,3,4个项目
|
||||
logger.warning(f"路径映射规则错误: {mapping}")
|
||||
continue
|
||||
else:
|
||||
# rule.len == 3 or 4
|
||||
if os.path.exists(rule[0] + ":" + rule[1]):
|
||||
# 前面两个项目合并路径存在,说明是本地Window路径。后面一个或两个项目组成的路径本地大概率无法解析,直接拼接
|
||||
from_ = rule[0] + ":" + rule[1]
|
||||
if len(rule) == 3:
|
||||
to_ = rule[2]
|
||||
else:
|
||||
to_ = rule[2] + ":" + rule[3]
|
||||
else:
|
||||
# 前面两个项目合并路径不存在,说明第一个项目是本地Linux路径,后面一个或两个项目直接拼接。
|
||||
from_ = rule[0]
|
||||
if len(rule) == 3:
|
||||
to_ = rule[1] + ":" + rule[2]
|
||||
else:
|
||||
# 这种情况下存在四个项目,说明规则也是错误的
|
||||
logger.warning(f"路径映射规则错误: {mapping}")
|
||||
continue
|
||||
|
||||
from_ = from_.removesuffix("/")
|
||||
from_ = from_.removesuffix("\\")
|
||||
to_ = to_.removesuffix("/")
|
||||
to_ = to_.removesuffix("\\")
|
||||
# logger.debug(f"\t路径映射-规则(处理): {from_} -> {to_}")
|
||||
|
||||
url = srcPath.removeprefix("file://")
|
||||
if url.startswith(from_):
|
||||
srcPath = url.replace(from_, to_, 1)
|
||||
if ":" in srcPath:
|
||||
# Windows路径处理
|
||||
srcPath = srcPath.replace("/", "\\")
|
||||
else:
|
||||
has_replaced_processed = False
|
||||
if srcPath.startswith("."):
|
||||
# 相对路径处理。如果是相对路径,可能是Linux路径,也可能是Windows路径
|
||||
sign = srcPath[1]
|
||||
# 处理两个点的情况
|
||||
if sign == ".":
|
||||
sign = srcPath[2]
|
||||
if sign == "/":
|
||||
srcPath = srcPath.replace("\\", "/")
|
||||
has_replaced_processed = True
|
||||
elif sign == "\\":
|
||||
srcPath = srcPath.replace("/", "\\")
|
||||
has_replaced_processed = True
|
||||
if not has_replaced_processed:
|
||||
# 如果不是相对路径或不能处理,默认按照Linux路径处理
|
||||
srcPath = srcPath.replace("\\", "/")
|
||||
logger.info(f"路径映射: {url} -> {srcPath}")
|
||||
return srcPath
|
||||
return srcPath
|
||||
@@ -1,14 +1,15 @@
|
||||
import logging
|
||||
from pip import main as pip_main
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
class PipInstaller:
|
||||
def __init__(self, pip_install_arg: str):
|
||||
def __init__(self, pip_install_arg: str, pypi_index_url: str = None):
|
||||
self.pip_install_arg = pip_install_arg
|
||||
self.pypi_index_url = pypi_index_url
|
||||
|
||||
def install(
|
||||
async def install(
|
||||
self,
|
||||
package_name: str = None,
|
||||
requirements_path: str = None,
|
||||
@@ -20,21 +21,37 @@ class PipInstaller:
|
||||
elif requirements_path:
|
||||
args.extend(["-r", requirements_path])
|
||||
|
||||
if not mirror:
|
||||
mirror = "https://mirrors.aliyun.com/pypi/simple/"
|
||||
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
|
||||
|
||||
args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", mirror])
|
||||
args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", index_url])
|
||||
|
||||
if self.pip_install_arg:
|
||||
args.extend(self.pip_install_arg.split())
|
||||
|
||||
logger.info(f"Pip 包管理器: pip {' '.join(args)}")
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"pip", *args,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
)
|
||||
|
||||
result_code = pip_main(args)
|
||||
assert process.stdout is not None
|
||||
async for line in process.stdout:
|
||||
logger.info(line.decode().strip())
|
||||
|
||||
# 清除 pip.main 导致的多余的 logging handlers
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
await process.wait()
|
||||
|
||||
if result_code != 0:
|
||||
raise Exception(f"安装失败,错误码:{result_code}")
|
||||
if process.returncode != 0:
|
||||
raise Exception(f"安装失败,错误码:{process.returncode}")
|
||||
except FileNotFoundError:
|
||||
# 没有 pip
|
||||
from pip import main as pip_main
|
||||
result_code = await asyncio.to_thread(pip_main, args)
|
||||
|
||||
# 清除 pip.main 导致的多余的 logging handlers
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
if result_code != 0:
|
||||
raise Exception(f"安装失败,错误码:{result_code}")
|
||||
|
||||
@@ -97,8 +97,8 @@ class SessionFilter:
|
||||
|
||||
class DefaultSessionFilter(SessionFilter):
|
||||
def filter(self, event: AstrMessageEvent) -> str:
|
||||
"""默认实现,返回发送者的 ID 作为会话标识符"""
|
||||
return event.get_sender_id()
|
||||
"""默认实现,返回统一消息来源字符串作为会话标识符"""
|
||||
return event.unified_msg_origin
|
||||
|
||||
|
||||
class SessionWaiter:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user