Compare commits
147 Commits
SHA1: 021ca8175b, 39d6207fe1, 23ce687229, 3715312fd2, 8196922cac, 8089ad91da, 2930cc3fd8, 0e841a8b25,
67fa1611cc, 91136bb9f7, 7c050d1adc, a0690a6afc, c51609b261, 72148f66eb, a04993a2bb, 74f845b06d,
50144ddcae, 94bf3b8195, e190bbeeed, 92abc43c9d, c8e34ff26f, 630df3e76e, bdbf382201, 00eefc82db,
dc97080837, 0b7fc29ac4, ff998fdd8d, d7461ed54c, 3ce577acf9, 50b1dccff3, c33e7e30d4, bc7f01ba36,
2ce653caad, 0d850d7b22, a2be155b8e, 68aa107689, 23096ed3a5, 90a65c35c1, 3d88827a95, 40a0a8df5a,
20f7129c0b, 0e962e95dd, 07ba9c772c, 0622d88b22, 594f0fed55, 04b0d9b88d, 1f2af8ef94, 598ea2d857,
6dd9bbb516, 3cd0b47dc6, 65c71b5f20, 1152b11202, 51246ea31b, 7e5592dd32, c6b28caebf, ca002f6fff,
14ec392091, 5e2eb91ac0, c1626613ce, 42042d9e73, 22c3b53ab8, 090c32c90e, 4f4a9b9e55, 6c7d7c9015,
562e62a8c0, 0823f7aa48, eb201c0420, 6cfed9a39d, 33618c4a6b, ace0a7c219, f7d018cf94, 8ae2a556e4,
4188deb386, 82cf4ed909, 88fc437abc, 57f868cab1, 6cb5527894, 016783a1e5, 594ccff9c8, 30792f0584,
8f021eb35a, 1969abc340, b1b53ab983, 9b5af23982, 4cedc6d3c8, 4e9cce76da, 9b004f3d2f, 9430e3090d,
ba44f9117b, eb56710a72, 38e3f27899, 3c58d96db5, a6be0cc135, a53510bc41, 1fd482e899, 2f130ba009,
e6d9db9395, e0ac743cdb, b0d3fc11f0, 7e0a50fbf2, 59df244173, deb31a02cf, e3aa1315ae, 65bc5efa19,
abc4bc24b4, 5df3f06f83, 0e1de82bd7, f31e41b3f1, fe8d2718c4, 8afefada0a, 745e1c37c0, fdb5988cec,
36ffcf3cc3, a0f8f3ae32, 130f52f315, a05868cc45, 2fc77aed15, c56edb4da6, 6672190760, f122b17097,
2c5f68e696, e1ca645a32, 333bf56ddc, b240594859, beccae933f, e6aa1d2c54, 5e808bab65, 361d78247b,
3550103e45, 8b0d4d4de4, dc71c04b67, a0254ed817, 2563ecf3c5, c04738d9fe, 1266b4d086, 99cf0a1522,
98a75e923d, ad96d676e6, 79333bbc35, 5c5b0f4fde, ed6cdfedbb, 23f13ef05f, f9c59d9706, e1cec42227,
8d79c50d53, d77830b97f, 394540f689
@@ -1,9 +1,8 @@
 # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
 # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
-# github acions
+# github actions
 .github/
 .*ignore
-.git/
 # User-specific stuff
 .idea/
 # Byte-compiled / optimized / DLL files
@@ -15,7 +14,6 @@ env/
 venv*/
 ENV/
 .conda/
-README*.md
 dashboard/
 data/
 changelogs/
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml (9 changes)

@@ -16,7 +16,7 @@ body:
 
 请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
 
-不熟悉 JSON ?现在可以从 [这里](https://plugins.astrbot.app/#/submit) 获取你的 JSON 啦!获取到了记得复制粘贴过来哦!
+不熟悉 JSON ?可以从 [此处](https://plugins.astrbot.app/submit) 生成 JSON ,生成后记得复制粘贴过来.
 
 - type: textarea
 id: plugin-info
@@ -26,12 +26,13 @@ body:
 value: |
 ```json
 {
-"name": "插件名",
-"desc": "插件介绍",
+"name": "插件名,请以 astrbot_plugin_ 开头",
+"display_name": "用于展示的插件名,方便人类阅读",
+"desc": "插件的简短介绍",
 "author": "作者名",
 "repo": "插件仓库链接",
 "tags": [],
-"social_link": ""
+"social_link": "",
 }
 ```
 validations:
.github/ISSUE_TEMPLATE/bug-report.yml (6 changes)

@@ -6,13 +6,13 @@ body:
 - type: markdown
 attributes:
 value: |
-感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。
+感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
 - type: textarea
 attributes:
 label: 发生了什么
 description: 描述你遇到的异常
 placeholder: >
-一个清晰且具体的描述这个异常是什么。
+一个清晰且具体的描述这个异常是什么。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
 validations:
 required: true
 
@@ -55,7 +55,7 @@ body:
 attributes:
 label: 报错日志
 description: >
-如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!
+如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
 placeholder: >
 请提供完整的报错日志或截图。
 validations:
.github/workflows/dashboard_ci.yml (13 changes)

@@ -13,11 +13,18 @@ jobs:
 - name: Checkout repository
 uses: actions/checkout@v5
 
+- name: Setup Node.js
+uses: actions/setup-node@v6
+with:
+node-version: 'latest'
+
 - name: npm install, build
 run: |
 cd dashboard
-npm install
-npm run build
+npm install pnpm -g
+pnpm install
+pnpm i --save-dev @types/markdown-it
+pnpm run build
 
 - name: Inject Commit SHA
 id: get_sha
@@ -29,7 +36,7 @@ jobs:
 zip -r dist.zip dist
 
 - name: Archive production artifacts
-uses: actions/upload-artifact@v4
+uses: actions/upload-artifact@v5
 with:
 name: dist-without-markdown
 path: |
.gitignore (62 changes)

@@ -1,35 +1,49 @@
+# Python related
 __pycache__
-botpy.log
-.vscode
+.mypy_cache
 .venv*
-.idea
-data_v2.db
-data_v3.db
-configs/session
-configs/config.yaml
-**/.DS_Store
-temp
-cmd_config.json
-data
-cookies.json
-logs/
-addons/plugins
+.conda/
+uv.lock
 .coverage
 
+# IDE and editors
+.vscode
+.idea
+
+# Logs and temporary files
+botpy.log
+logs/
+temp
+cookies.json
+
+# Data files
+data_v2.db
+data_v3.db
+data
+configs/session
+configs/config.yaml
+cmd_config.json
+
+# Plugins and packages
+addons/plugins
+packages/python_interpreter/workplace
 tests/astrbot_plugin_openai
-chroma
+
+# Dashboard
 dashboard/node_modules/
 dashboard/dist/
-.DS_Store
 package-lock.json
 package.json
-venv/*
-packages/python_interpreter/workplace
-.venv/*
-.conda/
-.idea
-pytest.ini
-.astrbot
 
-uv.lock
+# Operating System
+**/.DS_Store
+.DS_Store
+
+# AstrBot specific
+.astrbot
+astrbot.lock
+
+# Other
+chroma
+venv/*
+pytest.ini
@@ -6,8 +6,20 @@ ci:
 autoupdate_schedule: weekly
 autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
 repos:
 - repo: https://github.com/astral-sh/ruff-pre-commit
-rev: v0.11.2
-hooks:
-- id: ruff
-- id: ruff-format
+# Ruff version.
+rev: v0.14.1
+hooks:
+# Run the linter.
+- id: ruff-check
+types_or: [ python, pyi ]
+args: [ --fix ]
+# Run the formatter.
+- id: ruff-format
+types_or: [ python, pyi ]
+
+- repo: https://github.com/asottile/pyupgrade
+rev: v3.21.0
+hooks:
+- id: pyupgrade
+args: [--py310-plus]
@@ -12,19 +12,20 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
 ca-certificates \
 bash \
 ffmpeg \
+curl \
+gnupg \
+git \
 && apt-get clean \
-&& rm -rf /var/lib/apt/lists/*
+&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
 
 RUN apt-get update && apt-get install -y curl gnupg && \
 curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
 apt-get install -y nodejs && \
-rm -rf /var/lib/apt/lists/*
 
 RUN python -m pip install uv
 RUN uv pip install -r requirements.txt --no-cache-dir --system
 RUN uv pip install socksio uv pilk --no-cache-dir --system
 
 EXPOSE 6185
-EXPOSE 6186
 
-CMD [ "python", "main.py" ]
+CMD ["python", "main.py"]
@@ -1,35 +0,0 @@
-FROM python:3.10-slim
-
-WORKDIR /AstrBot
-
-COPY . /AstrBot/
-
-RUN apt-get update && apt-get install -y --no-install-recommends \
-gcc \
-build-essential \
-python3-dev \
-libffi-dev \
-libssl-dev \
-curl \
-unzip \
-ca-certificates \
-bash \
-&& apt-get clean \
-&& rm -rf /var/lib/apt/lists/*
-
-# Installation of Node.js
-ENV NVM_DIR="/root/.nvm"
-RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
-. "$NVM_DIR/nvm.sh" && \
-nvm install 22 && \
-nvm use 22
-RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"
-
-RUN python -m pip install uv
-RUN uv pip install -r requirements.txt --no-cache-dir --system
-RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system
-
-EXPOSE 6185
-EXPOSE 6186
-
-CMD ["python", "main.py"]
README.md (146 changes)

@@ -4,21 +4,32 @@
 
 <div align="center">
 
-<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
+<br>
 
-[](https://github.com/Soulter/AstrBot/releases/latest)
+<div>
+<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
+<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=1" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
+</div>
+
+<br>
+
+<div>
+<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
 <img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
 <a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
 <a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
 <a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
+<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
+</div>
 
-<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
-<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
+<br>
+
+<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
+<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
 <a href="https://astrbot.app/">文档</a> |
 <a href="https://blog.astrbot.app/">Blog</a> |
 <a href="https://astrbot.featurebase.app/roadmap">路线图</a> |
-<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
+<a href="https://github.com/AstrBotDevs/AstrBot/issues">问题提交</a>
 </div>
 
 AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架。
@@ -61,7 +72,7 @@ AstrBot 已由雨云官方上架至云应用平台,可一键部署。
 
 社区贡献的部署方式。
 
-[](https://repl.it/github/Soulter/AstrBot)
+[](https://repl.it/github/AstrBotDevs/AstrBot)
 
 #### Windows 一键安装器部署
 
@@ -108,82 +119,73 @@ uv run main.py
 
 <a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
 
-## ⚡ 消息平台支持情况
+## 支持的消息平台
 
 **官方维护**
 
-| 平台 | 支持性 |
-| -------- | ------- |
-| QQ(官方平台) | ✔ |
-| QQ(OneBot) | ✔ |
-| Telegram | ✔ |
-| 企微应用 | ✔ |
-| 微信客服 | ✔ |
-| 微信公众号 | ✔ |
-| 飞书 | ✔ |
-| 钉钉 | ✔ |
-| Slack | ✔ |
-| Discord | ✔ |
-| Satori | ✔ |
-| Misskey | ✔ |
-| 企微智能机器人 | 将支持 |
-| Whatsapp | 将支持 |
-| LINE | 将支持 |
+- QQ (官方平台 & OneBot)
+- Telegram
+- 企微应用 & 企微智能机器人
+- 微信客服 & 微信公众号
+- 飞书
+- 钉钉
+- Slack
+- Discord
+- Satori
+- Misskey
+- Whatsapp (将支持)
+- LINE (将支持)
 
 **社区维护**
 
-| 平台 | 支持性 |
-| -------- | ------- |
-| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
-| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
-| [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) | ✔ |
+- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
+- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
+- [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
 
-## ⚡ 提供商支持情况
+## 支持的模型服务
 
 **大模型服务**
 
-| 名称 | 支持性 | 备注 |
-| -------- | ------- | ------- |
-| OpenAI | ✔ | 支持任何兼容 OpenAI API 的服务 |
-| Anthropic | ✔ | |
-| Google Gemini | ✔ | |
-| Moonshot AI | ✔ | |
-| 智谱 AI | ✔ | |
-| DeepSeek | ✔ | |
-| Ollama | ✔ | 本地部署 DeepSeek 等开源语言模型 |
-| LM Studio | ✔ | 本地部署 DeepSeek 等开源语言模型 |
-| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | |
-| [302.AI](https://share.302.ai/rr1M3l) | ✔ | |
-| [小马算力](https://www.tokenpony.cn/3YPyf) | ✔ | |
-| 硅基流动 | ✔ | |
-| PPIO 派欧云 | ✔ | |
-| ModelScope | ✔ | |
-| OneAPI | ✔ | |
-| Dify | ✔ | |
-| 阿里云百炼应用 | ✔ | |
-| Coze | ✔ | |
+- OpenAI 及兼容服务
+- Anthropic
+- Google Gemini
+- Moonshot AI
+- 智谱 AI
+- DeepSeek
+- Ollama (本地部署)
+- LM Studio (本地部署)
+- [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
+- [302.AI](https://share.302.ai/rr1M3l)
+- [小马算力](https://www.tokenpony.cn/3YPyf)
+- [硅基流动](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
+- [PPIO 派欧云](https://ppio.com/user/register?invited_by=AIOONE)
+- ModelScope
+- OneAPI
+
+**LLMOps 平台**
+
+- Dify
+- 阿里云百炼应用
+- Coze
 
 **语音转文本服务**
 
-| 名称 | 支持性 | 备注 |
-| -------- | ------- | ------- |
-| Whisper | ✔ | 支持 API、本地部署 |
-| SenseVoice | ✔ | 本地部署 |
+- OpenAI Whisper
+- SenseVoice
 
 **文本转语音服务**
 
-| 名称 | 支持性 | 备注 |
-| -------- | ------- | ------- |
-| OpenAI TTS | ✔ | |
-| Gemini TTS | ✔ | |
-| GSVI | ✔ | GPT-Sovits-Inference |
-| GPT-SoVITs | ✔ | GPT-Sovits |
-| FishAudio | ✔ | |
-| Edge TTS | ✔ | Edge 浏览器的免费 TTS |
-| 阿里云百炼 TTS | ✔ | |
-| Azure TTS | ✔ | |
-| Minimax TTS | ✔ | |
-| 火山引擎 TTS | ✔ | |
+- OpenAI TTS
+- Gemini TTS
+- GPT-Sovits-Inference
+- GPT-Sovits
+- FishAudio
+- Edge TTS
+- 阿里云百炼 TTS
+- Azure TTS
+- Minimax TTS
+- 火山引擎 TTS
 
 ## ❤️ 贡献
 
@@ -198,7 +200,7 @@ uv run main.py
 AstrBot 使用 `ruff` 进行代码格式化和检查。
 
 ```bash
-git clone https://github.com/Soulter/AstrBot
+git clone https://github.com/AstrBotDevs/AstrBot
 pip install pre-commit
 pre-commit install
 ```
@@ -217,12 +219,12 @@ pre-commit install
 
 ## ⭐ Star History
 
 > [!TIP]
 > 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我们维护这个开源项目的动力 <3
 
 <div align="center">
 
-[](https://star-history.com/#soulter/astrbot&Date)
+[](https://star-history.com/#astrbotdevs/astrbot&Date)
 
 </div>
 
README_en.md (18 changes)

@@ -10,16 +10,16 @@ _✨ Easy-to-use Multi-platform LLM Chatbot & Development Framework ✨_
 
 <a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
 
-[](https://github.com/Soulter/AstrBot/releases/latest)
+[](https://github.com/AstrBotDevs/AstrBot/releases/latest)
 <img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
-<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
+<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot"/></a>
 <a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a>
 [](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
 
-[](https://codecov.io/gh/Soulter/AstrBot)
+[](https://codecov.io/gh/AstrBotDevs/AstrBot)
 
 <a href="https://astrbot.app/">Documentation</a> |
-<a href="https://github.com/Soulter/AstrBot/issues">Issue Tracking</a>
+<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracking</a>
 </div>
 
 AstrBot is a loosely coupled, asynchronous chatbot and development framework that supports multi-platform deployment, featuring an easy-to-use plugin system and comprehensive Large Language Model (LLM) integration capabilities.
@@ -49,7 +49,7 @@ Requires Python (>3.10). See docs: [Windows Installer Guide](https://astrbot.app
 
 #### Replit Deployment
 
-[](https://repl.it/github/Soulter/AstrBot)
+[](https://repl.it/github/AstrBotDevs/AstrBot)
 
 #### CasaOS Deployment
 
@@ -67,8 +67,8 @@ See docs: [Source Code Deployment](https://astrbot.app/deploy/astrbot/cli.html)
 | QQ (Official Bot) | ✔ | Private/Group chats | Text, Images |
 | QQ (OneBot) | ✔ | Private/Group chats | Text, Images, Voice |
 | WeChat (Personal) | ✔ | Private/Group chats | Text, Images, Voice |
-| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
-| [WeChat Work](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
+| [Telegram](https://github.com/AstrBotDevs/AstrBot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
+| [WeChat Work](https://github.com/AstrBotDevs/AstrBot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
 | Feishu | ✔ | Group chats | Text, Images |
 | WeChat Open Platform | 🚧 | Planned | - |
 | Discord | 🚧 | Planned | - |
@@ -157,7 +157,7 @@ _✨ Built-in Web Chat Interface ✨_
 
 <div align="center">
 
-[](https://star-history.com/#soulter/astrbot&Date)
+[](https://star-history.com/#AstrBotDevs/AstrBot&Date)
 
 </div>
 
@@ -169,7 +169,7 @@ _✨ Built-in Web Chat Interface ✨_
 
 <!-- ## ✨ ATRI [Beta]
 
-Available as plugin: [astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
+Available as plugin: [astrbot_plugin_atri](https://github.com/AstrBotDevs/AstrBot_plugin_atri)
 
 1. Qwen1.5-7B-Chat Lora model fine-tuned with ATRI character data
 2. Long-term memory
@@ -10,16 +10,16 @@ _✨ 簡単に使えるマルチプラットフォーム LLM チャットボッ
 
 <a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
 
-[](https://github.com/Soulter/AstrBot/releases/latest)
+[](https://github.com/AstrBotDevs/AstrBot/releases/latest)
 <img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
 <a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
 <img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple">
 [](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
 
-[](https://codecov.io/gh/Soulter/AstrBot)
+[](https://codecov.io/gh/AstrBotDevs/AstrBot)
 
 <a href="https://astrbot.app/">ドキュメントを見る</a> |
-<a href="https://github.com/Soulter/AstrBot/issues">問題を報告する</a>
+<a href="https://github.com/AstrBotDevs/AstrBot/issues">問題を報告する</a>
 </div>
 
 AstrBot は、疎結合、非同期、複数のメッセージプラットフォームに対応したデプロイ、使いやすいプラグインシステム、および包括的な大規模言語モデル(LLM)接続機能を備えたチャットボットおよび開発フレームワークです。
@@ -50,7 +50,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
 
 #### Replit デプロイ
 
-[](https://repl.it/github/Soulter/AstrBot)
+[](https://repl.it/github/AstrBotDevs/AstrBot)
 
 #### CasaOS デプロイ
 
@@ -1,20 +1,19 @@
-from astrbot.core.config.astrbot_config import AstrBotConfig
 from astrbot import logger
-from astrbot.core import html_renderer
-from astrbot.core import sp
-from astrbot.core.star.register import register_llm_tool as llm_tool
-from astrbot.core.star.register import register_agent as agent
-from astrbot.core.agent.tool import ToolSet, FunctionTool
+from astrbot.core import html_renderer, sp
+from astrbot.core.agent.tool import FunctionTool, ToolSet
 from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
+from astrbot.core.config.astrbot_config import AstrBotConfig
+from astrbot.core.star.register import register_agent as agent
+from astrbot.core.star.register import register_llm_tool as llm_tool
 
 __all__ = [
 "AstrBotConfig",
-"logger",
+"BaseFunctionToolExecutor",
+"FunctionTool",
+"ToolSet",
+"agent",
 "html_renderer",
 "llm_tool",
-"agent",
+"logger",
 "sp",
-"ToolSet",
-"FunctionTool",
-"BaseFunctionToolExecutor",
 ]
@@ -1,18 +1,17 @@
 from astrbot.core.message.message_event_result import (
-MessageEventResult,
-MessageChain,
 CommandResult,
 EventResultType,
+MessageChain,
+MessageEventResult,
 ResultContentType,
 )
 
 from astrbot.core.platform import AstrMessageEvent
 
 __all__ = [
-"MessageEventResult",
-"MessageChain",
+"AstrMessageEvent",
 "CommandResult",
 "EventResultType",
-"AstrMessageEvent",
+"MessageChain",
+"MessageEventResult",
 "ResultContentType",
 ]
@@ -1,51 +1,52 @@
-from astrbot.core.star.register import (
-register_command as command,
-register_command_group as command_group,
-register_event_message_type as event_message_type,
-register_regex as regex,
-register_platform_adapter_type as platform_adapter_type,
-register_permission_type as permission_type,
-register_custom_filter as custom_filter,
-register_on_astrbot_loaded as on_astrbot_loaded,
-register_on_platform_loaded as on_platform_loaded,
-register_on_llm_request as on_llm_request,
-register_on_llm_response as on_llm_response,
-register_llm_tool as llm_tool,
-register_on_decorating_result as on_decorating_result,
-register_after_message_sent as after_message_sent,
-)
 
-from astrbot.core.star.filter.event_message_type import (
-EventMessageTypeFilter,
-EventMessageType,
-)
-from astrbot.core.star.filter.platform_adapter_type import (
-PlatformAdapterTypeFilter,
-PlatformAdapterType,
-)
-from astrbot.core.star.filter.permission import PermissionTypeFilter, PermissionType
 from astrbot.core.star.filter.custom_filter import CustomFilter
+from astrbot.core.star.filter.event_message_type import (
+EventMessageType,
+EventMessageTypeFilter,
+)
+from astrbot.core.star.filter.permission import PermissionType, PermissionTypeFilter
+from astrbot.core.star.filter.platform_adapter_type import (
+PlatformAdapterType,
+PlatformAdapterTypeFilter,
+)
+from astrbot.core.star.register import register_after_message_sent as after_message_sent
+from astrbot.core.star.register import register_command as command
+from astrbot.core.star.register import register_command_group as command_group
+from astrbot.core.star.register import register_custom_filter as custom_filter
+from astrbot.core.star.register import register_event_message_type as event_message_type
+from astrbot.core.star.register import register_llm_tool as llm_tool
+from astrbot.core.star.register import register_on_astrbot_loaded as on_astrbot_loaded
+from astrbot.core.star.register import (
+register_on_decorating_result as on_decorating_result,
+)
+from astrbot.core.star.register import register_on_llm_request as on_llm_request
+from astrbot.core.star.register import register_on_llm_response as on_llm_response
+from astrbot.core.star.register import register_on_platform_loaded as on_platform_loaded
+from astrbot.core.star.register import register_permission_type as permission_type
+from astrbot.core.star.register import (
+register_platform_adapter_type as platform_adapter_type,
+)
+from astrbot.core.star.register import register_regex as regex
 
 __all__ = [
+"CustomFilter",
+"EventMessageType",
+"EventMessageTypeFilter",
+"PermissionType",
+"PermissionTypeFilter",
+"PlatformAdapterType",
+"PlatformAdapterTypeFilter",
+"after_message_sent",
 "command",
 "command_group",
-"event_message_type",
-"regex",
-"platform_adapter_type",
-"permission_type",
-"EventMessageTypeFilter",
-"EventMessageType",
-"PlatformAdapterTypeFilter",
-"PlatformAdapterType",
-"PermissionTypeFilter",
-"CustomFilter",
 "custom_filter",
-"PermissionType",
+"event_message_type",
-"on_astrbot_loaded",
-"on_platform_loaded",
-"on_llm_request",
 "llm_tool",
+"on_astrbot_loaded",
 "on_decorating_result",
-"after_message_sent",
+"on_llm_request",
 "on_llm_response",
+"on_platform_loaded",
+"permission_type",
+"platform_adapter_type",
+"regex",
 ]
@@ -1,23 +1,22 @@
+from astrbot.core.message.components import *
 from astrbot.core.platform import (
-AstrMessageEvent,
-Platform,
 AstrBotMessage,
+AstrMessageEvent,
+Group,
 MessageMember,
 MessageType,
+Platform,
 PlatformMetadata,
-Group,
 )
 
 from astrbot.core.platform.register import register_platform_adapter
-from astrbot.core.message.components import *
 
 __all__ = [
-"AstrMessageEvent",
-"Platform",
 "AstrBotMessage",
+"AstrMessageEvent",
+"Group",
 "MessageMember",
 "MessageType",
+"Platform",
 "PlatformMetadata",
 "register_platform_adapter",
-"Group",
 ]
@@ -1,17 +1,17 @@
-from astrbot.core.provider import Provider, STTProvider, Personality
+from astrbot.core.provider import Personality, Provider, STTProvider
 from astrbot.core.provider.entities import (
+LLMResponse,
+ProviderMetaData,
 ProviderRequest,
 ProviderType,
-ProviderMetaData,
-LLMResponse,
 )
 
 __all__ = [
-"Provider",
-"STTProvider",
+"LLMResponse",
 "Personality",
+"Provider",
+"ProviderMetaData",
 "ProviderRequest",
 "ProviderType",
-"ProviderMetaData",
-"LLMResponse",
+"STTProvider",
 ]
@@ -1,8 +1,7 @@
+from astrbot.core.star import Context, Star, StarTools
+from astrbot.core.star.config import *
 from astrbot.core.star.register import (
 register_star as register, # 注册插件(Star)
 )
 
-from astrbot.core.star import Context, Star, StarTools
-from astrbot.core.star.config import *
-
-__all__ = ["register", "Context", "Star", "StarTools"]
+__all__ = ["Context", "Star", "StarTools", "register"]
@@ -1,7 +1,7 @@
 from astrbot.core.utils.session_waiter import (
-SessionWaiter,
 SessionController,
+SessionWaiter,
 session_waiter,
 )
 
-__all__ = ["SessionWaiter", "SessionController", "session_waiter"]
+__all__ = ["SessionController", "SessionWaiter", "session_waiter"]
@@ -1,11 +1,11 @@
-"""
-AstrBot CLI入口
-"""
+"""AstrBot CLI入口"""
+
+import sys
 
 import click
-import sys
 from . import __version__
-from .commands import init, run, plug, conf
+from .commands import conf, init, plug, run
 
 logo_tmpl = r"""
 ___ _______.___________..______ .______ ______ .___________.
@@ -1,6 +1,6 @@
-from .cmd_init import init
-from .cmd_run import run
-from .cmd_plug import plug
 from .cmd_conf import conf
+from .cmd_init import init
+from .cmd_plug import plug
+from .cmd_run import run
 
-__all__ = ["init", "run", "plug", "conf"]
+__all__ = ["conf", "init", "plug", "run"]
@@ -1,9 +1,12 @@
-import json
-import click
 import hashlib
+import json
 import zoneinfo
-from typing import Any, Callable
-from ..utils import get_astrbot_root, check_astrbot_root
+from collections.abc import Callable
+from typing import Any
 
+import click
+
+from ..utils import check_astrbot_root, get_astrbot_root
 
 
 def _validate_log_level(value: str) -> str:
@@ -11,7 +14,7 @@ def _validate_log_level(value: str) -> str:
 value = value.upper()
 if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
 raise click.ClickException(
-"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一"
+"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一",
 )
 return value
 
@@ -73,7 +76,7 @@ def _load_config() -> dict[str, Any]:
 root = get_astrbot_root()
 if not check_astrbot_root(root):
 raise click.ClickException(
-f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
+f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
 )
 
 config_path = root / "data" / "cmd_config.json"
@@ -88,7 +91,7 @@ def _load_config() -> dict[str, Any]:
 try:
 return json.loads(config_path.read_text(encoding="utf-8-sig"))
 except json.JSONDecodeError as e:
-raise click.ClickException(f"配置文件解析失败: {str(e)}")
+raise click.ClickException(f"配置文件解析失败: {e!s}")
 
 
 def _save_config(config: dict[str, Any]) -> None:
@@ -96,7 +99,8 @@ def _save_config(config: dict[str, Any]) -> None:
 config_path = get_astrbot_root() / "data" / "cmd_config.json"
 
 config_path.write_text(
-json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8-sig"
+json.dumps(config, ensure_ascii=False, indent=2),
+encoding="utf-8-sig",
 )
 
 
@@ -108,7 +112,7 @@ def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
 obj[part] = {}
 elif not isinstance(obj[part], dict):
 raise click.ClickException(
-f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典"
+f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典",
 )
 obj = obj[part]
 obj[parts[-1]] = value
@@ -140,7 +144,6 @@ def conf():
 
 - callback_api_base: 回调接口基址
 """
-pass
 
 
 @conf.command(name="set")
@@ -148,7 +151,7 @@ def conf():
 @click.argument("value")
 def set_config(key: str, value: str):
 """设置配置项的值"""
-if key not in CONFIG_VALIDATORS.keys():
+if key not in CONFIG_VALIDATORS:
 raise click.ClickException(f"不支持的配置项: {key}")
 
 config = _load_config()
@@ -170,17 +173,17 @@ def set_config(key: str, value: str):
 except KeyError:
 raise click.ClickException(f"未知的配置项: {key}")
 except Exception as e:
-raise click.UsageError(f"设置配置失败: {str(e)}")
+raise click.UsageError(f"设置配置失败: {e!s}")
 
 
 @conf.command(name="get")
 @click.argument("key", required=False)
-def get_config(key: str = None):
+def get_config(key: str | None = None):
 """获取配置项的值,不提供key则显示所有可配置项"""
 config = _load_config()
 
 if key:
-if key not in CONFIG_VALIDATORS.keys():
+if key not in CONFIG_VALIDATORS:
 raise click.ClickException(f"不支持的配置项: {key}")
 
 try:
@@ -191,10 +194,10 @@ def get_config(key: str = None):
 except KeyError:
 raise click.ClickException(f"未知的配置项: {key}")
 except Exception as e:
-raise click.UsageError(f"获取配置失败: {str(e)}")
+raise click.UsageError(f"获取配置失败: {e!s}")
 else:
 click.echo("当前配置:")
-for key in CONFIG_VALIDATORS.keys():
+for key in CONFIG_VALIDATORS:
 try:
 value = (
 "********"
@@ -1,4 +1,5 @@
 import asyncio
+from pathlib import Path
 
 import click
 from filelock import FileLock, Timeout
@@ -6,14 +7,14 @@ from filelock import FileLock, Timeout
 from ..utils import check_dashboard, get_astrbot_root
 
 
-async def initialize_astrbot(astrbot_root) -> None:
+async def initialize_astrbot(astrbot_root: Path) -> None:
 """执行 AstrBot 初始化逻辑"""
 dot_astrbot = astrbot_root / ".astrbot"
 
 if not dot_astrbot.exists():
 click.echo(f"Current Directory: {astrbot_root}")
 click.echo(
-"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。"
+"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。",
 )
 if click.confirm(
 f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",
@@ -1,31 +1,29 @@
 import re
+import shutil
 from pathlib import Path
 
 import click
-import shutil
 
 
 from ..utils import (
-get_git_repo,
-build_plug_list,
-manage_plugin,
 PluginStatus,
+build_plug_list,
 check_astrbot_root,
 get_astrbot_root,
+get_git_repo,
+manage_plugin,
 )
 
 
 @click.group()
 def plug():
 """插件管理"""
-pass
 
 
 def _get_data_path() -> Path:
 base = get_astrbot_root()
 if not check_astrbot_root(base):
 raise click.ClickException(
-f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
+f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
 )
 return (base / "data").resolve()
 
@@ -41,7 +39,7 @@ def display_plugins(plugins, title=None, color=None):
 desc = p["desc"][:30] + ("..." if len(p["desc"]) > 30 else "")
 click.echo(
 f"{p['name']:<20} {p['version']:<10} {p['status']:<10} "
-f"{p['author']:<15} {desc:<30}"
+f"{p['author']:<15} {desc:<30}",
 )
 
 
@@ -78,7 +76,7 @@ def new(name: str):
 f"desc: {desc}\n"
 f"version: {version}\n"
 f"author: {author}\n"
-f"repo: {repo}\n"
+f"repo: {repo}\n",
 )
 
 # 重写 README.md
@@ -86,7 +84,7 @@ def new(name: str):
 f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n")
 
 # 重写 main.py
-with open(plug_path / "main.py", "r", encoding="utf-8") as f:
+with open(plug_path / "main.py", encoding="utf-8") as f:
 content = f.read()
 
 new_content = content.replace(
@@ -1,19 +1,18 @@
+import asyncio
 import os
 import sys
+import traceback
 from pathlib import Path
 
 import click
-import asyncio
-import traceback
 
 from filelock import FileLock, Timeout
 
-from ..utils import check_dashboard, check_astrbot_root, get_astrbot_root
+from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root
 
 
 async def run_astrbot(astrbot_root: Path):
 """运行 AstrBot"""
-from astrbot.core import logger, LogManager, LogBroker, db_helper
+from astrbot.core import LogBroker, LogManager, db_helper, logger
 from astrbot.core.initial_loader import InitialLoader
 
 await check_dashboard(astrbot_root / "data")
@@ -38,7 +37,7 @@ def run(reload: bool, port: str) -> None:
 
 if not check_astrbot_root(astrbot_root):
 raise click.ClickException(
-f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
+f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
 )
 
 os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
@@ -1,18 +1,18 @@
 from .basic import (
-get_astrbot_root,
 check_astrbot_root,
 check_dashboard,
+get_astrbot_root,
 )
-from .plugin import get_git_repo, manage_plugin, build_plug_list, PluginStatus
+from .plugin import PluginStatus, build_plug_list, get_git_repo, manage_plugin
 from .version_comparator import VersionComparator
 
 __all__ = [
-"get_astrbot_root",
+"PluginStatus",
+"VersionComparator",
+"build_plug_list",
 "check_astrbot_root",
 "check_dashboard",
+"get_astrbot_root",
 "get_git_repo",
 "manage_plugin",
-"build_plug_list",
-"VersionComparator",
-"PluginStatus",
 ]
@@ -21,8 +21,9 @@ def get_astrbot_root() -> Path:
 
 async def check_dashboard(astrbot_root: Path) -> None:
 """检查是否安装了dashboard"""
-from astrbot.core.utils.io import get_dashboard_version, download_dashboard
 from astrbot.core.config.default import VERSION
+from astrbot.core.utils.io import download_dashboard, get_dashboard_version
 
 from .version_comparator import VersionComparator
 
 try:
@@ -48,19 +49,18 @@ async def check_dashboard(astrbot_root: Path) -> None:
 if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
 click.echo("管理面板已是最新版本")
 return
-else:
-try:
-version = dashboard_version.split("v")[1]
-click.echo(f"管理面板版本: {version}")
-await download_dashboard(
-path="data/dashboard.zip",
-extract_path=str(astrbot_root),
-version=f"v{VERSION}",
-latest=False,
-)
-except Exception as e:
-click.echo(f"下载管理面板失败: {e}")
-return
+try:
+version = dashboard_version.split("v")[1]
+click.echo(f"管理面板版本: {version}")
+await download_dashboard(
+path="data/dashboard.zip",
+extract_path=str(astrbot_root),
+version=f"v{VERSION}",
+latest=False,
+)
+except Exception as e:
+click.echo(f"下载管理面板失败: {e}")
+return
 except FileNotFoundError:
 click.echo("初始化管理面板目录...")
 try:
@@ -1,14 +1,14 @@
 import shutil
 import tempfile
 
-import httpx
-import yaml
 from enum import Enum
 from io import BytesIO
 from pathlib import Path
 from zipfile import ZipFile
 
 import click
+import httpx
+import yaml
 
 from .version_comparator import VersionComparator
 
@@ -32,7 +32,8 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
 release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
 try:
 with httpx.Client(
-proxy=proxy if proxy else None, follow_redirects=True
+proxy=proxy if proxy else None,
+follow_redirects=True,
 ) as client:
 resp = client.get(release_url)
 resp.raise_for_status()
@@ -55,7 +56,8 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
 
 # 下载并解压
 with httpx.Client(
-proxy=proxy if proxy else None, follow_redirects=True
+proxy=proxy if proxy else None,
+follow_redirects=True,
 ) as client:
 resp = client.get(download_url)
 if (
@@ -89,6 +91,7 @@ def load_yaml_metadata(plugin_dir: Path) -> dict:
 
 Returns:
 dict: 包含元数据的字典,如果读取失败则返回空字典
+
 """
 yaml_path = plugin_dir / "metadata.yaml"
 if yaml_path.exists():
@@ -107,6 +110,7 @@ def build_plug_list(plugins_dir: Path) -> list:
 
 Returns:
 list: 包含插件信息的字典列表
+
 """
 # 获取本地插件信息
 result = []
@@ -133,7 +137,7 @@ def build_plug_list(plugins_dir: Path) -> list:
 "repo": str(metadata.get("repo", "")),
 "status": PluginStatus.INSTALLED,
 "local_path": str(plugin_dir),
-}
+},
 )
 
 # 获取在线插件列表
@@ -153,7 +157,7 @@ def build_plug_list(plugins_dir: Path) -> list:
 "repo": str(plugin_info.get("repo", "")),
 "status": PluginStatus.NOT_INSTALLED,
 "local_path": None,
-}
+},
 )
 except Exception as e:
 click.echo(f"获取在线插件列表失败: {e}", err=True)
@@ -168,7 +172,8 @@ def build_plug_list(plugins_dir: Path) -> list:
 )
 if (
 VersionComparator.compare_version(
-local_plugin["version"], online_plugin["version"]
+local_plugin["version"],
+online_plugin["version"],
 )
 < 0
 ):
@@ -186,7 +191,10 @@ def build_plug_list(plugins_dir: Path) -> list:
 
 
 def manage_plugin(
-plugin: dict, plugins_dir: Path, is_update: bool = False, proxy: str | None = None
+plugin: dict,
+plugins_dir: Path,
+is_update: bool = False,
+proxy: str | None = None,
 ) -> None:
 """安装或更新插件
 
@@ -195,6 +203,7 @@ def manage_plugin(
 plugins_dir (Path): 插件目录
 is_update (bool, optional): 是否为更新操作. 默认为 False
 proxy (str, optional): 代理服务器地址
+
 """
 plugin_name = plugin["name"]
 repo_url = plugin["repo"]
@@ -212,26 +221,26 @@ def manage_plugin(
 raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
 
 # 备份现有插件
-if is_update and backup_path.exists():
+if is_update and backup_path is not None and backup_path.exists():
 shutil.rmtree(backup_path)
-if is_update:
+if is_update and backup_path is not None:
 shutil.copytree(target_path, backup_path)
 
 try:
 click.echo(
-f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}..."
+f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}...",
 )
 get_git_repo(repo_url, target_path, proxy)
 
 # 更新成功,删除备份
-if is_update and backup_path.exists():
+if is_update and backup_path is not None and backup_path.exists():
 shutil.rmtree(backup_path)
 click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
 except Exception as e:
 if target_path.exists():
 shutil.rmtree(target_path, ignore_errors=True)
-if is_update and backup_path.exists():
+if is_update and backup_path is not None and backup_path.exists():
|
||||||
shutil.move(backup_path, target_path)
|
shutil.move(backup_path, target_path)
|
||||||
raise click.ClickException(
|
raise click.ClickException(
|
||||||
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}"
|
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
-"""
-拷贝自 astrbot.core.utils.version_comparator
-"""
+"""拷贝自 astrbot.core.utils.version_comparator"""
 
 import re
 
@@ -42,15 +40,15 @@ class VersionComparator:
         for i in range(length):
             if v1_parts[i] > v2_parts[i]:
                 return 1
-            elif v1_parts[i] < v2_parts[i]:
+            if v1_parts[i] < v2_parts[i]:
                 return -1
 
         # 比较预发布标签
         if v1_prerelease is None and v2_prerelease is not None:
             return 1  # 没有预发布标签的版本高于有预发布标签的版本
-        elif v1_prerelease is not None and v2_prerelease is None:
+        if v1_prerelease is not None and v2_prerelease is None:
             return -1  # 有预发布标签的版本低于没有预发布标签的版本
-        elif v1_prerelease is not None and v2_prerelease is not None:
+        if v1_prerelease is not None and v2_prerelease is not None:
             len_pre = max(len(v1_prerelease), len(v2_prerelease))
             for i in range(len_pre):
                 p1 = v1_prerelease[i] if i < len(v1_prerelease) else None
@@ -58,21 +56,21 @@ class VersionComparator:
 
                 if p1 is None and p2 is not None:
                     return -1
-                elif p1 is not None and p2 is None:
+                if p1 is not None and p2 is None:
                     return 1
-                elif isinstance(p1, int) and isinstance(p2, str):
+                if isinstance(p1, int) and isinstance(p2, str):
                     return -1
-                elif isinstance(p1, str) and isinstance(p2, int):
+                if isinstance(p1, str) and isinstance(p2, int):
                     return 1
-                elif isinstance(p1, int) and isinstance(p2, int):
+                if isinstance(p1, int) and isinstance(p2, int):
                     if p1 > p2:
                         return 1
-                    elif p1 < p2:
+                    if p1 < p2:
                         return -1
                 elif isinstance(p1, str) and isinstance(p2, str):
                     if p1 > p2:
                         return 1
-                    elif p1 < p2:
+                    if p1 < p2:
                         return -1
             return 0  # 预发布标签完全相同
 
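Note: for readers skimming the flattened comparisons above, this is roughly how the comparator is consumed elsewhere in the CLI. A minimal sketch, assuming the copied class keeps the same semantics and semver-style inputs; the version strings are illustrative:

    from .version_comparator import VersionComparator  # relative import as used in the CLI utils above

    # compare_version returns 1, 0 or -1 (left newer, equal, left older).
    assert VersionComparator.compare_version("1.2.0", "1.1.9") == 1
    assert VersionComparator.compare_version("1.2.0", "1.2.0") == 0
    # Per the comments above, a pre-release tag sorts below the plain release of the same version.
    assert VersionComparator.compare_version("1.2.0-beta", "1.2.0") == -1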
@@ -1,12 +1,14 @@
 import os
-from .log import LogManager, LogBroker  # noqa
-from astrbot.core.utils.t2i.renderer import HtmlRenderer
-from astrbot.core.utils.shared_preferences import SharedPreferences
-from astrbot.core.utils.pip_installer import PipInstaller
-from astrbot.core.db.sqlite import SQLiteDatabase
-from astrbot.core.config.default import DB_PATH
+
 from astrbot.core.config import AstrBotConfig
+from astrbot.core.config.default import DB_PATH
+from astrbot.core.db.sqlite import SQLiteDatabase
 from astrbot.core.file_token_service import FileTokenService
+from astrbot.core.utils.pip_installer import PipInstaller
+from astrbot.core.utils.shared_preferences import SharedPreferences
+from astrbot.core.utils.t2i.renderer import HtmlRenderer
+
+from .log import LogBroker, LogManager  # noqa
 from .utils.astrbot_path import get_astrbot_data_path
 
 # 初始化数据存储文件夹
@@ -1,8 +1,9 @@
 from dataclasses import dataclass
-from .tool import FunctionTool
 from typing import Generic
-from .run_context import TContext
+
 from .hooks import BaseAgentRunHooks
+from .run_context import TContext
+from .tool import FunctionTool
 
 
 @dataclass
@@ -1,14 +1,18 @@
 from typing import Generic
-from .tool import FunctionTool
+
 from .agent import Agent
 from .run_context import TContext
+from .tool import FunctionTool
 
 
 class HandoffTool(FunctionTool, Generic[TContext]):
     """Handoff tool for delegating tasks to another agent."""
 
     def __init__(
-        self, agent: Agent[TContext], parameters: dict | None = None, **kwargs
+        self,
+        agent: Agent[TContext],
+        parameters: dict | None = None,
+        **kwargs,
     ):
         self.agent = agent
         super().__init__(
@@ -1,12 +1,13 @@
-import mcp
-from dataclasses import dataclass
-from .run_context import ContextWrapper, TContext
 from typing import Generic
-from astrbot.core.provider.entities import LLMResponse
+
+import mcp
+
 from astrbot.core.agent.tool import FunctionTool
+from astrbot.core.provider.entities import LLMResponse
+
+from .run_context import ContextWrapper, TContext
 
 
-@dataclass
 class BaseAgentRunHooks(Generic[TContext]):
     async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
     async def on_tool_start(
@@ -23,5 +24,7 @@ class BaseAgentRunHooks(Generic[TContext]):
         tool_result: mcp.types.CallToolResult | None,
     ): ...
     async def on_agent_done(
-        self, run_context: ContextWrapper[TContext], llm_response: LLMResponse
+        self,
+        run_context: ContextWrapper[TContext],
+        llm_response: LLMResponse,
     ): ...
@@ -1,11 +1,16 @@
 import asyncio
 import logging
-from datetime import timedelta
-from typing import Optional
 from contextlib import AsyncExitStack
+from datetime import timedelta
+from typing import Generic
+
 from astrbot import logger
+from astrbot.core.agent.run_context import ContextWrapper
 from astrbot.core.utils.log_pipe import LogPipe
+
+from .run_context import TContext
+from .tool import FunctionTool
 
 try:
     import mcp
     from mcp.client.sse import sse_client
@@ -16,13 +21,13 @@ try:
     from mcp.client.streamable_http import streamablehttp_client
 except (ModuleNotFoundError, ImportError):
     logger.warning(
-        "警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
+        "警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。",
     )
 
 
 def _prepare_config(config: dict) -> dict:
     """准备配置,处理嵌套格式"""
-    if "mcpServers" in config and config["mcpServers"]:
+    if config.get("mcpServers"):
         first_key = next(iter(config["mcpServers"]))
         config = config["mcpServers"][first_key]
         config.pop("active", None)
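Note: a rough illustration of the nested-format handling in `_prepare_config` above; the server name and URL are made-up values:

    # Sketch: a nested "mcpServers" config collapses to its first entry,
    # with the AstrBot-specific "active" flag stripped.
    nested = {
        "mcpServers": {
            "demo-server": {"url": "http://localhost:8000/sse", "active": True},
        },
    }
    flat = _prepare_config(nested)
    # flat == {"url": "http://localhost:8000/sse"}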
@@ -71,8 +76,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
                ) as response:
                    if response.status == 200:
                        return True, ""
-                    else:
-                        return False, f"HTTP {response.status}: {response.reason}"
+                    return False, f"HTTP {response.status}: {response.reason}"
            else:
                async with session.get(
                    url,
@@ -84,8 +88,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
                ) as response:
                    if response.status == 200:
                        return True, ""
-                    else:
-                        return False, f"HTTP {response.status}: {response.reason}"
+                    return False, f"HTTP {response.status}: {response.reason}"
 
    except asyncio.TimeoutError:
        return False, f"连接超时: {timeout}秒"
@@ -96,7 +99,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
 class MCPClient:
     def __init__(self):
         # Initialize session and client objects
-        self.session: Optional[mcp.ClientSession] = None
+        self.session: mcp.ClientSession | None = None
         self.exit_stack = AsyncExitStack()
 
         self.name: str | None = None
@@ -115,6 +118,7 @@ class MCPClient:
 
        Args:
            mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
+
        """
        cfg = _prepare_config(mcp_server_config.copy())
 
@@ -144,7 +148,7 @@ class MCPClient:
                sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
            )
            streams = await self.exit_stack.enter_async_context(
-                self._streams_context
+                self._streams_context,
            )
 
            # Create a new client session
@@ -154,12 +158,12 @@ class MCPClient:
                    *streams,
                    read_timeout_seconds=read_timeout,
                    logging_callback=logging_callback,  # type: ignore
-                )
+                ),
            )
        else:
            timeout = timedelta(seconds=cfg.get("timeout", 30))
            sse_read_timeout = timedelta(
-                seconds=cfg.get("sse_read_timeout", 60 * 5)
+                seconds=cfg.get("sse_read_timeout", 60 * 5),
            )
            self._streams_context = streamablehttp_client(
                url=cfg["url"],
@@ -169,7 +173,7 @@ class MCPClient:
                terminate_on_close=cfg.get("terminate_on_close", True),
            )
            read_s, write_s, _ = await self.exit_stack.enter_async_context(
-                self._streams_context
+                self._streams_context,
            )
 
            # Create a new client session
@@ -180,7 +184,7 @@ class MCPClient:
                    write_stream=write_s,
                    read_timeout_seconds=read_timeout,
                    logging_callback=logging_callback,  # type: ignore
-                )
+                ),
            )
 
        else:
@@ -206,7 +210,7 @@ class MCPClient:
 
            # Create a new client session
            self.session = await self.exit_stack.enter_async_context(
-                mcp.ClientSession(*stdio_transport)
+                mcp.ClientSession(*stdio_transport),
            )
            await self.session.initialize()
 
@@ -222,3 +226,34 @@ class MCPClient:
        """Clean up resources"""
        await self.exit_stack.aclose()
        self.running_event.set()  # Set the running event to indicate cleanup is done
+
+
+class MCPTool(FunctionTool, Generic[TContext]):
+    """A function tool that calls an MCP service."""
+
+    def __init__(
+        self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs
+    ):
+        super().__init__(
+            name=mcp_tool.name,
+            description=mcp_tool.description or "",
+            parameters=mcp_tool.inputSchema,
+        )
+        self.mcp_tool = mcp_tool
+        self.mcp_client = mcp_client
+        self.mcp_server_name = mcp_server_name
+
+    async def call(
+        self, context: ContextWrapper[TContext], **kwargs
+    ) -> mcp.types.CallToolResult:
+        session = self.mcp_client.session
+        if not session:
+            raise ValueError("MCP session is not available for MCP function tools.")
+        res = await session.call_tool(
+            name=self.mcp_tool.name,
+            arguments=kwargs,
+            read_timeout_seconds=timedelta(
+                seconds=context.tool_call_timeout,
+            ),
+        )
+        return res
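Note: a hedged sketch of how the new `MCPTool` wrapper might be built from a connected client. `list_tools()` comes from the upstream `mcp` client session; the server name and helper function are illustrative, not part of this change:

    # Sketch: wrap every tool advertised by an MCP server into an MCPTool.
    async def collect_mcp_tools(client: MCPClient) -> list[MCPTool]:
        # assumes the client session has already been initialized elsewhere
        assert client.session is not None
        listed = await client.session.list_tools()
        return [
            MCPTool(mcp_tool=t, mcp_client=client, mcp_server_name="demo-server")
            for t in listed.tools
        ]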
astrbot/core/agent/message.py (new file, 168 lines)
@@ -0,0 +1,168 @@
+# Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation.
+# License: Apache License 2.0
+
+from typing import Any, ClassVar, Literal, cast
+
+from pydantic import BaseModel, GetCoreSchemaHandler
+from pydantic_core import core_schema
+
+
+class ContentPart(BaseModel):
+    """A part of the content in a message."""
+
+    __content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}
+
+    type: str
+
+    def __init_subclass__(cls, **kwargs: Any) -> None:
+        super().__init_subclass__(**kwargs)
+
+        invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`"
+
+        type_value = getattr(cls, "type", None)
+        if type_value is None or not isinstance(type_value, str):
+            raise ValueError(invalid_subclass_error_msg)
+
+        cls.__content_part_registry[type_value] = cls
+
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls, source_type: Any, handler: GetCoreSchemaHandler
+    ) -> core_schema.CoreSchema:
+        # If we're dealing with the base ContentPart class, use custom validation
+        if cls.__name__ == "ContentPart":
+
+            def validate_content_part(value: Any) -> Any:
+                # if it's already an instance of a ContentPart subclass, return it
+                if hasattr(value, "__class__") and issubclass(value.__class__, cls):
+                    return value
+
+                # if it's a dict with a type field, dispatch to the appropriate subclass
+                if isinstance(value, dict) and "type" in value:
+                    type_value: Any | None = cast(dict[str, Any], value).get("type")
+                    if not isinstance(type_value, str):
+                        raise ValueError(f"Cannot validate {value} as ContentPart")
+                    target_class = cls.__content_part_registry[type_value]
+                    return target_class.model_validate(value)
+
+                raise ValueError(f"Cannot validate {value} as ContentPart")
+
+            return core_schema.no_info_plain_validator_function(validate_content_part)
+
+        # for subclasses, use the default schema
+        return handler(source_type)
+
+
+class TextPart(ContentPart):
+    """
+    >>> TextPart(text="Hello, world!").model_dump()
+    {'type': 'text', 'text': 'Hello, world!'}
+    """
+
+    type: str = "text"
+    text: str
+
+
+class ImageURLPart(ContentPart):
+    """
+    >>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump()
+    {'type': 'image_url', 'image_url': 'http://example.com/image.jpg'}
+    """
+
+    class ImageURL(BaseModel):
+        url: str
+        """The URL of the image, can be data URI scheme like `data:image/png;base64,...`."""
+        id: str | None = None
+        """The ID of the image, to allow LLMs to distinguish different images."""
+
+    type: str = "image_url"
+    image_url: str
+
+
+class AudioURLPart(ContentPart):
+    """
+    >>> AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump()
+    {'type': 'audio_url', 'audio_url': {'url': 'https://example.com/audio.mp3', 'id': None}}
+    """
+
+    class AudioURL(BaseModel):
+        url: str
+        """The URL of the audio, can be data URI scheme like `data:audio/aac;base64,...`."""
+        id: str | None = None
+        """The ID of the audio, to allow LLMs to distinguish different audios."""
+
+    type: str = "audio_url"
+    audio_url: AudioURL
+
+
+class ToolCall(BaseModel):
+    """
+    A tool call requested by the assistant.
+
+    >>> ToolCall(
+    ...     id="123",
+    ...     function=ToolCall.FunctionBody(
+    ...         name="function",
+    ...         arguments="{}"
+    ...     ),
+    ... ).model_dump()
+    {'type': 'function', 'id': '123', 'function': {'name': 'function', 'arguments': '{}'}}
+    """
+
+    class FunctionBody(BaseModel):
+        name: str
+        arguments: str | None
+
+    type: Literal["function"] = "function"
+
+    id: str
+    """The ID of the tool call."""
+    function: FunctionBody
+    """The function body of the tool call."""
+
+
+class ToolCallPart(BaseModel):
+    """A part of the tool call."""
+
+    arguments_part: str | None = None
+    """A part of the arguments of the tool call."""
+
+
+class Message(BaseModel):
+    """A message in a conversation."""
+
+    role: Literal[
+        "system",
+        "user",
+        "assistant",
+        "tool",
+    ]
+
+    content: str | list[ContentPart]
+    """The content of the message."""
+
+
+class AssistantMessageSegment(Message):
+    """A message segment from the assistant."""
+
+    role: Literal["assistant"] = "assistant"
+    tool_calls: list[ToolCall] | list[dict] | None = None
+
+
+class ToolCallMessageSegment(Message):
+    """A message segment representing a tool call."""
+
+    role: Literal["tool"] = "tool"
+    tool_call_id: str
+
+
+class UserMessageSegment(Message):
+    """A message segment from the user."""
+
+    role: Literal["user"] = "user"
+
+
+class SystemMessageSegment(Message):
+    """A message segment from the system."""
+
+    role: Literal["system"] = "system"
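Note: a short sketch of how the registry-based validation in the new `message.py` is meant to behave when parts arrive as plain dicts; the field values are illustrative:

    from astrbot.core.agent.message import ImageURLPart, Message, TextPart

    msg = Message.model_validate(
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "describe this image"},
                {"type": "image_url", "image_url": "https://example.com/cat.png"},
            ],
        }
    )
    # Each dict is dispatched to the subclass registered for its "type".
    assert isinstance(msg.content[0], TextPart)
    assert isinstance(msg.content[1], ImageURLPart)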
@@ -1,5 +1,6 @@
-from dataclasses import dataclass
 import typing as T
+from dataclasses import dataclass
+
 from astrbot.core.message.message_event_result import MessageChain
 
 
@@ -1,8 +1,7 @@
 from dataclasses import dataclass
 from typing import Any, Generic
-from typing_extensions import TypeVar
 
-from astrbot.core.platform.astr_message_event import AstrMessageEvent
+from typing_extensions import TypeVar
 
 TContext = TypeVar("TContext", default=Any)
 
@@ -12,7 +11,7 @@ class ContextWrapper(Generic[TContext]):
     """A context for running an agent, which can be used to pass additional data or state."""
 
     context: TContext
-    event: AstrMessageEvent
+    tool_call_timeout: int = 60  # Default tool call timeout in seconds
 
 
 NoContext = ContextWrapper[None]
@@ -1,13 +1,15 @@
 import abc
 import typing as T
 from enum import Enum, auto
-from ..run_context import ContextWrapper, TContext
-from ..response import AgentResponse
-from ..hooks import BaseAgentRunHooks
-from ..tool_executor import BaseFunctionToolExecutor
+
 from astrbot.core.provider import Provider
 from astrbot.core.provider.entities import LLMResponse
 
+from ..hooks import BaseAgentRunHooks
+from ..response import AgentResponse
+from ..run_context import ContextWrapper, TContext
+from ..tool_executor import BaseFunctionToolExecutor
+
 
 class AgentState(Enum):
     """Defines the state of the agent."""
@@ -28,31 +30,26 @@ class BaseAgentRunner(T.Generic[TContext]):
         agent_hooks: BaseAgentRunHooks[TContext],
         **kwargs: T.Any,
     ) -> None:
-        """
-        Reset the agent to its initial state.
+        """Reset the agent to its initial state.
         This method should be called before starting a new run.
         """
         ...
 
     @abc.abstractmethod
     async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
-        """
-        Process a single step of the agent.
-        """
+        """Process a single step of the agent."""
         ...
 
     @abc.abstractmethod
     def done(self) -> bool:
-        """
-        Check if the agent has completed its task.
+        """Check if the agent has completed its task.
         Returns True if the agent is done, False otherwise.
         """
         ...
 
     @abc.abstractmethod
     def get_final_llm_resp(self) -> LLMResponse | None:
-        """
-        Get the final observation from the agent.
+        """Get the final observation from the agent.
         This method should be called after the agent is done.
         """
         ...
@@ -1,31 +1,33 @@
 import sys
 import traceback
 import typing as T
-from .base import BaseAgentRunner, AgentResponse, AgentState
-from ..hooks import BaseAgentRunHooks
-from ..tool_executor import BaseFunctionToolExecutor
-from ..run_context import ContextWrapper, TContext
-from ..response import AgentResponseData
-from astrbot.core.provider.provider import Provider
-from astrbot.core.message.message_event_result import (
-    MessageChain,
-)
-from astrbot.core.provider.entities import (
-    ProviderRequest,
-    LLMResponse,
-    ToolCallMessageSegment,
-    AssistantMessageSegment,
-    ToolCallsResult,
-)
-from mcp.types import (
-    TextContent,
-    ImageContent,
-    EmbeddedResource,
-    TextResourceContents,
-    BlobResourceContents,
-    CallToolResult,
-)
-from astrbot import logger
+
+from mcp.types import (
+    BlobResourceContents,
+    CallToolResult,
+    EmbeddedResource,
+    ImageContent,
+    TextContent,
+    TextResourceContents,
+)
+
+from astrbot import logger
+from astrbot.core.message.message_event_result import (
+    MessageChain,
+)
+from astrbot.core.provider.entities import (
+    LLMResponse,
+    ProviderRequest,
+    ToolCallsResult,
+)
+from astrbot.core.provider.provider import Provider
+
+from ..hooks import BaseAgentRunHooks
+from ..message import AssistantMessageSegment, ToolCallMessageSegment
+from ..response import AgentResponseData
+from ..run_context import ContextWrapper, TContext
+from ..tool_executor import BaseFunctionToolExecutor
+from .base import AgentResponse, AgentState, BaseAgentRunner
 
 if sys.version_info >= (3, 12):
     from typing import override
@@ -70,8 +72,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
 
     @override
     async def step(self):
-        """
-        Process a single step of the agent.
+        """Process a single step of the agent.
         This method should return the result of the step.
         """
         if not self.req:
@@ -99,7 +100,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
                    yield AgentResponse(
                        type="streaming_delta",
                        data=AgentResponseData(
-                            chain=MessageChain().message(llm_response.completion_text)
+                            chain=MessageChain().message(llm_response.completion_text),
                        ),
                    )
                    continue
@@ -120,8 +121,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
                type="err",
                data=AgentResponseData(
                    chain=MessageChain().message(
-                        f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
-                    )
+                        f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}",
+                    ),
                ),
            )
 
@@ -144,7 +145,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
            yield AgentResponse(
                type="llm_result",
                data=AgentResponseData(
-                    chain=MessageChain().message(llm_resp.completion_text)
+                    chain=MessageChain().message(llm_resp.completion_text),
                ),
            )
 
@@ -155,7 +156,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
            yield AgentResponse(
                type="tool_call",
                data=AgentResponseData(
-                    chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}")
+                    chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}"),
                ),
            )
        async for result in self._handle_function_tools(self.req, llm_resp):
@@ -169,8 +170,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
            # 将结果添加到上下文中
            tool_calls_result = ToolCallsResult(
                tool_calls_info=AssistantMessageSegment(
-                    role="assistant",
-                    tool_calls=llm_resp.to_openai_tool_calls(),
+                    tool_calls=llm_resp.to_openai_to_calls_model(),
                    content=llm_resp.completion_text,
                ),
                tool_calls_result=tool_call_result_blocks,
@@ -205,13 +205,43 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
                        role="tool",
                        tool_call_id=func_tool_id,
                        content=f"error: 未找到工具 {func_tool_name}",
-                    )
+                    ),
                )
                continue
 
+            valid_params = {}  # 参数过滤:只传递函数实际需要的参数
+
+            # 获取实际的 handler 函数
+            if func_tool.handler:
+                logger.debug(
+                    f"工具 {func_tool_name} 期望的参数: {func_tool.parameters}",
+                )
+                if func_tool.parameters and func_tool.parameters.get("properties"):
+                    expected_params = set(func_tool.parameters["properties"].keys())
+
+                    valid_params = {
+                        k: v
+                        for k, v in func_tool_args.items()
+                        if k in expected_params
+                    }
+
+                    # 记录被忽略的参数
+                    ignored_params = set(func_tool_args.keys()) - set(
+                        valid_params.keys(),
+                    )
+                    if ignored_params:
+                        logger.warning(
+                            f"工具 {func_tool_name} 忽略非期望参数: {ignored_params}",
+                        )
+            else:
+                # 如果没有 handler(如 MCP 工具),使用所有参数
+                valid_params = func_tool_args
+
            try:
                await self.agent_hooks.on_tool_start(
-                    self.run_context, func_tool, func_tool_args
+                    self.run_context,
+                    func_tool,
+                    valid_params,
                )
            except Exception as e:
                logger.error(f"Error in on_tool_start hook: {e}", exc_info=True)
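Note: the argument filtering added above reduces to intersecting the model-supplied arguments with the tool's declared JSON-Schema properties. A standalone sketch of that rule under the same assumptions (the function and variable names here are made up for illustration):

    def filter_tool_args(parameters: dict | None, args: dict, has_handler: bool) -> dict:
        """Keep only arguments declared in the tool schema; handler-less (MCP-style) tools keep everything."""
        if not has_handler:
            return args
        properties = (parameters or {}).get("properties") or {}
        return {k: v for k, v in args.items() if k in properties}

    # filter_tool_args({"properties": {"city": {}}}, {"city": "Tokyo", "foo": 1}, True)
    # -> {"city": "Tokyo"}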
@@ -219,7 +249,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
            executor = self.tool_executor.execute(
                tool=func_tool,
                run_context=self.run_context,
-                **func_tool_args,
+                **valid_params,  # 只传递有效的参数
            )
 
            _final_resp: CallToolResult | None = None
@@ -233,7 +263,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
                                role="tool",
                                tool_call_id=func_tool_id,
                                content=res.content[0].text,
-                            )
+                            ),
                        )
                        yield MessageChain().message(res.content[0].text)
                    elif isinstance(res.content[0], ImageContent):
@@ -242,10 +272,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
                                role="tool",
                                tool_call_id=func_tool_id,
                                content="返回了图片(已直接发送给用户)",
-                            )
+                            ),
                        )
                        yield MessageChain(type="tool_direct_result").base64_image(
-                            res.content[0].data
+                            res.content[0].data,
                        )
                    elif isinstance(res.content[0], EmbeddedResource):
                        resource = res.content[0].resource
@@ -255,7 +285,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
                                role="tool",
                                tool_call_id=func_tool_id,
                                content=resource.text,
-                            )
+                            ),
                        )
                        yield MessageChain().message(resource.text)
                    elif (
@@ -268,10 +298,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
                                role="tool",
                                tool_call_id=func_tool_id,
                                content="返回了图片(已直接发送给用户)",
-                            )
+                            ),
                        )
                        yield MessageChain(
-                            type="tool_direct_result"
+                            type="tool_direct_result",
                        ).base64_image(resource.blob)
                    else:
                        tool_call_result_blocks.append(
@@ -279,41 +309,41 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
                                role="tool",
                                tool_call_id=func_tool_id,
                                content="返回的数据类型不受支持",
-                            )
+                            ),
                        )
                        yield MessageChain().message("返回的数据类型不受支持。")
 
                elif resp is None:
                    # Tool 直接请求发送消息给用户
                    # 这里我们将直接结束 Agent Loop。
+                    # 发送消息逻辑在 ToolExecutor 中处理了。
+                    logger.warning(
+                        f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
+                    )
                    self._transition_state(AgentState.DONE)
-                    if res := self.run_context.event.get_result():
-                        if res.chain:
-                            yield MessageChain(
-                                chain=res.chain, type="tool_direct_result"
-                            )
                else:
                    # 不应该出现其他类型
                    logger.warning(
-                        f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
+                        f"Tool 返回了不支持的类型: {type(resp)},将忽略。",
                    )
 
                try:
                    await self.agent_hooks.on_tool_end(
-                        self.run_context, func_tool, func_tool_args, _final_resp
+                        self.run_context,
+                        func_tool,
+                        func_tool_args,
+                        _final_resp,
                    )
                except Exception as e:
                    logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
-
-                self.run_context.event.clear_result()
            except Exception as e:
                logger.warning(traceback.format_exc())
                tool_call_result_blocks.append(
                    ToolCallMessageSegment(
                        role="tool",
                        tool_call_id=func_tool_id,
-                        content=f"error: {str(e)}",
-                    )
+                        content=f"error: {e!s}",
+                    ),
                )
 
        # 处理函数调用响应
@@ -1,55 +1,75 @@
-from dataclasses import dataclass
-from deprecated import deprecated
-from typing import Awaitable, Callable, Literal, Any, Optional
-from .mcp_client import MCPClient
+from collections.abc import Awaitable, Callable
+from typing import Any, Generic
+
+import jsonschema
+import mcp
+from deprecated import deprecated
+from pydantic import model_validator
+from pydantic.dataclasses import dataclass
+
+from .run_context import ContextWrapper, TContext
+
+ParametersType = dict[str, Any]
 
 
 @dataclass
-class FunctionTool:
-    """A class representing a function tool that can be used in function calling."""
+class ToolSchema:
+    """A class representing the schema of a tool for function calling."""
 
     name: str
-    parameters: dict | None = None
-    description: str | None = None
-    handler: Callable[..., Awaitable[Any]] | None = None
-    """处理函数, 当 origin 为 mcp 时,这个为空"""
-    handler_module_path: str | None = None
-    """处理函数的模块路径,当 origin 为 mcp 时,这个为空
-
-    必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
+    """The name of the tool."""
+
+    description: str
+    """The description of the tool."""
+
+    parameters: ParametersType
+    """The parameters of the tool, in JSON Schema format."""
+
+    @model_validator(mode="after")
+    def validate_parameters(self) -> "ToolSchema":
+        jsonschema.validate(
+            self.parameters, jsonschema.Draft202012Validator.META_SCHEMA
+        )
+        return self
+
+
+@dataclass
+class FunctionTool(ToolSchema, Generic[TContext]):
+    """A callable tool, for function calling."""
+
+    handler: Callable[..., Awaitable[Any]] | None = None
+    """a callable that implements the tool's functionality. It should be an async function."""
+
+    handler_module_path: str | None = None
+    """
+    The module path of the handler function. This is empty when the origin is mcp.
+    This field must be retained, as the handler will be wrapped in functools.partial during initialization,
+    causing the handler's __module__ to be functools
     """
     active: bool = True
-    """是否激活"""
-    origin: Literal["local", "mcp"] = "local"
-    """函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
-
-    # MCP 相关字段
-    mcp_server_name: str | None = None
-    """MCP 服务名称,当 origin 为 mcp 时有效"""
-    mcp_client: MCPClient | None = None
-    """MCP 客户端,当 origin 为 mcp 时有效"""
+    """
+    Whether the tool is active. This field is a special field for AstrBot.
+    You can ignore it when integrating with other frameworks.
+    """
 
     def __repr__(self):
-        return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
+        return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
 
-    def __dict__(self) -> dict[str, Any]:
-        """将 FunctionTool 转换为字典格式"""
-        return {
-            "name": self.name,
-            "parameters": self.parameters,
-            "description": self.description,
-            "active": self.active,
-            "origin": self.origin,
-            "mcp_server_name": self.mcp_server_name,
-        }
+    async def call(
+        self, context: ContextWrapper[TContext], **kwargs
+    ) -> str | mcp.types.CallToolResult:
+        """Run the tool with the given arguments. The handler field has priority."""
+        raise NotImplementedError(
+            "FunctionTool.call() must be implemented by subclasses or set a handler."
+        )
 
 
 class ToolSet:
     """A set of function tools that can be used in function calling.
 
     This class provides methods to add, remove, and retrieve tools, as well as
-    convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
+    convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).
+    """
 
     def __init__(self, tools: list[FunctionTool] | None = None):
         self.tools: list[FunctionTool] = tools or []
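Note: a hedged sketch of what the pydantic-backed `ToolSchema` buys: `parameters` is validated as JSON Schema at construction time. The tool below is invented for illustration and is not part of this change; the import path matches the one used elsewhere in this PR:

    from astrbot.core.agent.tool import FunctionTool

    async def echo_handler(text: str) -> str:
        return text

    echo = FunctionTool(
        name="echo",
        description="Echo the given text back.",
        parameters={
            "type": "object",
            "properties": {"text": {"type": "string"}},
            "required": ["text"],
        },
        handler=echo_handler,
    )
    # A malformed schema (e.g. "type": 123) would raise from the model_validator above.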
@@ -71,7 +91,7 @@ class ToolSet:
         """Remove a tool by its name."""
         self.tools = [tool for tool in self.tools if tool.name != name]
 
-    def get_tool(self, name: str) -> Optional[FunctionTool]:
+    def get_tool(self, name: str) -> FunctionTool | None:
         """Get a tool by its name."""
         for tool in self.tools:
             if tool.name == name:
@@ -132,10 +152,8 @@ class ToolSet:
             }
 
             if (
-                tool.parameters
-                and tool.parameters.get("properties")
-                or not omit_empty_parameter_field
-            ):
+                tool.parameters and tool.parameters.get("properties")
+            ) or not omit_empty_parameter_field:
                 func_def["function"]["parameters"] = tool.parameters
 
             result.append(func_def)
@@ -185,7 +203,8 @@ class ToolSet:
         if "type" in schema and schema["type"] in supported_types:
             result["type"] = schema["type"]
             if "format" in schema and schema["format"] in supported_formats.get(
-                result["type"], set()
+                result["type"],
+                set(),
             ):
                 result["format"] = schema["format"]
         else:
@@ -222,7 +241,7 @@ class ToolSet:
 
         tools = []
         for tool in self.tools:
-            d = {
+            d: dict[str, Any] = {
                 "name": tool.name,
                 "description": tool.description,
             }
@@ -1,11 +1,17 @@
+from collections.abc import AsyncGenerator
+from typing import Any, Generic
+
 import mcp
-from typing import Any, Generic, AsyncGenerator
-from .run_context import TContext, ContextWrapper
+
+from .run_context import ContextWrapper, TContext
 from .tool import FunctionTool
 
 
 class BaseFunctionToolExecutor(Generic[TContext]):
     @classmethod
     async def execute(
-        cls, tool: FunctionTool, run_context: ContextWrapper[TContext], **tool_args
+        cls,
+        tool: FunctionTool,
+        run_context: ContextWrapper[TContext],
+        **tool_args,
     ) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ...
@@ -1,4 +1,6 @@
 from dataclasses import dataclass
+
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
 from astrbot.core.provider import Provider
 from astrbot.core.provider.entities import ProviderRequest
 
@@ -9,4 +11,4 @@ class AstrAgentContext:
     first_provider_request: ProviderRequest
     curr_provider_request: ProviderRequest
     streaming: bool
-    tool_call_timeout: int = 60  # Default tool call timeout in seconds
+    event: AstrMessageEvent
@@ -1,12 +1,14 @@
 import os
 import uuid
+from typing import TypedDict, TypeVar
+
 from astrbot.core import AstrBotConfig, logger
-from astrbot.core.utils.shared_preferences import SharedPreferences
 from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
 from astrbot.core.config.default import DEFAULT_CONFIG
 from astrbot.core.platform.message_session import MessageSession
+from astrbot.core.umop_config_router import UmopConfigRouter
 from astrbot.core.utils.astrbot_path import get_astrbot_config_path
-from typing import TypeVar, TypedDict
+from astrbot.core.utils.shared_preferences import SharedPreferences
 
 _VT = TypeVar("_VT")
 
@@ -15,14 +17,12 @@ class ConfInfo(TypedDict):
     """Configuration information for a specific session or platform."""
 
     id: str  # UUID of the configuration or "default"
-    umop: list[str]  # Unified Message Origin Pattern
     name: str
     path: str  # File name to the configuration file
 
 
 DEFAULT_CONFIG_CONF_INFO = ConfInfo(
     id="default",
-    umop=["::"],
     name="default",
     path=ASTRBOT_CONFIG_PATH,
 )
@@ -31,8 +31,14 @@ DEFAULT_CONFIG_CONF_INFO = ConfInfo(
 class AstrBotConfigManager:
     """A class to manage the system configuration of AstrBot, aka ACM"""
 
-    def __init__(self, default_config: AstrBotConfig, sp: SharedPreferences):
+    def __init__(
+        self,
+        default_config: AstrBotConfig,
+        ucr: UmopConfigRouter,
+        sp: SharedPreferences,
+    ):
         self.sp = sp
+        self.ucr = ucr
         self.confs: dict[str, AstrBotConfig] = {}
         """uuid / "default" -> AstrBotConfig"""
         self.confs["default"] = default_config
@@ -43,7 +49,10 @@ class AstrBotConfigManager:
         """获取所有的 abconf 数据"""
         if self.abconf_data is None:
             self.abconf_data = self.sp.get(
-                "abconf_mapping", {}, scope="global", scope_id="global"
+                "abconf_mapping",
+                {},
+                scope="global",
+                scope_id="global",
             )
         return self.abconf_data
 
@@ -59,28 +68,20 @@ class AstrBotConfigManager:
                 self.confs[uuid_] = conf
             else:
                 logger.warning(
-                    f"Config file {conf_path} for UUID {uuid_} does not exist, skipping."
+                    f"Config file {conf_path} for UUID {uuid_} does not exist, skipping.",
                 )
                 continue
 
-    def _is_umo_match(self, p1: str, p2: str) -> bool:
-        """判断 p2 umo 是否逻辑包含于 p1 umo"""
-        p1_ls = p1.split(":")
-        p2_ls = p2.split(":")
-
-        if len(p1_ls) != 3 or len(p2_ls) != 3:
-            return False  # 非法格式
-
-        return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
-
     def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo:
         """获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default")
 
         Returns:
             ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
+
         """
-        # uuid -> { "umop": list, "path": str, "name": str }
+        # uuid -> { "path": str, "name": str }
         abconf_data = self._get_abconf_data()
 
         if isinstance(umo, MessageSession):
             umo = str(umo)
         else:
@@ -89,10 +90,13 @@ class AstrBotConfigManager:
             except Exception:
                 return DEFAULT_CONFIG_CONF_INFO
 
-        for uuid_, meta in abconf_data.items():
-            for pattern in meta["umop"]:
-                if self._is_umo_match(pattern, umo):
-                    return ConfInfo(**meta, id=uuid_)
+        conf_id = self.ucr.get_conf_id_for_umop(umo)
+        if conf_id:
+            meta = abconf_data.get(conf_id)
+            if meta and isinstance(meta, dict):
+                # the bind relation between umo and conf is defined in ucr now, so we remove "umop" here
+                meta.pop("umop", None)
+                return ConfInfo(**meta, id=conf_id)
 
         return DEFAULT_CONFIG_CONF_INFO
 
@@ -100,23 +104,17 @@ class AstrBotConfigManager:
         self,
         abconf_path: str,
         abconf_id: str,
-        umo_parts: list[str] | list[MessageSession],
         abconf_name: str | None = None,
     ) -> None:
         """保存配置文件的映射关系"""
-        for part in umo_parts:
-            if isinstance(part, MessageSession):
-                part = str(part)
-            elif not isinstance(part, str):
-                raise ValueError(
-                    "umo_parts must be a list of strings or MessageSession instances"
-                )
         abconf_data = self.sp.get(
-            "abconf_mapping", {}, scope="global", scope_id="global"
+            "abconf_mapping",
+            {},
+            scope="global",
+            scope_id="global",
         )
         random_word = abconf_name or uuid.uuid4().hex[:8]
         abconf_data[abconf_id] = {
-            "umop": umo_parts,
             "path": abconf_path,
             "name": random_word,
         }
@@ -153,29 +151,26 @@ class AstrBotConfigManager:
     def get_conf_list(self) -> list[ConfInfo]:
         """获取所有配置文件的元数据列表"""
         conf_list = []
-        conf_list.append(DEFAULT_CONFIG_CONF_INFO)
         abconf_mapping = self._get_abconf_data()
         for uuid_, meta in abconf_mapping.items():
+            if not isinstance(meta, dict):
+                continue
+            meta.pop("umop", None)
             conf_list.append(ConfInfo(**meta, id=uuid_))
+        conf_list.append(DEFAULT_CONFIG_CONF_INFO)
         return conf_list
 
     def create_conf(
         self,
-        umo_parts: list[str] | list[MessageSession],
         config: dict = DEFAULT_CONFIG,
         name: str | None = None,
     ) -> str:
-        """
-        umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
-
-        umo_parts 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。
-        """
         conf_uuid = str(uuid.uuid4())
         conf_file_name = f"abconf_{conf_uuid}.json"
         conf_path = os.path.join(get_astrbot_config_path(), conf_file_name)
         conf = AstrBotConfig(config_path=conf_path, default_config=config)
         conf.save_config()
-        self._save_conf_mapping(conf_file_name, conf_uuid, umo_parts, abconf_name=name)
+        self._save_conf_mapping(conf_file_name, conf_uuid, abconf_name=name)
         self.confs[conf_uuid] = conf
         return conf_uuid
 
@@ -190,13 +185,17 @@ class AstrBotConfigManager:
 
         Raises:
             ValueError: 如果试图删除默认配置文件
+
         """
         if conf_id == "default":
             raise ValueError("不能删除默认配置文件")
 
         # 从映射中移除
         abconf_data = self.sp.get(
-            "abconf_mapping", {}, scope="global", scope_id="global"
+            "abconf_mapping",
+            {},
+            scope="global",
+            scope_id="global",
         )
         if conf_id not in abconf_data:
             logger.warning(f"配置文件 {conf_id} 不存在于映射中")
@@ -204,7 +203,8 @@ class AstrBotConfigManager:
 
         # 获取配置文件路径
         conf_path = os.path.join(
-            get_astrbot_config_path(), abconf_data[conf_id]["path"]
+            get_astrbot_config_path(),
+            abconf_data[conf_id]["path"],
         )
 
         # 删除配置文件
@@ -228,24 +228,25 @@ class AstrBotConfigManager:
         logger.info(f"成功删除配置文件 {conf_id}")
         return True
 
-    def update_conf_info(
-        self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None
-    ) -> bool:
+    def update_conf_info(self, conf_id: str, name: str | None = None) -> bool:
         """更新配置文件信息
 
         Args:
             conf_id: 配置文件的 UUID
             name: 新的配置文件名称 (可选)
-            umo_parts: 新的 UMO 部分列表 (可选)
 
         Returns:
             bool: 更新是否成功
+
         """
         if conf_id == "default":
             raise ValueError("不能更新默认配置文件的信息")
 
         abconf_data = self.sp.get(
-            "abconf_mapping", {}, scope="global", scope_id="global"
+            "abconf_mapping",
+            {},
+            scope="global",
+            scope_id="global",
        )
         if conf_id not in abconf_data:
             logger.warning(f"配置文件 {conf_id} 不存在于映射中")
@@ -255,18 +256,6 @@ class AstrBotConfigManager:
         if name is not None:
             abconf_data[conf_id]["name"] = name
 
-        # 更新 UMO 部分
-        if umo_parts is not None:
||||||
# 验证 UMO 部分格式
|
|
||||||
for part in umo_parts:
|
|
||||||
if isinstance(part, MessageSession):
|
|
||||||
part = str(part)
|
|
||||||
elif not isinstance(part, str):
|
|
||||||
raise ValueError(
|
|
||||||
"umo_parts must be a list of strings or MessageSession instances"
|
|
||||||
)
|
|
||||||
abconf_data[conf_id]["umop"] = umo_parts
|
|
||||||
|
|
||||||
# 保存更新
|
# 保存更新
|
||||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||||
self.abconf_data = abconf_data
|
self.abconf_data = abconf_data
|
||||||
@@ -274,7 +263,10 @@ class AstrBotConfigManager:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def g(
|
def g(
|
||||||
self, umo: str | None = None, key: str | None = None, default: _VT = None
|
self,
|
||||||
|
umo: str | None = None,
|
||||||
|
key: str | None = None,
|
||||||
|
default: _VT = None,
|
||||||
) -> _VT:
|
) -> _VT:
|
||||||
"""获取配置项。umo 为 None 时使用默认配置"""
|
"""获取配置项。umo 为 None 时使用默认配置"""
|
||||||
if umo is None:
|
if umo is None:
|
||||||
|
|||||||
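The hunks above drop per-config "umop" pattern matching and instead resolve a session's UMO ("platform:message_type:session_id") to a config id through a router (`self.ucr.get_conf_id_for_umop`). A minimal sketch of that routing idea is below; the class and storage here are illustrative stand-ins, not AstrBot's actual `UmopConfigRouter`.

```python
# Minimal sketch (assumed names, not the project's classes): route a UMO to a
# config id via an explicit binding table instead of scanning "umop" patterns.
from dataclasses import dataclass, field


@dataclass
class UmoConfigRouterSketch:
    # umo -> conf_id bindings; a real router would persist these
    bindings: dict[str, str] = field(default_factory=dict)

    def bind(self, umo: str, conf_id: str) -> None:
        self.bindings[umo] = conf_id

    def get_conf_id_for_umop(self, umo: str) -> str | None:
        # exact match first, then a platform-wide fallback such as "qq::"
        if umo in self.bindings:
            return self.bindings[umo]
        platform = umo.split(":", 1)[0]
        return self.bindings.get(f"{platform}::")


router = UmoConfigRouterSketch()
router.bind("aiocqhttp::", "conf-1234")
assert router.get_conf_id_for_umop("aiocqhttp:GroupMessage:123456") == "conf-1234"
```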
@@ -1,9 +1,9 @@
-from .default import DEFAULT_CONFIG, VERSION, DB_PATH
 from .astrbot_config import *
+from .default import DB_PATH, DEFAULT_CONFIG, VERSION

 __all__ = [
+"DB_PATH",
 "DEFAULT_CONFIG",
 "VERSION",
-"DB_PATH",
 "AstrBotConfig",
 ]
@@ -1,11 +1,12 @@
-import os
+import enum
 import json
 import logging
-import enum
-from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
-from typing import Dict
+import os
 from astrbot.core.utils.astrbot_path import get_astrbot_data_path

+from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP

 ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
 logger = logging.getLogger("astrbot")

@@ -27,7 +28,7 @@ class AstrBotConfig(dict):
 self,
 config_path: str = ASTRBOT_CONFIG_PATH,
 default_config: dict = DEFAULT_CONFIG,
-schema: dict = None,
+schema: dict | None = None,
 ):
 super().__init__()

@@ -45,7 +46,7 @@ class AstrBotConfig(dict):
 json.dump(default_config, f, indent=4, ensure_ascii=False)
 object.__setattr__(self, "first_deploy", True)  # 标记第一次部署

-with open(config_path, "r", encoding="utf-8-sig") as f:
+with open(config_path, encoding="utf-8-sig") as f:
 conf_str = f.read()
 conf = json.loads(conf_str)

@@ -65,7 +66,7 @@ class AstrBotConfig(dict):
 for k, v in schema.items():
 if v["type"] not in DEFAULT_VALUE_MAP:
 raise TypeError(
-f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}"
+f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}",
 )
 if "default" in v:
 default = v["default"]
@@ -82,7 +83,7 @@ class AstrBotConfig(dict):

 return conf

-def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
+def check_config_integrity(self, refer_conf: dict, conf: dict, path=""):
 """检查配置完整性,如果有新的配置项或顺序不一致则返回 True"""
 has_new = False

@@ -97,27 +98,28 @@ class AstrBotConfig(dict):
 logger.info(f"检查到配置项 {path_} 不存在,已插入默认值 {value}")
 new_conf[key] = value
 has_new = True
-else:
-if conf[key] is None:
-# 配置项为 None,使用默认值
+elif conf[key] is None:
+# 配置项为 None,使用默认值
+new_conf[key] = value
+has_new = True
+elif isinstance(value, dict):
+# 递归检查子配置项
+if not isinstance(conf[key], dict):
+# 类型不匹配,使用默认值
 new_conf[key] = value
 has_new = True
-elif isinstance(value, dict):
-# 递归检查子配置项
-if not isinstance(conf[key], dict):
-# 类型不匹配,使用默认值
-new_conf[key] = value
-has_new = True
-else:
-# 递归检查并同步顺序
-child_has_new = self.check_config_integrity(
-value, conf[key], path + "." + key if path else key
-)
-new_conf[key] = conf[key]
-has_new |= child_has_new
 else:
-# 直接使用现有配置
+# 递归检查并同步顺序
+child_has_new = self.check_config_integrity(
+value,
+conf[key],
+path + "." + key if path else key,
+)
 new_conf[key] = conf[key]
+has_new |= child_has_new
+else:
+# 直接使用现有配置
+new_conf[key] = conf[key]

 # 检查是否存在参考配置中没有的配置项
 for key in list(conf.keys()):
@@ -140,7 +142,7 @@ class AstrBotConfig(dict):

 return has_new

-def save_config(self, replace_config: Dict = None):
+def save_config(self, replace_config: dict | None = None):
 """将配置写入文件

 如果传入 replace_config,则将配置替换为 replace_config
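The restructured `check_config_integrity` above flattens a nested `else/if` into an `elif` chain: missing keys and `None` values fall back to defaults, nested dicts recurse, and everything else keeps the existing value. A simplified stand-in of that merge logic (not the `AstrBotConfig` implementation itself) looks like this:

```python
# Illustrative sketch of the integrity check: merge a user config against a
# reference config, filling missing/None values and recursing into dicts.
def merge_with_defaults(refer: dict, conf: dict) -> tuple[dict, bool]:
    has_new = False
    merged: dict = {}
    for key, value in refer.items():
        if key not in conf:
            merged[key] = value          # missing -> default
            has_new = True
        elif conf[key] is None:
            merged[key] = value          # None -> default
            has_new = True
        elif isinstance(value, dict):
            if not isinstance(conf[key], dict):
                merged[key] = value      # type mismatch -> default
                has_new = True
            else:                        # recurse and sync ordering
                merged[key], child_new = merge_with_defaults(value, conf[key])
                has_new |= child_new
        else:
            merged[key] = conf[key]      # keep the existing value
    return merged, has_new


merged, changed = merge_with_defaults(
    {"a": 1, "nested": {"x": True}},
    {"a": 2, "nested": None},
)
assert merged == {"a": 2, "nested": {"x": True}} and changed
```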
@@ -1,12 +1,10 @@
-"""
-如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
-"""
+"""如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。"""

 import os

 from astrbot.core.utils.astrbot_path import get_astrbot_data_path

-VERSION = "4.3.5"
+VERSION = "4.5.4"
 DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")

 # 默认配置
@@ -134,8 +132,11 @@ DEFAULT_CONFIG = {
 "persona": [],  # deprecated
 "timezone": "Asia/Shanghai",
 "callback_api_base": "",
-"default_kb_collection": "",  # 默认知识库名称
+"default_kb_collection": "",  # 默认知识库名称, 已经过时
 "plugin_set": ["*"],  # "*" 表示使用所有可用的插件, 空列表表示不使用任何插件
+"kb_names": [],  # 默认知识库名称列表
+"kb_fusion_top_k": 20,  # 知识库检索融合阶段返回结果数量
+"kb_final_top_k": 5,  # 知识库检索最终返回结果数量
 }


@@ -162,10 +163,11 @@ CONFIG_METADATA_2 = {
 "enable": False,
 "appid": "",
 "secret": "",
+"is_sandbox": False,
 "callback_server_host": "0.0.0.0",
 "port": 6196,
 },
-"QQ 个人号(aiocqhttp)": {
+"QQ 个人号(OneBot v11)": {
 "id": "default",
 "type": "aiocqhttp",
 "enable": False,
@@ -173,7 +175,7 @@ CONFIG_METADATA_2 = {
 "ws_reverse_port": 6199,
 "ws_reverse_token": "",
 },
-"微信个人号(WeChatPadPro)": {
+"WeChatPadPro": {
 "id": "wechatpadpro",
 "type": "wechatpadpro",
 "enable": False,
@@ -268,6 +270,14 @@ CONFIG_METADATA_2 = {
 "misskey_default_visibility": "public",
 "misskey_local_only": False,
 "misskey_enable_chat": True,
+# download / security options
+"misskey_allow_insecure_downloads": False,
+"misskey_download_timeout": 15,
+"misskey_download_chunk_size": 65536,
+"misskey_max_download_bytes": None,
+"misskey_enable_file_upload": True,
+"misskey_upload_concurrency": 3,
+"misskey_upload_folder": "",
 },
 "Slack": {
 "id": "slack",
@@ -292,8 +302,30 @@ CONFIG_METADATA_2 = {
 "satori_heartbeat_interval": 10,
 "satori_reconnect_delay": 5,
 },
+# "WebChat": {
+#     "id": "webchat",
+#     "type": "webchat",
+#     "enable": False,
+#     "webchat_link_path": "",
+#     "webchat_present_type": "fullscreen",
+# },
 },
 "items": {
+# "webchat_link_path": {
+#     "description": "链接路径",
+#     "_special": "webchat_link_path",
+#     "type": "string",
+# },
+# "webchat_present_type": {
+#     "_special": "webchat_present_type",
+#     "description": "展现形式",
+#     "type": "string",
+#     "options": ["fullscreen", "embedded"],
+# },
+"is_sandbox": {
+"description": "沙箱模式",
+"type": "bool",
+},
 "satori_api_base_url": {
 "description": "Satori API 终结点",
 "type": "string",
@@ -396,6 +428,41 @@ CONFIG_METADATA_2 = {
 "type": "bool",
 "hint": "启用后,机器人将会监听和响应私信聊天消息",
 },
+"misskey_enable_file_upload": {
+"description": "启用文件上传到 Misskey",
+"type": "bool",
+"hint": "启用后,适配器会尝试将消息链中的文件上传到 Misskey。URL 文件会先尝试服务器端上传,异步上传失败时会回退到下载后本地上传。",
+},
+"misskey_allow_insecure_downloads": {
+"description": "允许不安全下载(禁用 SSL 验证)",
+"type": "bool",
+"hint": "当远端服务器存在证书问题导致无法正常下载时,自动禁用 SSL 验证作为回退方案。适用于某些图床的证书配置问题。启用有安全风险,仅在必要时使用。",
+},
+"misskey_download_timeout": {
+"description": "远端下载超时时间(秒)",
+"type": "int",
+"hint": "下载远程文件时的超时时间(秒),用于异步上传回退到本地上传的场景。",
+},
+"misskey_download_chunk_size": {
+"description": "流式下载分块大小(字节)",
+"type": "int",
+"hint": "流式下载和计算 MD5 时使用的每次读取字节数,过小会增加开销,过大会占用内存。",
+},
+"misskey_max_download_bytes": {
+"description": "最大允许下载字节数(超出则中止)",
+"type": "int",
+"hint": "如果希望限制下载文件的最大大小以防止 OOM,请填写最大字节数;留空或 null 表示不限制。",
+},
+"misskey_upload_concurrency": {
+"description": "并发上传限制",
+"type": "int",
+"hint": "同时进行的文件上传任务上限(整数,默认 3)。",
+},
+"misskey_upload_folder": {
+"description": "上传到网盘的目标文件夹 ID",
+"type": "string",
+"hint": "可选:填写 Misskey 网盘中目标文件夹的 ID,上传的文件将放置到该文件夹内。留空则使用账号网盘根目录。",
+},
 "telegram_command_register": {
 "description": "Telegram 命令注册",
 "type": "bool",
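The Misskey options introduced above (timeout, chunk size, byte cap) are the kind of settings a streaming download helper would consume when falling back from server-side upload to a local download. A stdlib-only sketch of that consumption, not the adapter's actual code, is shown here:

```python
# Hedged sketch: stream a remote file while honoring a timeout, a chunk size
# and an optional maximum byte cap, computing the MD5 along the way.
import hashlib
import urllib.request


def download_capped(url: str, timeout: int = 15, chunk_size: int = 65536,
                    max_bytes: int | None = None) -> tuple[bytes, str]:
    """Return (data, md5 hexdigest); abort once max_bytes is exceeded."""
    md5 = hashlib.md5()
    chunks: list[bytes] = []
    total = 0
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        while True:
            chunk = resp.read(chunk_size)
            if not chunk:
                break
            total += len(chunk)
            if max_bytes is not None and total > max_bytes:
                raise ValueError(f"download exceeds {max_bytes} bytes, aborting")
            md5.update(chunk)
            chunks.append(chunk)
    return b"".join(chunks), md5.hexdigest()
```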
@@ -447,19 +514,18 @@ CONFIG_METADATA_2 = {
 "hint": "启用后,机器人可以接收到频道的私聊消息。",
 },
 "ws_reverse_host": {
-"description": "反向 Websocket 主机地址(AstrBot 为服务器端)",
+"description": "反向 Websocket 主机",
 "type": "string",
-"hint": "aiocqhttp 适配器的反向 Websocket 服务器 IP 地址,不包含端口号。",
+"hint": "AstrBot 将作为服务器端。",
 },
 "ws_reverse_port": {
 "description": "反向 Websocket 端口",
 "type": "int",
-"hint": "aiocqhttp 适配器的反向 Websocket 端口。",
 },
 "ws_reverse_token": {
 "description": "反向 Websocket Token",
 "type": "string",
-"hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。",
+"hint": "反向 Websocket Token。未设置则不启用 Token 验证。",
 },
 "wecom_ai_bot_name": {
 "description": "企业微信智能机器人的名字",
@@ -703,6 +769,7 @@ CONFIG_METADATA_2 = {
 "timeout": 120,
 "model_config": {"model": "grok-2-latest", "temperature": 0.4},
 "custom_extra_body": {},
+"xai_native_search": False,
 "modalities": ["text", "image", "tool_use"],
 },
 "Anthropic": {
@@ -1194,8 +1261,38 @@ CONFIG_METADATA_2 = {
 "rerank_model": "BAAI/bge-reranker-base",
 "timeout": 20,
 },
+"Xinference Rerank": {
+"id": "xinference_rerank",
+"type": "xinference_rerank",
+"provider": "xinference",
+"provider_type": "rerank",
+"enable": True,
+"rerank_api_key": "",
+"rerank_api_base": "http://127.0.0.1:9997",
+"rerank_model": "BAAI/bge-reranker-base",
+"timeout": 20,
+"launch_model_if_not_running": False,
+},
+"Xinference STT": {
+"id": "xinference_stt",
+"type": "xinference_stt",
+"provider": "xinference",
+"provider_type": "speech_to_text",
+"enable": False,
+"api_key": "",
+"api_base": "http://127.0.0.1:9997",
+"model": "whisper-large-v3",
+"timeout": 180,
+"launch_model_if_not_running": False,
+},
 },
 "items": {
+"xai_native_search": {
+"description": "启用原生搜索功能",
+"type": "bool",
+"hint": "启用后,将通过 xAI 的 Chat Completions 原生 Live Search 进行联网检索(按需计费)。仅对 xAI 提供商生效。",
+"condition": {"provider": "xai"},
+},
 "rerank_api_base": {
 "description": "重排序模型 API Base URL",
 "type": "string",
@@ -1210,6 +1307,11 @@ CONFIG_METADATA_2 = {
 "description": "重排序模型名称",
 "type": "string",
 },
+"launch_model_if_not_running": {
+"description": "模型未运行时自动启动",
+"type": "bool",
+"hint": "如果模型当前未在 Xinference 服务中运行,是否尝试自动启动它。在生产环境中建议关闭。",
+},
 "modalities": {
 "description": "模型能力",
 "type": "list",
@@ -1353,6 +1455,7 @@ CONFIG_METADATA_2 = {
 "description": "嵌入维度",
 "type": "int",
 "hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。",
+"_special": "get_embedding_dim",
 },
 "embedding_model": {
 "description": "嵌入模型",
@@ -2000,6 +2103,9 @@ CONFIG_METADATA_2 = {
 "default_kb_collection": {
 "type": "string",
 },
+"kb_names": {"type": "list", "items": {"type": "string"}},
+"kb_fusion_top_k": {"type": "int", "default": 20},
+"kb_final_top_k": {"type": "int", "default": 5},
 },
 },
 }
@@ -2078,10 +2184,22 @@ CONFIG_METADATA_3 = {
 "description": "知识库",
 "type": "object",
 "items": {
-"default_kb_collection": {
-"description": "默认使用的知识库",
-"type": "string",
+"kb_names": {
+"description": "知识库列表",
+"type": "list",
+"items": {"type": "string"},
 "_special": "select_knowledgebase",
+"hint": "支持多选",
+},
+"kb_fusion_top_k": {
+"description": "融合检索结果数",
+"type": "int",
+"hint": "多个知识库检索结果融合后的返回结果数量",
+},
+"kb_final_top_k": {
+"description": "最终返回结果数",
+"type": "int",
+"hint": "从知识库中检索到的结果数量,越大可能获得越多相关信息,但也可能引入噪音。建议根据实际需求调整",
 },
 },
 },
@@ -2175,7 +2293,7 @@ CONFIG_METADATA_3 = {
 "provider_settings.wake_prefix": {
 "description": "LLM 聊天额外唤醒前缀 ",
 "type": "string",
-"hint": "例子: 如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
+"hint": "如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
 },
 "provider_settings.prompt_prefix": {
 "description": "用户提示词",
@@ -2587,9 +2705,9 @@ CONFIG_METADATA_3_SYSTEM = {
 "items": {"type": "string"},
 },
 },
-}
+},
 },
-}
+},
 }

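The knowledge-base settings above replace a single `default_kb_collection` with a multi-select `kb_names` list plus two top-k limits. One plausible way such settings get combined at retrieval time is sketched below; this assumes scored hits per knowledge base and is not AstrBot's retrieval implementation.

```python
# Hedged sketch: fuse hits from several knowledge bases, keep the best
# `kb_fusion_top_k` overall, then trim to `kb_final_top_k` for the prompt.
def fuse_kb_results(results_per_kb: list[list[tuple[str, float]]],
                    fusion_top_k: int = 20, final_top_k: int = 5) -> list[str]:
    all_hits = [hit for hits in results_per_kb for hit in hits]
    fused = sorted(all_hits, key=lambda h: h[1], reverse=True)[:fusion_top_k]
    return [text for text, _ in fused[:final_top_k]]


docs = fuse_kb_results(
    [[("kb1 passage", 0.9), ("kb1 other", 0.4)], [("kb2 passage", 0.7)]],
    fusion_top_k=20,
    final_top_k=2,
)
assert docs == ["kb1 passage", "kb2 passage"]
```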
@@ -1,13 +1,14 @@
-"""
-AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库
+"""AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库.

 在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
 在一个会话中可以建立多个对话, 并且支持对话的切换和删除
 """

 import json
+from collections.abc import Awaitable, Callable

 from astrbot.core import sp
-from typing import Dict, List
+from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment
 from astrbot.core.db import BaseDatabase
 from astrbot.core.db.po import Conversation, ConversationV2

@@ -16,10 +17,45 @@ class ConversationManager:
 """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""

 def __init__(self, db_helper: BaseDatabase):
-self.session_conversations: Dict[str, str] = {}
+self.session_conversations: dict[str, str] = {}
 self.db = db_helper
 self.save_interval = 60  # 每 60 秒保存一次

+# 会话删除回调函数列表(用于级联清理,如知识库配置)
+self._on_session_deleted_callbacks: list[Callable[[str], Awaitable[None]]] = []
+
+def register_on_session_deleted(
+self,
+callback: Callable[[str], Awaitable[None]],
+) -> None:
+"""注册会话删除回调函数.
+
+其他模块可以注册回调来响应会话删除事件,实现级联清理。
+例如:知识库模块可以注册回调来清理会话的知识库配置。
+
+Args:
+callback: 回调函数,接收会话ID (unified_msg_origin) 作为参数
+
+"""
+self._on_session_deleted_callbacks.append(callback)
+
+async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:
+"""触发会话删除回调.
+
+Args:
+unified_msg_origin: 会话ID
+
+"""
+for callback in self._on_session_deleted_callbacks:
+try:
+await callback(unified_msg_origin)
+except Exception as e:
+from astrbot.core import logger
+
+logger.error(
+f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}",
+)
+
 def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
 """将 ConversationV2 对象转换为 Conversation 对象"""
 created_at = int(conv_v2.created_at.timestamp())
@@ -43,12 +79,13 @@ class ConversationManager:
 title: str | None = None,
 persona_id: str | None = None,
 ) -> str:
-"""新建对话,并将当前会话的对话转移到新对话
+"""新建对话,并将当前会话的对话转移到新对话.

 Args:
 unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
 Returns:
 conversation_id (str): 对话 ID, 是 uuid 格式的字符串

 """
 if not platform_id:
 # 如果没有提供 platform_id,则从 unified_msg_origin 中解析
@@ -74,18 +111,22 @@ class ConversationManager:
 Args:
 unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
 conversation_id (str): 对话 ID, 是 uuid 格式的字符串

 """
 self.session_conversations[unified_msg_origin] = conversation_id
 await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)

 async def delete_conversation(
-self, unified_msg_origin: str, conversation_id: str | None = None
+self,
+unified_msg_origin: str,
+conversation_id: str | None = None,
 ):
 """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话

 Args:
 unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
 conversation_id (str): 对话 ID, 是 uuid 格式的字符串

 """
 if not conversation_id:
 conversation_id = self.session_conversations.get(unified_msg_origin)
@@ -101,11 +142,15 @@ class ConversationManager:

 Args:
 unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id

 """
 await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
 self.session_conversations.pop(unified_msg_origin, None)
 await sp.session_remove(unified_msg_origin, "sel_conv_id")

+# 触发会话删除回调(级联清理)
+await self._trigger_session_deleted(unified_msg_origin)
+
 async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
 """获取会话当前的对话 ID

@@ -113,6 +158,7 @@ class ConversationManager:
 unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
 Returns:
 conversation_id (str): 对话 ID, 是 uuid 格式的字符串

 """
 ret = self.session_conversations.get(unified_msg_origin, None)
 if not ret:
@@ -127,13 +173,15 @@ class ConversationManager:
 conversation_id: str,
 create_if_not_exists: bool = False,
 ) -> Conversation | None:
-"""获取会话的对话
+"""获取会话的对话.

 Args:
 unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
 conversation_id (str): 对话 ID, 是 uuid 格式的字符串
+create_if_not_exists (bool): 如果对话不存在,是否创建一个新的对话
 Returns:
 conversation (Conversation): 对话对象

 """
 conv = await self.db.get_conversation_by_id(cid=conversation_id)
 if not conv and create_if_not_exists:
@@ -146,18 +194,22 @@ class ConversationManager:
 return conv_res

 async def get_conversations(
-self, unified_msg_origin: str | None = None, platform_id: str | None = None
-) -> List[Conversation]:
-"""获取对话列表
+self,
+unified_msg_origin: str | None = None,
+platform_id: str | None = None,
+) -> list[Conversation]:
+"""获取对话列表.

 Args:
 unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id,可选
 platform_id (str): 平台 ID, 可选参数, 用于过滤对话
 Returns:
 conversations (List[Conversation]): 对话对象列表

 """
 convs = await self.db.get_conversations(
-user_id=unified_msg_origin, platform_id=platform_id
+user_id=unified_msg_origin,
+platform_id=platform_id,
 )
 convs_res = []
 for conv in convs:
@@ -173,7 +225,7 @@ class ConversationManager:
 search_query: str = "",
 **kwargs,
 ) -> tuple[list[Conversation], int]:
-"""获取过滤后的对话列表
+"""获取过滤后的对话列表.

 Args:
 page (int): 页码, 默认为 1
@@ -182,6 +234,7 @@ class ConversationManager:
 search_query (str): 搜索查询字符串, 可选
 Returns:
 conversations (list[Conversation]): 对话对象列表

 """
 convs, cnt = await self.db.get_filtered_conversations(
 page=page,
@@ -203,13 +256,14 @@ class ConversationManager:
 history: list[dict] | None = None,
 title: str | None = None,
 persona_id: str | None = None,
-):
+) -> None:
-"""更新会话的对话
+"""更新会话的对话.

 Args:
 unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
 conversation_id (str): 对话 ID, 是 uuid 格式的字符串
 history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段

 """
 if not conversation_id:
 # 如果没有提供 conversation_id,则获取当前的
@@ -223,16 +277,20 @@ class ConversationManager:
 )

 async def update_conversation_title(
-self, unified_msg_origin: str, title: str, conversation_id: str | None = None
-):
-"""更新会话的对话标题
+self,
+unified_msg_origin: str,
+title: str,
+conversation_id: str | None = None,
+) -> None:
+"""更新会话的对话标题.

 Args:
 unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
 title (str): 对话标题
+conversation_id (str): 对话 ID, 是 uuid 格式的字符串
 Deprecated:
 Use `update_conversation` with `title` parameter instead.

 """
 await self.update_conversation(
 unified_msg_origin=unified_msg_origin,
@@ -245,15 +303,16 @@ class ConversationManager:
 unified_msg_origin: str,
 persona_id: str,
 conversation_id: str | None = None,
-):
+) -> None:
-"""更新会话的对话 Persona ID
+"""更新会话的对话 Persona ID.

 Args:
 unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
 persona_id (str): 对话 Persona ID
+conversation_id (str): 对话 ID, 是 uuid 格式的字符串
 Deprecated:
 Use `update_conversation` with `persona_id` parameter instead.

 """
 await self.update_conversation(
 unified_msg_origin=unified_msg_origin,
@@ -261,40 +320,85 @@ class ConversationManager:
 persona_id=persona_id,
 )

+async def add_message_pair(
+self,
+cid: str,
+user_message: UserMessageSegment | dict,
+assistant_message: AssistantMessageSegment | dict,
+) -> None:
+"""Add a user-assistant message pair to the conversation history.
+
+Args:
+cid (str): Conversation ID
+user_message (UserMessageSegment | dict): OpenAI-format user message object or dict
+assistant_message (AssistantMessageSegment | dict): OpenAI-format assistant message object or dict
+
+Raises:
+Exception: If the conversation with the given ID is not found
+"""
+conv = await self.db.get_conversation_by_id(cid=cid)
+if not conv:
+raise Exception(f"Conversation with id {cid} not found")
+history = conv.content or []
+if isinstance(user_message, UserMessageSegment):
+user_msg_dict = user_message.model_dump()
+else:
+user_msg_dict = user_message
+if isinstance(assistant_message, AssistantMessageSegment):
+assistant_msg_dict = assistant_message.model_dump()
+else:
+assistant_msg_dict = assistant_message
+history.append(user_msg_dict)
+history.append(assistant_msg_dict)
+await self.db.update_conversation(
+cid=cid,
+content=history,
+)
+
 async def get_human_readable_context(
-self, unified_msg_origin, conversation_id, page=1, page_size=10
-):
-"""获取人类可读的上下文
+self,
+unified_msg_origin: str,
+conversation_id: str,
+page: int = 1,
+page_size: int = 10,
+) -> tuple[list[str], int]:
+"""获取人类可读的上下文.

 Args:
 unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
 conversation_id (str): 对话 ID, 是 uuid 格式的字符串
 page (int): 页码
 page_size (int): 每页大小

 """
 conversation = await self.get_conversation(unified_msg_origin, conversation_id)
+if not conversation:
+return [], 0
 history = json.loads(conversation.history)

-contexts = []
-temp_contexts = []
+# contexts_groups 存放按顺序的段落(每个段落是一个 str 列表),
+# 之后会被展平成一个扁平的 str 列表返回。
+contexts_groups: list[list[str]] = []
+temp_contexts: list[str] = []
 for record in history:
 if record["role"] == "user":
 temp_contexts.append(f"User: {record['content']}")
 elif record["role"] == "assistant":
-if "content" in record and record["content"]:
+if record.get("content"):
 temp_contexts.append(f"Assistant: {record['content']}")
 elif "tool_calls" in record:
 tool_calls_str = json.dumps(
-record["tool_calls"], ensure_ascii=False
+record["tool_calls"],
+ensure_ascii=False,
 )
 temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
 else:
 temp_contexts.append("Assistant: [未知的内容]")
-contexts.insert(0, temp_contexts)
+contexts_groups.insert(0, temp_contexts)
 temp_contexts = []

-# 展平 contexts 列表
-contexts = [item for sublist in contexts for item in sublist]
+# 展平分组后的 contexts 列表为单层字符串列表
+contexts = [item for sublist in contexts_groups for item in sublist]

 # 计算分页
 paged_contexts = contexts[(page - 1) * page_size : page * page_size]
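The conversation-manager changes above add a session-deleted hook: other modules register an async callback that runs when a session is removed, so cleanup (for example, knowledge-base bindings) cascades automatically. A minimal, self-contained sketch of that pattern, with illustrative names rather than AstrBot's classes:

```python
# Hedged sketch: register async cleanup callbacks and fire them on deletion,
# isolating failures so one bad callback cannot break the others.
import asyncio
from collections.abc import Awaitable, Callable


class SessionRegistrySketch:
    def __init__(self) -> None:
        self._on_deleted: list[Callable[[str], Awaitable[None]]] = []

    def register_on_session_deleted(self, cb: Callable[[str], Awaitable[None]]) -> None:
        self._on_deleted.append(cb)

    async def delete_session(self, umo: str) -> None:
        # ... delete conversations / preferences for `umo` here ...
        for cb in self._on_deleted:
            try:
                await cb(umo)  # cascade cleanup
            except Exception as e:
                print(f"session-deleted callback failed for {umo}: {e}")


async def main() -> None:
    registry = SessionRegistrySketch()

    async def drop_kb_binding(umo: str) -> None:
        print(f"cleaning knowledge-base binding for {umo}")

    registry.register_on_session_deleted(drop_kb_binding)
    await registry.delete_session("aiocqhttp:GroupMessage:123456")


asyncio.run(main())
```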
@@ -1,5 +1,5 @@
-"""
-Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
+"""Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作.
 该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
 该类还负责加载和执行插件, 以及处理事件总线的分发。

@@ -9,42 +9,44 @@ Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、
 3. 执行启动完成事件钩子
 """

-import traceback
 import asyncio
-import time
-import threading
 import os
-from .event_bus import EventBus
-from . import astrbot_config, html_renderer
+import threading
+import time
+import traceback
 from asyncio import Queue
-from typing import List
-from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
-from astrbot.core.star import PluginManager
-from astrbot.core.platform.manager import PlatformManager
-from astrbot.core.star.context import Context
-from astrbot.core.persona_mgr import PersonaManager
-from astrbot.core.provider.manager import ProviderManager
-from astrbot.core import LogBroker
-from astrbot.core.db import BaseDatabase
-from astrbot.core.updator import AstrBotUpdator
-from astrbot.core import logger, sp
+from astrbot.core import LogBroker, logger, sp
+from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
 from astrbot.core.config.default import VERSION
 from astrbot.core.conversation_mgr import ConversationManager
+from astrbot.core.db import BaseDatabase
+from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
+from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
+from astrbot.core.persona_mgr import PersonaManager
+from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler
+from astrbot.core.platform.manager import PlatformManager
 from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
-from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
-from astrbot.core.star.star_handler import star_handlers_registry, EventType
-from astrbot.core.star.star_handler import star_map
+from astrbot.core.provider.manager import ProviderManager
+from astrbot.core.star import PluginManager
+from astrbot.core.star.context import Context
+from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
+from astrbot.core.umop_config_router import UmopConfigRouter
+from astrbot.core.updator import AstrBotUpdator
+
+from . import astrbot_config, html_renderer
+from .event_bus import EventBus


 class AstrBotCoreLifecycle:
-"""
-AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
+"""AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作.
 该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、
 EventBus 等。
 该类还负责加载和执行插件, 以及处理事件总线的分发。
 """

-def __init__(self, log_broker: LogBroker, db: BaseDatabase):
+def __init__(self, log_broker: LogBroker, db: BaseDatabase) -> None:
 self.log_broker = log_broker  # 初始化日志代理
 self.astrbot_config = astrbot_config  # 初始化配置
 self.db = db  # 初始化数据库
@@ -68,11 +70,11 @@ class AstrBotCoreLifecycle:
 del os.environ["no_proxy"]
 logger.debug("HTTP proxy cleared")

-async def initialize(self):
-"""
-初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
-"""
+async def initialize(self) -> None:
+"""初始化 AstrBot 核心生命周期管理类.
+
+负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
+"""
 # 初始化日志代理
 logger.info("AstrBot v" + VERSION)
 if os.environ.get("TESTING", ""):
@@ -84,11 +86,23 @@ class AstrBotCoreLifecycle:

 await html_renderer.initialize()

+# 初始化 UMOP 配置路由器
+self.umop_config_router = UmopConfigRouter(sp=sp)
+
 # 初始化 AstrBot 配置管理器
 self.astrbot_config_mgr = AstrBotConfigManager(
-default_config=self.astrbot_config, sp=sp
+default_config=self.astrbot_config,
+ucr=self.umop_config_router,
+sp=sp,
 )

+# 4.5 to 4.6 migration for umop_config_router
+try:
+await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
+except Exception as e:
+logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
+logger.error(traceback.format_exc())
+
 # 初始化事件队列
 self.event_queue = Queue()

@@ -98,7 +112,9 @@ class AstrBotCoreLifecycle:

 # 初始化供应商管理器
 self.provider_manager = ProviderManager(
-self.astrbot_config_mgr, self.db, self.persona_mgr
+self.astrbot_config_mgr,
+self.db,
+self.persona_mgr,
 )

 # 初始化平台管理器
@@ -110,6 +126,9 @@ class AstrBotCoreLifecycle:
 # 初始化平台消息历史管理器
 self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)

+# 初始化知识库管理器
+self.kb_manager = KnowledgeBaseManager(self.provider_manager)
+
 # 初始化提供给插件的上下文
 self.star_context = Context(
 self.event_queue,
@@ -121,6 +140,7 @@ class AstrBotCoreLifecycle:
 self.platform_message_history_manager,
 self.persona_mgr,
 self.astrbot_config_mgr,
+self.kb_manager,
 )

 # 初始化插件管理器
@@ -132,8 +152,9 @@ class AstrBotCoreLifecycle:
 # 根据配置实例化各个 Provider
 await self.provider_manager.initialize()

-# 初始化消息事件流水线调度器
+await self.kb_manager.initialize()
+
+# 初始化消息事件流水线调度器
 self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()

 # 初始化更新器
@@ -141,14 +162,16 @@ class AstrBotCoreLifecycle:

 # 初始化事件总线
 self.event_bus = EventBus(
-self.event_queue, self.pipeline_scheduler_mapping, self.astrbot_config_mgr
+self.event_queue,
+self.pipeline_scheduler_mapping,
+self.astrbot_config_mgr,
 )

 # 记录启动时间
 self.start_time = int(time.time())

 # 初始化当前任务列表
-self.curr_tasks: List[asyncio.Task] = []
+self.curr_tasks: list[asyncio.Task] = []

 # 根据配置实例化各个平台适配器
 await self.platform_manager.initialize()
@@ -156,13 +179,13 @@ class AstrBotCoreLifecycle:
 # 初始化关闭控制面板的事件
 self.dashboard_shutdown_event = asyncio.Event()

-def _load(self):
-"""加载事件总线和任务并初始化"""
+def _load(self) -> None:
+"""加载事件总线和任务并初始化."""

 # 创建一个异步任务来执行事件总线的 dispatch() 方法
 # dispatch是一个无限循环的协程, 从事件队列中获取事件并处理
 event_bus_task = asyncio.create_task(
-self.event_bus.dispatch(), name="event_bus"
+self.event_bus.dispatch(),
+name="event_bus",
 )

 # 把插件中注册的所有协程函数注册到事件总线中并执行
@@ -173,16 +196,17 @@ class AstrBotCoreLifecycle:
 tasks_ = [event_bus_task, *extra_tasks]
 for task in tasks_:
 self.curr_tasks.append(
-asyncio.create_task(self._task_wrapper(task), name=task.get_name())
+asyncio.create_task(self._task_wrapper(task), name=task.get_name()),
 )

 self.start_time = int(time.time())

-async def _task_wrapper(self, task: asyncio.Task):
-"""异步任务包装器, 用于处理异步任务执行中出现的各种异常
+async def _task_wrapper(self, task: asyncio.Task) -> None:
+"""异步任务包装器, 用于处理异步任务执行中出现的各种异常.

 Args:
 task (asyncio.Task): 要执行的异步任务

 """
 try:
 await task
@@ -195,19 +219,22 @@ class AstrBotCoreLifecycle:
 logger.error(f"| {line}")
 logger.error("-------")

-async def start(self):
-"""启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子"""
+async def start(self) -> None:
+"""启动 AstrBot 核心生命周期管理类.
+
+用load加载事件总线和任务并初始化, 执行启动完成事件钩子
+"""
 self._load()
 logger.info("AstrBot 启动完成。")

 # 执行启动完成事件钩子
 handlers = star_handlers_registry.get_handlers_by_event_type(
-EventType.OnAstrBotLoadedEvent
+EventType.OnAstrBotLoadedEvent,
 )
 for handler in handlers:
 try:
 logger.info(
-f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
+f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
 )
 await handler.handler()
 except BaseException:
@@ -216,8 +243,8 @@ class AstrBotCoreLifecycle:
 # 同时运行curr_tasks中的所有任务
 await asyncio.gather(*self.curr_tasks, return_exceptions=True)

-async def stop(self):
-"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器"""
+async def stop(self) -> None:
+"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器."""
 # 请求停止所有正在运行的异步任务
 for task in self.curr_tasks:
 task.cancel()
@@ -228,11 +255,12 @@ class AstrBotCoreLifecycle:
 except Exception as e:
 logger.warning(traceback.format_exc())
 logger.warning(
-f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。"
+f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。",
 )

 await self.provider_manager.terminate()
 await self.platform_manager.terminate()
+await self.kb_manager.terminate()
 self.dashboard_shutdown_event.set()

 # 再次遍历curr_tasks等待每个任务真正结束
@@ -244,16 +272,19 @@ class AstrBotCoreLifecycle:
 except Exception as e:
 logger.error(f"任务 {task.get_name()} 发生错误: {e}")

-async def restart(self):
+async def restart(self) -> None:
 """重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
 await self.provider_manager.terminate()
 await self.platform_manager.terminate()
+await self.kb_manager.terminate()
 self.dashboard_shutdown_event.set()
 threading.Thread(
-target=self.astrbot_updator._reboot, name="restart", daemon=True
+target=self.astrbot_updator._reboot,
+name="restart",
+daemon=True,
 ).start()

-def load_platform(self) -> List[asyncio.Task]:
+def load_platform(self) -> list[asyncio.Task]:
 """加载平台实例并返回所有平台实例的异步任务列表"""
 tasks = []
 platform_insts = self.platform_manager.get_insts()
@@ -262,36 +293,38 @@ class AstrBotCoreLifecycle:
 asyncio.create_task(
 platform_inst.run(),
 name=f"{platform_inst.meta().id}({platform_inst.meta().name})",
-)
+),
 )
 return tasks

 async def load_pipeline_scheduler(self) -> dict[str, PipelineScheduler]:
-"""加载消息事件流水线调度器
+"""加载消息事件流水线调度器.

 Returns:
 dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射

 """
 mapping = {}
 for conf_id, ab_config in self.astrbot_config_mgr.confs.items():
 scheduler = PipelineScheduler(
-PipelineContext(ab_config, self.plugin_manager, conf_id)
+PipelineContext(ab_config, self.plugin_manager, conf_id),
 )
 await scheduler.initialize()
 mapping[conf_id] = scheduler
 return mapping

-async def reload_pipeline_scheduler(self, conf_id: str):
-"""重新加载消息事件流水线调度器
+async def reload_pipeline_scheduler(self, conf_id: str) -> None:
+"""重新加载消息事件流水线调度器.

 Returns:
 dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射

 """
 ab_config = self.astrbot_config_mgr.confs.get(conf_id)
 if not ab_config:
 raise ValueError(f"配置文件 {conf_id} 不存在")
 scheduler = PipelineScheduler(
-PipelineContext(ab_config, self.plugin_manager, conf_id)
+PipelineContext(ab_config, self.plugin_manager, conf_id),
 )
 await scheduler.initialize()
 self.pipeline_scheduler_mapping[conf_id] = scheduler
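The lifecycle changes above keep the existing pattern of wrapping every long-running task (event bus, platform adapters) so that an exception is logged rather than silently swallowed and cancellation shuts down cleanly. A generic, self-contained sketch of that wrapper pattern, not AstrBot's code:

```python
# Hedged sketch: wrap long-running coroutines so failures are logged and
# CancelledError is treated as a normal shutdown signal.
import asyncio
import traceback


async def task_wrapper(coro) -> None:
    try:
        await coro
    except asyncio.CancelledError:
        pass  # normal shutdown path
    except Exception:
        for line in traceback.format_exc().splitlines():
            print(f"| {line}")  # stand-in for logger.error


async def main() -> None:
    async def flaky() -> None:
        raise RuntimeError("adapter crashed")

    tasks = [asyncio.create_task(task_wrapper(flaky()), name="flaky")]
    await asyncio.gather(*tasks, return_exceptions=True)


asyncio.run(main())
```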
@@ -1,27 +1,27 @@
import abc
import datetime
import typing as T
-from deprecated import deprecated
-from dataclasses import dataclass
-from astrbot.core.db.po import (
-Stats,
-PlatformStat,
-ConversationV2,
-PlatformMessageHistory,
-Attachment,
-Persona,
-Preference,
-)
from contextlib import asynccontextmanager
+from dataclasses import dataclass

+from deprecated import deprecated
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

+from astrbot.core.db.po import (
+Attachment,
+ConversationV2,
+Persona,
+PlatformMessageHistory,
+PlatformStat,
+Preference,
+Stats,
+)


@dataclass
class BaseDatabase(abc.ABC):
-"""
+"""数据库基类"""
-数据库基类
-"""

DATABASE_URL = ""

@@ -32,12 +32,13 @@ class BaseDatabase(abc.ABC):
future=True,
)
self.AsyncSessionLocal = sessionmaker(
-self.engine, class_=AsyncSession, expire_on_commit=False
+self.engine,
+class_=AsyncSession,
+expire_on_commit=False,
)

async def initialize(self):
"""初始化数据库连接"""
-pass

@asynccontextmanager
async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
@@ -91,7 +92,9 @@ class BaseDatabase(abc.ABC):

@abc.abstractmethod
async def get_conversations(
-self, user_id: str | None = None, platform_id: str | None = None
+self,
+user_id: str | None = None,
+platform_id: str | None = None,
) -> list[ConversationV2]:
"""Get all conversations for a specific user and platform_id(optional).

@@ -106,7 +109,9 @@ class BaseDatabase(abc.ABC):

@abc.abstractmethod
async def get_all_conversations(
-self, page: int = 1, page_size: int = 20
+self,
+page: int = 1,
+page_size: int = 20,
) -> list[ConversationV2]:
"""Get all conversations with pagination."""
...
@@ -173,7 +178,10 @@ class BaseDatabase(abc.ABC):

@abc.abstractmethod
async def delete_platform_message_offset(
-self, platform_id: str, user_id: str, offset_sec: int = 86400
+self,
+platform_id: str,
+user_id: str,
+offset_sec: int = 86400,
) -> None:
"""Delete platform message history records older than the specified offset."""
...
@@ -243,7 +251,11 @@ class BaseDatabase(abc.ABC):

@abc.abstractmethod
async def insert_preference_or_update(
-self, scope: str, scope_id: str, key: str, value: dict
+self,
+scope: str,
+scope_id: str,
+key: str,
+value: dict,
) -> Preference:
"""Insert a new preference record."""
...
@@ -255,7 +267,10 @@ class BaseDatabase(abc.ABC):

@abc.abstractmethod
async def get_preferences(
-self, scope: str, scope_id: str | None = None, key: str | None = None
+self,
+scope: str,
+scope_id: str | None = None,
+key: str | None = None,
) -> list[Preference]:
"""Get all preferences for a specific scope ID or key."""
...

@@ -1,27 +1,33 @@
import os
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
-from astrbot.core.db import BaseDatabase
-from astrbot.core.config import AstrBotConfig
from astrbot.api import logger, sp
+from astrbot.core.config import AstrBotConfig
+from astrbot.core.db import BaseDatabase
+from astrbot.core.utils.astrbot_path import get_astrbot_data_path

from .migra_3_to_4 import (
migration_conversation_table,
-migration_platform_table,
-migration_webchat_data,
migration_persona_data,
+migration_platform_table,
migration_preferences,
+migration_webchat_data,
)


async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
-"""
+"""检查是否需要进行数据库迁移
-检查是否需要进行数据库迁移
如果存在 data_v3.db 并且 preference 中没有 migration_done_v4,则需要进行迁移。
"""
-data_v3_exists = os.path.exists(get_astrbot_data_path())
+# 仅当 data 目录下存在旧版本数据(data_v3.db 文件)时才考虑迁移
-if not data_v3_exists:
+data_dir = get_astrbot_data_path()
+data_v3_db = os.path.join(data_dir, "data_v3.db")
+
+if not os.path.exists(data_v3_db):
return False
migration_done = await db_helper.get_preference(
-"global", "global", "migration_done_v4"
+"global",
+"global",
+"migration_done_v4",
)
if migration_done:
return False
@@ -32,9 +38,8 @@ async def do_migration_v4(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
astrbot_config: AstrBotConfig,
-):
+) -> None:
-"""
+"""执行数据库迁移
-执行数据库迁移
迁移旧的 webchat_conversation 表到新的 conversation 表。
迁移旧的 platform 到新的 platform_stats 表。
"""

@@ -1,15 +1,18 @@
-import json
import datetime
-from .. import BaseDatabase
+import json
-from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
-from .shared_preferences_v3 import sp as sp_v3
+from sqlalchemy import text
-from astrbot.core.config.default import DB_PATH
+from sqlalchemy.ext.asyncio import AsyncSession

from astrbot.api import logger, sp
from astrbot.core.config import AstrBotConfig
-from astrbot.core.platform.astr_message_event import MessageSesion
+from astrbot.core.config.default import DB_PATH
-from sqlalchemy.ext.asyncio import AsyncSession
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
-from sqlalchemy import text
+from astrbot.core.platform.astr_message_event import MessageSesion

+from .. import BaseDatabase
+from .shared_preferences_v3 import sp as sp_v3
+from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3

"""
1. 迁移旧的 webchat_conversation 表到新的 conversation 表。
@@ -18,7 +21,8 @@ from sqlalchemy import text


def get_platform_id(
-platform_id_map: dict[str, dict[str, str]], old_platform_name: str
+platform_id_map: dict[str, dict[str, str]],
+old_platform_name: str,
) -> str:
return platform_id_map.get(
old_platform_name,
@@ -27,7 +31,8 @@ def get_platform_id(


def get_platform_type(
-platform_id_map: dict[str, dict[str, str]], old_platform_name: str
+platform_id_map: dict[str, dict[str, str]],
+old_platform_name: str,
) -> str:
return platform_id_map.get(
old_platform_name,
@@ -36,13 +41,15 @@ def get_platform_type(


async def migration_conversation_table(
-db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
+db_helper: BaseDatabase,
+platform_id_map: dict[str, dict[str, str]],
):
db_helper_v3 = SQLiteV3DatabaseV3(
-db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
+db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
)
conversations, total_cnt = db_helper_v3.get_all_conversations(
-page=1, page_size=10000000
+page=1,
+page_size=10000000,
)
logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...")

@@ -61,13 +68,14 @@ async def migration_conversation_table(
)
if not conv:
logger.info(
-f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
+f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
)
if ":" not in conv.user_id:
continue
session = MessageSesion.from_str(session_str=conv.user_id)
platform_id = get_platform_id(
-platform_id_map, session.platform_name
+platform_id_map,
+session.platform_name,
)
session.platform_id = platform_id  # 更新平台名称为新的 ID
conv_v2 = ConversationV2(
@@ -90,10 +98,11 @@ async def migration_conversation_table(


async def migration_platform_table(
-db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
+db_helper: BaseDatabase,
+platform_id_map: dict[str, dict[str, str]],
):
db_helper_v3 = SQLiteV3DatabaseV3(
-db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
+db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
)
secs_from_2023_4_10_to_now = (
datetime.datetime.now(datetime.timezone.utc)
@@ -134,10 +143,12 @@ async def migration_platform_table(
if cnt == 0:
continue
platform_id = get_platform_id(
-platform_id_map, platform_stats_v3[idx].name
+platform_id_map,
+platform_stats_v3[idx].name,
)
platform_type = get_platform_type(
-platform_id_map, platform_stats_v3[idx].name
+platform_id_map,
+platform_stats_v3[idx].name,
)
try:
await dbsession.execute(
@@ -149,7 +160,8 @@ async def migration_platform_table(
"""),
{
"timestamp": datetime.datetime.fromtimestamp(
-bucket_end, tz=datetime.timezone.utc
+bucket_end,
+tz=datetime.timezone.utc,
),
"platform_id": platform_id,
"platform_type": platform_type,
@@ -165,14 +177,16 @@ async def migration_platform_table(


async def migration_webchat_data(
-db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
+db_helper: BaseDatabase,
+platform_id_map: dict[str, dict[str, str]],
):
"""迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中"""
db_helper_v3 = SQLiteV3DatabaseV3(
-db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
+db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
)
conversations, total_cnt = db_helper_v3.get_all_conversations(
-page=1, page_size=10000000
+page=1,
+page_size=10000000,
)
logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...")

@@ -191,7 +205,7 @@ async def migration_webchat_data(
)
if not conv:
logger.info(
-f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
+f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
)
if ":" in conv.user_id:
continue
@@ -218,10 +232,10 @@ async def migration_webchat_data(


async def migration_persona_data(
-db_helper: BaseDatabase, astrbot_config: AstrBotConfig
+db_helper: BaseDatabase,
+astrbot_config: AstrBotConfig,
):
-"""
+"""迁移 Persona 数据到新的表中。
-迁移 Persona 数据到新的表中。
旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
"""
v3_persona_config: list[dict] = astrbot_config.get("persona", [])
@@ -236,14 +250,15 @@ async def migration_persona_data(
try:
begin_dialogs = persona.get("begin_dialogs", [])
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
-mood_prompt = ""
+parts = []
user_turn = True
for mood_dialog in mood_imitation_dialogs:
if user_turn:
-mood_prompt += f"A: {mood_dialog}\n"
+parts.append(f"A: {mood_dialog}\n")
else:
-mood_prompt += f"B: {mood_dialog}\n"
+parts.append(f"B: {mood_dialog}\n")
user_turn = not user_turn
+mood_prompt = "".join(parts)
system_prompt = persona.get("prompt", "")
if mood_prompt:
system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
@@ -253,14 +268,15 @@ async def migration_persona_data(
begin_dialogs=begin_dialogs,
)
logger.info(
-f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
+f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。",
)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")


async def migration_preferences(
-db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
+db_helper: BaseDatabase,
+platform_id_map: dict[str, dict[str, str]],
):
# 1. global scope migration
keys = [
@@ -329,10 +345,13 @@ async def migration_preferences(

for provider_type, provider_id in perf.items():
await sp.put_async(
-"umo", str(session), f"provider_perf_{provider_type}", provider_id
+"umo",
+str(session),
+f"provider_perf_{provider_type}",
+provider_id,
)
logger.info(
-f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}"
+f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}",
)
except Exception as e:
logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True)

astrbot/core/db/migration/migra_45_to_46.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+from astrbot.api import logger, sp
+from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
+from astrbot.core.umop_config_router import UmopConfigRouter
+
+
+async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
+abconf_data = acm.abconf_data
+
+if not isinstance(abconf_data, dict):
+# should be unreachable
+logger.warning(
+f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}",
+)
+return
+
+# 如果任何一项带有 umop,则说明需要迁移
+need_migration = False
+for conf_id, conf_info in abconf_data.items():
+if isinstance(conf_info, dict) and "umop" in conf_info:
+need_migration = True
+break
+
+if not need_migration:
+return
+
+logger.info("Starting migration from version 4.5 to 4.6")
+
+# extract umo->conf_id mapping
+umo_to_conf_id = {}
+for conf_id, conf_info in abconf_data.items():
+if isinstance(conf_info, dict) and "umop" in conf_info:
+umop_ls = conf_info.pop("umop")
+if not isinstance(umop_ls, list):
+continue
+for umo in umop_ls:
+if isinstance(umo, str) and umo not in umo_to_conf_id:
+umo_to_conf_id[umo] = conf_id
+
+# update the abconf data
+await sp.global_put("abconf_mapping", abconf_data)
+# update the umop config router
+await ucr.update_routing_data(umo_to_conf_id)
+
+logger.info("Migration from version 45 to 46 completed successfully")

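A brief, hedged illustration of what the new migra_45_to_46.py above does: each config's old "umop" list is folded into a single umo -> conf_id routing map, and the first config to claim a umo wins. The config IDs and umo strings below are made up for illustration, not taken from the repository.

# Hypothetical pre-4.6 mapping data; each config entry carries its own "umop" list.
abconf_data = {
    "default": {"umop": ["aiocqhttp:GroupMessage:123", "webchat:FriendMessage:u1"]},
    "conf_b": {"umop": ["aiocqhttp:GroupMessage:123", "telegram:FriendMessage:42"]},
}

umo_to_conf_id: dict[str, str] = {}
for conf_id, conf_info in abconf_data.items():
    for umo in conf_info.pop("umop", []):
        # keep the first conf_id that claims a umo, mirroring the migration above
        umo_to_conf_id.setdefault(umo, conf_id)

print(umo_to_conf_id)
# {'aiocqhttp:GroupMessage:123': 'default', 'webchat:FriendMessage:u1': 'default', 'telegram:FriendMessage:42': 'conf_b'}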
@@ -1,6 +1,7 @@
import json
import os
from typing import TypeVar

from astrbot.core.utils.astrbot_path import get_astrbot_data_path

_VT = TypeVar("_VT")
@@ -16,7 +17,7 @@ class SharedPreferences:
def _load_preferences(self):
if os.path.exists(self.path):
try:
-with open(self.path, "r") as f:
+with open(self.path) as f:
return json.load(f)
except json.JSONDecodeError:
os.remove(self.path)

@@ -1,8 +1,9 @@
import sqlite3
import time
-from astrbot.core.db.po import Platform, Stats
-from typing import Tuple, List, Dict, Any
from dataclasses import dataclass
+from typing import Any

+from astrbot.core.db.po import Platform, Stats


@dataclass
@@ -94,7 +95,7 @@ class SQLiteDatabase:
c.execute(
"""
PRAGMA table_info(webchat_conversation)
-"""
+""",
)
res = c.fetchall()
has_title = False
@@ -108,14 +109,14 @@ class SQLiteDatabase:
c.execute(
"""
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
-"""
+""",
)
self.conn.commit()
if not has_persona_id:
c.execute(
"""
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
-"""
+""",
)
self.conn.commit()

@@ -126,7 +127,7 @@ class SQLiteDatabase:
conn.text_factory = str
return conn

-def _exec_sql(self, sql: str, params: Tuple = None):
+def _exec_sql(self, sql: str, params: tuple = None):
conn = self.conn
try:
c = self.conn.cursor()
@@ -174,7 +175,7 @@ class SQLiteDatabase:
"""
SELECT * FROM platform
"""
-+ where_clause
++ where_clause,
)

platform = []
@@ -194,7 +195,7 @@ class SQLiteDatabase:
c.execute(
"""
SELECT SUM(count) FROM platform
-"""
+""",
)
res = c.fetchone()
c.close()
@@ -214,7 +215,7 @@ class SQLiteDatabase:
SELECT name, SUM(count), timestamp FROM platform
"""
+ where_clause
-+ " GROUP BY name"
++ " GROUP BY name",
)

platform = []
@@ -242,7 +243,7 @@ class SQLiteDatabase:
c.close()

if not res:
-return
+return None

return Conversation(*res)

@@ -257,7 +258,7 @@ class SQLiteDatabase:
(user_id, cid, history, updated_at, created_at),
)

-def get_conversations(self, user_id: str) -> Tuple:
+def get_conversations(self, user_id: str) -> tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
@@ -280,7 +281,7 @@ class SQLiteDatabase:
title = row[3]
persona_id = row[4]
conversations.append(
-Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
+Conversation("", cid, "[]", created_at, updated_at, title, persona_id),
)
return conversations

@@ -319,8 +320,10 @@ class SQLiteDatabase:
)

def get_all_conversations(
-self, page: int = 1, page_size: int = 20
-) -> Tuple[List[Dict[str, Any]], int]:
+self,
+page: int = 1,
+page_size: int = 20,
+) -> tuple[list[dict[str, Any]], int]:
"""获取所有对话,支持分页,按更新时间降序排序"""
try:
c = self.conn.cursor()
@@ -366,7 +369,7 @@ class SQLiteDatabase:
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
-}
+},
)

return conversations, total_count
@@ -381,12 +384,12 @@ class SQLiteDatabase:
self,
page: int = 1,
page_size: int = 20,
-platforms: List[str] = None,
+platforms: list[str] | None = None,
-message_types: List[str] = None,
+message_types: list[str] | None = None,
-search_query: str = None,
+search_query: str | None = None,
-exclude_ids: List[str] = None,
+exclude_ids: list[str] | None = None,
-exclude_platforms: List[str] = None,
+exclude_platforms: list[str] | None = None,
-) -> Tuple[List[Dict[str, Any]], int]:
+) -> tuple[list[dict[str, Any]], int]:
"""获取筛选后的对话列表"""
try:
c = self.conn.cursor()
@@ -422,7 +425,7 @@ class SQLiteDatabase:
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
where_clauses.append(
-"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
+"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)",
)
search_param = f"%{search_query}%"
params.extend([search_param, search_param, search_param, search_param])
@@ -482,7 +485,7 @@ class SQLiteDatabase:
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
-}
+},
)

return conversations, total_count

@@ -1,15 +1,15 @@
import uuid

-from datetime import datetime, timezone
from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import TypedDict

from sqlmodel import (
+JSON,
+Field,
SQLModel,
Text,
-JSON,
UniqueConstraint,
-Field,
)
-from typing import Optional, TypedDict


class PlatformStat(SQLModel, table=True):
@@ -40,7 +40,8 @@ class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations"

inner_conversation_id: int = Field(
-primary_key=True, sa_column_kwargs={"autoincrement": True}
+primary_key=True,
+sa_column_kwargs={"autoincrement": True},
)
conversation_id: str = Field(
max_length=36,
@@ -50,14 +51,14 @@ class ConversationV2(SQLModel, table=True):
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False)
-content: Optional[list] = Field(default=None, sa_type=JSON)
+content: list | None = Field(default=None, sa_type=JSON)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
-title: Optional[str] = Field(default=None, max_length=255)
+title: str | None = Field(default=None, max_length=255)
-persona_id: Optional[str] = Field(default=None)
+persona_id: str | None = Field(default=None)

__table_args__ = (
UniqueConstraint(
@@ -76,13 +77,15 @@ class Persona(SQLModel, table=True):
__tablename__ = "personas"

id: int | None = Field(
-primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
+primary_key=True,
+sa_column_kwargs={"autoincrement": True},
+default=None,
)
persona_id: str = Field(max_length=255, nullable=False)
system_prompt: str = Field(sa_type=Text, nullable=False)
-begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
+begin_dialogs: list | None = Field(default=None, sa_type=JSON)
"""a list of strings, each representing a dialog to start with"""
-tools: Optional[list] = Field(default=None, sa_type=JSON)
+tools: list | None = Field(default=None, sa_type=JSON)
"""None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
@@ -104,7 +107,9 @@ class Preference(SQLModel, table=True):
__tablename__ = "preferences"

id: int | None = Field(
-default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
+default=None,
+primary_key=True,
+sa_column_kwargs={"autoincrement": True},
)
scope: str = Field(nullable=False)
"""Scope of the preference, such as 'global', 'umo', 'plugin'."""
@@ -138,13 +143,15 @@ class PlatformMessageHistory(SQLModel, table=True):
__tablename__ = "platform_message_history"

id: int | None = Field(
-primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
+primary_key=True,
+sa_column_kwargs={"autoincrement": True},
+default=None,
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False)  # An id of group, user in platform
-sender_id: Optional[str] = Field(default=None)  # ID of the sender in the platform
+sender_id: str | None = Field(default=None)  # ID of the sender in the platform
-sender_name: Optional[str] = Field(
+sender_name: str | None = Field(
-default=None
+default=None,
)  # Name of the sender in the platform
content: dict = Field(sa_type=JSON, nullable=False)  # a message chain list
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@@ -163,7 +170,9 @@ class Attachment(SQLModel, table=True):
__tablename__ = "attachments"

inner_attachment_id: int | None = Field(
-primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
+primary_key=True,
+sa_column_kwargs={"autoincrement": True},
+default=None,
)
attachment_id: str = Field(
max_length=36,

@@ -1,22 +1,27 @@
import asyncio
-import typing as T
import threading
+import typing as T
from datetime import datetime, timedelta

+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlmodel import col, delete, desc, func, or_, select, text, update

from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import (
-ConversationV2,
-PlatformStat,
-PlatformMessageHistory,
Attachment,
+ConversationV2,
Persona,
+PlatformMessageHistory,
+PlatformStat,
Preference,
-Stats as DeprecatedStats,
-Platform as DeprecatedPlatformStat,
SQLModel,
)
+from astrbot.core.db.po import (
-from sqlmodel import select, update, delete, text, func, or_, desc, col
+Platform as DeprecatedPlatformStat,
-from sqlalchemy.ext.asyncio import AsyncSession
+)
+from astrbot.core.db.po import (
+Stats as DeprecatedStats,
+)

NOT_GIVEN = T.TypeVar("NOT_GIVEN")

@@ -57,7 +62,9 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin():
if timestamp is None:
timestamp = datetime.now().replace(
-minute=0, second=0, microsecond=0
+minute=0,
+second=0,
+microsecond=0,
)
current_hour = timestamp
await session.execute(
@@ -81,13 +88,13 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
result = await session.execute(
select(func.count(col(PlatformStat.platform_id))).select_from(
-PlatformStat
+PlatformStat,
-)
+),
)
count = result.scalar_one_or_none()
return count if count is not None else 0

-async def get_platform_stats(self, offset_sec: int = 86400) -> T.List[PlatformStat]:
+async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
async with self.get_db() as session:
session: AsyncSession
@@ -138,7 +145,7 @@ class SQLiteDatabase(BaseDatabase):
select(ConversationV2)
.order_by(desc(ConversationV2.created_at))
.offset(offset)
-.limit(page_size)
+.limit(page_size),
)
return result.scalars().all()

@@ -157,7 +164,7 @@ class SQLiteDatabase(BaseDatabase):

if platform_ids:
base_query = base_query.where(
-col(ConversationV2.platform_id).in_(platform_ids)
+col(ConversationV2.platform_id).in_(platform_ids),
)
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
@@ -167,16 +174,16 @@ class SQLiteDatabase(BaseDatabase):
col(ConversationV2.content).ilike(f"%{search_query}%"),
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
-)
+),
)
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
for msg_type in kwargs["message_types"]:
base_query = base_query.where(
-col(ConversationV2.user_id).ilike(f"%:{msg_type}:%")
+col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"),
)
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
base_query = base_query.where(
-col(ConversationV2.platform_id).in_(kwargs["platforms"])
+col(ConversationV2.platform_id).in_(kwargs["platforms"]),
)

# Get total count matching the filters
@@ -233,7 +240,7 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
query = update(ConversationV2).where(
-col(ConversationV2.conversation_id) == cid
+col(ConversationV2.conversation_id) == cid,
)
values = {}
if title is not None:
@@ -243,7 +250,7 @@ class SQLiteDatabase(BaseDatabase):
if content is not None:
values["content"] = content
if not values:
-return
+return None
query = query.values(**values)
await session.execute(query)
return await self.get_conversation_by_id(cid)
@@ -254,8 +261,8 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin():
await session.execute(
delete(ConversationV2).where(
-col(ConversationV2.conversation_id) == cid
+col(ConversationV2.conversation_id) == cid,
-)
+),
)

async def delete_conversations_by_user_id(self, user_id: str) -> None:
@@ -263,7 +270,9 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
await session.execute(
-delete(ConversationV2).where(col(ConversationV2.user_id) == user_id)
+delete(ConversationV2).where(
+col(ConversationV2.user_id) == user_id
+),
)

async def get_session_conversations(
@@ -282,7 +291,7 @@ class SQLiteDatabase(BaseDatabase):
select(
col(Preference.scope_id).label("session_id"),
func.json_extract(Preference.value, "$.val").label(
-"conversation_id"
+"conversation_id",
),  # type: ignore
col(ConversationV2.persona_id).label("persona_id"),
col(ConversationV2.title).label("title"),
@@ -295,7 +304,8 @@ class SQLiteDatabase(BaseDatabase):
== ConversationV2.conversation_id,
)
.outerjoin(
-Persona, col(ConversationV2.persona_id) == Persona.persona_id
+Persona,
+col(ConversationV2.persona_id) == Persona.persona_id,
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
@@ -308,14 +318,14 @@ class SQLiteDatabase(BaseDatabase):
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
-)
+),
)

# 平台筛选
if platform:
platform_pattern = f"{platform}:%"
base_query = base_query.where(
-col(Preference.scope_id).like(platform_pattern)
+col(Preference.scope_id).like(platform_pattern),
)

# 排序
@@ -336,7 +346,8 @@ class SQLiteDatabase(BaseDatabase):
== ConversationV2.conversation_id,
)
.outerjoin(
-Persona, col(ConversationV2.persona_id) == Persona.persona_id
+Persona,
+col(ConversationV2.persona_id) == Persona.persona_id,
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
@@ -349,13 +360,13 @@ class SQLiteDatabase(BaseDatabase):
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
-)
+),
)

if platform:
platform_pattern = f"{platform}:%"
count_base_query = count_base_query.where(
-col(Preference.scope_id).like(platform_pattern)
+col(Preference.scope_id).like(platform_pattern),
)

total_result = await session.execute(count_base_query)
@@ -396,7 +407,10 @@ class SQLiteDatabase(BaseDatabase):
return new_history

async def delete_platform_message_offset(
-self, platform_id, user_id, offset_sec=86400
+self,
+platform_id,
+user_id,
+offset_sec=86400,
):
"""Delete platform message history records older than the specified offset."""
async with self.get_db() as session:
@@ -409,11 +423,15 @@ class SQLiteDatabase(BaseDatabase):
col(PlatformMessageHistory.platform_id) == platform_id,
col(PlatformMessageHistory.user_id) == user_id,
col(PlatformMessageHistory.created_at) < cutoff_time,
-)
+),
)

async def get_platform_message_history(
-self, platform_id, user_id, page=1, page_size=20
+self,
+platform_id,
+user_id,
+page=1,
+page_size=20,
):
"""Get platform message history records."""
async with self.get_db() as session:
@@ -452,7 +470,11 @@ class SQLiteDatabase(BaseDatabase):
return result.scalar_one_or_none()

async def insert_persona(
-self, persona_id, system_prompt, begin_dialogs=None, tools=None
+self,
+persona_id,
+system_prompt,
+begin_dialogs=None,
+tools=None,
):
"""Insert a new persona record."""
async with self.get_db() as session:
@@ -484,7 +506,11 @@ class SQLiteDatabase(BaseDatabase):
return result.scalars().all()

async def update_persona(
-self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN
+self,
+persona_id,
+system_prompt=None,
+begin_dialogs=None,
+tools=NOT_GIVEN,
):
"""Update a persona's system prompt or begin dialogs."""
async with self.get_db() as session:
@@ -499,7 +525,7 @@ class SQLiteDatabase(BaseDatabase):
if tools is not NOT_GIVEN:
values["tools"] = tools
if not values:
-return
+return None
query = query.values(**values)
await session.execute(query)
return await self.get_persona_by_id(persona_id)
@@ -510,7 +536,7 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
await session.execute(
-delete(Persona).where(col(Persona.persona_id) == persona_id)
+delete(Persona).where(col(Persona.persona_id) == persona_id),
)

async def insert_preference_or_update(self, scope, scope_id, key, value):
@@ -529,7 +555,10 @@ class SQLiteDatabase(BaseDatabase):
existing_preference.value = value
else:
new_preference = Preference(
-scope=scope, scope_id=scope_id, key=key, value=value
+scope=scope,
+scope_id=scope_id,
+key=key,
+value=value,
)
session.add(new_preference)
return existing_preference or new_preference
@@ -568,7 +597,7 @@ class SQLiteDatabase(BaseDatabase):
col(Preference.scope) == scope,
col(Preference.scope_id) == scope_id,
col(Preference.key) == key,
-)
+),
)
await session.commit()

@@ -581,7 +610,7 @@ class SQLiteDatabase(BaseDatabase):
delete(Preference).where(
col(Preference.scope) == scope,
col(Preference.scope_id) == scope_id,
-)
+),
)
await session.commit()

@@ -598,7 +627,7 @@ class SQLiteDatabase(BaseDatabase):
now = datetime.now()
start_time = now - timedelta(seconds=offset_sec)
result = await session.execute(
-select(PlatformStat).where(PlatformStat.timestamp >= start_time)
+select(PlatformStat).where(PlatformStat.timestamp >= start_time),
)
all_datas = result.scalars().all()
deprecated_stats = DeprecatedStats()
@@ -608,7 +637,7 @@ class SQLiteDatabase(BaseDatabase):
name=data.platform_id,
count=data.count,
timestamp=int(data.timestamp.timestamp()),
-)
+),
)
return deprecated_stats

@@ -630,7 +659,7 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(
-select(func.sum(PlatformStat.count)).select_from(PlatformStat)
+select(func.sum(PlatformStat.count)).select_from(PlatformStat),
)
total_count = result.scalar_one_or_none()
return total_count if total_count is not None else 0
@@ -656,7 +685,7 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(
select(PlatformStat.platform_id, func.sum(PlatformStat.count))
.where(PlatformStat.timestamp >= start_time)
-.group_by(PlatformStat.platform_id)
+.group_by(PlatformStat.platform_id),
)
grouped_stats = result.all()
deprecated_stats = DeprecatedStats()
@@ -666,7 +695,7 @@ class SQLiteDatabase(BaseDatabase):
name=platform_id,
count=count,
timestamp=int(start_time.timestamp()),
-)
+),
)
return deprecated_stats

@@ -10,22 +10,47 @@ class Result:

class BaseVecDB:
async def initialize(self):
-"""
+"""初始化向量数据库"""
-初始化向量数据库
-"""
-pass

@abc.abstractmethod
-async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
+async def insert(
-"""
+self,
-插入一条文本和其对应向量,自动生成 ID 并保持一致性。
+content: str,
+metadata: dict | None = None,
+id: str | None = None,
+) -> int:
+"""插入一条文本和其对应向量,自动生成 ID 并保持一致性。"""
+...
+
+@abc.abstractmethod
+async def insert_batch(
+self,
+contents: list[str],
+metadatas: list[dict] | None = None,
+ids: list[str] | None = None,
+batch_size: int = 32,
+tasks_limit: int = 3,
+max_retries: int = 3,
+progress_callback=None,
+) -> int:
+"""批量插入文本和其对应向量,自动生成 ID 并保持一致性。
+
+Args:
+progress_callback: 进度回调函数,接收参数 (current, total)
+
"""
...

@abc.abstractmethod
-async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
+async def retrieve(
-"""
+self,
-搜索最相似的文档。
+query: str,
+top_k: int = 5,
+fetch_k: int = 20,
+rerank: bool = False,
+metadata_filters: dict | None = None,
+) -> list[Result]:
+"""搜索最相似的文档。
Args:
query (str): 查询文本
top_k (int): 返回的最相似文档的数量
@@ -36,11 +61,13 @@ class BaseVecDB:

@abc.abstractmethod
async def delete(self, doc_id: str) -> bool:
-"""
+"""删除指定文档。
-删除指定文档。
Args:
doc_id (str): 要删除的文档 ID
Returns:
bool: 删除是否成功
"""
...

+@abc.abstractmethod
+async def close(self): ...
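The insert_batch method added above only defines the interface; the project's concrete implementation is not shown in this diff. As a rough, hedged sketch (the helper name, chunking, and retry strategy below are assumptions for illustration), a subclass could satisfy it by splitting inputs into batch_size groups, capping concurrency at tasks_limit with a semaphore, retrying each insert up to max_retries times, and reporting progress via progress_callback(current, total):

import asyncio


async def insert_batch_sketch(
    vec_db,                      # any BaseVecDB subclass that implements insert()
    contents: list[str],
    metadatas: list[dict] | None = None,
    ids: list[str] | None = None,
    batch_size: int = 32,
    tasks_limit: int = 3,
    max_retries: int = 3,
    progress_callback=None,
) -> int:
    metadatas = metadatas or [None] * len(contents)
    ids = ids or [None] * len(contents)
    rows = list(zip(contents, metadatas, ids))
    batches = [rows[i : i + batch_size] for i in range(0, len(rows), batch_size)]
    sem = asyncio.Semaphore(tasks_limit)  # at most tasks_limit batches run concurrently
    done = 0

    async def insert_one(content, metadata, id_):
        # per-item retry, up to max_retries attempts
        for attempt in range(max_retries):
            try:
                return await vec_db.insert(content, metadata, id_)
            except Exception:
                if attempt == max_retries - 1:
                    raise

    async def run_batch(batch):
        nonlocal done
        async with sem:
            for content, metadata, id_ in batch:
                await insert_one(content, metadata, id_)
        done += len(batch)
        if progress_callback:
            progress_callback(done, len(contents))  # (current, total), as documented above

    await asyncio.gather(*(run_batch(b) for b in batches))
    return len(contents)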
@@ -1,59 +1,232 @@
-import aiosqlite
+import json
import os
+from contextlib import asynccontextmanager
+from datetime import datetime
+
+from sqlalchemy import Column, Text
+from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
+from sqlalchemy.orm import sessionmaker
+from sqlmodel import Field, MetaData, SQLModel, col, func, select, text
+
+from astrbot.core import logger
+
+
+class BaseDocModel(SQLModel, table=False):
+metadata = MetaData()
+
+
+class Document(BaseDocModel, table=True):
+"""SQLModel for documents table."""
+
+__tablename__ = "documents"  # type: ignore
+
+id: int | None = Field(
+default=None,
+primary_key=True,
+sa_column_kwargs={"autoincrement": True},
+)
+doc_id: str = Field(nullable=False)
+text: str = Field(nullable=False)
+metadata_: str | None = Field(default=None, sa_column=Column("metadata", Text))
+created_at: datetime | None = Field(default=None)
+updated_at: datetime | None = Field(default=None)
+
+
class DocumentStorage:
def __init__(self, db_path: str):
self.db_path = db_path
-self.connection = None
+self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
+self.engine: AsyncEngine | None = None
+self.async_session_maker: sessionmaker | None = None
self.sqlite_init_path = os.path.join(
-os.path.dirname(__file__), "sqlite_init.sql"
+os.path.dirname(__file__),
+"sqlite_init.sql",
)

async def initialize(self):
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
-if not os.path.exists(self.db_path):
-await self.connect()
-async with self.connection.cursor() as cursor:
-with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
-sql_script = f.read()
-await cursor.executescript(sql_script)
-await self.connection.commit()
-else:
-await self.connect()
+await self.connect()
+async with self.engine.begin() as conn:  # type: ignore
+# Create tables using SQLModel
+await conn.run_sync(BaseDocModel.metadata.create_all)
+try:
+await conn.execute(
+text(
+"ALTER TABLE documents ADD COLUMN kb_doc_id TEXT "
+"GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED",
+),
+)
+await conn.execute(
+text(
+"ALTER TABLE documents ADD COLUMN user_id TEXT "
+"GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED",
+),
+)
+
+# Create indexes
+await conn.execute(
+text(
+"CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)",
+),
+)
+await conn.execute(
+text(
+"CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)",
+),
+)
+except BaseException:
+pass
+
+await conn.commit()

async def connect(self):
"""Connect to the SQLite database."""
-self.connection = await aiosqlite.connect(self.db_path)
+if self.engine is None:
+self.engine = create_async_engine(
+self.DATABASE_URL,
+echo=False,
+future=True,
+)
+self.async_session_maker = sessionmaker(
+self.engine,  # type: ignore
+class_=AsyncSession,
+expire_on_commit=False,
+)  # type: ignore

-async def get_documents(self, metadata_filters: dict, ids: list = None):
+@asynccontextmanager
+async def get_session(self):
+"""Context manager for database sessions."""
+async with self.async_session_maker() as session:  # type: ignore
+yield session
+
+async def get_documents(
+self,
+metadata_filters: dict,
+ids: list | None = None,
+offset: int | None = 0,
+limit: int | None = 100,
+) -> list[dict]:
"""Retrieve documents by metadata filters and ids.

Args:
metadata_filters (dict): The metadata filters to apply.
+ids (list | None): Optional list of document IDs to filter.
+offset (int | None): Offset for pagination.
+limit (int | None): Limit for pagination.
+
Returns:
-list: The list of document IDs(primary key, not doc_id) that match the filters.
+list: The list of documents that match the filters.
-"""
-# metadata filter -> SQL WHERE clause
-where_clauses = []
-values = []
-for key, val in metadata_filters.items():
-where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
-values.append(val)
-if ids is not None and len(ids) > 0:
-ids = [str(i) for i in ids if i != -1]
-where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
-values.extend(ids)
-where_sql = " AND ".join(where_clauses) or "1=1"

-result = []
+"""
-async with self.connection.cursor() as cursor:
+if self.engine is None:
-sql = "SELECT * FROM documents WHERE " + where_sql
|
logger.warning(
|
||||||
await cursor.execute(sql, values)
|
"Database connection is not initialized, returning empty result",
|
||||||
for row in await cursor.fetchall():
|
)
|
||||||
result.append(await self.tuple_to_dict(row))
|
return []
|
||||||
return result
|
|
||||||
|
async with self.get_session() as session:
|
||||||
|
query = select(Document)
|
||||||
|
|
||||||
|
for key, val in metadata_filters.items():
|
||||||
|
query = query.where(
|
||||||
|
text(f"json_extract(metadata, '$.{key}') = :filter_{key}"),
|
||||||
|
).params(**{f"filter_{key}": val})
|
||||||
|
|
||||||
|
if ids is not None and len(ids) > 0:
|
||||||
|
valid_ids = [int(i) for i in ids if i != -1]
|
||||||
|
if valid_ids:
|
||||||
|
query = query.where(col(Document.id).in_(valid_ids))
|
||||||
|
|
||||||
|
if limit is not None:
|
||||||
|
query = query.limit(limit)
|
||||||
|
if offset is not None:
|
||||||
|
query = query.offset(offset)
|
||||||
|
|
||||||
|
result = await session.execute(query)
|
||||||
|
documents = result.scalars().all()
|
||||||
|
|
||||||
|
return [self._document_to_dict(doc) for doc in documents]
|
||||||
|
|
||||||
|
async def insert_document(self, doc_id: str, text: str, metadata: dict) -> int:
|
||||||
|
"""Insert a single document and return its integer ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_id (str): The document ID (UUID string).
|
||||||
|
text (str): The document text.
|
||||||
|
metadata (dict): The document metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The integer ID of the inserted document.
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert self.engine is not None, "Database connection is not initialized."
|
||||||
|
|
||||||
|
async with self.get_session() as session, session.begin():
|
||||||
|
document = Document(
|
||||||
|
doc_id=doc_id,
|
||||||
|
text=text,
|
||||||
|
metadata_=json.dumps(metadata),
|
||||||
|
created_at=datetime.now(),
|
||||||
|
updated_at=datetime.now(),
|
||||||
|
)
|
||||||
|
session.add(document)
|
||||||
|
await session.flush() # Flush to get the ID
|
||||||
|
return document.id # type: ignore
|
||||||
|
|
||||||
|
async def insert_documents_batch(
|
||||||
|
self,
|
||||||
|
doc_ids: list[str],
|
||||||
|
texts: list[str],
|
||||||
|
metadatas: list[dict],
|
||||||
|
) -> list[int]:
|
||||||
|
"""Batch insert documents and return their integer IDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_ids (list[str]): List of document IDs (UUID strings).
|
||||||
|
texts (list[str]): List of document texts.
|
||||||
|
metadatas (list[dict]): List of document metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[int]: List of integer IDs of the inserted documents.
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert self.engine is not None, "Database connection is not initialized."
|
||||||
|
|
||||||
|
async with self.get_session() as session, session.begin():
|
||||||
|
import json
|
||||||
|
|
||||||
|
documents = []
|
||||||
|
for doc_id, text, metadata in zip(doc_ids, texts, metadatas):
|
||||||
|
document = Document(
|
||||||
|
doc_id=doc_id,
|
||||||
|
text=text,
|
||||||
|
metadata_=json.dumps(metadata),
|
||||||
|
created_at=datetime.now(),
|
||||||
|
updated_at=datetime.now(),
|
||||||
|
)
|
||||||
|
documents.append(document)
|
||||||
|
session.add(document)
|
||||||
|
|
||||||
|
await session.flush() # Flush to get all IDs
|
||||||
|
return [doc.id for doc in documents] # type: ignore
|
||||||
|
|
||||||
|
async def delete_document_by_doc_id(self, doc_id: str):
|
||||||
|
"""Delete a document by its doc_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_id (str): The doc_id of the document to delete.
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert self.engine is not None, "Database connection is not initialized."
|
||||||
|
|
||||||
|
async with self.get_session() as session, session.begin():
|
||||||
|
query = select(Document).where(col(Document.doc_id) == doc_id)
|
||||||
|
result = await session.execute(query)
|
||||||
|
document = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if document:
|
||||||
|
await session.delete(document)
|
||||||
|
|
||||||
async def get_document_by_doc_id(self, doc_id: str):
|
async def get_document_by_doc_id(self, doc_id: str):
|
||||||
"""Retrieve a document by its doc_id.
|
"""Retrieve a document by its doc_id.
|
||||||
@@ -62,40 +235,134 @@ class DocumentStorage:
|
|||||||
doc_id (str): The doc_id of the document to retrieve.
|
doc_id (str): The doc_id of the document to retrieve.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: The document data.
|
dict: The document data or None if not found.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
async with self.connection.cursor() as cursor:
|
assert self.engine is not None, "Database connection is not initialized."
|
||||||
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
|
|
||||||
row = await cursor.fetchone()
|
async with self.get_session() as session:
|
||||||
if row:
|
query = select(Document).where(col(Document.doc_id) == doc_id)
|
||||||
return await self.tuple_to_dict(row)
|
result = await session.execute(query)
|
||||||
else:
|
document = result.scalar_one_or_none()
|
||||||
return None
|
|
||||||
|
if document:
|
||||||
|
return self._document_to_dict(document)
|
||||||
|
return None
|
||||||
|
|
||||||
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
|
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
|
||||||
"""Retrieve a document by its doc_id.
|
"""Update a document by its doc_id.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
doc_id (str): The doc_id.
|
doc_id (str): The doc_id.
|
||||||
new_text (str): The new text to update the document with.
|
new_text (str): The new text to update the document with.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
async with self.connection.cursor() as cursor:
|
assert self.engine is not None, "Database connection is not initialized."
|
||||||
await cursor.execute(
|
|
||||||
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
|
async with self.get_session() as session, session.begin():
|
||||||
|
query = select(Document).where(col(Document.doc_id) == doc_id)
|
||||||
|
result = await session.execute(query)
|
||||||
|
document = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if document:
|
||||||
|
document.text = new_text
|
||||||
|
document.updated_at = datetime.now()
|
||||||
|
session.add(document)
|
||||||
|
|
||||||
|
async def delete_documents(self, metadata_filters: dict):
|
||||||
|
"""Delete documents by their metadata filters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata_filters (dict): The metadata filters to apply.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.engine is None:
|
||||||
|
logger.warning(
|
||||||
|
"Database connection is not initialized, skipping delete operation",
|
||||||
)
|
)
|
||||||
await self.connection.commit()
|
return
|
||||||
|
|
||||||
|
async with self.get_session() as session, session.begin():
|
||||||
|
query = select(Document)
|
||||||
|
|
||||||
|
for key, val in metadata_filters.items():
|
||||||
|
query = query.where(
|
||||||
|
text(f"json_extract(metadata, '$.{key}') = :filter_{key}"),
|
||||||
|
).params(**{f"filter_{key}": val})
|
||||||
|
|
||||||
|
result = await session.execute(query)
|
||||||
|
documents = result.scalars().all()
|
||||||
|
|
||||||
|
for doc in documents:
|
||||||
|
await session.delete(doc)
|
||||||
|
|
||||||
|
async def count_documents(self, metadata_filters: dict | None = None) -> int:
|
||||||
|
"""Count documents in the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata_filters (dict | None): Metadata filters to apply.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The count of documents.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.engine is None:
|
||||||
|
logger.warning("Database connection is not initialized, returning 0")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async with self.get_session() as session:
|
||||||
|
query = select(func.count(col(Document.id)))
|
||||||
|
|
||||||
|
if metadata_filters:
|
||||||
|
for key, val in metadata_filters.items():
|
||||||
|
query = query.where(
|
||||||
|
text(f"json_extract(metadata, '$.{key}') = :filter_{key}"),
|
||||||
|
).params(**{f"filter_{key}": val})
|
||||||
|
|
||||||
|
result = await session.execute(query)
|
||||||
|
count = result.scalar_one_or_none()
|
||||||
|
return count if count is not None else 0
|
||||||
|
|
||||||
async def get_user_ids(self) -> list[str]:
|
async def get_user_ids(self) -> list[str]:
|
||||||
"""Retrieve all user IDs from the documents table.
|
"""Retrieve all user IDs from the documents table.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: A list of user IDs.
|
list: A list of user IDs.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
async with self.connection.cursor() as cursor:
|
assert self.engine is not None, "Database connection is not initialized."
|
||||||
await cursor.execute("SELECT DISTINCT user_id FROM documents")
|
|
||||||
rows = await cursor.fetchall()
|
async with self.get_session() as session:
|
||||||
|
query = text(
|
||||||
|
"SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL",
|
||||||
|
)
|
||||||
|
result = await session.execute(query)
|
||||||
|
rows = result.fetchall()
|
||||||
return [row[0] for row in rows]
|
return [row[0] for row in rows]
|
||||||
|
|
||||||
|
def _document_to_dict(self, document: Document) -> dict:
|
||||||
|
"""Convert a Document model to a dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document (Document): The document to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The converted dictionary.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"id": document.id,
|
||||||
|
"doc_id": document.doc_id,
|
||||||
|
"text": document.text,
|
||||||
|
"metadata": document.metadata_,
|
||||||
|
"created_at": document.created_at.isoformat()
|
||||||
|
if isinstance(document.created_at, datetime)
|
||||||
|
else document.created_at,
|
||||||
|
"updated_at": document.updated_at.isoformat()
|
||||||
|
if isinstance(document.updated_at, datetime)
|
||||||
|
else document.updated_at,
|
||||||
|
}
|
||||||
|
|
||||||
async def tuple_to_dict(self, row):
|
async def tuple_to_dict(self, row):
|
||||||
"""Convert a tuple to a dictionary.
|
"""Convert a tuple to a dictionary.
|
||||||
|
|
||||||
@@ -104,6 +371,9 @@ class DocumentStorage:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: The converted dictionary.
|
dict: The converted dictionary.
|
||||||
|
|
||||||
|
Note: This method is kept for backward compatibility but is no longer used internally.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"id": row[0],
|
"id": row[0],
|
||||||
@@ -116,6 +386,7 @@ class DocumentStorage:
|
|||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""Close the connection to the SQLite database."""
|
"""Close the connection to the SQLite database."""
|
||||||
if self.connection:
|
if self.engine:
|
||||||
await self.connection.close()
|
await self.engine.dispose()
|
||||||
self.connection = None
|
self.engine = None
|
||||||
|
self.async_session_maker = None
|
||||||
|
|||||||
@@ -2,14 +2,15 @@ try:
|
|||||||
import faiss
|
import faiss
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。"
|
"faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。",
|
||||||
)
|
)
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingStorage:
|
class EmbeddingStorage:
|
||||||
def __init__(self, dimension: int, path: str = None):
|
def __init__(self, dimension: int, path: str | None = None):
|
||||||
self.dimension = dimension
|
self.dimension = dimension
|
||||||
self.path = path
|
self.path = path
|
||||||
self.index = None
|
self.index = None
|
||||||
@@ -18,7 +19,6 @@ class EmbeddingStorage:
|
|||||||
else:
|
else:
|
||||||
base_index = faiss.IndexFlatL2(dimension)
|
base_index = faiss.IndexFlatL2(dimension)
|
||||||
self.index = faiss.IndexIDMap(base_index)
|
self.index = faiss.IndexIDMap(base_index)
|
||||||
self.storage = {}
|
|
||||||
|
|
||||||
async def insert(self, vector: np.ndarray, id: int):
|
async def insert(self, vector: np.ndarray, id: int):
|
||||||
"""插入向量
|
"""插入向量
|
||||||
@@ -28,13 +28,32 @@ class EmbeddingStorage:
|
|||||||
id (int): 向量的ID
|
id (int): 向量的ID
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: 如果向量的维度与存储的维度不匹配
|
ValueError: 如果向量的维度与存储的维度不匹配
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
assert self.index is not None, "FAISS index is not initialized."
|
||||||
if vector.shape[0] != self.dimension:
|
if vector.shape[0] != self.dimension:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
|
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}",
|
||||||
)
|
)
|
||||||
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
|
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
|
||||||
self.storage[id] = vector
|
await self.save_index()
|
||||||
|
|
||||||
|
async def insert_batch(self, vectors: np.ndarray, ids: list[int]):
|
||||||
|
"""批量插入向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vectors (np.ndarray): 要插入的向量数组
|
||||||
|
ids (list[int]): 向量的ID列表
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果向量的维度与存储的维度不匹配
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert self.index is not None, "FAISS index is not initialized."
|
||||||
|
if vectors.shape[1] != self.dimension:
|
||||||
|
raise ValueError(
|
||||||
|
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}",
|
||||||
|
)
|
||||||
|
self.index.add_with_ids(vectors, np.array(ids))
|
||||||
await self.save_index()
|
await self.save_index()
|
||||||
|
|
||||||
async def search(self, vector: np.ndarray, k: int) -> tuple:
|
async def search(self, vector: np.ndarray, k: int) -> tuple:
|
||||||
@@ -45,15 +64,30 @@ class EmbeddingStorage:
|
|||||||
k (int): 返回的最相似向量的数量
|
k (int): 返回的最相似向量的数量
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (距离, 索引)
|
tuple: (距离, 索引)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
assert self.index is not None, "FAISS index is not initialized."
|
||||||
faiss.normalize_L2(vector)
|
faiss.normalize_L2(vector)
|
||||||
distances, indices = self.index.search(vector, k)
|
distances, indices = self.index.search(vector, k)
|
||||||
return distances, indices
|
return distances, indices
|
||||||
|
|
||||||
|
async def delete(self, ids: list[int]):
|
||||||
|
"""删除向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ids (list[int]): 要删除的向量ID列表
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert self.index is not None, "FAISS index is not initialized."
|
||||||
|
id_array = np.array(ids, dtype=np.int64)
|
||||||
|
self.index.remove_ids(id_array)
|
||||||
|
await self.save_index()
|
||||||
|
|
||||||
async def save_index(self):
|
async def save_index(self):
|
||||||
"""保存索引
|
"""保存索引
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (str): 保存索引的路径
|
path (str): 保存索引的路径
|
||||||
|
|
||||||
"""
|
"""
|
||||||
faiss.write_index(self.index, self.path)
|
faiss.write_index(self.index, self.path)
|
||||||
|
|||||||
@@ -1,17 +1,18 @@
|
|||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from astrbot import logger
|
||||||
|
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
|
||||||
|
|
||||||
|
from ..base import BaseVecDB, Result
|
||||||
from .document_storage import DocumentStorage
|
from .document_storage import DocumentStorage
|
||||||
from .embedding_storage import EmbeddingStorage
|
from .embedding_storage import EmbeddingStorage
|
||||||
from ..base import Result, BaseVecDB
|
|
||||||
from astrbot.core.provider.provider import EmbeddingProvider
|
|
||||||
from astrbot.core.provider.provider import RerankProvider
|
|
||||||
|
|
||||||
|
|
||||||
class FaissVecDB(BaseVecDB):
|
class FaissVecDB(BaseVecDB):
|
||||||
"""
|
"""A class to represent a vector database."""
|
||||||
A class to represent a vector database.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -25,7 +26,8 @@ class FaissVecDB(BaseVecDB):
|
|||||||
self.embedding_provider = embedding_provider
|
self.embedding_provider = embedding_provider
|
||||||
self.document_storage = DocumentStorage(doc_store_path)
|
self.document_storage = DocumentStorage(doc_store_path)
|
||||||
self.embedding_storage = EmbeddingStorage(
|
self.embedding_storage = EmbeddingStorage(
|
||||||
embedding_provider.get_dim(), index_store_path
|
embedding_provider.get_dim(),
|
||||||
|
index_store_path,
|
||||||
)
|
)
|
||||||
self.embedding_provider = embedding_provider
|
self.embedding_provider = embedding_provider
|
||||||
self.rerank_provider = rerank_provider
|
self.rerank_provider = rerank_provider
|
||||||
@@ -34,28 +36,69 @@ class FaissVecDB(BaseVecDB):
|
|||||||
await self.document_storage.initialize()
|
await self.document_storage.initialize()
|
||||||
|
|
||||||
async def insert(
|
async def insert(
|
||||||
self, content: str, metadata: dict | None = None, id: str | None = None
|
self,
|
||||||
|
content: str,
|
||||||
|
metadata: dict | None = None,
|
||||||
|
id: str | None = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""插入一条文本和其对应向量,自动生成 ID 并保持一致性。"""
|
||||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
|
||||||
"""
|
|
||||||
metadata = metadata or {}
|
metadata = metadata or {}
|
||||||
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
|
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
|
||||||
|
|
||||||
vector = await self.embedding_provider.get_embedding(content)
|
vector = await self.embedding_provider.get_embedding(content)
|
||||||
vector = np.array(vector, dtype=np.float32)
|
vector = np.array(vector, dtype=np.float32)
|
||||||
async with self.document_storage.connection.cursor() as cursor:
|
|
||||||
await cursor.execute(
|
|
||||||
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
|
|
||||||
(str_id, content, json.dumps(metadata)),
|
|
||||||
)
|
|
||||||
await self.document_storage.connection.commit()
|
|
||||||
result = await self.document_storage.get_document_by_doc_id(str_id)
|
|
||||||
int_id = result["id"]
|
|
||||||
|
|
||||||
# 插入向量到 FAISS
|
# 使用 DocumentStorage 的方法插入文档
|
||||||
await self.embedding_storage.insert(vector, int_id)
|
int_id = await self.document_storage.insert_document(str_id, content, metadata)
|
||||||
return int_id
|
|
||||||
|
# 插入向量到 FAISS
|
||||||
|
await self.embedding_storage.insert(vector, int_id)
|
||||||
|
return int_id
|
||||||
|
|
||||||
|
async def insert_batch(
|
||||||
|
self,
|
||||||
|
contents: list[str],
|
||||||
|
metadatas: list[dict] | None = None,
|
||||||
|
ids: list[str] | None = None,
|
||||||
|
batch_size: int = 32,
|
||||||
|
tasks_limit: int = 3,
|
||||||
|
max_retries: int = 3,
|
||||||
|
progress_callback=None,
|
||||||
|
) -> list[int]:
|
||||||
|
"""批量插入文本和其对应向量,自动生成 ID 并保持一致性。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
progress_callback: 进度回调函数,接收参数 (current, total)
|
||||||
|
|
||||||
|
"""
|
||||||
|
metadatas = metadatas or [{} for _ in contents]
|
||||||
|
ids = ids or [str(uuid.uuid4()) for _ in contents]
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
logger.debug(f"Generating embeddings for {len(contents)} contents...")
|
||||||
|
vectors = await self.embedding_provider.get_embeddings_batch(
|
||||||
|
contents,
|
||||||
|
batch_size=batch_size,
|
||||||
|
tasks_limit=tasks_limit,
|
||||||
|
max_retries=max_retries,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
)
|
||||||
|
end = time.time()
|
||||||
|
logger.debug(
|
||||||
|
f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 使用 DocumentStorage 的批量插入方法
|
||||||
|
int_ids = await self.document_storage.insert_documents_batch(
|
||||||
|
ids,
|
||||||
|
contents,
|
||||||
|
metadatas,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 批量插入向量到 FAISS
|
||||||
|
vectors_array = np.array(vectors).astype("float32")
|
||||||
|
await self.embedding_storage.insert_batch(vectors_array, int_ids)
|
||||||
|
return int_ids
|
||||||
|
|
||||||
async def retrieve(
|
async def retrieve(
|
||||||
self,
|
self,
|
||||||
@@ -65,8 +108,7 @@ class FaissVecDB(BaseVecDB):
|
|||||||
rerank: bool = False,
|
rerank: bool = False,
|
||||||
metadata_filters: dict | None = None,
|
metadata_filters: dict | None = None,
|
||||||
) -> list[Result]:
|
) -> list[Result]:
|
||||||
"""
|
"""搜索最相似的文档。
|
||||||
搜索最相似的文档。
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): 查询文本
|
query (str): 查询文本
|
||||||
@@ -77,6 +119,7 @@ class FaissVecDB(BaseVecDB):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Result]: 查询结果
|
List[Result]: 查询结果
|
||||||
|
|
||||||
"""
|
"""
|
||||||
embedding = await self.embedding_provider.get_embedding(query)
|
embedding = await self.embedding_provider.get_embedding(query)
|
||||||
scores, indices = await self.embedding_storage.search(
|
scores, indices = await self.embedding_storage.search(
|
||||||
@@ -89,7 +132,8 @@ class FaissVecDB(BaseVecDB):
|
|||||||
scores[0] = 1.0 - (scores[0] / 2.0)
|
scores[0] = 1.0 - (scores[0] / 2.0)
|
||||||
# NOTE: maybe the size is less than k.
|
# NOTE: maybe the size is less than k.
|
||||||
fetched_docs = await self.document_storage.get_documents(
|
fetched_docs = await self.document_storage.get_documents(
|
||||||
metadata_filters=metadata_filters or {}, ids=indices[0]
|
metadata_filters=metadata_filters or {},
|
||||||
|
ids=indices[0],
|
||||||
)
|
)
|
||||||
if not fetched_docs:
|
if not fetched_docs:
|
||||||
return []
|
return []
|
||||||
@@ -110,7 +154,9 @@ class FaissVecDB(BaseVecDB):
|
|||||||
documents = [doc.data["text"] for doc in top_k_results]
|
documents = [doc.data["text"] for doc in top_k_results]
|
||||||
reranked_results = await self.rerank_provider.rerank(query, documents)
|
reranked_results = await self.rerank_provider.rerank(query, documents)
|
||||||
reranked_results = sorted(
|
reranked_results = sorted(
|
||||||
reranked_results, key=lambda x: x.relevance_score, reverse=True
|
reranked_results,
|
||||||
|
key=lambda x: x.relevance_score,
|
||||||
|
reverse=True,
|
||||||
)
|
)
|
||||||
top_k_results = [
|
top_k_results = [
|
||||||
top_k_results[reranked_result.index]
|
top_k_results[reranked_result.index]
|
||||||
@@ -119,23 +165,40 @@ class FaissVecDB(BaseVecDB):
|
|||||||
|
|
||||||
return top_k_results
|
return top_k_results
|
||||||
|
|
||||||
async def delete(self, doc_id: int):
|
async def delete(self, doc_id: str):
|
||||||
"""
|
"""删除一条文档块(chunk)"""
|
||||||
删除一条文档
|
# 获得对应的 int id
|
||||||
"""
|
result = await self.document_storage.get_document_by_doc_id(doc_id)
|
||||||
await self.document_storage.connection.execute(
|
int_id = result["id"] if result else None
|
||||||
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
|
if int_id is None:
|
||||||
)
|
return
|
||||||
await self.document_storage.connection.commit()
|
|
||||||
|
# 使用 DocumentStorage 的删除方法
|
||||||
|
await self.document_storage.delete_document_by_doc_id(doc_id)
|
||||||
|
await self.embedding_storage.delete([int_id])
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
await self.document_storage.close()
|
await self.document_storage.close()
|
||||||
|
|
||||||
async def count_documents(self) -> int:
|
async def count_documents(self, metadata_filter: dict | None = None) -> int:
|
||||||
|
"""计算文档数量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata_filter (dict | None): 元数据过滤器
|
||||||
|
|
||||||
"""
|
"""
|
||||||
计算文档数量
|
count = await self.document_storage.count_documents(
|
||||||
"""
|
metadata_filters=metadata_filter or {},
|
||||||
async with self.document_storage.connection.cursor() as cursor:
|
)
|
||||||
await cursor.execute("SELECT COUNT(*) FROM documents")
|
return count
|
||||||
count = await cursor.fetchone()
|
|
||||||
return count[0] if count else 0
|
async def delete_documents(self, metadata_filters: dict):
|
||||||
|
"""根据元数据过滤器删除文档"""
|
||||||
|
docs = await self.document_storage.get_documents(
|
||||||
|
metadata_filters=metadata_filters,
|
||||||
|
offset=None,
|
||||||
|
limit=None,
|
||||||
|
)
|
||||||
|
doc_ids: list[int] = [doc["id"] for doc in docs]
|
||||||
|
await self.embedding_storage.delete(doc_ids)
|
||||||
|
await self.document_storage.delete_documents(metadata_filters=metadata_filters)
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
"""
|
"""事件总线, 用于处理事件的分发和处理
|
||||||
事件总线, 用于处理事件的分发和处理
|
|
||||||
事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
|
事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
|
||||||
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑
|
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑
|
||||||
|
|
||||||
@@ -13,10 +12,12 @@ class:
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from astrbot.core.pipeline.scheduler import PipelineScheduler
|
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from .platform import AstrMessageEvent
|
|
||||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||||
|
from astrbot.core.pipeline.scheduler import PipelineScheduler
|
||||||
|
|
||||||
|
from .platform import AstrMessageEvent
|
||||||
|
|
||||||
|
|
||||||
class EventBus:
|
class EventBus:
|
||||||
@@ -46,14 +47,15 @@ class EventBus:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (AstrMessageEvent): 事件对象
|
event (AstrMessageEvent): 事件对象
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
|
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
|
||||||
if event.get_sender_name():
|
if event.get_sender_name():
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
|
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}",
|
||||||
)
|
)
|
||||||
# 没有发送者名称: [平台名] 发送者ID: 消息概要
|
# 没有发送者名称: [平台名] 发送者ID: 消息概要
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}"
|
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import uuid
|
|
||||||
import time
|
|
||||||
from urllib.parse import urlparse, unquote
|
|
||||||
import platform
|
import platform
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from urllib.parse import unquote, urlparse
|
||||||
|
|
||||||
|
|
||||||
class FileTokenService:
|
class FileTokenService:
|
||||||
@@ -23,7 +23,12 @@ class FileTokenService:
|
|||||||
for token in expired_tokens:
|
for token in expired_tokens:
|
||||||
self.staged_files.pop(token, None)
|
self.staged_files.pop(token, None)
|
||||||
|
|
||||||
async def register_file(self, file_path: str, timeout: float = None) -> str:
|
async def check_token_expired(self, file_token: str) -> bool:
|
||||||
|
async with self.lock:
|
||||||
|
await self._cleanup_expired_tokens()
|
||||||
|
return file_token not in self.staged_files
|
||||||
|
|
||||||
|
async def register_file(self, file_path: str, timeout: float | None = None) -> str:
|
||||||
"""向令牌服务注册一个文件。
|
"""向令牌服务注册一个文件。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -35,8 +40,8 @@ class FileTokenService:
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
FileNotFoundError: 当路径不存在时抛出
|
FileNotFoundError: 当路径不存在时抛出
|
||||||
"""
|
|
||||||
|
|
||||||
|
"""
|
||||||
# 处理 file:///
|
# 处理 file:///
|
||||||
try:
|
try:
|
||||||
parsed_uri = urlparse(file_path)
|
parsed_uri = urlparse(file_path)
|
||||||
@@ -56,7 +61,7 @@ class FileTokenService:
|
|||||||
|
|
||||||
if not os.path.exists(local_path):
|
if not os.path.exists(local_path):
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"文件不存在: {local_path} (原始输入: {file_path})"
|
f"文件不存在: {local_path} (原始输入: {file_path})",
|
||||||
)
|
)
|
||||||
|
|
||||||
file_token = str(uuid.uuid4())
|
file_token = str(uuid.uuid4())
|
||||||
@@ -79,6 +84,7 @@ class FileTokenService:
|
|||||||
Raises:
|
Raises:
|
||||||
KeyError: 当令牌不存在或已过期时抛出
|
KeyError: 当令牌不存在或已过期时抛出
|
||||||
FileNotFoundError: 当文件本身已被删除时抛出
|
FileNotFoundError: 当文件本身已被删除时抛出
|
||||||
|
|
||||||
"""
|
"""
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
await self._cleanup_expired_tokens()
|
await self._cleanup_expired_tokens()
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
"""
|
"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。
|
||||||
AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。
|
|
||||||
|
|
||||||
工作流程:
|
工作流程:
|
||||||
1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期
|
1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期
|
||||||
@@ -8,10 +7,10 @@ AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
from astrbot.core import logger
|
|
||||||
|
from astrbot.core import LogBroker, logger
|
||||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from astrbot.core import LogBroker
|
|
||||||
from astrbot.dashboard.server import AstrBotDashboard
|
from astrbot.dashboard.server import AstrBotDashboard
|
||||||
|
|
||||||
|
|
||||||
@@ -39,12 +38,18 @@ class InitialLoader:
|
|||||||
webui_dir = self.webui_dir
|
webui_dir = self.webui_dir
|
||||||
|
|
||||||
self.dashboard_server = AstrBotDashboard(
|
self.dashboard_server = AstrBotDashboard(
|
||||||
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event, webui_dir
|
core_lifecycle,
|
||||||
|
self.db,
|
||||||
|
core_lifecycle.dashboard_shutdown_event,
|
||||||
|
webui_dir,
|
||||||
)
|
)
|
||||||
task = asyncio.gather(
|
|
||||||
core_task, self.dashboard_server.run()
|
|
||||||
) # 启动核心任务和仪表板服务器
|
|
||||||
|
|
||||||
|
coro = self.dashboard_server.run()
|
||||||
|
if coro:
|
||||||
|
# 启动核心任务和仪表板服务器
|
||||||
|
task = asyncio.gather(core_task, coro)
|
||||||
|
else:
|
||||||
|
task = core_task
|
||||||
try:
|
try:
|
||||||
await task # 整个AstrBot在这里运行
|
await task # 整个AstrBot在这里运行
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
|
|||||||
9
astrbot/core/knowledge_base/chunking/__init__.py
Normal file
9
astrbot/core/knowledge_base/chunking/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
"""文档分块模块"""
|
||||||
|
|
||||||
|
from .base import BaseChunker
|
||||||
|
from .fixed_size import FixedSizeChunker
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseChunker",
|
||||||
|
"FixedSizeChunker",
|
||||||
|
]
|
||||||
25
astrbot/core/knowledge_base/chunking/base.py
Normal file
25
astrbot/core/knowledge_base/chunking/base.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""文档分块器基类
|
||||||
|
|
||||||
|
定义了文档分块处理的抽象接口。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class BaseChunker(ABC):
|
||||||
|
"""分块器基类
|
||||||
|
|
||||||
|
所有分块器都应该继承此类并实现 chunk 方法。
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def chunk(self, text: str, **kwargs) -> list[str]:
|
||||||
|
"""将文本分块
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 输入文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: 分块后的文本列表
|
||||||
|
|
||||||
|
"""
|
||||||
59
astrbot/core/knowledge_base/chunking/fixed_size.py
Normal file
59
astrbot/core/knowledge_base/chunking/fixed_size.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""固定大小分块器
|
||||||
|
|
||||||
|
按照固定的字符数将文本分块,支持重叠区域。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import BaseChunker
|
||||||
|
|
||||||
|
|
||||||
|
class FixedSizeChunker(BaseChunker):
|
||||||
|
"""固定大小分块器
|
||||||
|
|
||||||
|
按照固定的字符数分块,并支持块之间的重叠。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
|
||||||
|
"""初始化分块器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk_size: 块的大小(字符数)
|
||||||
|
chunk_overlap: 块之间的重叠字符数
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.chunk_overlap = chunk_overlap
|
||||||
|
|
||||||
|
async def chunk(self, text: str, **kwargs) -> list[str]:
|
||||||
|
"""固定大小分块
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 输入文本
|
||||||
|
chunk_size: 每个文本块的最大大小
|
||||||
|
chunk_overlap: 每个文本块之间的重叠部分大小
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: 分块后的文本列表
|
||||||
|
|
||||||
|
"""
|
||||||
|
chunk_size = kwargs.get("chunk_size", self.chunk_size)
|
||||||
|
chunk_overlap = kwargs.get("chunk_overlap", self.chunk_overlap)
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
start = 0
|
||||||
|
text_len = len(text)
|
||||||
|
|
||||||
|
while start < text_len:
|
||||||
|
end = start + chunk_size
|
||||||
|
chunk = text[start:end]
|
||||||
|
|
||||||
|
if chunk:
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
# 移动窗口,保留重叠部分
|
||||||
|
start = end - chunk_overlap
|
||||||
|
|
||||||
|
# 防止无限循环: 如果重叠过大,直接移到end
|
||||||
|
if start >= end or chunk_overlap >= chunk_size:
|
||||||
|
start = end
|
||||||
|
|
||||||
|
return chunks
|
||||||
161
astrbot/core/knowledge_base/chunking/recursive.py
Normal file
161
astrbot/core/knowledge_base/chunking/recursive.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
from .base import BaseChunker
|
||||||
|
|
||||||
|
|
||||||
|
class RecursiveCharacterChunker(BaseChunker):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int = 500,
|
||||||
|
chunk_overlap: int = 100,
|
||||||
|
length_function: Callable[[str], int] = len,
|
||||||
|
is_separator_regex: bool = False,
|
||||||
|
separators: list[str] | None = None,
|
||||||
|
):
|
||||||
|
"""初始化递归字符文本分割器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk_size: 每个文本块的最大大小
|
||||||
|
chunk_overlap: 每个文本块之间的重叠部分大小
|
||||||
|
length_function: 计算文本长度的函数
|
||||||
|
is_separator_regex: 分隔符是否为正则表达式
|
||||||
|
separators: 用于分割文本的分隔符列表,按优先级排序
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.chunk_overlap = chunk_overlap
|
||||||
|
self.length_function = length_function
|
||||||
|
self.is_separator_regex = is_separator_regex
|
||||||
|
|
||||||
|
# 默认分隔符列表,按优先级从高到低
|
||||||
|
self.separators = separators or [
|
||||||
|
"\n\n", # 段落
|
||||||
|
"\n", # 换行
|
||||||
|
"。", # 中文句子
|
||||||
|
",", # 中文逗号
|
||||||
|
". ", # 句子
|
||||||
|
", ", # 逗号分隔
|
||||||
|
" ", # 单词
|
||||||
|
"", # 字符
|
||||||
|
]
|
||||||
|
|
||||||
|
async def chunk(self, text: str, **kwargs) -> list[str]:
|
||||||
|
"""递归地将文本分割成块
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 要分割的文本
|
||||||
|
chunk_size: 每个文本块的最大大小
|
||||||
|
chunk_overlap: 每个文本块之间的重叠部分大小
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
分割后的文本块列表
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
overlap = kwargs.get("chunk_overlap", self.chunk_overlap)
|
||||||
|
chunk_size = kwargs.get("chunk_size", self.chunk_size)
|
||||||
|
|
||||||
|
text_length = self.length_function(text)
|
||||||
|
if text_length <= chunk_size:
|
||||||
|
return [text]
|
||||||
|
|
||||||
|
for separator in self.separators:
|
||||||
|
if separator == "":
|
||||||
|
return self._split_by_character(text, chunk_size, overlap)
|
||||||
|
|
||||||
|
if separator in text:
|
||||||
|
splits = text.split(separator)
|
||||||
|
# 重新添加分隔符(除了最后一个片段)
|
||||||
|
splits = [s + separator for s in splits[:-1]] + [splits[-1]]
|
||||||
|
splits = [s for s in splits if s]
|
||||||
|
if len(splits) == 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 递归合并分割后的文本块
|
||||||
|
final_chunks = []
|
||||||
|
current_chunk = []
|
||||||
|
current_chunk_length = 0
|
||||||
|
|
||||||
|
for split in splits:
|
||||||
|
split_length = self.length_function(split)
|
||||||
|
|
||||||
|
# 如果单个分割部分已经超过了chunk_size,需要递归分割
|
||||||
|
if split_length > chunk_size:
|
||||||
|
# 先处理当前积累的块
|
||||||
|
if current_chunk:
|
||||||
|
combined_text = "".join(current_chunk)
|
||||||
|
final_chunks.extend(
|
||||||
|
await self.chunk(
|
||||||
|
combined_text,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=overlap,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
current_chunk = []
|
||||||
|
current_chunk_length = 0
|
||||||
|
|
||||||
|
# 递归分割过大的部分
|
||||||
|
final_chunks.extend(
|
||||||
|
await self.chunk(
|
||||||
|
split,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=overlap,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# 如果添加这部分会使当前块超过chunk_size
|
||||||
|
elif current_chunk_length + split_length > chunk_size:
|
||||||
|
# 合并当前块并添加到结果中
|
||||||
|
combined_text = "".join(current_chunk)
|
||||||
|
final_chunks.append(combined_text)
|
||||||
|
|
||||||
|
# 处理重叠部分
|
||||||
|
overlap_start = max(0, len(combined_text) - overlap)
|
||||||
|
if overlap_start > 0:
|
||||||
|
overlap_text = combined_text[overlap_start:]
|
||||||
|
current_chunk = [overlap_text, split]
|
||||||
|
current_chunk_length = (
|
||||||
|
self.length_function(overlap_text) + split_length
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
current_chunk = [split]
|
||||||
|
current_chunk_length = split_length
|
||||||
|
else:
|
||||||
|
# 添加到当前块
|
||||||
|
current_chunk.append(split)
|
||||||
|
current_chunk_length += split_length
|
||||||
|
|
||||||
|
# 处理剩余的块
|
||||||
|
if current_chunk:
|
||||||
|
final_chunks.append("".join(current_chunk))
|
||||||
|
|
||||||
|
return final_chunks
|
||||||
|
|
||||||
|
return [text]
|
||||||
|
|
||||||
|
def _split_by_character(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
chunk_size: int | None = None,
|
||||||
|
overlap: int | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""按字符级别分割文本
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 要分割的文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
分割后的文本块列表
|
||||||
|
|
||||||
|
"""
|
||||||
|
chunk_size = chunk_size or self.chunk_size
|
||||||
|
overlap = overlap or self.chunk_overlap
|
||||||
|
result = []
|
||||||
|
for i in range(0, len(text), chunk_size - overlap):
|
||||||
|
end = min(i + chunk_size, len(text))
|
||||||
|
result.append(text[i:end])
|
||||||
|
if end == len(text):
|
||||||
|
break
|
||||||
|
|
||||||
|
return result
|
||||||
301
astrbot/core/knowledge_base/kb_db_sqlite.py
Normal file
301
astrbot/core/knowledge_base/kb_db_sqlite.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sqlalchemy import delete, func, select, text, update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlmodel import col, desc
|
||||||
|
|
||||||
|
from astrbot.core import logger
|
||||||
|
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||||
|
from astrbot.core.knowledge_base.models import (
|
||||||
|
BaseKBModel,
|
||||||
|
KBDocument,
|
||||||
|
KBMedia,
|
||||||
|
KnowledgeBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class KBSQLiteDatabase:
|
||||||
|
def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
|
||||||
|
"""初始化知识库数据库
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: 数据库文件路径, 默认为 data/knowledge_base/kb.db
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.db_path = db_path
|
||||||
|
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||||
|
self.inited = False
|
||||||
|
|
||||||
|
# 确保目录存在
|
||||||
|
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 创建异步引擎
|
||||||
|
self.engine = create_async_engine(
|
||||||
|
self.DATABASE_URL,
|
||||||
|
echo=False,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
pool_recycle=3600,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建会话工厂
|
||||||
|
self.async_session = async_sessionmaker(
|
||||||
|
self.engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def get_db(self):
|
||||||
|
"""获取数据库会话
|
||||||
|
|
||||||
|
用法:
|
||||||
|
async with kb_db.get_db() as session:
|
||||||
|
# 执行数据库操作
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
"""
|
||||||
|
async with self.async_session() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""初始化数据库,创建表并配置 SQLite 参数"""
|
||||||
|
async with self.engine.begin() as conn:
|
||||||
|
# 创建所有知识库相关表
|
||||||
|
await conn.run_sync(BaseKBModel.metadata.create_all)
|
||||||
|
|
||||||
|
# 配置 SQLite 性能优化参数
|
||||||
|
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||||
|
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||||
|
await conn.execute(text("PRAGMA cache_size=20000"))
|
||||||
|
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||||
|
await conn.execute(text("PRAGMA mmap_size=134217728"))
|
||||||
|
await conn.execute(text("PRAGMA optimize"))
|
||||||
|
await conn.commit()
|
||||||
|
|
||||||
|
self.inited = True
|
||||||
|
|
||||||
|
async def migrate_to_v1(self) -> None:
|
||||||
|
"""执行知识库数据库 v1 迁移
|
||||||
|
|
||||||
|
创建所有必要的索引以优化查询性能
|
||||||
|
"""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
# 创建知识库表索引
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_kb_kb_id "
|
||||||
|
"ON knowledge_bases(kb_id)",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_kb_name "
|
||||||
|
"ON knowledge_bases(kb_name)",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_kb_created_at "
|
||||||
|
"ON knowledge_bases(created_at)",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建文档表索引
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_doc_doc_id "
|
||||||
|
"ON kb_documents(doc_id)",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_doc_kb_id "
|
||||||
|
"ON kb_documents(kb_id)",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_doc_name "
|
||||||
|
"ON kb_documents(doc_name)",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_doc_type "
|
||||||
|
"ON kb_documents(file_type)",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_doc_created_at "
|
||||||
|
"ON kb_documents(created_at)",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建多媒体表索引
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_media_media_id "
|
||||||
|
"ON kb_media(media_id)",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_media_doc_id "
|
||||||
|
"ON kb_media(doc_id)",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_media_type "
|
||||||
|
"ON kb_media(media_type)",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""关闭数据库连接"""
|
||||||
|
await self.engine.dispose()
|
||||||
|
logger.info(f"知识库数据库已关闭: {self.db_path}")
|
||||||
|
|
||||||
|
async def get_kb_by_id(self, kb_id: str) -> KnowledgeBase | None:
|
||||||
|
"""根据 ID 获取知识库"""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_kb_by_name(self, kb_name: str) -> KnowledgeBase | None:
|
||||||
|
"""根据名称获取知识库"""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_name) == kb_name)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
|
||||||
|
"""列出所有知识库"""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
stmt = (
|
||||||
|
select(KnowledgeBase)
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
.order_by(desc(KnowledgeBase.created_at))
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def count_kbs(self) -> int:
|
||||||
|
"""统计知识库数量"""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
stmt = select(func.count(col(KnowledgeBase.id)))
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return result.scalar() or 0
|
||||||
|
|
||||||
|
# ===== 文档查询 =====
|
||||||
|
|
||||||
|
async def get_document_by_id(self, doc_id: str) -> KBDocument | None:
|
||||||
|
"""根据 ID 获取文档"""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
stmt = select(KBDocument).where(col(KBDocument.doc_id) == doc_id)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def list_documents_by_kb(
|
||||||
|
self,
|
||||||
|
kb_id: str,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> list[KBDocument]:
|
||||||
|
"""列出知识库的所有文档"""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
stmt = (
|
||||||
|
select(KBDocument)
|
||||||
|
.where(col(KBDocument.kb_id) == kb_id)
|
||||||
|
.offset(offset)
|
||||||
|
.limit(limit)
|
||||||
|
.order_by(desc(KBDocument.created_at))
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def count_documents_by_kb(self, kb_id: str) -> int:
|
||||||
|
"""统计知识库的文档数量"""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
stmt = select(func.count(col(KBDocument.id))).where(
|
||||||
|
col(KBDocument.kb_id) == kb_id,
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return result.scalar() or 0
|
||||||
|
|
||||||
|
async def get_document_with_metadata(self, doc_id: str) -> dict | None:
|
||||||
|
async with self.get_db() as session:
|
||||||
|
stmt = (
|
||||||
|
select(KBDocument, KnowledgeBase)
|
||||||
|
.join(KnowledgeBase, col(KBDocument.kb_id) == col(KnowledgeBase.kb_id))
|
||||||
|
.where(col(KBDocument.doc_id) == doc_id)
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
row = result.first()
|
||||||
|
|
||||||
|
if not row:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"document": row[0],
|
||||||
|
"knowledge_base": row[1],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB):
|
||||||
|
"""删除单个文档及其相关数据"""
|
||||||
|
# 在知识库表中删除
|
||||||
|
async with self.get_db() as session, session.begin():
|
||||||
|
# 删除文档记录
|
||||||
|
delete_stmt = delete(KBDocument).where(col(KBDocument.doc_id) == doc_id)
|
||||||
|
await session.execute(delete_stmt)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# 在 vec db 中删除相关向量
|
||||||
|
await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id})
|
||||||
|
|
||||||
|
# ===== 多媒体查询 =====
|
||||||
|
|
||||||
|
async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
|
||||||
|
"""列出文档的所有多媒体资源"""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
stmt = select(KBMedia).where(col(KBMedia.doc_id) == doc_id)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def get_media_by_id(self, media_id: str) -> KBMedia | None:
|
||||||
|
"""根据 ID 获取多媒体资源"""
|
||||||
|
async with self.get_db() as session:
|
||||||
|
stmt = select(KBMedia).where(col(KBMedia.media_id) == media_id)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def update_kb_stats(self, kb_id: str, vec_db: FaissVecDB) -> None:
|
||||||
|
"""更新知识库统计信息"""
|
||||||
|
chunk_cnt = await vec_db.count_documents()
|
||||||
|
|
||||||
|
async with self.get_db() as session, session.begin():
|
||||||
|
update_stmt = (
|
||||||
|
update(KnowledgeBase)
|
||||||
|
.where(col(KnowledgeBase.kb_id) == kb_id)
|
||||||
|
.values(
|
||||||
|
doc_count=select(func.count(col(KBDocument.id)))
|
||||||
|
.where(col(KBDocument.kb_id) == kb_id)
|
||||||
|
.scalar_subquery(),
|
||||||
|
chunk_count=chunk_cnt,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await session.execute(update_stmt)
|
||||||
|
await session.commit()
|
||||||
361
astrbot/core/knowledge_base/kb_helper.py
Normal file
361
astrbot/core/knowledge_base/kb_helper.py
Normal file
@@ -0,0 +1,361 @@
|
|||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
|
||||||
|
from astrbot.core import logger
|
||||||
|
from astrbot.core.db.vec_db.base import BaseVecDB
|
||||||
|
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
|
||||||
|
from astrbot.core.provider.manager import ProviderManager
|
||||||
|
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
|
||||||
|
|
||||||
|
from .chunking.base import BaseChunker
|
||||||
|
from .kb_db_sqlite import KBSQLiteDatabase
|
||||||
|
from .models import KBDocument, KBMedia, KnowledgeBase
|
||||||
|
from .parsers.util import select_parser
|
||||||
|
|
||||||
|
|
||||||
|
class KBHelper:
|
||||||
|
vec_db: BaseVecDB
|
||||||
|
kb: KnowledgeBase
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
kb_db: KBSQLiteDatabase,
|
||||||
|
kb: KnowledgeBase,
|
||||||
|
provider_manager: ProviderManager,
|
||||||
|
kb_root_dir: str,
|
||||||
|
chunker: BaseChunker,
|
||||||
|
):
|
||||||
|
self.kb_db = kb_db
|
||||||
|
self.kb = kb
|
||||||
|
self.prov_mgr = provider_manager
|
||||||
|
self.kb_root_dir = kb_root_dir
|
||||||
|
self.chunker = chunker
|
||||||
|
|
||||||
|
self.kb_dir = Path(self.kb_root_dir) / self.kb.kb_id
|
||||||
|
self.kb_medias_dir = Path(self.kb_dir) / "medias" / self.kb.kb_id
|
||||||
|
self.kb_files_dir = Path(self.kb_dir) / "files" / self.kb.kb_id
|
||||||
|
|
||||||
|
self.kb_medias_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.kb_files_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
await self._ensure_vec_db()
|
||||||
|
|
||||||
|
async def get_ep(self) -> EmbeddingProvider:
|
||||||
|
if not self.kb.embedding_provider_id:
|
||||||
|
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
|
||||||
|
ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id(
|
||||||
|
self.kb.embedding_provider_id,
|
||||||
|
) # type: ignore
|
||||||
|
if not ep:
|
||||||
|
raise ValueError(
|
||||||
|
f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider",
|
||||||
|
)
|
||||||
|
return ep
|
||||||
|
|
||||||
|
async def get_rp(self) -> RerankProvider | None:
|
||||||
|
if not self.kb.rerank_provider_id:
|
||||||
|
return None
|
||||||
|
rp: RerankProvider = await self.prov_mgr.get_provider_by_id(
|
||||||
|
self.kb.rerank_provider_id,
|
||||||
|
) # type: ignore
|
||||||
|
if not rp:
|
||||||
|
raise ValueError(
|
||||||
|
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider",
|
||||||
|
)
|
||||||
|
return rp
|
||||||
|
|
||||||
|
async def _ensure_vec_db(self) -> FaissVecDB:
|
||||||
|
if not self.kb.embedding_provider_id:
|
||||||
|
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
|
||||||
|
|
||||||
|
ep = await self.get_ep()
|
||||||
|
rp = await self.get_rp()
|
||||||
|
|
||||||
|
vec_db = FaissVecDB(
|
||||||
|
doc_store_path=str(self.kb_dir / "doc.db"),
|
||||||
|
index_store_path=str(self.kb_dir / "index.faiss"),
|
||||||
|
embedding_provider=ep,
|
||||||
|
rerank_provider=rp,
|
||||||
|
)
|
||||||
|
await vec_db.initialize()
|
||||||
|
self.vec_db = vec_db
|
||||||
|
return vec_db
|
||||||
|
|
||||||
|
async def delete_vec_db(self):
|
||||||
|
"""删除知识库的向量数据库和所有相关文件"""
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
await self.terminate()
|
||||||
|
if self.kb_dir.exists():
|
||||||
|
shutil.rmtree(self.kb_dir)
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
if self.vec_db:
|
||||||
|
await self.vec_db.close()
|
||||||
|
|
||||||
|
async def upload_document(
|
||||||
|
self,
|
||||||
|
file_name: str,
|
||||||
|
file_content: bytes,
|
||||||
|
file_type: str,
|
||||||
|
chunk_size: int = 512,
|
||||||
|
chunk_overlap: int = 50,
|
||||||
|
batch_size: int = 32,
|
||||||
|
tasks_limit: int = 3,
|
||||||
|
max_retries: int = 3,
|
||||||
|
progress_callback=None,
|
||||||
|
) -> KBDocument:
|
||||||
|
"""上传并处理文档(带原子性保证和失败清理)
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 保存原始文件
|
||||||
|
2. 解析文档内容
|
||||||
|
3. 提取多媒体资源
|
||||||
|
4. 分块处理
|
||||||
|
5. 生成向量并存储
|
||||||
|
6. 保存元数据(事务)
|
||||||
|
7. 更新统计
|
||||||
|
|
||||||
|
Args:
|
||||||
|
progress_callback: 进度回调函数,接收参数 (stage, current, total)
|
||||||
|
- stage: 当前阶段 ('parsing', 'chunking', 'embedding')
|
||||||
|
- current: 当前进度
|
||||||
|
- total: 总数
|
||||||
|
|
||||||
|
"""
|
||||||
|
await self._ensure_vec_db()
|
||||||
|
doc_id = str(uuid.uuid4())
|
||||||
|
media_paths: list[Path] = []
|
||||||
|
|
||||||
|
# file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
|
||||||
|
# async with aiofiles.open(file_path, "wb") as f:
|
||||||
|
# await f.write(file_content)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 阶段1: 解析文档
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("parsing", 0, 100)
|
||||||
|
|
||||||
|
parser = await select_parser(f".{file_type}")
|
||||||
|
parse_result = await parser.parse(file_content, file_name)
|
||||||
|
text_content = parse_result.text
|
||||||
|
media_items = parse_result.media
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("parsing", 100, 100)
|
||||||
|
|
||||||
|
# 保存媒体文件
|
||||||
|
saved_media = []
|
||||||
|
for media_item in media_items:
|
||||||
|
media = await self._save_media(
|
||||||
|
doc_id=doc_id,
|
||||||
|
media_type=media_item.media_type,
|
||||||
|
file_name=media_item.file_name,
|
||||||
|
content=media_item.content,
|
||||||
|
mime_type=media_item.mime_type,
|
||||||
|
)
|
||||||
|
saved_media.append(media)
|
||||||
|
media_paths.append(Path(media.file_path))
|
||||||
|
|
||||||
|
# 阶段2: 分块
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("chunking", 0, 100)
|
||||||
|
|
||||||
|
chunks_text = await self.chunker.chunk(
|
||||||
|
text_content,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
)
|
||||||
|
contents = []
|
||||||
|
metadatas = []
|
||||||
|
for idx, chunk_text in enumerate(chunks_text):
|
||||||
|
contents.append(chunk_text)
|
||||||
|
metadatas.append(
|
||||||
|
{
|
||||||
|
"kb_id": self.kb.kb_id,
|
||||||
|
"kb_doc_id": doc_id,
|
||||||
|
"chunk_index": idx,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("chunking", 100, 100)
|
||||||
|
|
||||||
|
# 阶段3: 生成向量(带进度回调)
|
||||||
|
async def embedding_progress_callback(current, total):
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("embedding", current, total)
|
||||||
|
|
||||||
|
await self.vec_db.insert_batch(
|
||||||
|
contents=contents,
|
||||||
|
metadatas=metadatas,
|
||||||
|
batch_size=batch_size,
|
||||||
|
tasks_limit=tasks_limit,
|
||||||
|
max_retries=max_retries,
|
||||||
|
progress_callback=embedding_progress_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存文档的元数据
|
||||||
|
doc = KBDocument(
|
||||||
|
doc_id=doc_id,
|
||||||
|
kb_id=self.kb.kb_id,
|
||||||
|
doc_name=file_name,
|
||||||
|
file_type=file_type,
|
||||||
|
file_size=len(file_content),
|
||||||
|
# file_path=str(file_path),
|
||||||
|
file_path="",
|
||||||
|
chunk_count=len(chunks_text),
|
||||||
|
media_count=0,
|
||||||
|
)
|
||||||
|
async with self.kb_db.get_db() as session:
|
||||||
|
async with session.begin():
|
||||||
|
session.add(doc)
|
||||||
|
for media in saved_media:
|
||||||
|
session.add(media)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
await session.refresh(doc)
|
||||||
|
|
||||||
|
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||||
|
await self.kb_db.update_kb_stats(kb_id=self.kb.kb_id, vec_db=vec_db)
|
||||||
|
await self.refresh_kb()
|
||||||
|
await self.refresh_document(doc_id)
|
||||||
|
return doc
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"上传文档失败: {e}")
|
||||||
|
# if file_path.exists():
|
||||||
|
# file_path.unlink()
|
||||||
|
|
||||||
|
for media_path in media_paths:
|
||||||
|
try:
|
||||||
|
if media_path.exists():
|
||||||
|
media_path.unlink()
|
||||||
|
except Exception as me:
|
||||||
|
logger.warning(f"清理多媒体文件失败 {media_path}: {me}")
|
||||||
|
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def list_documents(
|
||||||
|
self,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> list[KBDocument]:
|
||||||
|
"""列出知识库的所有文档"""
|
||||||
|
docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit)
|
||||||
|
return docs
|
||||||
|
|
||||||
|
async def get_document(self, doc_id: str) -> KBDocument | None:
|
||||||
|
"""获取单个文档"""
|
||||||
|
doc = await self.kb_db.get_document_by_id(doc_id)
|
||||||
|
return doc
|
||||||
|
|
||||||
|
async def delete_document(self, doc_id: str):
|
||||||
|
"""删除单个文档及其相关数据"""
|
||||||
|
await self.kb_db.delete_document_by_id(
|
||||||
|
doc_id=doc_id,
|
||||||
|
vec_db=self.vec_db, # type: ignore
|
||||||
|
)
|
||||||
|
await self.kb_db.update_kb_stats(
|
||||||
|
kb_id=self.kb.kb_id,
|
||||||
|
vec_db=self.vec_db, # type: ignore
|
||||||
|
)
|
||||||
|
await self.refresh_kb()
|
||||||
|
|
||||||
|
async def delete_chunk(self, chunk_id: str, doc_id: str):
|
||||||
|
"""删除单个文本块及其相关数据"""
|
||||||
|
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||||
|
await vec_db.delete(chunk_id)
|
||||||
|
await self.kb_db.update_kb_stats(
|
||||||
|
kb_id=self.kb.kb_id,
|
||||||
|
vec_db=self.vec_db, # type: ignore
|
||||||
|
)
|
||||||
|
await self.refresh_kb()
|
||||||
|
await self.refresh_document(doc_id)
|
||||||
|
|
||||||
|
async def refresh_kb(self):
|
||||||
|
if self.kb:
|
||||||
|
kb = await self.kb_db.get_kb_by_id(self.kb.kb_id)
|
||||||
|
if kb:
|
||||||
|
self.kb = kb
|
||||||
|
|
||||||
|
async def refresh_document(self, doc_id: str) -> None:
|
||||||
|
"""更新文档的元数据"""
|
||||||
|
doc = await self.get_document(doc_id)
|
||||||
|
if not doc:
|
||||||
|
raise ValueError(f"无法找到 ID 为 {doc_id} 的文档")
|
||||||
|
chunk_count = await self.get_chunk_count_by_doc_id(doc_id)
|
||||||
|
doc.chunk_count = chunk_count
|
||||||
|
async with self.kb_db.get_db() as session:
|
||||||
|
async with session.begin():
|
||||||
|
session.add(doc)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(doc)
|
||||||
|
|
||||||
|
async def get_chunks_by_doc_id(
|
||||||
|
self,
|
||||||
|
doc_id: str,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""获取文档的所有块及其元数据"""
|
||||||
|
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||||
|
chunks = await vec_db.document_storage.get_documents(
|
||||||
|
metadata_filters={"kb_doc_id": doc_id},
|
||||||
|
offset=offset,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
result = []
|
||||||
|
for chunk in chunks:
|
||||||
|
chunk_md = json.loads(chunk["metadata"])
|
||||||
|
result.append(
|
||||||
|
{
|
||||||
|
"chunk_id": chunk["doc_id"],
|
||||||
|
"doc_id": chunk_md["kb_doc_id"],
|
||||||
|
"kb_id": chunk_md["kb_id"],
|
||||||
|
"chunk_index": chunk_md["chunk_index"],
|
||||||
|
"content": chunk["text"],
|
||||||
|
"char_count": len(chunk["text"]),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_chunk_count_by_doc_id(self, doc_id: str) -> int:
|
||||||
|
"""获取文档的块数量"""
|
||||||
|
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||||
|
count = await vec_db.count_documents(metadata_filter={"kb_doc_id": doc_id})
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def _save_media(
|
||||||
|
self,
|
||||||
|
doc_id: str,
|
||||||
|
media_type: str,
|
||||||
|
file_name: str,
|
||||||
|
content: bytes,
|
||||||
|
mime_type: str,
|
||||||
|
) -> KBMedia:
|
||||||
|
"""保存多媒体资源"""
|
||||||
|
media_id = str(uuid.uuid4())
|
||||||
|
ext = Path(file_name).suffix
|
||||||
|
|
||||||
|
# 保存文件
|
||||||
|
file_path = self.kb_medias_dir / doc_id / f"{media_id}{ext}"
|
||||||
|
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
async with aiofiles.open(file_path, "wb") as f:
|
||||||
|
await f.write(content)
|
||||||
|
|
||||||
|
media = KBMedia(
|
||||||
|
media_id=media_id,
|
||||||
|
doc_id=doc_id,
|
||||||
|
kb_id=self.kb.kb_id,
|
||||||
|
media_type=media_type,
|
||||||
|
file_name=file_name,
|
||||||
|
file_path=str(file_path),
|
||||||
|
file_size=len(content),
|
||||||
|
mime_type=mime_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
return media
|
||||||
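A minimal usage sketch of the upload flow implemented by KBHelper.upload_document above. It is not part of the diff: it assumes an already-initialized KBHelper (normally obtained from KnowledgeBaseManager) and a local file named manual.pdf; the callback signature follows the docstring (stage, current, total).

# Hypothetical driver code, not from the repository: it only exercises the
# upload_document() API defined above.
from pathlib import Path


async def log_progress(stage: str, current: int, total: int):
    # stage is one of 'parsing', 'chunking', 'embedding' per the docstring
    print(f"[{stage}] {current}/{total}")


async def upload_example(kb_helper):  # kb_helper: an initialized KBHelper
    pdf_path = Path("manual.pdf")  # assumed sample file
    doc = await kb_helper.upload_document(
        file_name=pdf_path.name,
        file_content=pdf_path.read_bytes(),
        file_type="pdf",
        chunk_size=512,
        chunk_overlap=50,
        progress_callback=log_progress,
    )
    # On success the KBDocument row has been committed and stats refreshed.
    print(doc.doc_id, doc.chunk_count)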
286
astrbot/core/knowledge_base/kb_mgr.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
import traceback
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from astrbot.core import logger
|
||||||
|
from astrbot.core.provider.manager import ProviderManager
|
||||||
|
|
||||||
|
# from .chunking.fixed_size import FixedSizeChunker
|
||||||
|
from .chunking.recursive import RecursiveCharacterChunker
|
||||||
|
from .kb_db_sqlite import KBSQLiteDatabase
|
||||||
|
from .kb_helper import KBHelper
|
||||||
|
from .models import KnowledgeBase
|
||||||
|
from .retrieval.manager import RetrievalManager, RetrievalResult
|
||||||
|
from .retrieval.rank_fusion import RankFusion
|
||||||
|
from .retrieval.sparse_retriever import SparseRetriever
|
||||||
|
|
||||||
|
FILES_PATH = "data/knowledge_base"
|
||||||
|
DB_PATH = Path(FILES_PATH) / "kb.db"
|
||||||
|
"""Knowledge Base storage root directory"""
|
||||||
|
CHUNKER = RecursiveCharacterChunker()
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeBaseManager:
|
||||||
|
kb_db: KBSQLiteDatabase
|
||||||
|
retrieval_manager: RetrievalManager
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
provider_manager: ProviderManager,
|
||||||
|
):
|
||||||
|
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.provider_manager = provider_manager
|
||||||
|
self._session_deleted_callback_registered = False
|
||||||
|
|
||||||
|
self.kb_insts: dict[str, KBHelper] = {}
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""初始化知识库模块"""
|
||||||
|
try:
|
||||||
|
logger.info("正在初始化知识库模块...")
|
||||||
|
|
||||||
|
# 初始化数据库
|
||||||
|
await self._init_kb_database()
|
||||||
|
|
||||||
|
# 初始化检索管理器
|
||||||
|
sparse_retriever = SparseRetriever(self.kb_db)
|
||||||
|
rank_fusion = RankFusion(self.kb_db)
|
||||||
|
self.retrieval_manager = RetrievalManager(
|
||||||
|
sparse_retriever=sparse_retriever,
|
||||||
|
rank_fusion=rank_fusion,
|
||||||
|
kb_db=self.kb_db,
|
||||||
|
)
|
||||||
|
await self.load_kbs()
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"知识库模块导入失败: {e}")
|
||||||
|
logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"知识库模块初始化失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
async def _init_kb_database(self):
|
||||||
|
self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix())
|
||||||
|
await self.kb_db.initialize()
|
||||||
|
await self.kb_db.migrate_to_v1()
|
||||||
|
logger.info(f"KnowledgeBase database initialized: {DB_PATH}")
|
||||||
|
|
||||||
|
async def load_kbs(self):
|
||||||
|
"""加载所有知识库实例"""
|
||||||
|
kb_records = await self.kb_db.list_kbs()
|
||||||
|
for record in kb_records:
|
||||||
|
kb_helper = KBHelper(
|
||||||
|
kb_db=self.kb_db,
|
||||||
|
kb=record,
|
||||||
|
provider_manager=self.provider_manager,
|
||||||
|
kb_root_dir=FILES_PATH,
|
||||||
|
chunker=CHUNKER,
|
||||||
|
)
|
||||||
|
await kb_helper.initialize()
|
||||||
|
self.kb_insts[record.kb_id] = kb_helper
|
||||||
|
|
||||||
|
async def create_kb(
|
||||||
|
self,
|
||||||
|
kb_name: str,
|
||||||
|
description: str | None = None,
|
||||||
|
emoji: str | None = None,
|
||||||
|
embedding_provider_id: str | None = None,
|
||||||
|
rerank_provider_id: str | None = None,
|
||||||
|
chunk_size: int | None = None,
|
||||||
|
chunk_overlap: int | None = None,
|
||||||
|
top_k_dense: int | None = None,
|
||||||
|
top_k_sparse: int | None = None,
|
||||||
|
top_m_final: int | None = None,
|
||||||
|
) -> KBHelper:
|
||||||
|
"""创建新的知识库实例"""
|
||||||
|
kb = KnowledgeBase(
|
||||||
|
kb_name=kb_name,
|
||||||
|
description=description,
|
||||||
|
emoji=emoji or "📚",
|
||||||
|
embedding_provider_id=embedding_provider_id,
|
||||||
|
rerank_provider_id=rerank_provider_id,
|
||||||
|
chunk_size=chunk_size if chunk_size is not None else 512,
|
||||||
|
chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
|
||||||
|
top_k_dense=top_k_dense if top_k_dense is not None else 50,
|
||||||
|
top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
|
||||||
|
top_m_final=top_m_final if top_m_final is not None else 5,
|
||||||
|
)
|
||||||
|
async with self.kb_db.get_db() as session:
|
||||||
|
session.add(kb)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(kb)
|
||||||
|
|
||||||
|
kb_helper = KBHelper(
|
||||||
|
kb_db=self.kb_db,
|
||||||
|
kb=kb,
|
||||||
|
provider_manager=self.provider_manager,
|
||||||
|
kb_root_dir=FILES_PATH,
|
||||||
|
chunker=CHUNKER,
|
||||||
|
)
|
||||||
|
await kb_helper.initialize()
|
||||||
|
self.kb_insts[kb.kb_id] = kb_helper
|
||||||
|
return kb_helper
|
||||||
|
|
||||||
|
async def get_kb(self, kb_id: str) -> KBHelper | None:
|
||||||
|
"""获取知识库实例"""
|
||||||
|
if kb_id in self.kb_insts:
|
||||||
|
return self.kb_insts[kb_id]
|
||||||
|
|
||||||
|
async def get_kb_by_name(self, kb_name: str) -> KBHelper | None:
|
||||||
|
"""通过名称获取知识库实例"""
|
||||||
|
for kb_helper in self.kb_insts.values():
|
||||||
|
if kb_helper.kb.kb_name == kb_name:
|
||||||
|
return kb_helper
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def delete_kb(self, kb_id: str) -> bool:
|
||||||
|
"""删除知识库实例"""
|
||||||
|
kb_helper = await self.get_kb(kb_id)
|
||||||
|
if not kb_helper:
|
||||||
|
return False
|
||||||
|
|
||||||
|
await kb_helper.delete_vec_db()
|
||||||
|
async with self.kb_db.get_db() as session:
|
||||||
|
await session.delete(kb_helper.kb)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
self.kb_insts.pop(kb_id, None)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def list_kbs(self) -> list[KnowledgeBase]:
|
||||||
|
"""列出所有知识库实例"""
|
||||||
|
kbs = [kb_helper.kb for kb_helper in self.kb_insts.values()]
|
||||||
|
return kbs
|
||||||
|
|
||||||
|
async def update_kb(
|
||||||
|
self,
|
||||||
|
kb_id: str,
|
||||||
|
kb_name: str,
|
||||||
|
description: str | None = None,
|
||||||
|
emoji: str | None = None,
|
||||||
|
embedding_provider_id: str | None = None,
|
||||||
|
rerank_provider_id: str | None = None,
|
||||||
|
chunk_size: int | None = None,
|
||||||
|
chunk_overlap: int | None = None,
|
||||||
|
top_k_dense: int | None = None,
|
||||||
|
top_k_sparse: int | None = None,
|
||||||
|
top_m_final: int | None = None,
|
||||||
|
) -> KBHelper | None:
|
||||||
|
"""更新知识库实例"""
|
||||||
|
kb_helper = await self.get_kb(kb_id)
|
||||||
|
if not kb_helper:
|
||||||
|
return None
|
||||||
|
|
||||||
|
kb = kb_helper.kb
|
||||||
|
if kb_name is not None:
|
||||||
|
kb.kb_name = kb_name
|
||||||
|
if description is not None:
|
||||||
|
kb.description = description
|
||||||
|
if emoji is not None:
|
||||||
|
kb.emoji = emoji
|
||||||
|
if embedding_provider_id is not None:
|
||||||
|
kb.embedding_provider_id = embedding_provider_id
|
||||||
|
kb.rerank_provider_id = rerank_provider_id # 允许设置为 None
|
||||||
|
if chunk_size is not None:
|
||||||
|
kb.chunk_size = chunk_size
|
||||||
|
if chunk_overlap is not None:
|
||||||
|
kb.chunk_overlap = chunk_overlap
|
||||||
|
if top_k_dense is not None:
|
||||||
|
kb.top_k_dense = top_k_dense
|
||||||
|
if top_k_sparse is not None:
|
||||||
|
kb.top_k_sparse = top_k_sparse
|
||||||
|
if top_m_final is not None:
|
||||||
|
kb.top_m_final = top_m_final
|
||||||
|
async with self.kb_db.get_db() as session:
|
||||||
|
session.add(kb)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(kb)
|
||||||
|
|
||||||
|
return kb_helper
|
||||||
|
|
||||||
|
async def retrieve(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
kb_names: list[str],
|
||||||
|
top_k_fusion: int = 20,
|
||||||
|
top_m_final: int = 5,
|
||||||
|
) -> dict | None:
|
||||||
|
"""从指定知识库中检索相关内容"""
|
||||||
|
kb_ids = []
|
||||||
|
kb_id_helper_map = {}
|
||||||
|
for kb_name in kb_names:
|
||||||
|
if kb_helper := await self.get_kb_by_name(kb_name):
|
||||||
|
kb_ids.append(kb_helper.kb.kb_id)
|
||||||
|
kb_id_helper_map[kb_helper.kb.kb_id] = kb_helper
|
||||||
|
|
||||||
|
if not kb_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
results = await self.retrieval_manager.retrieve(
|
||||||
|
query=query,
|
||||||
|
kb_ids=kb_ids,
|
||||||
|
kb_id_helper_map=kb_id_helper_map,
|
||||||
|
top_k_fusion=top_k_fusion,
|
||||||
|
top_m_final=top_m_final,
|
||||||
|
)
|
||||||
|
if not results:
|
||||||
|
return None
|
||||||
|
|
||||||
|
context_text = self._format_context(results)
|
||||||
|
|
||||||
|
results_dict = [
|
||||||
|
{
|
||||||
|
"chunk_id": r.chunk_id,
|
||||||
|
"doc_id": r.doc_id,
|
||||||
|
"kb_id": r.kb_id,
|
||||||
|
"kb_name": r.kb_name,
|
||||||
|
"doc_name": r.doc_name,
|
||||||
|
"chunk_index": r.metadata.get("chunk_index", 0),
|
||||||
|
"content": r.content,
|
||||||
|
"score": r.score,
|
||||||
|
"char_count": r.metadata.get("char_count", 0),
|
||||||
|
}
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"context_text": context_text,
|
||||||
|
"results": results_dict,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _format_context(self, results: list[RetrievalResult]) -> str:
|
||||||
|
"""格式化知识上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: 检索结果列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 格式化的上下文文本
|
||||||
|
|
||||||
|
"""
|
||||||
|
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
|
||||||
|
|
||||||
|
for i, result in enumerate(results, 1):
|
||||||
|
lines.append(f"【知识 {i}】")
|
||||||
|
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
|
||||||
|
lines.append(f"内容: {result.content}")
|
||||||
|
lines.append(f"相关度: {result.score:.2f}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
"""终止所有知识库实例,关闭数据库连接"""
|
||||||
|
for kb_id, kb_helper in self.kb_insts.items():
|
||||||
|
try:
|
||||||
|
await kb_helper.terminate()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"关闭知识库 {kb_id} 失败: {e}")
|
||||||
|
|
||||||
|
self.kb_insts.clear()
|
||||||
|
|
||||||
|
# 关闭元数据数据库
|
||||||
|
if hasattr(self, "kb_db") and self.kb_db:
|
||||||
|
try:
|
||||||
|
await self.kb_db.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"关闭知识库元数据数据库失败: {e}")
|
||||||
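A sketch of how the manager above could be driven end to end; this is illustrative code, not from the diff. provider_manager is assumed to be the host application's ProviderManager, and the embedding provider id is a placeholder.

# Hypothetical wiring code exercising KnowledgeBaseManager.
async def kb_manager_example(provider_manager):
    mgr = KnowledgeBaseManager(provider_manager)
    await mgr.initialize()

    await mgr.create_kb(
        kb_name="product-docs",
        description="Internal product documentation",
        embedding_provider_id="my-embedding-provider",  # placeholder id
    )

    # retrieve() resolves kb_names to kb_ids internally and returns
    # {"context_text": ..., "results": [...]}, {} when no KB matches,
    # or None when nothing relevant is found.
    hit = await mgr.retrieve(query="如何上传 PDF 文档?", kb_names=["product-docs"])
    if hit and hit.get("results"):
        print(hit["context_text"])

    await mgr.terminate()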
120
astrbot/core/knowledge_base/models.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlmodel import Field, MetaData, SQLModel, Text, UniqueConstraint
|
||||||
|
|
||||||
|
|
||||||
|
class BaseKBModel(SQLModel, table=False):
|
||||||
|
metadata = MetaData()
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeBase(BaseKBModel, table=True):
|
||||||
|
"""知识库表
|
||||||
|
|
||||||
|
存储知识库的基本信息和统计数据。
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "knowledge_bases" # type: ignore
|
||||||
|
|
||||||
|
id: int | None = Field(
|
||||||
|
primary_key=True,
|
||||||
|
sa_column_kwargs={"autoincrement": True},
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
kb_id: str = Field(
|
||||||
|
max_length=36,
|
||||||
|
nullable=False,
|
||||||
|
unique=True,
|
||||||
|
default_factory=lambda: str(uuid.uuid4()),
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
kb_name: str = Field(max_length=100, nullable=False)
|
||||||
|
description: str | None = Field(default=None, sa_type=Text)
|
||||||
|
emoji: str | None = Field(default="📚", max_length=10)
|
||||||
|
embedding_provider_id: str | None = Field(default=None, max_length=100)
|
||||||
|
rerank_provider_id: str | None = Field(default=None, max_length=100)
|
||||||
|
# 分块配置参数
|
||||||
|
chunk_size: int | None = Field(default=512, nullable=True)
|
||||||
|
chunk_overlap: int | None = Field(default=50, nullable=True)
|
||||||
|
# 检索配置参数
|
||||||
|
top_k_dense: int | None = Field(default=50, nullable=True)
|
||||||
|
top_k_sparse: int | None = Field(default=50, nullable=True)
|
||||||
|
top_m_final: int | None = Field(default=5, nullable=True)
|
||||||
|
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
updated_at: datetime = Field(
|
||||||
|
default_factory=lambda: datetime.now(timezone.utc),
|
||||||
|
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||||
|
)
|
||||||
|
doc_count: int = Field(default=0, nullable=False)
|
||||||
|
chunk_count: int = Field(default=0, nullable=False)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"kb_name",
|
||||||
|
name="uix_kb_name",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class KBDocument(BaseKBModel, table=True):
|
||||||
|
"""文档表
|
||||||
|
|
||||||
|
存储上传到知识库的文档元数据。
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "kb_documents" # type: ignore
|
||||||
|
|
||||||
|
id: int | None = Field(
|
||||||
|
primary_key=True,
|
||||||
|
sa_column_kwargs={"autoincrement": True},
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
doc_id: str = Field(
|
||||||
|
max_length=36,
|
||||||
|
nullable=False,
|
||||||
|
unique=True,
|
||||||
|
default_factory=lambda: str(uuid.uuid4()),
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
kb_id: str = Field(max_length=36, nullable=False, index=True)
|
||||||
|
doc_name: str = Field(max_length=255, nullable=False)
|
||||||
|
file_type: str = Field(max_length=20, nullable=False)
|
||||||
|
file_size: int = Field(nullable=False)
|
||||||
|
file_path: str = Field(max_length=512, nullable=False)
|
||||||
|
chunk_count: int = Field(default=0, nullable=False)
|
||||||
|
media_count: int = Field(default=0, nullable=False)
|
||||||
|
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
updated_at: datetime = Field(
|
||||||
|
default_factory=lambda: datetime.now(timezone.utc),
|
||||||
|
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class KBMedia(BaseKBModel, table=True):
|
||||||
|
"""多媒体资源表
|
||||||
|
|
||||||
|
存储从文档中提取的图片、视频等多媒体资源。
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "kb_media" # type: ignore
|
||||||
|
|
||||||
|
id: int | None = Field(
|
||||||
|
primary_key=True,
|
||||||
|
sa_column_kwargs={"autoincrement": True},
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
media_id: str = Field(
|
||||||
|
max_length=36,
|
||||||
|
nullable=False,
|
||||||
|
unique=True,
|
||||||
|
default_factory=lambda: str(uuid.uuid4()),
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
doc_id: str = Field(max_length=36, nullable=False, index=True)
|
||||||
|
kb_id: str = Field(max_length=36, nullable=False, index=True)
|
||||||
|
media_type: str = Field(max_length=20, nullable=False)
|
||||||
|
file_name: str = Field(max_length=255, nullable=False)
|
||||||
|
file_path: str = Field(max_length=512, nullable=False)
|
||||||
|
file_size: int = Field(nullable=False)
|
||||||
|
mime_type: str = Field(max_length=100, nullable=False)
|
||||||
|
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
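A standalone sketch of creating these tables and inserting one row, assuming an async SQLite engine via aiosqlite; in the project this wiring lives inside KBSQLiteDatabase, so treat the engine URL and session handling here as illustrative only.

# Illustrative only; the real code path goes through KBSQLiteDatabase.
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession


async def models_example():
    engine = create_async_engine("sqlite+aiosqlite:///kb_demo.db")
    async with engine.begin() as conn:
        # BaseKBModel carries its own MetaData, so only the KB tables are created.
        await conn.run_sync(BaseKBModel.metadata.create_all)

    async with AsyncSession(engine) as session:
        kb = KnowledgeBase(kb_name="demo")  # kb_id / timestamps use the defaults
        session.add(kb)
        await session.commit()
        await session.refresh(kb)
        print(kb.kb_id, kb.doc_count, kb.chunk_count)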
13
astrbot/core/knowledge_base/parsers/__init__.py
Normal file
@@ -0,0 +1,13 @@
"""文档解析器模块"""

from .base import BaseParser, MediaItem, ParseResult
from .pdf_parser import PDFParser
from .text_parser import TextParser

__all__ = [
    "BaseParser",
    "MediaItem",
    "PDFParser",
    "ParseResult",
    "TextParser",
]
51
astrbot/core/knowledge_base/parsers/base.py
Normal file
@@ -0,0 +1,51 @@
"""文档解析器基类和数据结构

定义了文档解析器的抽象接口和相关数据类。
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class MediaItem:
    """多媒体项

    表示从文档中提取的多媒体资源。
    """

    media_type: str  # image, video
    file_name: str
    content: bytes
    mime_type: str


@dataclass
class ParseResult:
    """解析结果

    包含解析后的文本内容和提取的多媒体资源。
    """

    text: str
    media: list[MediaItem]


class BaseParser(ABC):
    """文档解析器基类

    所有文档解析器都应该继承此类并实现 parse 方法。
    """

    @abstractmethod
    async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
        """解析文档

        Args:
            file_content: 文件内容
            file_name: 文件名

        Returns:
            ParseResult: 解析结果

        """
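To make the contract concrete, a hypothetical extra parser (not part of the diff) that satisfies BaseParser: decode the bytes, return a ParseResult with no media.

class CSVParser(BaseParser):
    """Hypothetical example parser: treat a CSV file as plain text."""

    async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
        text = file_content.decode("utf-8", errors="replace")
        # CSV files carry no embedded media in this sketch
        return ParseResult(text=text, media=[])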
26
astrbot/core/knowledge_base/parsers/markitdown_parser.py
Normal file
@@ -0,0 +1,26 @@
import io
import os

from markitdown_no_magika import MarkItDown, StreamInfo

from astrbot.core.knowledge_base.parsers.base import (
    BaseParser,
    ParseResult,
)


class MarkitdownParser(BaseParser):
    """解析 docx, xls, xlsx 格式"""

    async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
        md = MarkItDown(enable_plugins=False)
        bio = io.BytesIO(file_content)
        stream_info = StreamInfo(
            extension=os.path.splitext(file_name)[1].lower(),
            filename=file_name,
        )
        result = md.convert(bio, stream_info=stream_info)
        return ParseResult(
            text=result.markdown,
            media=[],
        )
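A quick usage sketch for the parser above; report.docx is an assumed local file and the snippet is not part of the diff.

async def markitdown_example():
    parser = MarkitdownParser()
    with open("report.docx", "rb") as f:
        data = f.read()
    result = await parser.parse(data, "report.docx")
    print(result.text[:200])   # Markdown produced by MarkItDown
    print(result.media)        # always [] for this parser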
101
astrbot/core/knowledge_base/parsers/pdf_parser.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""PDF 文件解析器
|
||||||
|
|
||||||
|
支持解析 PDF 文件中的文本和图片资源。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
|
|
||||||
|
from pypdf import PdfReader
|
||||||
|
|
||||||
|
from astrbot.core.knowledge_base.parsers.base import (
|
||||||
|
BaseParser,
|
||||||
|
MediaItem,
|
||||||
|
ParseResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PDFParser(BaseParser):
|
||||||
|
"""PDF 文档解析器
|
||||||
|
|
||||||
|
提取 PDF 中的文本内容和嵌入的图片资源。
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
|
||||||
|
"""解析 PDF 文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_content: 文件内容
|
||||||
|
file_name: 文件名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ParseResult: 包含文本和图片的解析结果
|
||||||
|
|
||||||
|
"""
|
||||||
|
pdf_file = io.BytesIO(file_content)
|
||||||
|
reader = PdfReader(pdf_file)
|
||||||
|
|
||||||
|
text_parts = []
|
||||||
|
media_items = []
|
||||||
|
|
||||||
|
# 提取文本
|
||||||
|
for page in reader.pages:
|
||||||
|
text = page.extract_text()
|
||||||
|
if text:
|
||||||
|
text_parts.append(text)
|
||||||
|
|
||||||
|
# 提取图片
|
||||||
|
image_counter = 0
|
||||||
|
for page_num, page in enumerate(reader.pages):
|
||||||
|
try:
|
||||||
|
# 安全检查 Resources
|
||||||
|
if "/Resources" not in page:
|
||||||
|
continue
|
||||||
|
|
||||||
|
resources = page["/Resources"]
|
||||||
|
if not resources or "/XObject" not in resources: # type: ignore
|
||||||
|
continue
|
||||||
|
|
||||||
|
xobjects = resources["/XObject"].get_object() # type: ignore
|
||||||
|
if not xobjects:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for obj_name in xobjects:
|
||||||
|
try:
|
||||||
|
obj = xobjects[obj_name]
|
||||||
|
|
||||||
|
if obj.get("/Subtype") != "/Image":
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 提取图片数据
|
||||||
|
image_data = obj.get_data()
|
||||||
|
|
||||||
|
# 确定格式
|
||||||
|
filter_type = obj.get("/Filter", "")
|
||||||
|
if filter_type == "/DCTDecode":
|
||||||
|
ext = "jpg"
|
||||||
|
mime_type = "image/jpeg"
|
||||||
|
elif filter_type == "/FlateDecode":
|
||||||
|
ext = "png"
|
||||||
|
mime_type = "image/png"
|
||||||
|
else:
|
||||||
|
ext = "png"
|
||||||
|
mime_type = "image/png"
|
||||||
|
|
||||||
|
image_counter += 1
|
||||||
|
media_items.append(
|
||||||
|
MediaItem(
|
||||||
|
media_type="image",
|
||||||
|
file_name=f"page_{page_num}_img_{image_counter}.{ext}",
|
||||||
|
content=image_data,
|
||||||
|
mime_type=mime_type,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# 单个图片提取失败不影响整体
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
# 页面处理失败不影响其他页面
|
||||||
|
continue
|
||||||
|
|
||||||
|
full_text = "\n\n".join(text_parts)
|
||||||
|
return ParseResult(text=full_text, media=media_items)
|
||||||
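An illustrative driver for PDFParser (not part of the diff); paper.pdf is an assumed local file, and extracted images are written next to the script.

async def pdf_parse_example():
    parser = PDFParser()
    with open("paper.pdf", "rb") as f:
        result = await parser.parse(f.read(), "paper.pdf")
    print(f"text chars: {len(result.text)}, images: {len(result.media)}")
    for item in result.media:
        # file names follow the page_{n}_img_{m}.{ext} scheme used above
        with open(item.file_name, "wb") as out:
            out.write(item.content)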
42
astrbot/core/knowledge_base/parsers/text_parser.py
Normal file
@@ -0,0 +1,42 @@
"""文本文件解析器

支持解析 TXT 和 Markdown 文件。
"""

from astrbot.core.knowledge_base.parsers.base import BaseParser, ParseResult


class TextParser(BaseParser):
    """TXT/MD 文本解析器

    支持多种字符编码的自动检测。
    """

    async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
        """解析文本文件

        尝试使用多种编码解析文件内容。

        Args:
            file_content: 文件内容
            file_name: 文件名

        Returns:
            ParseResult: 解析结果,不包含多媒体资源

        Raises:
            ValueError: 如果无法解码文件

        """
        # 尝试多种编码
        for encoding in ["utf-8", "gbk", "gb2312", "gb18030"]:
            try:
                text = file_content.decode(encoding)
                break
            except UnicodeDecodeError:
                continue
        else:
            raise ValueError(f"无法解码文件: {file_name}")

        # 文本文件无多媒体资源
        return ParseResult(text=text, media=[])
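The for/else above walks the encoding list until one decodes. A small sketch (not from the diff) showing the GBK fallback path with synthetic bytes:

async def text_parse_example():
    parser = TextParser()
    gbk_bytes = "知识库测试".encode("gbk")  # typically not valid UTF-8
    result = await parser.parse(gbk_bytes, "note.txt")
    print(result.text)   # decoded via the gbk fallback
    print(result.media)  # []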
13
astrbot/core/knowledge_base/parsers/util.py
Normal file
@@ -0,0 +1,13 @@
from .base import BaseParser


async def select_parser(ext: str) -> BaseParser:
    if ext in {".md", ".txt", ".markdown", ".xlsx", ".docx", ".xls"}:
        from .markitdown_parser import MarkitdownParser

        return MarkitdownParser()
    if ext == ".pdf":
        from .pdf_parser import PDFParser

        return PDFParser()
    raise ValueError(f"暂时不支持的文件格式: {ext}")
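select_parser dispatches purely on the extension string (leading dot included); a short sketch of a supported branch and the unsupported case:

async def select_parser_example():
    parser = await select_parser(".pdf")      # -> PDFParser instance
    print(type(parser).__name__)

    try:
        await select_parser(".pptx")           # not in the supported set
    except ValueError as e:
        print(e)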
14
astrbot/core/knowledge_base/retrieval/__init__.py
Normal file
@@ -0,0 +1,14 @@
"""检索模块"""

from .manager import RetrievalManager, RetrievalResult
from .rank_fusion import FusedResult, RankFusion
from .sparse_retriever import SparseResult, SparseRetriever

__all__ = [
    "FusedResult",
    "RankFusion",
    "RetrievalManager",
    "RetrievalResult",
    "SparseResult",
    "SparseRetriever",
]
767
astrbot/core/knowledge_base/retrieval/hit_stopwords.txt
Normal file
@@ -0,0 +1,767 @@
|
|||||||
|
———
|
||||||
|
》),
|
||||||
|
)÷(1-
|
||||||
|
”,
|
||||||
|
)、
|
||||||
|
=(
|
||||||
|
:
|
||||||
|
→
|
||||||
|
℃
|
||||||
|
&
|
||||||
|
*
|
||||||
|
一一
|
||||||
|
~~~~
|
||||||
|
’
|
||||||
|
.
|
||||||
|
『
|
||||||
|
.一
|
||||||
|
./
|
||||||
|
--
|
||||||
|
』
|
||||||
|
=″
|
||||||
|
【
|
||||||
|
[*]
|
||||||
|
}>
|
||||||
|
[⑤]]
|
||||||
|
[①D]
|
||||||
|
c]
|
||||||
|
ng昉
|
||||||
|
*
|
||||||
|
//
|
||||||
|
[
|
||||||
|
]
|
||||||
|
[②e]
|
||||||
|
[②g]
|
||||||
|
={
|
||||||
|
}
|
||||||
|
,也
|
||||||
|
‘
|
||||||
|
A
|
||||||
|
[①⑥]
|
||||||
|
[②B]
|
||||||
|
[①a]
|
||||||
|
[④a]
|
||||||
|
[①③]
|
||||||
|
[③h]
|
||||||
|
③]
|
||||||
|
1.
|
||||||
|
--
|
||||||
|
[②b]
|
||||||
|
’‘
|
||||||
|
×××
|
||||||
|
[①⑧]
|
||||||
|
0:2
|
||||||
|
=[
|
||||||
|
[⑤b]
|
||||||
|
[②c]
|
||||||
|
[④b]
|
||||||
|
[②③]
|
||||||
|
[③a]
|
||||||
|
[④c]
|
||||||
|
[①⑤]
|
||||||
|
[①⑦]
|
||||||
|
[①g]
|
||||||
|
∈[
|
||||||
|
[①⑨]
|
||||||
|
[①④]
|
||||||
|
[①c]
|
||||||
|
[②f]
|
||||||
|
[②⑧]
|
||||||
|
[②①]
|
||||||
|
[①C]
|
||||||
|
[③c]
|
||||||
|
[③g]
|
||||||
|
[②⑤]
|
||||||
|
[②②]
|
||||||
|
一.
|
||||||
|
[①h]
|
||||||
|
.数
|
||||||
|
[]
|
||||||
|
[①B]
|
||||||
|
数/
|
||||||
|
[①i]
|
||||||
|
[③e]
|
||||||
|
[①①]
|
||||||
|
[④d]
|
||||||
|
[④e]
|
||||||
|
[③b]
|
||||||
|
[⑤a]
|
||||||
|
[①A]
|
||||||
|
[②⑧]
|
||||||
|
[②⑦]
|
||||||
|
[①d]
|
||||||
|
[②j]
|
||||||
|
〕〔
|
||||||
|
][
|
||||||
|
://
|
||||||
|
′∈
|
||||||
|
[②④
|
||||||
|
[⑤e]
|
||||||
|
12%
|
||||||
|
b]
|
||||||
|
...
|
||||||
|
...................
|
||||||
|
…………………………………………………③
|
||||||
|
ZXFITL
|
||||||
|
[③F]
|
||||||
|
」
|
||||||
|
[①o]
|
||||||
|
]∧′=[
|
||||||
|
∪φ∈
|
||||||
|
′|
|
||||||
|
{-
|
||||||
|
②c
|
||||||
|
}
|
||||||
|
[③①]
|
||||||
|
R.L.
|
||||||
|
[①E]
|
||||||
|
Ψ
|
||||||
|
-[*]-
|
||||||
|
↑
|
||||||
|
.日
|
||||||
|
[②d]
|
||||||
|
[②
|
||||||
|
[②⑦]
|
||||||
|
[②②]
|
||||||
|
[③e]
|
||||||
|
[①i]
|
||||||
|
[①B]
|
||||||
|
[①h]
|
||||||
|
[①d]
|
||||||
|
[①g]
|
||||||
|
[①②]
|
||||||
|
[②a]
|
||||||
|
f]
|
||||||
|
[⑩]
|
||||||
|
a]
|
||||||
|
[①e]
|
||||||
|
[②h]
|
||||||
|
[②⑥]
|
||||||
|
[③d]
|
||||||
|
[②⑩]
|
||||||
|
e]
|
||||||
|
〉
|
||||||
|
】
|
||||||
|
元/吨
|
||||||
|
[②⑩]
|
||||||
|
2.3%
|
||||||
|
5:0
|
||||||
|
[①]
|
||||||
|
::
|
||||||
|
[②]
|
||||||
|
[③]
|
||||||
|
[④]
|
||||||
|
[⑤]
|
||||||
|
[⑥]
|
||||||
|
[⑦]
|
||||||
|
[⑧]
|
||||||
|
[⑨]
|
||||||
|
……
|
||||||
|
——
|
||||||
|
?
|
||||||
|
、
|
||||||
|
。
|
||||||
|
“
|
||||||
|
”
|
||||||
|
《
|
||||||
|
》
|
||||||
|
!
|
||||||
|
,
|
||||||
|
:
|
||||||
|
;
|
||||||
|
?
|
||||||
|
.
|
||||||
|
,
|
||||||
|
.
|
||||||
|
'
|
||||||
|
?
|
||||||
|
·
|
||||||
|
———
|
||||||
|
──
|
||||||
|
?
|
||||||
|
—
|
||||||
|
<
|
||||||
|
>
|
||||||
|
(
|
||||||
|
)
|
||||||
|
〔
|
||||||
|
〕
|
||||||
|
[
|
||||||
|
]
|
||||||
|
(
|
||||||
|
)
|
||||||
|
-
|
||||||
|
+
|
||||||
|
~
|
||||||
|
×
|
||||||
|
/
|
||||||
|
/
|
||||||
|
①
|
||||||
|
②
|
||||||
|
③
|
||||||
|
④
|
||||||
|
⑤
|
||||||
|
⑥
|
||||||
|
⑦
|
||||||
|
⑧
|
||||||
|
⑨
|
||||||
|
⑩
|
||||||
|
Ⅲ
|
||||||
|
В
|
||||||
|
"
|
||||||
|
;
|
||||||
|
#
|
||||||
|
@
|
||||||
|
γ
|
||||||
|
μ
|
||||||
|
φ
|
||||||
|
φ.
|
||||||
|
×
|
||||||
|
Δ
|
||||||
|
■
|
||||||
|
▲
|
||||||
|
sub
|
||||||
|
exp
|
||||||
|
sup
|
||||||
|
sub
|
||||||
|
Lex
|
||||||
|
#
|
||||||
|
%
|
||||||
|
&
|
||||||
|
'
|
||||||
|
+
|
||||||
|
+ξ
|
||||||
|
++
|
||||||
|
-
|
||||||
|
-β
|
||||||
|
<
|
||||||
|
<±
|
||||||
|
<Δ
|
||||||
|
<λ
|
||||||
|
<φ
|
||||||
|
<<
|
||||||
|
=
|
||||||
|
=
|
||||||
|
=☆
|
||||||
|
=-
|
||||||
|
>
|
||||||
|
>λ
|
||||||
|
_
|
||||||
|
~±
|
||||||
|
~+
|
||||||
|
[⑤f]
|
||||||
|
[⑤d]
|
||||||
|
[②i]
|
||||||
|
≈
|
||||||
|
[②G]
|
||||||
|
[①f]
|
||||||
|
LI
|
||||||
|
㈧
|
||||||
|
[-
|
||||||
|
......
|
||||||
|
〉
|
||||||
|
[③⑩]
|
||||||
|
第二
|
||||||
|
一番
|
||||||
|
一直
|
||||||
|
一个
|
||||||
|
一些
|
||||||
|
许多
|
||||||
|
种
|
||||||
|
有的是
|
||||||
|
也就是说
|
||||||
|
末##末
|
||||||
|
啊
|
||||||
|
阿
|
||||||
|
哎
|
||||||
|
哎呀
|
||||||
|
哎哟
|
||||||
|
唉
|
||||||
|
俺
|
||||||
|
俺们
|
||||||
|
按
|
||||||
|
按照
|
||||||
|
吧
|
||||||
|
吧哒
|
||||||
|
把
|
||||||
|
罢了
|
||||||
|
被
|
||||||
|
本
|
||||||
|
本着
|
||||||
|
比
|
||||||
|
比方
|
||||||
|
比如
|
||||||
|
鄙人
|
||||||
|
彼
|
||||||
|
彼此
|
||||||
|
边
|
||||||
|
别
|
||||||
|
别的
|
||||||
|
别说
|
||||||
|
并
|
||||||
|
并且
|
||||||
|
不比
|
||||||
|
不成
|
||||||
|
不单
|
||||||
|
不但
|
||||||
|
不独
|
||||||
|
不管
|
||||||
|
不光
|
||||||
|
不过
|
||||||
|
不仅
|
||||||
|
不拘
|
||||||
|
不论
|
||||||
|
不怕
|
||||||
|
不然
|
||||||
|
不如
|
||||||
|
不特
|
||||||
|
不惟
|
||||||
|
不问
|
||||||
|
不只
|
||||||
|
朝
|
||||||
|
朝着
|
||||||
|
趁
|
||||||
|
趁着
|
||||||
|
乘
|
||||||
|
冲
|
||||||
|
除
|
||||||
|
除此之外
|
||||||
|
除非
|
||||||
|
除了
|
||||||
|
此
|
||||||
|
此间
|
||||||
|
此外
|
||||||
|
从
|
||||||
|
从而
|
||||||
|
打
|
||||||
|
待
|
||||||
|
但
|
||||||
|
但是
|
||||||
|
当
|
||||||
|
当着
|
||||||
|
到
|
||||||
|
得
|
||||||
|
的
|
||||||
|
的话
|
||||||
|
等
|
||||||
|
等等
|
||||||
|
地
|
||||||
|
第
|
||||||
|
叮咚
|
||||||
|
对
|
||||||
|
对于
|
||||||
|
多
|
||||||
|
多少
|
||||||
|
而
|
||||||
|
而况
|
||||||
|
而且
|
||||||
|
而是
|
||||||
|
而外
|
||||||
|
而言
|
||||||
|
而已
|
||||||
|
尔后
|
||||||
|
反过来
|
||||||
|
反过来说
|
||||||
|
反之
|
||||||
|
非但
|
||||||
|
非徒
|
||||||
|
否则
|
||||||
|
嘎
|
||||||
|
嘎登
|
||||||
|
该
|
||||||
|
赶
|
||||||
|
个
|
||||||
|
各
|
||||||
|
各个
|
||||||
|
各位
|
||||||
|
各种
|
||||||
|
各自
|
||||||
|
给
|
||||||
|
根据
|
||||||
|
跟
|
||||||
|
故
|
||||||
|
故此
|
||||||
|
固然
|
||||||
|
关于
|
||||||
|
管
|
||||||
|
归
|
||||||
|
果然
|
||||||
|
果真
|
||||||
|
过
|
||||||
|
哈
|
||||||
|
哈哈
|
||||||
|
呵
|
||||||
|
和
|
||||||
|
何
|
||||||
|
何处
|
||||||
|
何况
|
||||||
|
何时
|
||||||
|
嘿
|
||||||
|
哼
|
||||||
|
哼唷
|
||||||
|
呼哧
|
||||||
|
乎
|
||||||
|
哗
|
||||||
|
还是
|
||||||
|
还有
|
||||||
|
换句话说
|
||||||
|
换言之
|
||||||
|
或
|
||||||
|
或是
|
||||||
|
或者
|
||||||
|
极了
|
||||||
|
及
|
||||||
|
及其
|
||||||
|
及至
|
||||||
|
即
|
||||||
|
即便
|
||||||
|
即或
|
||||||
|
即令
|
||||||
|
即若
|
||||||
|
即使
|
||||||
|
几
|
||||||
|
几时
|
||||||
|
己
|
||||||
|
既
|
||||||
|
既然
|
||||||
|
既是
|
||||||
|
继而
|
||||||
|
加之
|
||||||
|
假如
|
||||||
|
假若
|
||||||
|
假使
|
||||||
|
鉴于
|
||||||
|
将
|
||||||
|
较
|
||||||
|
较之
|
||||||
|
叫
|
||||||
|
接着
|
||||||
|
结果
|
||||||
|
借
|
||||||
|
紧接着
|
||||||
|
进而
|
||||||
|
尽
|
||||||
|
尽管
|
||||||
|
经
|
||||||
|
经过
|
||||||
|
就
|
||||||
|
就是
|
||||||
|
就是说
|
||||||
|
据
|
||||||
|
具体地说
|
||||||
|
具体说来
|
||||||
|
开始
|
||||||
|
开外
|
||||||
|
靠
|
||||||
|
咳
|
||||||
|
可
|
||||||
|
可见
|
||||||
|
可是
|
||||||
|
可以
|
||||||
|
况且
|
||||||
|
啦
|
||||||
|
来
|
||||||
|
来着
|
||||||
|
离
|
||||||
|
例如
|
||||||
|
哩
|
||||||
|
连
|
||||||
|
连同
|
||||||
|
两者
|
||||||
|
了
|
||||||
|
临
|
||||||
|
另
|
||||||
|
另外
|
||||||
|
另一方面
|
||||||
|
论
|
||||||
|
嘛
|
||||||
|
吗
|
||||||
|
慢说
|
||||||
|
漫说
|
||||||
|
冒
|
||||||
|
么
|
||||||
|
每
|
||||||
|
每当
|
||||||
|
们
|
||||||
|
莫若
|
||||||
|
某
|
||||||
|
某个
|
||||||
|
某些
|
||||||
|
拿
|
||||||
|
哪
|
||||||
|
哪边
|
||||||
|
哪儿
|
||||||
|
哪个
|
||||||
|
哪里
|
||||||
|
哪年
|
||||||
|
哪怕
|
||||||
|
哪天
|
||||||
|
哪些
|
||||||
|
哪样
|
||||||
|
那
|
||||||
|
那边
|
||||||
|
那儿
|
||||||
|
那个
|
||||||
|
那会儿
|
||||||
|
那里
|
||||||
|
那么
|
||||||
|
那么些
|
||||||
|
那么样
|
||||||
|
那时
|
||||||
|
那些
|
||||||
|
那样
|
||||||
|
乃
|
||||||
|
乃至
|
||||||
|
呢
|
||||||
|
能
|
||||||
|
你
|
||||||
|
你们
|
||||||
|
您
|
||||||
|
宁
|
||||||
|
宁可
|
||||||
|
宁肯
|
||||||
|
宁愿
|
||||||
|
哦
|
||||||
|
呕
|
||||||
|
啪达
|
||||||
|
旁人
|
||||||
|
呸
|
||||||
|
凭
|
||||||
|
凭借
|
||||||
|
其
|
||||||
|
其次
|
||||||
|
其二
|
||||||
|
其他
|
||||||
|
其它
|
||||||
|
其一
|
||||||
|
其余
|
||||||
|
其中
|
||||||
|
起
|
||||||
|
起见
|
||||||
|
起见
|
||||||
|
岂但
|
||||||
|
恰恰相反
|
||||||
|
前后
|
||||||
|
前者
|
||||||
|
且
|
||||||
|
然而
|
||||||
|
然后
|
||||||
|
然则
|
||||||
|
让
|
||||||
|
人家
|
||||||
|
任
|
||||||
|
任何
|
||||||
|
任凭
|
||||||
|
如
|
||||||
|
如此
|
||||||
|
如果
|
||||||
|
如何
|
||||||
|
如其
|
||||||
|
如若
|
||||||
|
如上所述
|
||||||
|
若
|
||||||
|
若非
|
||||||
|
若是
|
||||||
|
啥
|
||||||
|
上下
|
||||||
|
尚且
|
||||||
|
设若
|
||||||
|
设使
|
||||||
|
甚而
|
||||||
|
甚么
|
||||||
|
甚至
|
||||||
|
省得
|
||||||
|
时候
|
||||||
|
什么
|
||||||
|
什么样
|
||||||
|
使得
|
||||||
|
是
|
||||||
|
是的
|
||||||
|
首先
|
||||||
|
谁
|
||||||
|
谁知
|
||||||
|
顺
|
||||||
|
顺着
|
||||||
|
似的
|
||||||
|
虽
|
||||||
|
虽然
|
||||||
|
虽说
|
||||||
|
虽则
|
||||||
|
随
|
||||||
|
随着
|
||||||
|
所
|
||||||
|
所以
|
||||||
|
他
|
||||||
|
他们
|
||||||
|
他人
|
||||||
|
它
|
||||||
|
它们
|
||||||
|
她
|
||||||
|
她们
|
||||||
|
倘
|
||||||
|
倘或
|
||||||
|
倘然
|
||||||
|
倘若
|
||||||
|
倘使
|
||||||
|
腾
|
||||||
|
替
|
||||||
|
通过
|
||||||
|
同
|
||||||
|
同时
|
||||||
|
哇
|
||||||
|
万一
|
||||||
|
往
|
||||||
|
望
|
||||||
|
为
|
||||||
|
为何
|
||||||
|
为了
|
||||||
|
为什么
|
||||||
|
为着
|
||||||
|
喂
|
||||||
|
嗡嗡
|
||||||
|
我
|
||||||
|
我们
|
||||||
|
呜
|
||||||
|
呜呼
|
||||||
|
乌乎
|
||||||
|
无论
|
||||||
|
无宁
|
||||||
|
毋宁
|
||||||
|
嘻
|
||||||
|
吓
|
||||||
|
相对而言
|
||||||
|
像
|
||||||
|
向
|
||||||
|
向着
|
||||||
|
嘘
|
||||||
|
呀
|
||||||
|
焉
|
||||||
|
沿
|
||||||
|
沿着
|
||||||
|
要
|
||||||
|
要不
|
||||||
|
要不然
|
||||||
|
要不是
|
||||||
|
要么
|
||||||
|
要是
|
||||||
|
也
|
||||||
|
也罢
|
||||||
|
也好
|
||||||
|
一
|
||||||
|
一般
|
||||||
|
一旦
|
||||||
|
一方面
|
||||||
|
一来
|
||||||
|
一切
|
||||||
|
一样
|
||||||
|
一则
|
||||||
|
依
|
||||||
|
依照
|
||||||
|
矣
|
||||||
|
以
|
||||||
|
以便
|
||||||
|
以及
|
||||||
|
以免
|
||||||
|
以至
|
||||||
|
以至于
|
||||||
|
以致
|
||||||
|
抑或
|
||||||
|
因
|
||||||
|
因此
|
||||||
|
因而
|
||||||
|
因为
|
||||||
|
哟
|
||||||
|
用
|
||||||
|
由
|
||||||
|
由此可见
|
||||||
|
由于
|
||||||
|
有
|
||||||
|
有的
|
||||||
|
有关
|
||||||
|
有些
|
||||||
|
又
|
||||||
|
于
|
||||||
|
于是
|
||||||
|
于是乎
|
||||||
|
与
|
||||||
|
与此同时
|
||||||
|
与否
|
||||||
|
与其
|
||||||
|
越是
|
||||||
|
云云
|
||||||
|
哉
|
||||||
|
再说
|
||||||
|
再者
|
||||||
|
在
|
||||||
|
在下
|
||||||
|
咱
|
||||||
|
咱们
|
||||||
|
则
|
||||||
|
怎
|
||||||
|
怎么
|
||||||
|
怎么办
|
||||||
|
怎么样
|
||||||
|
怎样
|
||||||
|
咋
|
||||||
|
照
|
||||||
|
照着
|
||||||
|
者
|
||||||
|
这
|
||||||
|
这边
|
||||||
|
这儿
|
||||||
|
这个
|
||||||
|
这会儿
|
||||||
|
这就是说
|
||||||
|
这里
|
||||||
|
这么
|
||||||
|
这么点儿
|
||||||
|
这么些
|
||||||
|
这么样
|
||||||
|
这时
|
||||||
|
这些
|
||||||
|
这样
|
||||||
|
正如
|
||||||
|
吱
|
||||||
|
之
|
||||||
|
之类
|
||||||
|
之所以
|
||||||
|
之一
|
||||||
|
只是
|
||||||
|
只限
|
||||||
|
只要
|
||||||
|
只有
|
||||||
|
至
|
||||||
|
至于
|
||||||
|
诸位
|
||||||
|
着
|
||||||
|
着呢
|
||||||
|
自
|
||||||
|
自从
|
||||||
|
自个儿
|
||||||
|
自各儿
|
||||||
|
自己
|
||||||
|
自家
|
||||||
|
自身
|
||||||
|
综上所述
|
||||||
|
总的来看
|
||||||
|
总的来说
|
||||||
|
总的说来
|
||||||
|
总而言之
|
||||||
|
总之
|
||||||
|
纵
|
||||||
|
纵令
|
||||||
|
纵然
|
||||||
|
纵使
|
||||||
|
遵照
|
||||||
|
作为
|
||||||
|
兮
|
||||||
|
呃
|
||||||
|
呗
|
||||||
|
咚
|
||||||
|
咦
|
||||||
|
喏
|
||||||
|
啐
|
||||||
|
喔唷
|
||||||
|
嗬
|
||||||
|
嗯
|
||||||
|
嗳
|
||||||
276
astrbot/core/knowledge_base/retrieval/manager.py
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
"""检索管理器
|
||||||
|
|
||||||
|
协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from astrbot import logger
|
||||||
|
from astrbot.core.db.vec_db.base import Result
|
||||||
|
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||||
|
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
||||||
|
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
|
||||||
|
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
|
||||||
|
from astrbot.core.provider.provider import RerankProvider
|
||||||
|
|
||||||
|
from ..kb_helper import KBHelper
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RetrievalResult:
|
||||||
|
"""检索结果"""
|
||||||
|
|
||||||
|
chunk_id: str
|
||||||
|
doc_id: str
|
||||||
|
doc_name: str
|
||||||
|
kb_id: str
|
||||||
|
kb_name: str
|
||||||
|
content: str
|
||||||
|
score: float
|
||||||
|
metadata: dict
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalManager:
|
||||||
|
"""检索管理器
|
||||||
|
|
||||||
|
职责:
|
||||||
|
- 协调稠密检索、稀疏检索和 Rerank
|
||||||
|
- 结果融合和排序
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sparse_retriever: SparseRetriever,
|
||||||
|
rank_fusion: RankFusion,
|
||||||
|
kb_db: KBSQLiteDatabase,
|
||||||
|
):
|
||||||
|
"""初始化检索管理器
|
||||||
|
|
||||||
|
Args:
|
||||||
||||||
|
sparse_retriever: 稀疏检索器
|
||||||
|
rank_fusion: 结果融合器
|
||||||
|
kb_db: 知识库数据库实例
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.sparse_retriever = sparse_retriever
|
||||||
|
self.rank_fusion = rank_fusion
|
||||||
|
self.kb_db = kb_db
|
||||||
|
|
||||||
|
async def retrieve(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
kb_ids: list[str],
|
||||||
|
kb_id_helper_map: dict[str, KBHelper],
|
||||||
|
top_k_fusion: int = 20,
|
||||||
|
top_m_final: int = 5,
|
||||||
|
) -> list[RetrievalResult]:
|
||||||
|
"""混合检索
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 稠密检索 (向量相似度)
|
||||||
|
2. 稀疏检索 (BM25)
|
||||||
|
3. 结果融合 (RRF)
|
||||||
|
4. Rerank 重排序
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询文本
|
||||||
|
kb_ids: 知识库 ID 列表
|
||||||
|
top_m_final: 最终返回数量
|
||||||
|
kb_id_helper_map: kb_id 到已初始化 KBHelper 实例的映射
top_k_fusion: RRF 融合阶段保留的结果数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[RetrievalResult]: 检索结果列表
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not kb_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
kb_options: dict = {}
|
||||||
|
new_kb_ids = []
|
||||||
|
for kb_id in kb_ids:
|
||||||
|
kb_helper = kb_id_helper_map.get(kb_id)
|
||||||
|
if kb_helper:
|
||||||
|
kb = kb_helper.kb
|
||||||
|
kb_options[kb_id] = {
|
||||||
|
"top_k_dense": kb.top_k_dense or 50,
|
||||||
|
"top_k_sparse": kb.top_k_sparse or 50,
|
||||||
|
"top_m_final": kb.top_m_final or 5,
|
||||||
|
"vec_db": kb_helper.vec_db,
|
||||||
|
"rerank_provider_id": kb.rerank_provider_id,
|
||||||
|
}
|
||||||
|
new_kb_ids.append(kb_id)
|
||||||
|
else:
|
||||||
|
logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索")
|
||||||
|
|
||||||
|
kb_ids = new_kb_ids
|
||||||
|
|
||||||
|
# 1. 稠密检索
|
||||||
|
time_start = time.time()
|
||||||
|
dense_results = await self._dense_retrieve(
|
||||||
|
query=query,
|
||||||
|
kb_ids=kb_ids,
|
||||||
|
kb_options=kb_options,
|
||||||
|
)
|
||||||
|
time_end = time.time()
|
||||||
|
logger.debug(
|
||||||
|
f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. 稀疏检索
|
||||||
|
time_start = time.time()
|
||||||
|
sparse_results = await self.sparse_retriever.retrieve(
|
||||||
|
query=query,
|
||||||
|
kb_ids=kb_ids,
|
||||||
|
kb_options=kb_options,
|
||||||
|
)
|
||||||
|
time_end = time.time()
|
||||||
|
logger.debug(
|
||||||
|
f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 结果融合
|
||||||
|
time_start = time.time()
|
||||||
|
fused_results = await self.rank_fusion.fuse(
|
||||||
|
dense_results=dense_results,
|
||||||
|
sparse_results=sparse_results,
|
||||||
|
top_k=top_k_fusion,
|
||||||
|
)
|
||||||
|
time_end = time.time()
|
||||||
|
logger.debug(
|
||||||
|
f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. 转换为 RetrievalResult (获取元数据)
|
||||||
|
retrieval_results = []
|
||||||
|
for fr in fused_results:
|
||||||
|
metadata_dict = await self.kb_db.get_document_with_metadata(fr.doc_id)
|
||||||
|
if metadata_dict:
|
||||||
|
retrieval_results.append(
|
||||||
|
RetrievalResult(
|
||||||
|
chunk_id=fr.chunk_id,
|
||||||
|
doc_id=fr.doc_id,
|
||||||
|
doc_name=metadata_dict["document"].doc_name,
|
||||||
|
kb_id=fr.kb_id,
|
||||||
|
kb_name=metadata_dict["knowledge_base"].kb_name,
|
||||||
|
content=fr.content,
|
||||||
|
score=fr.score,
|
||||||
|
metadata={
|
||||||
|
"chunk_index": fr.chunk_index,
|
||||||
|
"char_count": len(fr.content),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Rerank
|
||||||
|
first_rerank = None
|
||||||
|
for kb_id in kb_ids:
|
||||||
|
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
|
||||||
|
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
|
||||||
|
if (
|
||||||
|
vec_db
|
||||||
|
and vec_db.rerank_provider
|
||||||
|
and rerank_pi
|
||||||
|
and rerank_pi == vec_db.rerank_provider.meta().id
|
||||||
|
):
|
||||||
|
first_rerank = vec_db.rerank_provider
|
||||||
|
break
|
||||||
|
if first_rerank and retrieval_results:
|
||||||
|
retrieval_results = await self._rerank(
|
||||||
|
query=query,
|
||||||
|
results=retrieval_results,
|
||||||
|
top_k=top_m_final,
|
||||||
|
rerank_provider=first_rerank,
|
||||||
|
)
|
||||||
|
|
||||||
|
return retrieval_results[:top_m_final]
|
||||||
|
|
||||||
|
async def _dense_retrieve(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
kb_ids: list[str],
|
||||||
|
kb_options: dict,
|
||||||
|
):
|
||||||
|
"""稠密检索 (向量相似度)
|
||||||
|
|
||||||
|
为每个知识库使用独立的向量数据库进行检索,然后合并结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询文本
|
||||||
|
kb_ids: 知识库 ID 列表
|
||||||
|
top_k: 返回结果数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Result]: 检索结果列表
|
||||||
|
|
||||||
|
"""
|
||||||
|
all_results: list[Result] = []
|
||||||
|
for kb_id in kb_ids:
|
||||||
|
if kb_id not in kb_options:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
|
||||||
|
dense_k = int(kb_options[kb_id]["top_k_dense"])
|
||||||
|
vec_results = await vec_db.retrieve(
|
||||||
|
query=query,
|
||||||
|
k=dense_k,
|
||||||
|
fetch_k=dense_k * 2,
|
||||||
|
rerank=False, # 稠密检索阶段不进行 rerank
|
||||||
|
metadata_filters={"kb_id": kb_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
all_results.extend(vec_results)
|
||||||
|
except Exception as e:
|
||||||
|
from astrbot.core import logger
|
||||||
|
|
||||||
|
logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 按相似度排序并返回 top_k
|
||||||
|
all_results.sort(key=lambda x: x.similarity, reverse=True)
|
||||||
|
# return all_results[: len(all_results) // len(kb_ids)]
|
||||||
|
return all_results
|
||||||
|
|
||||||
|
async def _rerank(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
results: list[RetrievalResult],
|
||||||
|
top_k: int,
|
||||||
|
rerank_provider: RerankProvider,
|
||||||
|
) -> list[RetrievalResult]:
|
||||||
|
"""Rerank 重排序
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询文本
|
||||||
|
results: 检索结果列表
|
||||||
|
top_k: 返回结果数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[RetrievalResult]: 重排序后的结果列表
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 准备文档列表
|
||||||
|
docs = [r.content for r in results]
|
||||||
|
|
||||||
|
# 调用 Rerank Provider
|
||||||
|
rerank_results = await rerank_provider.rerank(
|
||||||
|
query=query,
|
||||||
|
documents=docs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新分数并重新排序
|
||||||
|
reranked_list = []
|
||||||
|
for rerank_result in rerank_results:
|
||||||
|
idx = rerank_result.index
|
||||||
|
if idx < len(results):
|
||||||
|
result = results[idx]
|
||||||
|
result.score = rerank_result.relevance_score
|
||||||
|
reranked_list.append(result)
|
||||||
|
|
||||||
|
reranked_list.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
|
||||||
|
return reranked_list[:top_k]
|
||||||
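Direct use of RetrievalManager is normally hidden behind KnowledgeBaseManager.retrieve; this sketch (not from the diff) shows the expected inputs, where kb_id_helper_map maps each kb_id to an initialized KBHelper so the per-KB top_k settings and FAISS instance can be read.

async def retrieval_manager_example(retrieval_manager, kb_helper):
    kb_id = kb_helper.kb.kb_id
    results = await retrieval_manager.retrieve(
        query="向量检索是什么",
        kb_ids=[kb_id],
        kb_id_helper_map={kb_id: kb_helper},
        top_k_fusion=20,
        top_m_final=5,
    )
    for r in results:
        print(f"{r.score:.3f} {r.kb_name}/{r.doc_name}: {r.content[:60]}")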
142
astrbot/core/knowledge_base/retrieval/rank_fusion.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
"""检索结果融合器
|
||||||
|
|
||||||
|
使用 Reciprocal Rank Fusion (RRF) 算法融合稠密检索和稀疏检索的结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from astrbot.core.db.vec_db.base import Result
|
||||||
|
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
||||||
|
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FusedResult:
|
||||||
|
"""融合后的检索结果"""
|
||||||
|
|
||||||
|
chunk_id: str
|
||||||
|
chunk_index: int
|
||||||
|
doc_id: str
|
||||||
|
kb_id: str
|
||||||
|
content: str
|
||||||
|
score: float
|
||||||
|
|
||||||
|
|
||||||
|
class RankFusion:
|
||||||
|
"""检索结果融合器
|
||||||
|
|
||||||
|
职责:
|
||||||
|
- 融合稠密检索和稀疏检索的结果
|
||||||
|
- 使用 Reciprocal Rank Fusion (RRF) 算法
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60):
|
||||||
|
"""初始化结果融合器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kb_db: 知识库数据库实例
|
||||||
|
k: RRF 参数,用于平滑排名
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.kb_db = kb_db
|
||||||
|
self.k = k
|
||||||
|
|
||||||
|
async def fuse(
|
||||||
|
self,
|
||||||
|
dense_results: list[Result],
|
||||||
|
sparse_results: list[SparseResult],
|
||||||
|
top_k: int = 20,
|
||||||
|
) -> list[FusedResult]:
|
||||||
|
"""融合稠密和稀疏检索结果
|
||||||
|
|
||||||
|
RRF 公式:
|
||||||
|
score(doc) = sum(1 / (k + rank_i))
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dense_results: 稠密检索结果
|
||||||
|
sparse_results: 稀疏检索结果
|
||||||
|
top_k: 返回结果数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[FusedResult]: 融合后的结果列表
|
||||||
|
|
||||||
|
"""
|
||||||
|
# 1. 构建排名映射
|
||||||
|
dense_ranks = {
|
||||||
|
r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results)
|
||||||
|
} # 这里的 doc_id 实际上是 chunk_id
|
||||||
|
sparse_ranks = {r.chunk_id: (idx + 1) for idx, r in enumerate(sparse_results)}
|
||||||
|
|
||||||
|
# 2. 收集所有唯一的 ID
|
||||||
|
# 需要统一为 chunk_id
|
||||||
|
all_chunk_ids = set()
|
||||||
|
vec_doc_id_to_dense: dict[str, Result] = {} # vec_doc_id -> Result
|
||||||
|
chunk_id_to_sparse: dict[str, SparseResult] = {} # chunk_id -> SparseResult
|
||||||
|
|
||||||
|
# 处理稀疏检索结果
|
||||||
|
for r in sparse_results:
|
||||||
|
all_chunk_ids.add(r.chunk_id)
|
||||||
|
chunk_id_to_sparse[r.chunk_id] = r
|
||||||
|
|
||||||
|
# 处理稠密检索结果 (需要转换 vec_doc_id 到 chunk_id)
|
||||||
|
for r in dense_results:
|
||||||
|
vec_doc_id = r.data["doc_id"]
|
||||||
|
all_chunk_ids.add(vec_doc_id)
|
||||||
|
vec_doc_id_to_dense[vec_doc_id] = r
|
||||||
|
|
||||||
|
# 3. 计算 RRF 分数
|
||||||
|
rrf_scores: dict[str, float] = {}
|
||||||
|
|
||||||
|
for identifier in all_chunk_ids:
|
||||||
|
score = 0.0
|
||||||
|
|
||||||
|
# 来自稠密检索的贡献
|
||||||
|
if identifier in dense_ranks:
|
||||||
|
score += 1.0 / (self.k + dense_ranks[identifier])
|
||||||
|
|
||||||
|
# 来自稀疏检索的贡献
|
||||||
|
if identifier in sparse_ranks:
|
||||||
|
score += 1.0 / (self.k + sparse_ranks[identifier])
|
||||||
|
|
||||||
|
rrf_scores[identifier] = score
|
||||||
|
|
||||||
|
# 4. 排序
|
||||||
|
sorted_ids = sorted(
|
||||||
|
rrf_scores.keys(),
|
||||||
|
key=lambda cid: rrf_scores[cid],
|
||||||
|
reverse=True,
|
||||||
|
)[:top_k]
|
||||||
|
|
||||||
|
# 5. 构建融合结果
|
||||||
|
fused_results = []
|
||||||
|
for identifier in sorted_ids:
|
||||||
|
# 优先从稀疏检索获取完整信息
|
||||||
|
if identifier in chunk_id_to_sparse:
|
||||||
|
sr = chunk_id_to_sparse[identifier]
|
||||||
|
fused_results.append(
|
||||||
|
FusedResult(
|
||||||
|
chunk_id=sr.chunk_id,
|
||||||
|
chunk_index=sr.chunk_index,
|
||||||
|
doc_id=sr.doc_id,
|
||||||
|
kb_id=sr.kb_id,
|
||||||
|
content=sr.content,
|
||||||
|
score=rrf_scores[identifier],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
elif identifier in vec_doc_id_to_dense:
|
||||||
|
# 从向量检索获取信息,需要从数据库获取块的详细信息
|
||||||
|
vec_result = vec_doc_id_to_dense[identifier]
|
||||||
|
chunk_md = json.loads(vec_result.data["metadata"])
|
||||||
|
fused_results.append(
|
||||||
|
FusedResult(
|
||||||
|
chunk_id=identifier,
|
||||||
|
chunk_index=chunk_md["chunk_index"],
|
||||||
|
doc_id=chunk_md["kb_doc_id"],
|
||||||
|
kb_id=chunk_md["kb_id"],
|
||||||
|
content=vec_result.data["text"],
|
||||||
|
score=rrf_scores[identifier],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return fused_results
|
||||||
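A small numeric check of the RRF scoring used in fuse() with the default k = 60: a chunk ranked 1st by dense retrieval and 3rd by sparse retrieval outscores one that only appears in a single list.

# Standalone arithmetic; only the formula score = sum(1 / (k + rank_i)) is used.
k = 60

both_lists = 1 / (k + 1) + 1 / (k + 3)  # rank 1 (dense) + rank 3 (sparse)
dense_only = 1 / (k + 2)                # rank 2 (dense), absent from sparse

print(round(both_lists, 4))  # ~0.0323
print(round(dense_only, 4))  # ~0.0161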
136
astrbot/core/knowledge_base/retrieval/sparse_retriever.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
"""稀疏检索器
|
||||||
|
|
||||||
|
使用 BM25 算法进行基于关键词的文档检索
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import jieba
|
||||||
|
from rank_bm25 import BM25Okapi
|
||||||
|
|
||||||
|
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||||
|
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SparseResult:
|
||||||
|
"""稀疏检索结果"""
|
||||||
|
|
||||||
|
chunk_index: int
|
||||||
|
chunk_id: str
|
||||||
|
doc_id: str
|
||||||
|
kb_id: str
|
||||||
|
content: str
|
||||||
|
score: float
|
||||||
|
|
||||||
|
|
||||||
|
class SparseRetriever:
|
||||||
|
"""BM25 稀疏检索器
|
||||||
|
|
||||||
|
职责:
|
||||||
|
- 基于关键词的文档检索
|
||||||
|
- 使用 BM25 算法计算相关度
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, kb_db: KBSQLiteDatabase):
|
||||||
|
"""初始化稀疏检索器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kb_db: 知识库数据库实例
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.kb_db = kb_db
|
||||||
|
self._index_cache = {} # 缓存 BM25 索引
|
||||||
|
|
||||||
|
with open(
|
||||||
|
os.path.join(os.path.dirname(__file__), "hit_stopwords.txt"),
|
||||||
|
encoding="utf-8",
|
||||||
|
) as f:
|
||||||
|
self.hit_stopwords = {
|
||||||
|
word.strip() for word in set(f.read().splitlines()) if word.strip()
|
||||||
|
}
|
||||||
|
|
||||||
|
async def retrieve(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
kb_ids: list[str],
|
||||||
|
kb_options: dict,
|
||||||
|
) -> list[SparseResult]:
|
||||||
|
"""执行稀疏检索
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询文本
|
||||||
|
kb_ids: 知识库 ID 列表
|
||||||
|
kb_options: 每个知识库的检索选项
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[SparseResult]: 检索结果列表
|
||||||
|
|
||||||
|
"""
|
||||||
|
# 1. 获取所有相关块
|
||||||
|
top_k_sparse = 0
|
||||||
|
chunks = []
|
||||||
|
for kb_id in kb_ids:
|
||||||
|
vec_db: FaissVecDB = kb_options.get(kb_id, {}).get("vec_db")
|
||||||
|
if not vec_db:
|
||||||
|
continue
|
||||||
|
result = await vec_db.document_storage.get_documents(
|
||||||
|
metadata_filters={},
|
||||||
|
limit=None,
|
||||||
|
offset=None,
|
||||||
|
)
|
||||||
|
chunk_mds = [json.loads(doc["metadata"]) for doc in result]
|
||||||
|
result = [
|
||||||
|
{
|
||||||
|
"chunk_id": doc["doc_id"],
|
||||||
|
"chunk_index": chunk_md["chunk_index"],
|
||||||
|
"doc_id": chunk_md["kb_doc_id"],
|
||||||
|
"kb_id": kb_id,
|
||||||
|
"text": doc["text"],
|
||||||
|
}
|
||||||
|
for doc, chunk_md in zip(result, chunk_mds)
|
||||||
|
]
|
||||||
|
chunks.extend(result)
|
||||||
|
top_k_sparse += kb_options.get(kb_id, {}).get("top_k_sparse", 50)
|
||||||
|
|
||||||
|
if not chunks:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 2. 准备文档和索引
|
||||||
|
corpus = [chunk["text"] for chunk in chunks]
|
||||||
|
tokenized_corpus = [list(jieba.cut(doc)) for doc in corpus]
|
||||||
|
tokenized_corpus = [
|
||||||
|
[word for word in doc if word not in self.hit_stopwords]
|
||||||
|
for doc in tokenized_corpus
|
||||||
|
]
|
||||||
|
|
||||||
|
# 3. 构建 BM25 索引
|
||||||
|
bm25 = BM25Okapi(tokenized_corpus)
|
||||||
|
|
||||||
|
# 4. 执行检索
|
||||||
|
tokenized_query = list(jieba.cut(query))
|
||||||
|
tokenized_query = [
|
||||||
|
word for word in tokenized_query if word not in self.hit_stopwords
|
||||||
|
]
|
||||||
|
scores = bm25.get_scores(tokenized_query)
|
||||||
|
|
||||||
|
# 5. 排序并返回 Top-K
|
||||||
|
results = []
|
||||||
|
for idx, score in enumerate(scores):
|
||||||
|
chunk = chunks[idx]
|
||||||
|
results.append(
|
||||||
|
SparseResult(
|
||||||
|
chunk_id=chunk["chunk_id"],
|
||||||
|
chunk_index=chunk["chunk_index"],
|
||||||
|
doc_id=chunk["doc_id"],
|
||||||
|
kb_id=chunk["kb_id"],
|
||||||
|
content=chunk["text"],
|
||||||
|
score=float(score),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
results.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
# return results[: len(results) // len(kb_ids)]
|
||||||
|
return results[:top_k_sparse]
|
||||||
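As a quick illustration of the retrieval core used in the new file, the snippet below exercises rank_bm25 and jieba directly on a toy corpus (the small stop-word set stands in for the real hit_stopwords.txt):

import jieba
from rank_bm25 import BM25Okapi

stopwords = {"的", "了", "是"}  # stand-in for the real stop-word list

corpus = [
    "AstrBot 支持知识库检索",
    "BM25 是一种基于关键词的排序函数",
    "向量检索关注语义相似度",
]
tokenized = [[w for w in jieba.cut(doc) if w not in stopwords] for doc in corpus]

bm25 = BM25Okapi(tokenized)
query_tokens = [w for w in jieba.cut("关键词 检索") if w not in stopwords]
scores = bm25.get_scores(query_tokens)  # one score per document, higher = more relevant

best = max(range(len(corpus)), key=lambda i: scores[i])
print(corpus[best], float(scores[best]))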
@@ -1,5 +1,4 @@
-"""
-日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
+"""日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
 
 const:
     CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量
@@ -21,14 +20,14 @@ function:
     4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流
 """
 
-import logging
-import colorlog
 import asyncio
+import logging
 import os
 import sys
-from collections import deque
 from asyncio import Queue
-from typing import List
+from collections import deque
+
+import colorlog
 
 # 日志缓存大小
 CACHED_SIZE = 200
@@ -52,6 +51,7 @@ def is_plugin_path(pathname):
 
     Returns:
         bool: 如果路径来自插件目录,则返回 True,否则返回 False
+
     """
     if not pathname:
         return False
@@ -68,6 +68,7 @@ def get_short_level_name(level_name):
 
     Returns:
         str: 四个字母的日志级别缩写
+
    """
    level_map = {
        "DEBUG": "DBUG",
@@ -87,13 +88,14 @@ class LogBroker:
 
     def __init__(self):
         self.log_cache = deque(maxlen=CACHED_SIZE)  # 环形缓冲区, 保存最近的日志
-        self.subscribers: List[Queue] = []  # 订阅者列表
+        self.subscribers: list[Queue] = []  # 订阅者列表
 
     def register(self) -> Queue:
         """注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
 
         Returns:
             Queue: 订阅者的队列, 可用于接收日志消息
+
         """
         q = Queue(maxsize=CACHED_SIZE + 10)
         self.subscribers.append(q)
@@ -104,6 +106,7 @@ class LogBroker:
 
         Args:
             q (Queue): 需要取消订阅的队列
+
         """
         self.subscribers.remove(q)
 
@@ -113,6 +116,7 @@ class LogBroker:
         Args:
             log_entry (dict): 日志消息, 包含日志级别和日志内容.
                 example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
+
         """
         self.log_cache.append(log_entry)
         for q in self.subscribers:
@@ -138,6 +142,7 @@ class LogQueueHandler(logging.Handler):
 
         Args:
             record (logging.LogRecord): 日志记录对象, 包含日志信息
+
         """
         log_entry = self.format(record)
         self.log_broker.publish(
@@ -145,7 +150,7 @@ class LogQueueHandler(logging.Handler):
                 "level": record.levelname,
                 "time": record.asctime,
                 "data": log_entry,
-            }
+            },
         )
 
 
@@ -164,6 +169,7 @@ class LogManager:
 
         Returns:
             logging.Logger: 返回配置好的日志记录器
+
         """
         logger = logging.getLogger(log_name)
         # 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置
@@ -171,10 +177,10 @@ class LogManager:
             return logger
         # 如果logger没有处理器
         console_handler = logging.StreamHandler(
-            sys.stdout
+            sys.stdout,
         )  # 创建一个StreamHandler用于控制台输出
         console_handler.setLevel(
-            logging.DEBUG
+            logging.DEBUG,
         )  # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
 
         # 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息
@@ -195,7 +201,8 @@ class LogManager:
 
         class FileNameFilter(logging.Filter):
             """文件名过滤器类, 用于修改日志记录的文件名格式
-            例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式"""
+            例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式
+            """
 
             # 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
             def filter(self, record):
@@ -231,6 +238,7 @@ class LogManager:
         Args:
             logger (logging.Logger): 日志记录器
             log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息
+
         """
         handler = LogQueueHandler(log_broker)
         handler.setLevel(logging.DEBUG)
@@ -240,7 +248,7 @@ class LogManager:
         # 为队列处理器设置相同格式的formatter
         handler.setFormatter(
             logging.Formatter(
-                "[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s"
-            )
+                "[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s",
+            ),
         )
         logger.addHandler(handler)
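The broker pattern in the hunks above (a bounded deque of recent records plus one asyncio.Queue per subscriber, with the cache replayed to new subscribers) can be summarized in a standalone sketch; the names here are illustrative, not the project's API:

import asyncio
from collections import deque

class MiniLogBroker:
    def __init__(self, cached_size: int = 200):
        self.cache = deque(maxlen=cached_size)       # ring buffer of recent log entries
        self.subscribers: list[asyncio.Queue] = []

    def register(self) -> asyncio.Queue:
        q = asyncio.Queue(maxsize=210)
        for entry in self.cache:                     # replay cached history to the newcomer
            q.put_nowait(entry)
        self.subscribers.append(q)
        return q

    def unregister(self, q: asyncio.Queue) -> None:
        self.subscribers.remove(q)

    def publish(self, entry: dict) -> None:
        self.cache.append(entry)
        for q in self.subscribers:
            if not q.full():                         # drop rather than block when a consumer lags
                q.put_nowait(entry)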
@@ -1,5 +1,4 @@
-"""
-MIT License
+"""MIT License
 
 Copyright (c) 2021 Lxns-Network
 
@@ -26,7 +25,6 @@ import asyncio
 import base64
 import json
 import os
-import typing as T
 import uuid
 from enum import Enum
 
@@ -38,60 +36,36 @@ from astrbot.core.utils.io import download_file, download_image_by_url, file_to_
 
 
 class ComponentType(str, Enum):
-    Plain = "Plain"  # 纯文本消息
-    Face = "Face"  # QQ表情
-    Record = "Record"  # 语音
-    Video = "Video"  # 视频
-    At = "At"  # At
-    Node = "Node"  # 转发消息的一个节点
-    Nodes = "Nodes"  # 转发消息的多个节点
-    Poke = "Poke"  # QQ 戳一戳
-    Image = "Image"  # 图片
-    Reply = "Reply"  # 回复
-    Forward = "Forward"  # 转发消息
-    File = "File"  # 文件
+    # Basic Segment Types
+    Plain = "Plain"  # plain text message
+    Image = "Image"  # image
+    Record = "Record"  # audio
+    Video = "Video"  # video
+    File = "File"  # file attachment
+
+    # IM-specific Segment Types
+    Face = "Face"  # Emoji segment for Tencent QQ platform
+    At = "At"  # mention a user in IM apps
+    Node = "Node"  # a node in a forwarded message
+    Nodes = "Nodes"  # a forwarded message consisting of multiple nodes
+    Poke = "Poke"  # a poke message for Tencent QQ platform
+    Reply = "Reply"  # a reply message segment
+    Forward = "Forward"  # a forwarded message segment
     RPS = "RPS"  # TODO
     Dice = "Dice"  # TODO
     Shake = "Shake"  # TODO
-    Anonymous = "Anonymous"  # TODO
     Share = "Share"
     Contact = "Contact"  # TODO
     Location = "Location"  # TODO
     Music = "Music"
-    RedBag = "RedBag"
-    Xml = "Xml"
     Json = "Json"
-    CardImage = "CardImage"
-    TTS = "TTS"
     Unknown = "Unknown"
 
     WechatEmoji = "WechatEmoji"  # Wechat 下的 emoji 表情包
 
 
 class BaseMessageComponent(BaseModel):
     type: ComponentType
 
-    def toString(self):
-        output = f"[CQ:{self.type.lower()}"
-        for k, v in self.__dict__.items():
-            if k == "type" or v is None:
-                continue
-            if k == "_type":
-                k = "type"
-            if isinstance(v, bool):
-                v = 1 if v else 0
-            output += ",%s=%s" % (
-                k,
-                str(v)
-                .replace("&", "&amp;")
-                .replace(",", "&#44;")
-                .replace("[", "&#91;")
-                .replace("]", "&#93;"),
-            )
-        output += "]"
-        return output
-
     def toDict(self):
         data = {}
         for k, v in self.__dict__.items():
@@ -110,18 +84,11 @@ class BaseMessageComponent(BaseModel):
 class Plain(BaseMessageComponent):
     type = ComponentType.Plain
     text: str
-    convert: T.Optional[bool] = True  # 若为 False 则直接发送未转换 CQ 码的消息
+    convert: bool | None = True
 
     def __init__(self, text: str, convert: bool = True, **_):
         super().__init__(text=text, convert=convert, **_)
 
-    def toString(self):  # 没有 [CQ:plain] 这种东西,所以直接导出纯文本
-        if not self.convert:
-            return self.text
-        return (
-            self.text.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;")
-        )
-
     def toDict(self):
         return {"type": "text", "data": {"text": self.text.strip()}}
 
@@ -139,17 +106,17 @@ class Face(BaseMessageComponent):
 
 class Record(BaseMessageComponent):
     type = ComponentType.Record
-    file: T.Optional[str] = ""
-    magic: T.Optional[bool] = False
-    url: T.Optional[str] = ""
-    cache: T.Optional[bool] = True
-    proxy: T.Optional[bool] = True
-    timeout: T.Optional[int] = 0
+    file: str | None = ""
+    magic: bool | None = False
+    url: str | None = ""
+    cache: bool | None = True
+    proxy: bool | None = True
+    timeout: int | None = 0
     # 额外
-    path: T.Optional[str]
+    path: str | None
 
-    def __init__(self, file: T.Optional[str], **_):
-        for k in _.keys():
+    def __init__(self, file: str | None, **_):
+        for k in _:
             if k == "url":
                 pass
                 # Protocol.warn(f"go-cqhttp doesn't support send {self.type} by {k}")
@@ -174,15 +141,16 @@ class Record(BaseMessageComponent):
 
         Returns:
             str: 语音的本地路径,以绝对路径表示。
+
         """
         if not self.file:
             raise Exception(f"not a valid file: {self.file}")
         if self.file.startswith("file:///"):
             return self.file[8:]
-        elif self.file.startswith("http"):
+        if self.file.startswith("http"):
             file_path = await download_image_by_url(self.file)
             return os.path.abspath(file_path)
-        elif self.file.startswith("base64://"):
+        if self.file.startswith("base64://"):
             bs64_data = self.file.removeprefix("base64://")
             image_bytes = base64.b64decode(bs64_data)
             temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -190,16 +158,16 @@ class Record(BaseMessageComponent):
             with open(file_path, "wb") as f:
                 f.write(image_bytes)
             return os.path.abspath(file_path)
-        elif os.path.exists(self.file):
+        if os.path.exists(self.file):
             return os.path.abspath(self.file)
-        else:
-            raise Exception(f"not a valid file: {self.file}")
+        raise Exception(f"not a valid file: {self.file}")
 
     async def convert_to_base64(self) -> str:
         """将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。
 
         Returns:
             str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
+
         """
         # convert to base64
         if not self.file:
@@ -219,14 +187,14 @@ class Record(BaseMessageComponent):
         return bs64_data
 
     async def register_to_file_service(self) -> str:
-        """
-        将语音注册到文件服务。
+        """将语音注册到文件服务。
 
         Returns:
             str: 注册后的URL
 
         Raises:
             Exception: 如果未配置 callback_api_base
+
         """
         callback_host = astrbot_config.get("callback_api_base")
 
@@ -245,10 +213,10 @@ class Record(BaseMessageComponent):
 class Video(BaseMessageComponent):
     type = ComponentType.Video
     file: str
-    cover: T.Optional[str] = ""
-    c: T.Optional[int] = 2
+    cover: str | None = ""
+    c: int | None = 2
     # 额外
-    path: T.Optional[str] = ""
+    path: str | None = ""
 
     def __init__(self, file: str, **_):
         super().__init__(file=file, **_)
@@ -268,32 +236,31 @@ class Video(BaseMessageComponent):
 
         Returns:
             str: 视频的本地路径,以绝对路径表示。
+
         """
         url = self.file
         if url and url.startswith("file:///"):
             return url[8:]
-        elif url and url.startswith("http"):
+        if url and url.startswith("http"):
             download_dir = os.path.join(get_astrbot_data_path(), "temp")
             video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
             await download_file(url, video_file_path)
             if os.path.exists(video_file_path):
                 return os.path.abspath(video_file_path)
-            else:
-                raise Exception(f"download failed: {url}")
-        elif os.path.exists(url):
+            raise Exception(f"download failed: {url}")
+        if os.path.exists(url):
             return os.path.abspath(url)
-        else:
-            raise Exception(f"not a valid file: {url}")
+        raise Exception(f"not a valid file: {url}")
 
     async def register_to_file_service(self):
-        """
-        将视频注册到文件服务。
+        """将视频注册到文件服务。
 
         Returns:
             str: 注册后的URL
 
         Raises:
             Exception: 如果未配置 callback_api_base
+
         """
         callback_host = astrbot_config.get("callback_api_base")
 
@@ -330,8 +297,8 @@ class Video(BaseMessageComponent):
 
 class At(BaseMessageComponent):
     type = ComponentType.At
-    qq: T.Union[int, str]  # 此处str为all时代表所有人
-    name: T.Optional[str] = ""
+    qq: int | str  # 此处str为all时代表所有人
+    name: str | None = ""
 
     def __init__(self, **_):
         super().__init__(**_)
@@ -371,20 +338,12 @@ class Shake(BaseMessageComponent):  # TODO
         super().__init__(**_)
 
 
-class Anonymous(BaseMessageComponent):  # TODO
-    type = ComponentType.Anonymous
-    ignore: T.Optional[bool] = False
-
-    def __init__(self, **_):
-        super().__init__(**_)
-
-
 class Share(BaseMessageComponent):
     type = ComponentType.Share
     url: str
     title: str
-    content: T.Optional[str] = ""
-    image: T.Optional[str] = ""
+    content: str | None = ""
+    image: str | None = ""
 
     def __init__(self, **_):
         super().__init__(**_)
@@ -393,7 +352,7 @@ class Share(BaseMessageComponent):
 class Contact(BaseMessageComponent):  # TODO
     type = ComponentType.Contact
     _type: str  # type 字段冲突
-    id: T.Optional[int] = 0
+    id: int | None = 0
 
     def __init__(self, **_):
         super().__init__(**_)
@@ -403,8 +362,8 @@ class Location(BaseMessageComponent):  # TODO
     type = ComponentType.Location
     lat: float
     lon: float
-    title: T.Optional[str] = ""
-    content: T.Optional[str] = ""
+    title: str | None = ""
+    content: str | None = ""
 
     def __init__(self, **_):
         super().__init__(**_)
@@ -413,12 +372,12 @@ class Location(BaseMessageComponent):  # TODO
 class Music(BaseMessageComponent):
     type = ComponentType.Music
     _type: str
-    id: T.Optional[int] = 0
-    url: T.Optional[str] = ""
-    audio: T.Optional[str] = ""
-    title: T.Optional[str] = ""
-    content: T.Optional[str] = ""
-    image: T.Optional[str] = ""
+    id: int | None = 0
+    url: str | None = ""
+    audio: str | None = ""
+    title: str | None = ""
+    content: str | None = ""
+    image: str | None = ""
 
     def __init__(self, **_):
         # for k in _.keys():
@@ -429,18 +388,18 @@ class Music(BaseMessageComponent):
 
 class Image(BaseMessageComponent):
     type = ComponentType.Image
-    file: T.Optional[str] = ""
-    _type: T.Optional[str] = ""
-    subType: T.Optional[int] = 0
-    url: T.Optional[str] = ""
-    cache: T.Optional[bool] = True
-    id: T.Optional[int] = 40000
-    c: T.Optional[int] = 2
+    file: str | None = ""
+    _type: str | None = ""
+    subType: int | None = 0
+    url: str | None = ""
+    cache: bool | None = True
+    id: int | None = 40000
+    c: int | None = 2
     # 额外
-    path: T.Optional[str] = ""
-    file_unique: T.Optional[str] = ""  # 某些平台可能有图片缓存的唯一标识
+    path: str | None = ""
+    file_unique: str | None = ""  # 某些平台可能有图片缓存的唯一标识
 
-    def __init__(self, file: T.Optional[str], **_):
+    def __init__(self, file: str | None, **_):
         super().__init__(file=file, **_)
 
     @staticmethod
@@ -470,16 +429,17 @@ class Image(BaseMessageComponent):
 
         Returns:
             str: 图片的本地路径,以绝对路径表示。
+
         """
         url = self.url or self.file
         if not url:
             raise ValueError("No valid file or URL provided")
         if url.startswith("file:///"):
             return url[8:]
-        elif url.startswith("http"):
+        if url.startswith("http"):
             image_file_path = await download_image_by_url(url)
             return os.path.abspath(image_file_path)
-        elif url.startswith("base64://"):
+        if url.startswith("base64://"):
             bs64_data = url.removeprefix("base64://")
             image_bytes = base64.b64decode(bs64_data)
             temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -487,16 +447,16 @@ class Image(BaseMessageComponent):
             with open(image_file_path, "wb") as f:
                 f.write(image_bytes)
             return os.path.abspath(image_file_path)
-        elif os.path.exists(url):
+        if os.path.exists(url):
             return os.path.abspath(url)
-        else:
-            raise Exception(f"not a valid file: {url}")
+        raise Exception(f"not a valid file: {url}")
 
     async def convert_to_base64(self) -> str:
         """将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。
 
         Returns:
             str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
+
         """
         # convert to base64
         url = self.url or self.file
@@ -517,14 +477,14 @@ class Image(BaseMessageComponent):
         return bs64_data
 
     async def register_to_file_service(self) -> str:
-        """
-        将图片注册到文件服务。
+        """将图片注册到文件服务。
 
         Returns:
             str: 注册后的URL
 
         Raises:
             Exception: 如果未配置 callback_api_base
+
         """
         callback_host = astrbot_config.get("callback_api_base")
 
@@ -542,42 +502,34 @@ class Image(BaseMessageComponent):
 
 class Reply(BaseMessageComponent):
     type = ComponentType.Reply
-    id: T.Union[str, int]
+    id: str | int
     """所引用的消息 ID"""
-    chain: T.Optional[T.List["BaseMessageComponent"]] = []
+    chain: list["BaseMessageComponent"] | None = []
     """被引用的消息段列表"""
-    sender_id: T.Optional[int] | T.Optional[str] = 0
+    sender_id: int | None | str = 0
     """被引用的消息对应的发送者的 ID"""
-    sender_nickname: T.Optional[str] = ""
+    sender_nickname: str | None = ""
    """被引用的消息对应的发送者的昵称"""
-    time: T.Optional[int] = 0
+    time: int | None = 0
    """被引用的消息发送时间"""
-    message_str: T.Optional[str] = ""
+    message_str: str | None = ""
    """被引用的消息解析后的纯文本消息字符串"""
 
-    text: T.Optional[str] = ""
+    text: str | None = ""
     """deprecated"""
-    qq: T.Optional[int] = 0
+    qq: int | None = 0
     """deprecated"""
-    seq: T.Optional[int] = 0
+    seq: int | None = 0
     """deprecated"""
 
     def __init__(self, **_):
         super().__init__(**_)
 
 
-class RedBag(BaseMessageComponent):
-    type = ComponentType.RedBag
-    title: str
-
-    def __init__(self, **_):
-        super().__init__(**_)
-
-
 class Poke(BaseMessageComponent):
     type: str = ComponentType.Poke
-    id: T.Optional[int] = 0
-    qq: T.Optional[int] = 0
+    id: int | None = 0
+    qq: int | None = 0
 
     def __init__(self, type: str, **_):
         type = f"Poke:{type}"
@@ -596,12 +548,12 @@ class Node(BaseMessageComponent):
     """群合并转发消息"""
 
     type = ComponentType.Node
-    id: T.Optional[int] = 0  # 忽略
-    name: T.Optional[str] = ""  # qq昵称
-    uin: T.Optional[str] = "0"  # qq号
-    content: T.Optional[list[BaseMessageComponent]] = []
-    seq: T.Optional[T.Union[str, list]] = ""  # 忽略
-    time: T.Optional[int] = 0  # 忽略
+    id: int | None = 0  # 忽略
+    name: str | None = ""  # qq昵称
+    uin: str | None = "0"  # qq号
+    content: list[BaseMessageComponent] | None = []
+    seq: str | list | None = ""  # 忽略
+    time: int | None = 0  # 忽略
 
     def __init__(self, content: list[BaseMessageComponent], **_):
         if isinstance(content, Node):
@@ -619,7 +571,7 @@ class Node(BaseMessageComponent):
                     {
                         "type": comp.type.lower(),
                         "data": {"file": f"base64://{bs64}"},
-                    }
+                    },
                 )
             elif isinstance(comp, Plain):
                 # For Plain segments, we need to handle the plain differently
@@ -648,9 +600,9 @@ class Node(BaseMessageComponent):
 
 class Nodes(BaseMessageComponent):
     type = ComponentType.Nodes
-    nodes: T.List[Node]
+    nodes: list[Node]
 
-    def __init__(self, nodes: T.List[Node], **_):
+    def __init__(self, nodes: list[Node], **_):
         super().__init__(nodes=nodes, **_)
 
     def toDict(self):
@@ -672,19 +624,10 @@ class Nodes(BaseMessageComponent):
         return ret
 
 
-class Xml(BaseMessageComponent):
-    type = ComponentType.Xml
-    data: str
-    resid: T.Optional[int] = 0
-
-    def __init__(self, **_):
-        super().__init__(**_)
-
-
 class Json(BaseMessageComponent):
     type = ComponentType.Json
-    data: T.Union[str, dict]
-    resid: T.Optional[int] = 0
+    data: str | dict
+    resid: int | None = 0
 
     def __init__(self, data, **_):
         if isinstance(data, dict):
@@ -692,50 +635,18 @@ class Json(BaseMessageComponent):
         super().__init__(data=data, **_)
 
 
-class CardImage(BaseMessageComponent):
-    type = ComponentType.CardImage
-    file: str
-    cache: T.Optional[bool] = True
-    minwidth: T.Optional[int] = 400
-    minheight: T.Optional[int] = 400
-    maxwidth: T.Optional[int] = 500
-    maxheight: T.Optional[int] = 500
-    source: T.Optional[str] = ""
-    icon: T.Optional[str] = ""
-
-    def __init__(self, **_):
-        super().__init__(**_)
-
-    @staticmethod
-    def fromFileSystem(path, **_):
-        return CardImage(file=f"file:///{os.path.abspath(path)}", **_)
-
-
-class TTS(BaseMessageComponent):
-    type = ComponentType.TTS
-    text: str
-
-    def __init__(self, **_):
-        super().__init__(**_)
-
-
 class Unknown(BaseMessageComponent):
     type = ComponentType.Unknown
     text: str
 
-    def toString(self):
-        return ""
-
 
 class File(BaseMessageComponent):
-    """
-    文件消息段
-    """
+    """文件消息段"""
 
     type = ComponentType.File
-    name: T.Optional[str] = ""  # 名字
-    file_: T.Optional[str] = ""  # 本地路径
-    url: T.Optional[str] = ""  # url
+    name: str | None = ""  # 名字
+    file_: str | None = ""  # 本地路径
+    url: str | None = ""  # url
 
     def __init__(self, name: str, file: str = "", url: str = ""):
         """文件消息段。"""
@@ -743,11 +654,11 @@ class File(BaseMessageComponent):
 
     @property
     def file(self) -> str:
-        """
-        获取文件路径,如果文件不存在但有URL,则同步下载文件
+        """获取文件路径,如果文件不存在但有URL,则同步下载文件
 
         Returns:
             str: 文件路径
+
         """
         if self.file_ and os.path.exists(self.file_):
             return os.path.abspath(self.file_)
@@ -757,19 +668,16 @@ class File(BaseMessageComponent):
             loop = asyncio.get_event_loop()
             if loop.is_running():
                 logger.warning(
-                    (
-                        "不可以在异步上下文中同步等待下载! "
-                        "这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
-                        "请使用 await get_file() 代替直接获取 <File>.file 字段"
-                    )
+                    "不可以在异步上下文中同步等待下载! "
+                    "这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
+                    "请使用 await get_file() 代替直接获取 <File>.file 字段",
                 )
                 return ""
-            else:
-                # 等待下载完成
-                loop.run_until_complete(self._download_file())
+            # 等待下载完成
+            loop.run_until_complete(self._download_file())
 
             if self.file_ and os.path.exists(self.file_):
                 return os.path.abspath(self.file_)
         except Exception as e:
             logger.error(f"文件下载失败: {e}")
 
@@ -777,11 +685,11 @@ class File(BaseMessageComponent):
 
     @file.setter
     def file(self, value: str):
-        """
-        向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
+        """向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
 
         Args:
             value (str): 文件路径或URL
+
         """
         if value.startswith("http://") or value.startswith("https://"):
             self.url = value
@@ -796,6 +704,7 @@ class File(BaseMessageComponent):
             注意,如果为 True,也可能返回文件路径。
         Returns:
             str: 文件路径或者 http 下载链接
+
         """
         if allow_return_url and self.url:
             return self.url
@@ -818,14 +727,14 @@ class File(BaseMessageComponent):
         self.file_ = os.path.abspath(file_path)
 
     async def register_to_file_service(self):
-        """
-        将文件注册到文件服务。
+        """将文件注册到文件服务。
 
         Returns:
             str: 注册后的URL
 
         Raises:
             Exception: 如果未配置 callback_api_base
+
         """
         callback_host = astrbot_config.get("callback_api_base")
 
@@ -863,41 +772,38 @@ class File(BaseMessageComponent):
 
 class WechatEmoji(BaseMessageComponent):
     type = ComponentType.WechatEmoji
-    md5: T.Optional[str] = ""
-    md5_len: T.Optional[int] = 0
-    cdnurl: T.Optional[str] = ""
+    md5: str | None = ""
+    md5_len: int | None = 0
+    cdnurl: str | None = ""
 
     def __init__(self, **_):
         super().__init__(**_)
 
 
 ComponentTypes = {
+    # Basic Message Segments
     "plain": Plain,
     "text": Plain,
-    "face": Face,
+    "image": Image,
     "record": Record,
     "video": Video,
+    "file": File,
+    # IM-specific Message Segments
+    "face": Face,
     "at": At,
     "rps": RPS,
     "dice": Dice,
     "shake": Shake,
-    "anonymous": Anonymous,
     "share": Share,
     "contact": Contact,
     "location": Location,
     "music": Music,
-    "image": Image,
     "reply": Reply,
-    "redbag": RedBag,
     "poke": Poke,
     "forward": Forward,
     "node": Node,
    "nodes": Nodes,
-    "xml": Xml,
     "json": Json,
-    "cardimage": CardImage,
-    "tts": TTS,
     "unknown": Unknown,
-    "file": File,
     "WechatEmoji": WechatEmoji,
 }
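Most of the churn in the file above is mechanical: typing.Optional/Union/List spellings are rewritten to the PEP 604 / PEP 585 builtin forms. A minimal before/after outside of any real class, just to make the equivalence explicit:

import typing as T

# before
def old_style(file: T.Optional[str], qq: T.Union[int, str], chain: T.List[str]) -> T.Optional[str]:
    return file

# after (Python 3.10+ syntax, identical runtime meaning)
def new_style(file: str | None, qq: int | str, chain: list[str]) -> str | None:
    return file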
@@ -1,15 +1,16 @@
 import enum
-from typing import List, Optional, Union, AsyncGenerator
+from collections.abc import AsyncGenerator
 from dataclasses import dataclass, field
 
+from typing_extensions import deprecated
+
 from astrbot.core.message.components import (
-    BaseMessageComponent,
-    Plain,
-    Image,
     At,
     AtAll,
+    BaseMessageComponent,
+    Image,
+    Plain,
 )
-from typing_extensions import deprecated
 
 
 @dataclass
@@ -20,18 +21,18 @@ class MessageChain:
     Attributes:
         `chain` (list): 用于顺序存储各个组件。
         `use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
 
     """
 
-    chain: List[BaseMessageComponent] = field(default_factory=list)
-    use_t2i_: Optional[bool] = None  # None 为跟随用户设置
-    type: Optional[str] = None
+    chain: list[BaseMessageComponent] = field(default_factory=list)
+    use_t2i_: bool | None = None  # None 为跟随用户设置
+    type: str | None = None
     """消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。"""
 
     def message(self, message: str):
         """添加一条文本消息到消息链 `chain` 中。
 
         Example:
 
             CommandResult().message("Hello ").message("world!")
             # 输出 Hello world!
@@ -39,11 +40,10 @@ class MessageChain:
         self.chain.append(Plain(message))
         return self
 
-    def at(self, name: str, qq: Union[str, int]):
+    def at(self, name: str, qq: str | int):
         """添加一条 At 消息到消息链 `chain` 中。
 
         Example:
-
             CommandResult().at("张三", "12345678910")
             # 输出 @张三
 
@@ -55,7 +55,6 @@ class MessageChain:
         """添加一条 AtAll 消息到消息链 `chain` 中。
 
         Example:
-
             CommandResult().at_all()
             # 输出 @所有人
 
@@ -68,7 +67,6 @@ class MessageChain:
         """添加一条错误消息到消息链 `chain` 中
 
         Example:
-
             CommandResult().error("解析失败")
 
         """
@@ -82,7 +80,6 @@ class MessageChain:
         如果需要发送本地图片,请使用 `file_image` 方法。
 
         Example:
-
             CommandResult().image("https://example.com/image.jpg")
 
         """
@@ -96,6 +93,7 @@ class MessageChain:
         如果需要发送网络图片,请使用 `url_image` 方法。
 
             CommandResult().image("image.jpg")
+
         """
         self.chain.append(Image.fromFileSystem(path))
         return self
@@ -114,6 +112,7 @@ class MessageChain:
 
         Args:
             use_t2i (bool): 是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
+
         """
         self.use_t2i_ = use_t2i
         return self
@@ -125,7 +124,7 @@ class MessageChain:
     def squash_plain(self):
         """将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。"""
         if not self.chain:
-            return
+            return None
 
         new_chain = []
         first_plain = None
@@ -153,6 +152,7 @@ class EventResultType(enum.Enum):
     Attributes:
         CONTINUE: 事件将会继续传播
         STOP: 事件将会终止传播
+
     """
 
     CONTINUE = enum.auto()
@@ -181,17 +181,18 @@ class MessageEventResult(MessageChain):
         `chain` (list): 用于顺序存储各个组件。
         `use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
         `result_type` (EventResultType): 事件处理的结果类型。
+
     """
 
-    result_type: Optional[EventResultType] = field(
-        default_factory=lambda: EventResultType.CONTINUE
+    result_type: EventResultType | None = field(
+        default_factory=lambda: EventResultType.CONTINUE,
     )
 
-    result_content_type: Optional[ResultContentType] = field(
-        default_factory=lambda: ResultContentType.GENERAL_RESULT
+    result_content_type: ResultContentType | None = field(
+        default_factory=lambda: ResultContentType.GENERAL_RESULT,
    )
 
-    async_stream: Optional[AsyncGenerator] = None
+    async_stream: AsyncGenerator | None = None
     """异步流"""
 
     def stop_event(self) -> "MessageEventResult":
@@ -205,9 +206,7 @@ class MessageEventResult(MessageChain):
         return self
 
     def is_stopped(self) -> bool:
-        """
-        是否终止事件传播。
-        """
+        """是否终止事件传播。"""
         return self.result_type == EventResultType.STOP
 
     def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult":
@@ -220,6 +219,7 @@ class MessageEventResult(MessageChain):
 
         Args:
             result_type (EventResultType): 事件处理的结果类型。
+
         """
         self.result_content_type = typ
         return self
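The squash_plain behavior documented above ("将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中") can be pictured with a self-contained toy; this is only an approximation of the idea, not the project's exact ordering rules or classes:

from dataclasses import dataclass

@dataclass
class PlainSeg:
    text: str

@dataclass
class ImageSeg:
    url: str

def squash_plain(chain: list) -> list:
    """Merge every PlainSeg's text into one leading PlainSeg, keeping other segments."""
    texts = [seg.text for seg in chain if isinstance(seg, PlainSeg)]
    if not texts:
        return chain
    others = [seg for seg in chain if not isinstance(seg, PlainSeg)]
    return [PlainSeg("".join(texts)), *others]

print(squash_plain([PlainSeg("Hello "), ImageSeg("a.png"), PlainSeg("world!")]))
# [PlainSeg(text='Hello world!'), ImageSeg(url='a.png')]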
@@ -1,8 +1,8 @@
+from astrbot import logger
+from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
 from astrbot.core.db import BaseDatabase
 from astrbot.core.db.po import Persona, Personality
-from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
 from astrbot.core.platform.message_session import MessageSession
-from astrbot import logger
 
 DEFAULT_PERSONALITY = Personality(
     prompt="You are a helpful and friendly assistant.",
@@ -41,12 +41,14 @@ class PersonaManager:
         return persona
 
     async def get_default_persona_v3(
-        self, umo: str | MessageSession | None = None
+        self,
+        umo: str | MessageSession | None = None,
     ) -> Personality:
         """获取默认 persona"""
         cfg = self.acm.get_conf(umo)
         default_persona_id = cfg.get("provider_settings", {}).get(
-            "default_personality", "default"
+            "default_personality",
+            "default",
         )
         if not default_persona_id or default_persona_id == "default":
             return DEFAULT_PERSONALITY
@@ -66,16 +68,19 @@ class PersonaManager:
     async def update_persona(
         self,
         persona_id: str,
-        system_prompt: str = None,
-        begin_dialogs: list[str] = None,
-        tools: list[str] = None,
+        system_prompt: str | None = None,
+        begin_dialogs: list[str] | None = None,
+        tools: list[str] | None = None,
     ):
         """更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
         existing_persona = await self.db.get_persona_by_id(persona_id)
         if not existing_persona:
             raise ValueError(f"Persona with ID {persona_id} does not exist.")
         persona = await self.db.update_persona(
-            persona_id, system_prompt, begin_dialogs, tools=tools
+            persona_id,
+            system_prompt,
+            begin_dialogs,
+            tools=tools,
         )
         if persona:
             for i, p in enumerate(self.personas):
@@ -100,7 +105,10 @@ class PersonaManager:
         if await self.db.get_persona_by_id(persona_id):
             raise ValueError(f"Persona with ID {persona_id} already exists.")
         new_persona = await self.db.insert_persona(
-            persona_id, system_prompt, begin_dialogs, tools=tools
+            persona_id,
+            system_prompt,
+            begin_dialogs,
+            tools=tools,
         )
         self.personas.append(new_persona)
         self.get_v3_persona_data()
@@ -115,6 +123,7 @@ class PersonaManager:
             - list[dict]: 包含 persona 配置的字典列表。
             - list[Personality]: 包含 Personality 对象的列表。
             - Personality: 默认选择的 Personality 对象。
+
         """
         v3_persona_config = [
             {
@@ -136,7 +145,7 @@ class PersonaManager:
             if begin_dialogs:
                 if len(begin_dialogs) % 2 != 0:
                     logger.error(
-                        f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。"
+                        f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。",
                     )
                     begin_dialogs = []
                 user_turn = True
@@ -146,7 +155,7 @@ class PersonaManager:
                             "role": "user" if user_turn else "assistant",
                             "content": dialog,
                             "_no_save": None,  # 不持久化到 db
-                        }
+                        },
                     )
                     user_turn = not user_turn
 
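The begin_dialogs handling above requires an even number of preset lines and alternates user/assistant roles. A compact standalone version of that conversion (it raises on odd counts purely for brevity, whereas the code above logs an error and falls back to an empty list):

def build_begin_dialogs(lines: list[str]) -> list[dict]:
    """Turn preset dialog lines into alternating user/assistant messages."""
    if len(lines) % 2 != 0:
        raise ValueError("人格情景预设对话条数应该为偶数")
    messages = []
    for i, line in enumerate(lines):
        messages.append({
            "role": "user" if i % 2 == 0 else "assistant",
            "content": line,
            "_no_save": None,  # mirrors the flag above: not persisted to db
        })
    return messages

print(build_begin_dialogs(["你是谁?", "我是一个助理。"]))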
@@ -27,15 +27,15 @@ STAGES_ORDER = [
 ]
 
 __all__ = [
-    "WakingCheckStage",
-    "WhitelistCheckStage",
-    "SessionStatusCheckStage",
-    "RateLimitStage",
     "ContentSafetyCheckStage",
+    "EventResultType",
+    "MessageEventResult",
     "PreProcessStage",
     "ProcessStage",
-    "ResultDecorateStage",
+    "RateLimitStage",
     "RespondStage",
-    "MessageEventResult",
-    "EventResultType",
+    "ResultDecorateStage",
+    "SessionStatusCheckStage",
+    "WakingCheckStage",
+    "WhitelistCheckStage",
 ]
@@ -1,9 +1,11 @@
-from typing import Union, AsyncGenerator
-from ..stage import Stage, register_stage
-from ..context import PipelineContext
-from astrbot.core.platform.astr_message_event import AstrMessageEvent
-from astrbot.core.message.message_event_result import MessageEventResult
+from collections.abc import AsyncGenerator
+
 from astrbot.core import logger
+from astrbot.core.message.message_event_result import MessageEventResult
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
+
+from ..context import PipelineContext
+from ..stage import Stage, register_stage
 from .strategies.strategy import StrategySelector
 
 
@@ -19,8 +21,10 @@ class ContentSafetyCheckStage(Stage):
         self.strategy_selector = StrategySelector(config)
 
     async def process(
-        self, event: AstrMessageEvent, check_text: str | None = None
-    ) -> Union[None, AsyncGenerator[None, None]]:
+        self,
+        event: AstrMessageEvent,
+        check_text: str | None = None,
+    ) -> None | AsyncGenerator[None, None]:
         """检查内容安全"""
         text = check_text if check_text else event.get_message_str()
         ok, info = self.strategy_selector.check(text)
@@ -28,8 +32,8 @@ class ContentSafetyCheckStage(Stage):
             if event.is_at_or_wake_command:
                 event.set_result(
                     MessageEventResult().message(
-                        "你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。"
-                    )
+                        "你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。",
+                    ),
                 )
             yield
             event.stop_event()
@@ -1,8 +1,7 @@
 import abc
-from typing import Tuple
 
 
 class ContentSafetyStrategy(abc.ABC):
     @abc.abstractmethod
-    def check(self, content: str) -> Tuple[bool, str]:
+    def check(self, content: str) -> tuple[bool, str]:
         raise NotImplementedError
@@ -1,9 +1,8 @@
-"""
-使用此功能应该先 pip install baidu-aip
-"""
+"""使用此功能应该先 pip install baidu-aip"""
+
+from aip import AipContentCensor
 
 from . import ContentSafetyStrategy
-from aip import AipContentCensor
 
 
 class BaiduAipStrategy(ContentSafetyStrategy):
@@ -19,12 +18,12 @@ class BaiduAipStrategy(ContentSafetyStrategy):
             return False, ""
         if res["conclusionType"] == 1:
             return True, ""
-        else:
-            if "data" not in res:
-                return False, ""
-            count = len(res["data"])
-            info = f"百度审核服务发现 {count} 处违规:\n"
-            for i in res["data"]:
-                info += f"{i['msg']};\n"
-            info += "\n判断结果:" + res["conclusion"]
+        if "data" not in res:
+            return False, ""
+        count = len(res["data"])
+        parts = [f"百度审核服务发现 {count} 处违规:\n"]
+        for i in res["data"]:
+            parts.append(f"{i['msg']};\n")
+        parts.append("\n判断结果:" + res["conclusion"])
+        info = "".join(parts)
         return False, info
@@ -1,4 +1,5 @@
 import re
+
 from . import ContentSafetyStrategy
 
 
@@ -1,16 +1,16 @@
-from . import ContentSafetyStrategy
-from typing import List, Tuple
 from astrbot import logger
 
+from . import ContentSafetyStrategy
+
 
 class StrategySelector:
     def __init__(self, config: dict) -> None:
-        self.enabled_strategies: List[ContentSafetyStrategy] = []
+        self.enabled_strategies: list[ContentSafetyStrategy] = []
         if config["internal_keywords"]["enable"]:
             from .keywords import KeywordsStrategy
 
             self.enabled_strategies.append(
-                KeywordsStrategy(config["internal_keywords"]["extra_keywords"])
+                KeywordsStrategy(config["internal_keywords"]["extra_keywords"]),
             )
         if config["baidu_aip"]["enable"]:
             try:
@@ -23,10 +23,10 @@ class StrategySelector:
                     config["baidu_aip"]["app_id"],
                     config["baidu_aip"]["api_key"],
                     config["baidu_aip"]["secret_key"],
-                )
+                ),
             )
 
-    def check(self, content: str) -> Tuple[bool, str]:
+    def check(self, content: str) -> tuple[bool, str]:
         for strategy in self.enabled_strategies:
             ok, info = strategy.check(content)
             if not ok:
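The strategy interface in the hunks above is just check(content) -> tuple[bool, str], where False means the content was flagged. A self-contained keyword strategy in that shape (illustrative only, not the project's KeywordsStrategy) might look like:

import abc
import re

class ContentSafetyStrategy(abc.ABC):
    @abc.abstractmethod
    def check(self, content: str) -> tuple[bool, str]:
        raise NotImplementedError

class SimpleKeywordsStrategy(ContentSafetyStrategy):
    def __init__(self, extra_keywords: list[str]):
        # each keyword is treated as a regular expression, as a simplifying assumption
        self.patterns = [re.compile(kw) for kw in extra_keywords]

    def check(self, content: str) -> tuple[bool, str]:
        for pattern in self.patterns:
            if pattern.search(content):
                return False, f"matched banned pattern: {pattern.pattern}"
        return True, ""

ok, info = SimpleKeywordsStrategy(["赌博", "spam"]).check("no spam here")
print(ok, info)  # False matched banned pattern: spam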
@@ -1,7 +1,9 @@
 from dataclasses import dataclass
 
 from astrbot.core.config import AstrBotConfig
 from astrbot.core.star import PluginManager
-from .context_utils import call_handler, call_event_hook
+from .context_utils import call_event_hook, call_handler, call_local_llm_tool
 
 
 @dataclass
@@ -13,3 +15,4 @@ class PipelineContext:
 astrbot_config_id: str
 call_handler = call_handler
 call_event_hook = call_event_hook
+call_local_llm_tool = call_local_llm_tool
@@ -1,11 +1,14 @@
 import inspect
 import traceback
 import typing as T
 
 from astrbot import logger
-from astrbot.core.star.star_handler import star_handlers_registry, EventType
-from astrbot.core.star.star import star_map
-from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
+from astrbot.core.agent.run_context import ContextWrapper
+from astrbot.core.astr_agent_context import AstrAgentContext
+from astrbot.core.message.message_event_result import CommandResult, MessageEventResult
 from astrbot.core.platform.astr_message_event import AstrMessageEvent
+from astrbot.core.star.star import star_map
+from astrbot.core.star.star_handler import EventType, star_handlers_registry
 
 
 async def call_handler(
@@ -26,6 +29,7 @@ async def call_handler(
 
 Returns:
 AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
 
 """
 ready_to_call = None  # 一个协程或者异步生成器
 
@@ -80,14 +84,17 @@ async def call_event_hook(
 
 Returns:
 bool: 如果事件被终止,返回 True
-#"""
+#
 
+"""
 handlers = star_handlers_registry.get_handlers_by_event_type(
-hook_type, plugins_name=event.plugins_name
+hook_type,
+plugins_name=event.plugins_name,
 )
 for handler in handlers:
 try:
 logger.debug(
-f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
+f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
 )
 await handler.handler(event, *args, **kwargs)
 except BaseException:
@@ -95,8 +102,71 @@ async def call_event_hook(
 
 if event.is_stopped():
 logger.info(
-f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
+f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。",
 )
 return True
 
 return event.is_stopped()
 
 
+async def call_local_llm_tool(
+context: ContextWrapper[AstrAgentContext],
+handler: T.Callable[..., T.Awaitable[T.Any]],
+method_name: str,
+*args,
+**kwargs,
+) -> T.AsyncGenerator[T.Any, None]:
+"""执行本地 LLM 工具的处理函数并处理其返回结果"""
+ready_to_call = None  # 一个协程或者异步生成器
+
+trace_ = None
+
+event = context.context.event
+
+try:
+if method_name == "run" or method_name == "decorator_handler":
+ready_to_call = handler(event, *args, **kwargs)
+elif method_name == "call":
+ready_to_call = handler(context, *args, **kwargs)
+else:
+raise ValueError(f"未知的方法名: {method_name}")
+except ValueError as e:
+logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
+except TypeError:
+logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
+except Exception as e:
+trace_ = traceback.format_exc()
+logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
+
+if not ready_to_call:
+return
+
+if inspect.isasyncgen(ready_to_call):
+_has_yielded = False
+try:
+async for ret in ready_to_call:
+# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
+# 返回值只能是 MessageEventResult 或者 None(无返回值)
+_has_yielded = True
+if isinstance(ret, (MessageEventResult, CommandResult)):
+# 如果返回值是 MessageEventResult, 设置结果并继续
+event.set_result(ret)
+yield
+else:
+# 如果返回值是 None, 则不设置结果并继续
+# 继续执行后续阶段
+yield ret
+if not _has_yielded:
+# 如果这个异步生成器没有执行到 yield 分支
+yield
+except Exception as e:
+logger.error(f"Previous Error: {trace_}")
+raise e
+elif inspect.iscoroutine(ready_to_call):
+# 如果只是一个协程, 直接执行
+ret = await ready_to_call
+if isinstance(ret, (MessageEventResult, CommandResult)):
+event.set_result(ret)
+yield
+else:
+yield ret
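The new call_local_llm_tool above dispatches on method_name and then normalises whatever the handler returns — a coroutine or an async generator — into a single async-generator interface for the caller. A framework-free sketch of that normalisation step; the names here are illustrative, not AstrBot APIs:

import asyncio
import inspect
from collections.abc import AsyncGenerator, Callable
from typing import Any


async def drive(handler: Callable[..., Any], *args, **kwargs) -> AsyncGenerator[Any, None]:
    """Consume a tool handler uniformly, whether it is a coroutine or an async generator."""
    ready = handler(*args, **kwargs)
    if inspect.isasyncgen(ready):
        async for item in ready:   # generator-style tool: forward every yielded value
            yield item
    elif inspect.iscoroutine(ready):
        yield await ready          # plain coroutine tool: await once, yield the single result
    else:
        raise TypeError("handler must be an async function or an async generator function")


async def coro_tool(x: int) -> int:
    return x * 2


async def gen_tool(x: int):
    yield x
    yield x + 1


async def main() -> None:
    print([v async for v in drive(coro_tool, 21)])  # [42]
    print([v async for v in drive(gen_tool, 1)])    # [1, 2]


asyncio.run(main())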
@@ -1,12 +1,14 @@
-import traceback
 import asyncio
 import random
-from typing import Union, AsyncGenerator
-from ..stage import Stage, register_stage
-from ..context import PipelineContext
-from astrbot.core.platform.astr_message_event import AstrMessageEvent
+import traceback
+from collections.abc import AsyncGenerator
 from astrbot.core import logger
-from astrbot.core.message.components import Plain, Record, Image
+from astrbot.core.message.components import Image, Plain, Record
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
 
+from ..context import PipelineContext
+from ..stage import Stage, register_stage
 
 
 @register_stage
@@ -20,8 +22,9 @@ class PreProcessStage(Stage):
 self.platform_settings: dict = self.config.get("platform_settings", {})
 
 async def process(
-self, event: AstrMessageEvent
-) -> Union[None, AsyncGenerator[None, None]]:
+self,
+event: AstrMessageEvent,
+) -> None | AsyncGenerator[None, None]:
 """在处理事件之前的预处理"""
 # 平台特异配置:platform_specific.<platform>.pre_ack_emoji
 supported = {"telegram", "lark"}
@@ -68,7 +71,7 @@ class PreProcessStage(Stage):
 stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
 if not stt_provider:
 logger.warning(
-f"会话 {event.unified_msg_origin} 未配置语音转文本模型。"
+f"会话 {event.unified_msg_origin} 未配置语音转文本模型。",
 )
 return
 message_chain = event.get_messages()
@@ -1,15 +1,24 @@
-"""
-本地 Agent 模式的 LLM 调用 Stage
-"""
+"""本地 Agent 模式的 LLM 调用 Stage"""
 
 import asyncio
 import copy
 import json
 import traceback
-from datetime import timedelta
-from typing import AsyncGenerator, Union
-from astrbot.core.conversation_mgr import Conversation
+from collections.abc import AsyncGenerator
+from typing import Any
+
+from mcp.types import CallToolResult
 
 from astrbot.core import logger
+from astrbot.core.agent.handoff import HandoffTool
+from astrbot.core.agent.hooks import BaseAgentRunHooks
+from astrbot.core.agent.mcp_client import MCPTool
+from astrbot.core.agent.run_context import ContextWrapper
+from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
+from astrbot.core.agent.tool import FunctionTool, ToolSet
+from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
+from astrbot.core.astr_agent_context import AstrAgentContext
+from astrbot.core.conversation_mgr import Conversation
 from astrbot.core.message.components import Image
 from astrbot.core.message.message_event_result import (
 MessageChain,
@@ -22,20 +31,14 @@ from astrbot.core.provider.entities import (
 LLMResponse,
 ProviderRequest,
 )
-from astrbot.core.agent.hooks import BaseAgentRunHooks
-from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
-from astrbot.core.agent.run_context import ContextWrapper
-from astrbot.core.agent.tool import ToolSet, FunctionTool
-from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
-from astrbot.core.agent.handoff import HandoffTool
-from astrbot.core.star.session_llm_manager import SessionServiceManager
-from astrbot.core.star.star_handler import EventType
-from astrbot.core.utils.metrics import Metric
-from ...context import PipelineContext, call_event_hook, call_handler
-from ..stage import Stage
 from astrbot.core.provider.register import llm_tools
-from astrbot.core.star.star_handler import star_map
-from astrbot.core.astr_agent_context import AstrAgentContext
+from astrbot.core.star.session_llm_manager import SessionServiceManager
+from astrbot.core.star.star_handler import EventType, star_map
+from astrbot.core.utils.metrics import Metric
+
+from ...context import PipelineContext, call_event_hook, call_local_llm_tool
+from ..stage import Stage
+from ..utils import inject_kb_context
 
 try:
 import mcp
@@ -44,7 +47,7 @@ except (ModuleNotFoundError, ImportError):
 
 
 AgentContextWrapper = ContextWrapper[AstrAgentContext]
-AgentRunner = ToolLoopAgentRunner[AgentContextWrapper]
+AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
 
 
 class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@@ -58,23 +61,22 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
 
 Returns:
 AsyncGenerator[None | mcp.types.CallToolResult, None]
 
 """
 if isinstance(tool, HandoffTool):
 async for r in cls._execute_handoff(tool, run_context, **tool_args):
 yield r
 return
 
-if tool.origin == "local":
-async for r in cls._execute_local(tool, run_context, **tool_args):
-yield r
-return
-
-elif tool.origin == "mcp":
+elif isinstance(tool, MCPTool):
 async for r in cls._execute_mcp(tool, run_context, **tool_args):
 yield r
 return
 
-raise Exception(f"Unknown function origin: {tool.origin}")
+else:
+async for r in cls._execute_local(tool, run_context, **tool_args):
+yield r
+return
 
 @classmethod
 async def _execute_handoff(
@@ -102,7 +104,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
 
 request = ProviderRequest(
 prompt=input_,
-system_prompt=tool.description,
+system_prompt=tool.description or "",
 image_urls=[],  # 暂时不传递原始 agent 的上下文
 contexts=[],  # 暂时不传递原始 agent 的上下文
 func_tool=toolset,
@@ -112,18 +114,22 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
 first_provider_request=run_context.context.first_provider_request,
 curr_provider_request=request,
 streaming=run_context.context.streaming,
+event=run_context.context.event,
 )
 
+event = run_context.context.event
+
 logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
-await run_context.event.send(
-MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name)
+await event.send(
+MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
 )
 
 await agent_runner.reset(
 provider=run_context.context.provider,
 request=request,
 run_context=AgentContextWrapper(
-context=astr_agent_ctx, event=run_context.event
+context=astr_agent_ctx,
+tool_call_timeout=run_context.tool_call_timeout,
 ),
 tool_executor=FunctionToolExecutor(),
 agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
@@ -145,7 +151,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
 return
 
 logger.debug(
-f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
+f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}",
 )
 
 result = (
@@ -173,25 +179,46 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
 run_context: ContextWrapper[AstrAgentContext],
 **tool_args,
 ):
-if not run_context.event:
+event = run_context.context.event
+if not event:
 raise ValueError("Event must be provided for local function tools.")
 
-# 检查 tool 下有没有 run 方法
-if not tool.handler and not hasattr(tool, "run"):
-raise ValueError("Tool must have a valid handler or 'run' method.")
-awaitable = tool.handler or getattr(tool, "run")
+is_override_call = False
+for ty in type(tool).mro():
+if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
+logger.debug(f"Found call in: {ty}")
+is_override_call = True
+break
 
-wrapper = call_handler(
-event=run_context.event,
+# 检查 tool 下有没有 run 方法
+if not tool.handler and not hasattr(tool, "run") and not is_override_call:
+raise ValueError("Tool must have a valid handler or override 'run' method.")
+
+awaitable = None
+method_name = ""
+if tool.handler:
+awaitable = tool.handler
+method_name = "decorator_handler"
+elif is_override_call:
+awaitable = tool.call
+method_name = "call"
+elif hasattr(tool, "run"):
+awaitable = getattr(tool, "run")
+method_name = "run"
+if awaitable is None:
+raise ValueError("Tool must have a valid handler or override 'run' method.")
+
+wrapper = call_local_llm_tool(
+context=run_context,
 handler=awaitable,
+method_name=method_name,
 **tool_args,
 )
-# async for resp in wrapper:
 while True:
 try:
 resp = await asyncio.wait_for(
 anext(wrapper),
-timeout=run_context.context.tool_call_timeout,
+timeout=run_context.tool_call_timeout,
 )
 if resp is not None:
 if isinstance(resp, mcp.types.CallToolResult):
@@ -206,10 +233,24 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
 # NOTE: Tool 在这里直接请求发送消息给用户
 # TODO: 是否需要判断 event.get_result() 是否为空?
 # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
+if res := run_context.context.event.get_result():
+if res.chain:
+try:
+await event.send(
+MessageChain(
+chain=res.chain,
+type="tool_direct_result",
+)
+)
+except Exception as e:
+logger.error(
+f"Tool 直接发送消息失败: {e}",
+exc_info=True,
+)
 yield None
 except asyncio.TimeoutError:
 raise Exception(
-f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds."
+f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
 )
 except StopAsyncIteration:
 break
@@ -221,40 +262,41 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
 run_context: ContextWrapper[AstrAgentContext],
 **tool_args,
 ):
-if not tool.mcp_client:
-raise ValueError("MCP client is not available for MCP function tools.")
-
-session = tool.mcp_client.session
-if not session:
-raise ValueError("MCP session is not available for MCP function tools.")
-res = await session.call_tool(
-name=tool.name,
-arguments=tool_args,
-read_timeout_seconds=timedelta(
-seconds=run_context.context.tool_call_timeout
-),
-)
+res = await tool.call(run_context, **tool_args)
 if not res:
 return
 yield res
 
 
-class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]):
+class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
 async def on_agent_done(self, run_context, llm_response):
 # 执行事件钩子
 await call_event_hook(
-run_context.event, EventType.OnLLMResponseEvent, llm_response
+run_context.context.event,
+EventType.OnLLMResponseEvent,
+llm_response,
 )
 
+async def on_tool_end(
+self,
+run_context: ContextWrapper[AstrAgentContext],
+tool: FunctionTool[Any],
+tool_args: dict | None,
+tool_result: CallToolResult | None,
+):
+run_context.context.event.clear_result()
+
 
 MAIN_AGENT_HOOKS = MainAgentHooks()
 
 
 async def run_agent(
-agent_runner: AgentRunner, max_step: int = 30, show_tool_use: bool = True
+agent_runner: AgentRunner,
+max_step: int = 30,
+show_tool_use: bool = True,
 ) -> AsyncGenerator[MessageChain, None]:
 step_idx = 0
-astr_event = agent_runner.run_context.event
+astr_event = agent_runner.run_context.context.event
 while step_idx < max_step:
 step_idx += 1
 try:
@@ -289,19 +331,18 @@ async def run_agent(
 MessageEventResult(
 chain=resp.data["chain"].chain,
 result_content_type=content_typ,
-)
+),
 )
 yield
 astr_event.clear_result()
-else:
-if resp.type == "streaming_delta":
-yield resp.data["chain"]  # MessageChain
+elif resp.type == "streaming_delta":
+yield resp.data["chain"]  # MessageChain
 if agent_runner.done():
 break
 
 except Exception as e:
 logger.error(traceback.format_exc())
-err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
+err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
 if agent_runner.streaming:
 yield MessageChain().message(err_msg)
 else:
@@ -331,13 +372,13 @@ class LLMRequestSubStage(Stage):
 for bwp in self.bot_wake_prefixs:
 if self.provider_wake_prefix.startswith(bwp):
 logger.info(
-f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。"
+f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
 )
 self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :]
 
 self.conv_manager = ctx.plugin_manager.context.conversation_manager
 
-def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
+def _select_provider(self, event: AstrMessageEvent):
 """选择使用的 LLM 提供商"""
 sel_provider = event.get_extra("selected_provider")
 _ctx = self.ctx.plugin_manager.context
@@ -366,8 +407,10 @@ class LLMRequestSubStage(Stage):
 return conversation
 
 async def process(
-self, event: AstrMessageEvent, _nested: bool = False
-) -> Union[None, AsyncGenerator[None, None]]:
+self,
+event: AstrMessageEvent,
+_nested: bool = False,
+) -> None | AsyncGenerator[None, None]:
 req: ProviderRequest | None = None
 
 if not self.ctx.astrbot_config["provider_settings"]["enable"]:
@@ -382,6 +425,9 @@ class LLMRequestSubStage(Stage):
 provider = self._select_provider(event)
 if provider is None:
 return
+if not isinstance(provider, Provider):
+logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
+return
 
 if event.get_extra("provider_request"):
 req = event.get_extra("provider_request")
@@ -416,6 +462,16 @@ class LLMRequestSubStage(Stage):
 if not req.prompt and not req.image_urls:
 return
 
+# 应用知识库
+try:
+await inject_kb_context(
+umo=event.unified_msg_origin,
+p_ctx=self.ctx,
+req=req,
+)
+except Exception as e:
+logger.error(f"调用知识库时遇到问题: {e}")
+
 # 执行请求 LLM 前事件钩子。
 if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
 return
@@ -463,7 +519,7 @@ class LLMRequestSubStage(Stage):
 # 如果模型不支持工具使用,但请求中包含工具列表,则清空。
 if "tool_use" not in provider_cfg:
 logger.debug(
-f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。"
+f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
 )
 req.func_tool = None
 # 插件可用性设置
@@ -480,22 +536,28 @@ class LLMRequestSubStage(Stage):
 new_tool_set.add_tool(tool)
 req.func_tool = new_tool_set
 
+# 备份 req.contexts
+backup_contexts = copy.deepcopy(req.contexts)
+
 # run agent
 agent_runner = AgentRunner()
 logger.debug(
-f"handle provider[id: {provider.provider_config['id']}] request: {req}"
+f"handle provider[id: {provider.provider_config['id']}] request: {req}",
 )
 astr_agent_ctx = AstrAgentContext(
 provider=provider,
 first_provider_request=req,
 curr_provider_request=req,
 streaming=self.streaming_response,
-tool_call_timeout=self.tool_call_timeout,
+event=event,
 )
 await agent_runner.reset(
 provider=provider,
 request=req,
-run_context=AgentContextWrapper(context=astr_agent_ctx, event=event),
+run_context=AgentContextWrapper(
+context=astr_agent_ctx,
+tool_call_timeout=self.tool_call_timeout,
+),
 tool_executor=FunctionToolExecutor(),
 agent_hooks=MAIN_AGENT_HOOKS,
 streaming=self.streaming_response,
@@ -507,8 +569,8 @@ class LLMRequestSubStage(Stage):
 MessageEventResult()
 .set_result_content_type(ResultContentType.STREAMING_RESULT)
 .set_async_stream(
-run_agent(agent_runner, self.max_step, self.show_tool_use)
-)
+run_agent(agent_runner, self.max_step, self.show_tool_use),
+),
 )
 yield
 if agent_runner.done():
@@ -517,18 +579,23 @@ class LLMRequestSubStage(Stage):
 chain = (
 MessageChain().message(final_llm_resp.completion_text).chain
 )
-else:
+elif final_llm_resp.result_chain:
 chain = final_llm_resp.result_chain.chain
+else:
+chain = MessageChain().chain
 event.set_result(
 MessageEventResult(
 chain=chain,
 result_content_type=ResultContentType.STREAMING_FINISH,
-)
+),
 )
 else:
 async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
 yield
 
+# 恢复备份的 contexts
+req.contexts = backup_contexts
+
 await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
 
 # 异步处理 WebChat 特殊情况
@@ -540,15 +607,21 @@ class LLMRequestSubStage(Stage):
 llm_tick=1,
 model_name=agent_runner.provider.get_model(),
 provider_type=agent_runner.provider.meta().type,
-)
+),
 )
 
 async def _handle_webchat(
-self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
+self,
+event: AstrMessageEvent,
+req: ProviderRequest,
+prov: Provider,
 ):
 """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
+if not req.conversation:
+return
 conversation = await self.conv_manager.get_conversation(
-event.unified_msg_origin, req.conversation.cid
+event.unified_msg_origin,
+req.conversation.cid,
 )
 if conversation and not req.conversation.title:
 messages = json.loads(conversation.history)
@@ -585,7 +658,7 @@ class LLMRequestSubStage(Stage):
 )
 if llm_resp and llm_resp.completion_text:
 logger.debug(
-f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}"
+f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}",
 )
 title = llm_resp.completion_text.strip()
 if not title or "<None>" in title:
@@ -628,7 +701,9 @@ class LLMRequestSubStage(Stage):
 messages.append({"role": "assistant", "content": llm_response.completion_text})
 messages = list(filter(lambda item: "_no_save" not in item, messages))
 await self.conv_manager.update_conversation(
-event.unified_msg_origin, req.conversation.cid, history=messages
+event.unified_msg_origin,
+req.conversation.cid,
+history=messages,
 )
 
 def fix_messages(self, messages: list[dict]) -> list[dict]:
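A recurring change in this file is the move from dispatching on a string tool.origin to dispatching on the tool's type (HandoffTool, MCPTool, everything else treated as a local tool). A tiny standalone illustration of that shape, using placeholder classes rather than AstrBot's real ones:

class FunctionTool:
    def __init__(self, name: str) -> None:
        self.name = name


class HandoffTool(FunctionTool):
    pass


class MCPTool(FunctionTool):
    pass


def execute(tool: FunctionTool) -> str:
    # special-case subclasses first, then fall back to the local-tool path
    if isinstance(tool, HandoffTool):
        return f"handoff:{tool.name}"  # delegate to a sub-agent
    if isinstance(tool, MCPTool):
        return f"mcp:{tool.name}"      # call through the MCP session
    return f"local:{tool.name}"        # default: decorator handler / run() / call()


for t in (HandoffTool("planner"), MCPTool("search"), FunctionTool("echo")):
    print(execute(t))  # handoff:planner, mcp:search, local:echo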
@@ -1,16 +1,17 @@
-"""
-本地 Agent 模式的 AstrBot 插件调用 Stage
-"""
+"""本地 Agent 模式的 AstrBot 插件调用 Stage"""
+import traceback
+from collections.abc import AsyncGenerator
+from typing import Any
+
+from astrbot.core import logger
+from astrbot.core.message.message_event_result import MessageEventResult
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
+from astrbot.core.star.star import star_map
+from astrbot.core.star.star_handler import StarHandlerMetadata
+
 from ...context import PipelineContext, call_handler
 from ..stage import Stage
-from typing import Dict, Any, List, AsyncGenerator, Union
-from astrbot.core.platform.astr_message_event import AstrMessageEvent
-from astrbot.core.message.message_event_result import MessageEventResult
-from astrbot.core import logger
-from astrbot.core.star.star_handler import StarHandlerMetadata
-from astrbot.core.star.star import star_map
-import traceback
 
 
 class StarRequestSubStage(Stage):
@@ -21,13 +22,14 @@ class StarRequestSubStage(Stage):
 self.ctx = ctx
 
 async def process(
-self, event: AstrMessageEvent
-) -> Union[None, AsyncGenerator[None, None]]:
-activated_handlers: List[StarHandlerMetadata] = event.get_extra(
-"activated_handlers"
+self,
+event: AstrMessageEvent,
+) -> None | AsyncGenerator[None, None]:
+activated_handlers: list[StarHandlerMetadata] = event.get_extra(
+"activated_handlers",
 )
-handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra(
-"handlers_parsed_params"
+handlers_parsed_params: dict[str, dict[str, Any]] = event.get_extra(
+"handlers_parsed_params",
 )
 if not handlers_parsed_params:
 handlers_parsed_params = {}
@@ -37,7 +39,7 @@ class StarRequestSubStage(Stage):
 md = star_map.get(handler.handler_module_path)
 if not md:
 logger.warning(
-f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
+f"Cannot find plugin for given handler module path: {handler.handler_module_path}",
 )
 continue
 logger.debug(f"plugin -> {md.name} - {handler.handler_name}")
@@ -1,12 +1,14 @@
-from typing import List, Union, AsyncGenerator
-from ..stage import Stage, register_stage
+from collections.abc import AsyncGenerator
+
+from astrbot.core import logger
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
+from astrbot.core.provider.entities import ProviderRequest
+from astrbot.core.star.star_handler import StarHandlerMetadata
 
 from ..context import PipelineContext
+from ..stage import Stage, register_stage
 from .method.llm_request import LLMRequestSubStage
 from .method.star_request import StarRequestSubStage
-from astrbot.core.platform.astr_message_event import AstrMessageEvent
-from astrbot.core.star.star_handler import StarHandlerMetadata
-from astrbot.core.provider.entities import ProviderRequest
-from astrbot.core import logger
 
 
 @register_stage
@@ -22,11 +24,12 @@ class ProcessStage(Stage):
 await self.star_request_sub_stage.initialize(ctx)
 
 async def process(
-self, event: AstrMessageEvent
-) -> Union[None, AsyncGenerator[None, None]]:
+self,
+event: AstrMessageEvent,
+) -> None | AsyncGenerator[None, None]:
 """处理事件"""
-activated_handlers: List[StarHandlerMetadata] = event.get_extra(
-"activated_handlers"
+activated_handlers: list[StarHandlerMetadata] = event.get_extra(
+"activated_handlers",
 )
 # 有插件 Handler 被激活
 if activated_handlers:
astrbot/core/pipeline/process_stage/utils.py (new file, 81 lines)
@@ -0,0 +1,81 @@
+from astrbot.api import logger, sp
+from astrbot.core.provider.entities import ProviderRequest
+
+from ..context import PipelineContext
+
+
+async def inject_kb_context(
+umo: str,
+p_ctx: PipelineContext,
+req: ProviderRequest,
+) -> None:
+"""Inject knowledge base context into the provider request
+
+Args:
+umo: Unique message object (session ID)
+p_ctx: Pipeline context
+req: Provider request
+
+"""
+kb_mgr = p_ctx.plugin_manager.context.kb_manager
+
+# 1. 优先读取会话级配置
+session_config = await sp.session_get(umo, "kb_config", default={})
+
+if session_config and "kb_ids" in session_config:
+# 会话级配置
+kb_ids = session_config.get("kb_ids", [])
+
+# 如果配置为空列表,明确表示不使用知识库
+if not kb_ids:
+logger.info(f"[知识库] 会话 {umo} 已被配置为不使用知识库")
+return
+
+top_k = session_config.get("top_k", 5)
+
+# 将 kb_ids 转换为 kb_names
+kb_names = []
+invalid_kb_ids = []
+for kb_id in kb_ids:
+kb_helper = await kb_mgr.get_kb(kb_id)
+if kb_helper:
+kb_names.append(kb_helper.kb.kb_name)
+else:
+logger.warning(f"[知识库] 知识库不存在或未加载: {kb_id}")
+invalid_kb_ids.append(kb_id)
+
+if invalid_kb_ids:
+logger.warning(
+f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}",
+)
+
+if not kb_names:
+return
+
+logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}")
+else:
+kb_names = p_ctx.astrbot_config.get("kb_names", [])
+top_k = p_ctx.astrbot_config.get("kb_final_top_k", 5)
+logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}")
+
+top_k_fusion = p_ctx.astrbot_config.get("kb_fusion_top_k", 20)
+
+if not kb_names:
+return
+
+logger.debug(f"[知识库] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}")
+kb_context = await kb_mgr.retrieve(
+query=req.prompt,
+kb_names=kb_names,
+top_k_fusion=top_k_fusion,
+top_m_final=top_k,
+)
+
+if not kb_context:
+return
+
+formatted = kb_context.get("context_text", "")
+if formatted:
+results = kb_context.get("results", [])
+logger.debug(f"[知识库] 为会话 {umo} 注入了 {len(results)} 条相关知识块")
+req.system_prompt = f"{formatted}\n\n{req.system_prompt or ''}"
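The new inject_kb_context helper retrieves knowledge-base snippets for the current prompt and prepends them to the request's system prompt. A simplified, dependency-free sketch of that retrieval-augmentation flow; the naive keyword retriever here is only a stand-in for the real kb_manager.retrieve fusion step:

from dataclasses import dataclass


@dataclass
class Request:
    prompt: str
    system_prompt: str = ""


def retrieve(query: str, kb: dict[str, list[str]], kb_names: list[str], top_k: int) -> list[str]:
    # naive keyword overlap as a placeholder for the real fusion / re-ranking retrieval
    docs = [doc for name in kb_names for doc in kb.get(name, [])]
    scored = sorted(docs, key=lambda d: sum(w in d for w in query.split()), reverse=True)
    return scored[:top_k]


def inject(req: Request, kb: dict[str, list[str]], kb_names: list[str], top_k: int = 5) -> None:
    if not kb_names:
        return
    results = retrieve(req.prompt, kb, kb_names, top_k)
    if results:
        formatted = "\n".join(f"- {r}" for r in results)
        # same shape as the helper above: retrieved context goes in front of the system prompt
        req.system_prompt = f"{formatted}\n\n{req.system_prompt or ''}"


kb = {"faq": ["Rate limiting uses a fixed window", "Sessions are identified by unified_msg_origin"]}
req = Request(prompt="how does rate limiting work", system_prompt="You are a helpful bot.")
inject(req, kb, ["faq"])
print(req.system_prompt)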
@@ -1,18 +1,19 @@
 import asyncio
-from datetime import datetime, timedelta
 from collections import defaultdict, deque
-from typing import DefaultDict, Deque, Union, AsyncGenerator
-from ..stage import Stage, register_stage
-from ..context import PipelineContext
-from astrbot.core.platform.astr_message_event import AstrMessageEvent
+from collections.abc import AsyncGenerator
+from datetime import datetime, timedelta
 from astrbot.core import logger
 from astrbot.core.config.astrbot_config import RateLimitStrategy
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
 
+from ..context import PipelineContext
+from ..stage import Stage, register_stage
 
 
 @register_stage
 class RateLimitStage(Stage):
-"""
-检查是否需要限制消息发送的限流器。
+"""检查是否需要限制消息发送的限流器。
 
 使用 Fixed Window 算法。
 如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。
@@ -20,32 +21,30 @@ class RateLimitStage(Stage):
 
 def __init__(self):
 # 存储每个会话的请求时间队列
-self.event_timestamps: DefaultDict[str, Deque[datetime]] = defaultdict(deque)
+self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque)
 # 为每个会话设置一个锁,避免并发冲突
-self.locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
+self.locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
 # 限流参数
 self.rate_limit_count: int = 0
 self.rate_limit_time: timedelta = timedelta(0)
 
 async def initialize(self, ctx: PipelineContext) -> None:
-"""
-初始化限流器,根据配置设置限流参数。
-"""
+"""初始化限流器,根据配置设置限流参数。"""
 self.rate_limit_count = ctx.astrbot_config["platform_settings"]["rate_limit"][
 "count"
 ]
 self.rate_limit_time = timedelta(
-seconds=ctx.astrbot_config["platform_settings"]["rate_limit"]["time"]
+seconds=ctx.astrbot_config["platform_settings"]["rate_limit"]["time"],
 )
 self.rl_strategy = ctx.astrbot_config["platform_settings"]["rate_limit"][
 "strategy"
 ]  # stall or discard
 
 async def process(
-self, event: AstrMessageEvent
-) -> Union[None, AsyncGenerator[None, None]]:
-"""
-检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。
+self,
+event: AstrMessageEvent,
+) -> None | AsyncGenerator[None, None]:
+"""检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。
 
 Args:
 event (AstrMessageEvent): 当前消息事件。
@@ -53,6 +52,7 @@ class RateLimitStage(Stage):
 
 Returns:
 MessageEventResult: 继续或停止事件处理的结果。
 
 """
 session_id = event.session_id
 now = datetime.now()
@@ -66,32 +66,33 @@ class RateLimitStage(Stage):
 if len(timestamps) < self.rate_limit_count:
 timestamps.append(now)
 break
-else:
-next_window_time = timestamps[0] + self.rate_limit_time
-stall_duration = (next_window_time - now).total_seconds() + 0.3
+next_window_time = timestamps[0] + self.rate_limit_time
+stall_duration = (next_window_time - now).total_seconds() + 0.3
 
 match self.rl_strategy:
 case RateLimitStrategy.STALL.value:
 logger.info(
-f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
+f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。",
 )
 await asyncio.sleep(stall_duration)
 now = datetime.now()
 case RateLimitStrategy.DISCARD.value:
 logger.info(
-f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
+f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。",
 )
 return event.stop_event()
 
 def _remove_expired_timestamps(
-self, timestamps: Deque[datetime], now: datetime
+self,
+timestamps: deque[datetime],
+now: datetime,
 ) -> None:
-"""
-移除时间窗口外的时间戳。
+"""移除时间窗口外的时间戳。
 
 Args:
 timestamps (Deque[datetime]): 当前会话的时间戳队列。
 now (datetime): 当前时间,用于计算过期时间。
 
 """
 expiry_threshold: datetime = now - self.rate_limit_time
 while timestamps and timestamps[0] < expiry_threshold:
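The rate-limit stage's bookkeeping is compact enough to show standalone: keep per-session timestamps, drop the ones that fell out of the window, then either report how long to stall or discard the request once the quota is used up. A synchronous, simplified illustration — not the stage itself, and without the per-session locks:

from collections import defaultdict, deque
from datetime import datetime, timedelta


class WindowLimiter:
    def __init__(self, count: int, window: timedelta) -> None:
        self.count = count
        self.window = window
        self.timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque)

    def check(self, session_id: str, now: datetime | None = None) -> tuple[bool, float]:
        """Return (allowed, stall_seconds); stall_seconds > 0 only when blocked."""
        now = now or datetime.now()
        ts = self.timestamps[session_id]
        expiry = now - self.window
        while ts and ts[0] < expiry:       # drop timestamps outside the window
            ts.popleft()
        if len(ts) < self.count:
            ts.append(now)
            return True, 0.0
        next_window = ts[0] + self.window  # earliest request leaves the window here
        return False, (next_window - now).total_seconds()


limiter = WindowLimiter(count=2, window=timedelta(seconds=60))
t0 = datetime(2024, 1, 1, 12, 0, 0)
print(limiter.check("s1", t0))                         # (True, 0.0)
print(limiter.check("s1", t0 + timedelta(seconds=1)))  # (True, 0.0)
print(limiter.check("s1", t0 + timedelta(seconds=2)))  # (False, 58.0)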
@@ -1,25 +1,27 @@
-import random
 import asyncio
 import math
+import random
+from collections.abc import AsyncGenerator
 
 import astrbot.core.message.components as Comp
-from typing import Union, AsyncGenerator
-from ..stage import register_stage, Stage
-from ..context import PipelineContext, call_event_hook
-from astrbot.core.platform.astr_message_event import AstrMessageEvent
-from astrbot.core.message.message_event_result import MessageChain, ResultContentType
 from astrbot.core import logger
 from astrbot.core.message.components import BaseMessageComponent, ComponentType
+from astrbot.core.message.message_event_result import MessageChain, ResultContentType
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
 from astrbot.core.star.star_handler import EventType
 from astrbot.core.utils.path_util import path_Mapping
 from astrbot.core.utils.session_lock import session_lock_manager
 
+from ..context import PipelineContext, call_event_hook
+from ..stage import Stage, register_stage
 
 
 @register_stage
 class RespondStage(Stage):
 # 组件类型到其非空判断函数的映射
 _component_validators = {
 Comp.Plain: lambda comp: bool(
-comp.text and comp.text.strip()
+comp.text and comp.text.strip(),
 ),  # 纯文本消息需要strip
 Comp.Face: lambda comp: comp.id is not None,  # QQ表情
 Comp.Record: lambda comp: bool(comp.file),  # 语音
@@ -58,7 +60,7 @@ class RespondStage(Stage):
 "segmented_reply"
 ]["interval_method"]
 self.log_base = float(
-ctx.astrbot_config["platform_settings"]["segmented_reply"]["log_base"]
+ctx.astrbot_config["platform_settings"]["segmented_reply"]["log_base"],
 )
 interval_str: str = ctx.astrbot_config["platform_settings"]["segmented_reply"][
 "interval"
@@ -86,17 +88,16 @@ class RespondStage(Stage):
 wc = await self._word_cnt(comp.text)
 i = math.log(wc + 1, self.log_base)
 return random.uniform(i, i + 0.5)
-else:
-return random.uniform(1, 1.75)
-else:
-# random
-return random.uniform(self.interval[0], self.interval[1])
+return random.uniform(1, 1.75)
+# random
+return random.uniform(self.interval[0], self.interval[1])
 
 async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]):
 """检查消息链是否为空
 
 Args:
 chain (list[BaseMessageComponent]): 包含消息对象的列表
 
 """
 if not chain:
 return True
@@ -150,8 +151,9 @@ class RespondStage(Stage):
 return extracted
 
 async def process(
-self, event: AstrMessageEvent
-) -> Union[None, AsyncGenerator[None, None]]:
+self,
+event: AstrMessageEvent,
+) -> None | AsyncGenerator[None, None]:
 result = event.get_result()
 if result is None:
 return
@@ -159,7 +161,7 @@ class RespondStage(Stage):
 return
 
 logger.info(
-f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
+f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}",
 )
 
 if result.result_content_type == ResultContentType.STREAMING_RESULT:
@@ -168,12 +170,13 @@ class RespondStage(Stage):
 return
 # 流式结果直接交付平台适配器处理
 use_fallback = self.config.get("provider_settings", {}).get(
-"streaming_segmented", False
+"streaming_segmented",
+False,
 )
 logger.info(f"应用流式输出({event.get_platform_id()})")
 await event.send_streaming(result.async_stream, use_fallback)
 return
-elif len(result.chain) > 0:
+if len(result.chain) > 0:
 # 检查路径映射
 if mappings := self.platform_settings.get("path_mapping", []):
 for idx, component in enumerate(result.chain):
@@ -212,7 +215,7 @@ class RespondStage(Stage):
 if not result.chain or len(result.chain) == 0:
 # may fix #2670
 logger.warning(
-f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}"
+f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}",
 )
 return
 async with session_lock_manager.acquire_lock(event.unified_msg_origin):
@@ -237,7 +240,7 @@ class RespondStage(Stage):
 ):
 # may fix #2670
 logger.warning(
-f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}"
+f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}",
 )
 return
 sep_comps = self._extract_comp(
Some files were not shown because too many files have changed in this diff.