Compare commits

..

1 Commits

430 changed files with 15785 additions and 35389 deletions

View File

@@ -1,8 +1,9 @@
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# github actions # github acions
.github/ .github/
.*ignore .*ignore
.git/
# User-specific stuff # User-specific stuff
.idea/ .idea/
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
@@ -14,6 +15,7 @@ env/
venv*/ venv*/
ENV/ ENV/
.conda/ .conda/
README*.md
dashboard/ dashboard/
data/ data/
changelogs/ changelogs/

View File

@@ -16,7 +16,7 @@ body:
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。 请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
不熟悉 JSON ?可以从 [此处](https://plugins.astrbot.app/submit) 生成 JSON ,生成后记得复制粘贴过来. 不熟悉 JSON 现在可以从 [这里](https://plugins.astrbot.app/#/submit) 获取你的 JSON 啦!获取到了记得复制粘贴过来哦!
- type: textarea - type: textarea
id: plugin-info id: plugin-info
@@ -26,13 +26,12 @@ body:
value: | value: |
```json ```json
{ {
"name": "插件名,请以 astrbot_plugin_ 开头", "name": "插件名",
"display_name": "用于展示的插件名,方便人类阅读", "desc": "插件介绍",
"desc": "插件的简短介绍",
"author": "作者名", "author": "作者名",
"repo": "插件仓库链接", "repo": "插件仓库链接",
"tags": [], "tags": [],
"social_link": "", "social_link": ""
} }
``` ```
validations: validations:

View File

@@ -6,13 +6,13 @@ body:
- type: markdown - type: markdown
attributes: attributes:
value: | value: |
感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。 感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。
- type: textarea - type: textarea
attributes: attributes:
label: 发生了什么 label: 发生了什么
description: 描述你遇到的异常 description: 描述你遇到的异常
placeholder: > placeholder: >
一个清晰且具体的描述这个异常是什么。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。 一个清晰且具体的描述这个异常是什么。
validations: validations:
required: true required: true
@@ -55,7 +55,7 @@ body:
attributes: attributes:
label: 报错日志 label: 报错日志
description: > description: >
如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。 如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!
placeholder: > placeholder: >
请提供完整的报错日志或截图。 请提供完整的报错日志或截图。
validations: validations:

View File

@@ -1,46 +1,19 @@
<!-- 如果有的话,指定 PR 旨在解决的 ISSUE 编号。 --> <!-- 如果有的话,指定这个 PR 解决的 ISSUE -->
<!-- If applicable, please specify the ISSUE number this PR aims to resolve. --> 解决了 #XYZ
fixes #XYZ ### Motivation
--- <!--解释为什么要改动-->
### Motivation / 动机 ### Modifications
<!--请描述此项更改的动机:它解决了什么问题?(例如:修复了 XX 错误,添加了 YY 功能)--> <!--简单解释你的改动-->
<!--Please describe the motivation for this change: What problem does it solve? (e.g., Fixes XX bug, adds YY feature)-->
### Modifications / 改动点 ### Check
<!--请总结你的改动:哪些核心文件被修改了?实现了什么功能?--> <!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容-->
<!--Please summarize your changes: What core files were modified? What functionality was implemented?-->
### Verification Steps / 验证步骤 - [ ] 😊 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
- [ ] 👀 我的更改经过良好的测试
<!--请为审查者 (Reviewer) 提供清晰、可复现的验证步骤例如1. 导航到... 2. 点击...)。--> - [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt``pyproject.toml` 文件相应位置。
<!--Please provide clear and reproducible verification steps for the Reviewer (e.g., 1. Navigate to... 2. Click...).--> - [ ] 😮 我的更改没有引入恶意代码
### Screenshots or Test Results / 运行截图或测试结果
<!--请粘贴截图、GIF 或测试日志,作为执行“验证步骤”的证据,证明此改动有效。-->
<!--Please paste screenshots, GIFs, or test logs here as evidence of executing the "Verification Steps" to prove this change is effective.-->
### Compatibility & Breaking Changes / 兼容性与破坏性变更
<!--请说明此变更的兼容性:哪些是破坏性变更?哪些地方做了向后兼容处理?是否提供了数据迁移方法?-->
<!--Please explain the compatibility of this change: What are the breaking changes? What backward-compatible measures were taken? Are data migration paths provided?-->
- [ ] 这是一个破坏性变更 (Breaking Change)。/ This is a breaking change.
- [ ] 这不是一个破坏性变更。/ This is NOT a breaking change.
---
### Checklist / 检查清单
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
- [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
- [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt``pyproject.toml` 文件相应位置。/ I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in `requirements.txt` and `pyproject.toml`.
- [ ] 😮 我的更改没有引入恶意代码。/ My changes do not introduce malicious code.

View File

@@ -1,38 +0,0 @@
# Set to true to add reviewers to pull requests
addReviewers: true
# Set to true to add assignees to pull requests
addAssignees: false
# A list of reviewers to be added to pull requests (GitHub user name)
reviewers:
- Soulter
- Raven95676
- Larch-C
- anka-afk
- advent259141
- Fridemn
- LIghtJUNction
# - zouyonghe
# A number of reviewers added to the pull request
# Set 0 to add all the reviewers (default: 0)
numberOfReviewers: 2
# A list of assignees, overrides reviewers if set
# assignees:
# - assigneeA
# A number of assignees to add to the pull request
# Set to 0 to add all of the assignees.
# Uses numberOfReviewers if unset.
# numberOfAssignees: 2
# A list of keywords to be skipped the process that add reviewers if pull requests include it
skipKeywords:
- wip
- draft
# A list of users to be skipped by both the add reviewers and add assignees processes
# skipUsers:
# - dependabot[bot]

View File

@@ -73,7 +73,7 @@ jobs:
uses: actions/checkout@v5 uses: actions/checkout@v5
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v6 uses: actions/setup-python@v5
with: with:
python-version: '3.10' python-version: '3.10'

View File

@@ -1,34 +0,0 @@
name: Code Format Check
on:
pull_request:
branches: [ master ]
push:
branches: [ master ]
jobs:
format-check:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.10'
- name: Install UV
run: pip install uv
- name: Install dependencies
run: uv sync
- name: Check code formatting with ruff
run: |
uv run ruff format --check .
- name: Check code style with ruff
run: |
uv run ruff check .

View File

@@ -60,7 +60,7 @@ jobs:
# Initializes the CodeQL tools for scanning. # Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL - name: Initialize CodeQL
uses: github/codeql-action/init@v4 uses: github/codeql-action/init@v3
with: with:
languages: ${{ matrix.language }} languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }} build-mode: ${{ matrix.build-mode }}
@@ -88,6 +88,6 @@ jobs:
exit 1 exit 1
- name: Perform CodeQL Analysis - name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v4 uses: github/codeql-action/analyze@v3
with: with:
category: "/language:${{matrix.language}}" category: "/language:${{matrix.language}}"

View File

@@ -22,7 +22,7 @@ jobs:
fetch-depth: 0 fetch-depth: 0
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v6 uses: actions/setup-python@v5
- name: Install dependencies - name: Install dependencies
run: | run: |

View File

@@ -13,18 +13,11 @@ jobs:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v5 uses: actions/checkout@v5
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 'latest'
- name: npm install, build - name: npm install, build
run: | run: |
cd dashboard cd dashboard
npm install pnpm -g npm install
pnpm install npm run build
pnpm i --save-dev @types/markdown-it
pnpm run build
- name: Inject Commit SHA - name: Inject Commit SHA
id: get_sha id: get_sha
@@ -36,7 +29,7 @@ jobs:
zip -r dist.zip dist zip -r dist.zip dist
- name: Archive production artifacts - name: Archive production artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v4
with: with:
name: dist-without-markdown name: dist-without-markdown
path: | path: |
@@ -44,7 +37,6 @@ jobs:
!dist/**/*.md !dist/**/*.md
- name: Create GitHub Release - name: Create GitHub Release
if: github.event_name == 'push'
uses: ncipollo/release-action@v1 uses: ncipollo/release-action@v1
with: with:
tag: release-${{ github.sha }} tag: release-${{ github.sha }}

View File

@@ -27,33 +27,6 @@ jobs:
if: github.event_name == 'workflow_dispatch' if: github.event_name == 'workflow_dispatch'
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }} run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
- name: Check if version is pre-release
id: check-prerelease
run: |
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
version="${{ steps.get-latest-tag.outputs.latest_tag }}"
else
version="${{ github.ref_name }}"
fi
if [[ "$version" == *"beta"* ]] || [[ "$version" == *"alpha"* ]]; then
echo "is_prerelease=true" >> $GITHUB_OUTPUT
echo "Version $version is a pre-release, will not push latest tag"
else
echo "is_prerelease=false" >> $GITHUB_OUTPUT
echo "Version $version is a stable release, will push latest tag"
fi
- name: Build Dashboard
run: |
cd dashboard
npm install
npm run build
mkdir -p dist/assets
echo $(git rev-parse HEAD) > dist/assets/version
cd ..
mkdir -p data
cp -r dashboard/dist data/
- name: Set QEMU - name: Set QEMU
uses: docker/setup-qemu-action@v3 uses: docker/setup-qemu-action@v3
@@ -80,9 +53,9 @@ jobs:
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
push: true push: true
tags: | tags: |
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && format('{0}/astrbot:latest', secrets.DOCKER_HUB_USERNAME) || '' }} ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }} ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && 'ghcr.io/soulter/astrbot:latest' || '' }} ghcr.io/soulter/astrbot:latest
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }} ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
- name: Post build notifications - name: Post build notifications

View File

@@ -18,7 +18,7 @@ jobs:
pull-requests: write pull-requests: write
steps: steps:
- uses: actions/stale@v10 - uses: actions/stale@v9
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'Stale issue message' stale-issue-message: 'Stale issue message'

52
.gitignore vendored
View File

@@ -1,49 +1,33 @@
# Python related
__pycache__ __pycache__
.mypy_cache
.venv*
.conda/
uv.lock
.coverage
# IDE and editors
.vscode
.idea
# Logs and temporary files
botpy.log botpy.log
logs/ .vscode
temp .venv*
cookies.json .idea
# Data files
data_v2.db data_v2.db
data_v3.db data_v3.db
data
configs/session configs/session
configs/config.yaml configs/config.yaml
**/.DS_Store
temp
cmd_config.json cmd_config.json
data
# Plugins and packages cookies.json
logs/
addons/plugins addons/plugins
packages/python_interpreter/workplace .coverage
tests/astrbot_plugin_openai
# Dashboard
tests/astrbot_plugin_openai
chroma
dashboard/node_modules/ dashboard/node_modules/
dashboard/dist/ dashboard/dist/
.DS_Store
package-lock.json package-lock.json
package.json package.json
# Operating System
**/.DS_Store
.DS_Store
# AstrBot specific
.astrbot
astrbot.lock
# Other
chroma
venv/* venv/*
packages/python_interpreter/workplace
.venv/*
.conda/
.idea
pytest.ini pytest.ini
.astrbot

View File

@@ -6,20 +6,8 @@ ci:
autoupdate_schedule: weekly autoupdate_schedule: weekly
autoupdate_commit_msg: ":balloon: pre-commit autoupdate" autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
repos: repos:
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version. rev: v0.11.2
rev: v0.14.1 hooks:
hooks: - id: ruff
# Run the linter. - id: ruff-format
- id: ruff-check
types_or: [ python, pyi ]
args: [ --fix ]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi ]
- repo: https://github.com/asottile/pyupgrade
rev: v3.21.0
hooks:
- id: pyupgrade
args: [--py310-plus]

View File

@@ -4,6 +4,8 @@ WORKDIR /AstrBot
COPY . /AstrBot/ COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
nodejs \
npm \
gcc \ gcc \
build-essential \ build-essential \
python3-dev \ python3-dev \
@@ -11,22 +13,23 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libssl-dev \ libssl-dev \
ca-certificates \ ca-certificates \
bash \ bash \
ffmpeg \
curl \
gnupg \
git \
&& apt-get clean \ && apt-get clean \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && rm -rf /var/lib/apt/lists/*
RUN curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \ RUN python -m pip install uv
apt-get install -y --no-install-recommends nodejs && \ RUN uv pip install -r requirements.txt --no-cache-dir --system
echo "3.11" > .python-version && \ RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
rm -rf /var/lib/apt/lists/*
RUN python -m pip install --no-cache-dir uv && \ # 释出 ffmpeg
uv pip install socksio pilk --no-cache-dir --system RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
EXPOSE 6185 EXPOSE 6185
EXPOSE 6186 EXPOSE 6186
CMD ["uv", "run", "main.py"] CMD [ "python", "main.py" ]

View File

@@ -1,4 +1,4 @@
FROM python:3.11-slim FROM python:3.10-slim
WORKDIR /AstrBot WORKDIR /AstrBot
@@ -14,27 +14,22 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
unzip \ unzip \
ca-certificates \ ca-certificates \
bash \ bash \
git \
&& apt-get clean \ && apt-get clean \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && rm -rf /var/lib/apt/lists/*
ENV NVM_DIR="/root/.nvm" \ # Installation of Node.js
NODE_VERSION=22 ENV NVM_DIR="/root/.nvm"
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \ RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
. "$NVM_DIR/nvm.sh" && \ . "$NVM_DIR/nvm.sh" && \
nvm install $NODE_VERSION && \ nvm install 22 && \
nvm use $NODE_VERSION && \ nvm use 22
nvm alias default $NODE_VERSION && \ RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"
node -v && npm -v && \
echo "3.11" > .python-version
ENV PATH="$NVM_DIR/versions/node/v$(node -v | cut -d 'v' -f 2)/bin:$PATH"
RUN python -m pip install --no-cache-dir uv RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
# 安装项目依赖(根据指南,使用 uv sync RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system
RUN uv sync --no-cache
EXPOSE 6185 EXPOSE 6185
EXPOSE 6186 EXPOSE 6186
CMD ["uv", "run", "main.py"] CMD ["python", "main.py"]

201
README.md
View File

@@ -1,40 +1,33 @@
<img width="430" height="31" alt="image" src="https://github.com/user-attachments/assets/474c822c-fab7-41be-8c23-6dae252823ed" /><p align="center">
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9) ![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
</p> </p>
<div align="center"> <div align="center">
<br> _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<div>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a> <a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=1" alt="FeaturedHelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</div>
<br> [![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot?style=for-the-badge&color=76bad9)](https://github.com/Soulter/AstrBot/releases/latest)
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python"> <img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a> <a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a> <a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a> <a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600"> [![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7日消息量&cacheSeconds=3600&style=for-the-badge&color=3b618e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://astrbot.app/">查看文档</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
</div> </div>
<br> AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> ## ✨ 主要功能
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://astrbot.app/">文档</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">路线图</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">问题提交</a>
</div>
AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架。
## 主要功能
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。 1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。 2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
@@ -42,9 +35,9 @@ AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。 4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
5. **WebUI**。可视化配置和管理机器人,功能齐全。 5. **WebUI**。可视化配置和管理机器人,功能齐全。
## 部署方式 ## ✨ 使用方式
#### Docker 部署(推荐 🥳) #### Docker 部署
推荐使用 Docker / Docker Compose 方式部署 AstrBot。 推荐使用 Docker / Docker Compose 方式部署 AstrBot。
@@ -72,7 +65,7 @@ AstrBot 已由雨云官方上架至云应用平台,可一键部署。
社区贡献的部署方式。 社区贡献的部署方式。
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot) [![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### Windows 一键安装器部署 #### Windows 一键安装器部署
@@ -86,7 +79,9 @@ AstrBot 已由雨云官方上架至云应用平台,可一键部署。
#### 手动部署 #### 手动部署
首先安装 uv > 推荐使用 `uv`。
首先,安装 uv
```bash ```bash
pip install uv pip install uv
@@ -101,101 +96,53 @@ uv run main.py
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。 或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
## 🌍 社区
### QQ 群组
- 1 群322154837
- 3 群630166526
- 5 群822130018
- 6 群753075035
- 开发者群975206796
### Telegram 群组
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
### Discord 群组
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## ⚡ 消息平台支持情况 ## ⚡ 消息平台支持情况
**官方维护**
| 平台 | 支持性 | | 平台 | 支持性 |
| -------- | ------- | | -------- | ------- |
| QQ(官方平台) | ✔ | | QQ(官方机器人接口) | ✔ |
| QQ(OneBot) | ✔ | | QQ(OneBot) | ✔ |
| Telegram | ✔ | | Telegram | ✔ |
| 企微应用 | ✔ | | 企业微信 | ✔ |
| 企微智能机器人 | ✔ |
| 微信客服 | ✔ | | 微信客服 | ✔ |
| 微信公众号 | ✔ | | 微信公众号 | ✔ |
| 飞书 | ✔ | | 飞书 | ✔ |
| 钉钉 | ✔ | | 钉钉 | ✔ |
| Slack | ✔ | | Slack | ✔ |
| Discord | ✔ | | Discord | ✔ |
| Satori | ✔ |
| Misskey | ✔ |
| Whatsapp | 将支持 |
| LINE | 将支持 |
**社区维护**
| 平台 | 支持性 |
| -------- | ------- |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ | | [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ | | [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
| [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) | | | 微信对话开放平台 | 🚧 |
| [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11) | | | WhatsApp | 🚧 |
| 小爱音响 | 🚧 |
## ⚡ 提供商支持情况 ## ⚡ 提供商支持情况
**大模型服务** | 名称 | 支持性 | 类型 | 备注 |
| -------- | ------- | ------- | ------- |
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Gemini、Kimi、xAI 等兼容 OpenAI API 的服务 |
| Claude API | ✔ | 文本生成 | |
| Google Gemini API | ✔ | 文本生成 | |
| Dify | ✔ | LLMOps | |
| 阿里云百炼应用 | ✔ | LLMOps | |
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
| 硅基流动 | ✔ | 模型 API 服务平台 | |
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
| OneAPI | ✔ | LLM 分发系统 | |
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
| OpenAI TTS API | ✔ | 文本转语音 | |
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | 文本转语音 | GPT-Sovits-Inference |
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI | ✔ | 支持任何兼容 OpenAI API 的服务 |
| Anthropic | ✔ | |
| Google Gemini | ✔ | |
| Moonshot AI | ✔ | |
| 智谱 AI | ✔ | |
| DeepSeek | ✔ | |
| Ollama | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| LM Studio | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | |
| [小马算力](https://www.tokenpony.cn/3YPyf) | ✔ | |
| 硅基流动 | ✔ | |
| PPIO 派欧云 | ✔ | |
| ModelScope | ✔ | |
| OneAPI | ✔ | |
| Dify | ✔ | |
| 阿里云百炼应用 | ✔ | |
| Coze | ✔ | |
**语音转文本服务**
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| Whisper | ✔ | 支持 API、本地部署 |
| SenseVoice | ✔ | 本地部署 |
**文本转语音服务**
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI TTS | ✔ | |
| Gemini TTS | ✔ | |
| GSVI | ✔ | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | GPT-Sovits |
| FishAudio | ✔ | |
| Edge TTS | ✔ | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | |
| Azure TTS | ✔ | |
| Minimax TTS | ✔ | |
| 火山引擎 TTS | ✔ | |
## ❤️ 贡献 ## ❤️ 贡献
@@ -210,11 +157,44 @@ uv run main.py
AstrBot 使用 `ruff` 进行代码格式化和检查。 AstrBot 使用 `ruff` 进行代码格式化和检查。
```bash ```bash
git clone https://github.com/AstrBotDevs/AstrBot git clone https://github.com/Soulter/AstrBot
pip install pre-commit pip install pre-commit
pre-commit install pre-commit install
``` ```
## 🌟 支持
- Star 这个项目!
- 在[爱发电](https://afdian.com/a/soulter)支持我!
## ✨ Demo
<details><summary>👉 点击展开多张 Demo 截图 👈</summary>
<div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
_✨基于 Docker 的沙箱化代码执行器Beta 测试✨_
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ 插件系统——部分插件展示 ✨_
<img src="https://github.com/user-attachments/assets/0cdbf564-2f59-4da5-b524-ce0e7ef3d978" width=600>
_✨ WebUI ✨_
</div>
</details>
## ❤️ Special Thanks ## ❤️ Special Thanks
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️ 特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
@@ -223,21 +203,24 @@ pre-commit install
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" /> <img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a> </a>
此外,本项目的诞生离不开以下开源项目的帮助 此外,本项目的诞生离不开以下开源项目:
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架 - [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
## ⭐ Star History ## ⭐ Star History
> [!TIP] > [!TIP]
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我维护这个开源项目的动力 <3 > 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我维护这个开源项目的动力 <3
<div align="center"> <div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=astrbotdevs/astrbot&type=Date)](https://star-history.com/#astrbotdevs/astrbot&Date) [![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div> </div>
</details> ![10k-star-banner-credit-by-kevin](https://github.com/user-attachments/assets/c97fc5fb-20b9-4bc8-9998-c20b930ab097)
_私は、高性能ですから!_ _私は、高性能ですから!_

View File

@@ -10,16 +10,16 @@ _✨ Easy-to-use Multi-platform LLM Chatbot & Development Framework ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a> <a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/AstrBotDevs/AstrBot)](https://github.com/AstrBotDevs/AstrBot/releases/latest) [![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python"> <img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot"/></a> <a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a> <a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e) [![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600) ![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/AstrBotDevs/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/AstrBotDevs/AstrBot) [![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
<a href="https://astrbot.app/">Documentation</a> <a href="https://astrbot.app/">Documentation</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracking</a> <a href="https://github.com/Soulter/AstrBot/issues">Issue Tracking</a>
</div> </div>
AstrBot is a loosely coupled, asynchronous chatbot and development framework that supports multi-platform deployment, featuring an easy-to-use plugin system and comprehensive Large Language Model (LLM) integration capabilities. AstrBot is a loosely coupled, asynchronous chatbot and development framework that supports multi-platform deployment, featuring an easy-to-use plugin system and comprehensive Large Language Model (LLM) integration capabilities.
@@ -49,7 +49,7 @@ Requires Python (>3.10). See docs: [Windows Installer Guide](https://astrbot.app
#### Replit Deployment #### Replit Deployment
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot) [![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### CasaOS Deployment #### CasaOS Deployment
@@ -67,8 +67,8 @@ See docs: [Source Code Deployment](https://astrbot.app/deploy/astrbot/cli.html)
| QQ (Official Bot) | ✔ | Private/Group chats | Text, Images | | QQ (Official Bot) | ✔ | Private/Group chats | Text, Images |
| QQ (OneBot) | ✔ | Private/Group chats | Text, Images, Voice | | QQ (OneBot) | ✔ | Private/Group chats | Text, Images, Voice |
| WeChat (Personal) | ✔ | Private/Group chats | Text, Images, Voice | | WeChat (Personal) | ✔ | Private/Group chats | Text, Images, Voice |
| [Telegram](https://github.com/AstrBotDevs/AstrBot_plugin_telegram) | ✔ | Private/Group chats | Text, Images | | [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
| [WeChat Work](https://github.com/AstrBotDevs/AstrBot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice | | [WeChat Work](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
| Feishu | ✔ | Group chats | Text, Images | | Feishu | ✔ | Group chats | Text, Images |
| WeChat Open Platform | 🚧 | Planned | - | | WeChat Open Platform | 🚧 | Planned | - |
| Discord | 🚧 | Planned | - | | Discord | 🚧 | Planned | - |
@@ -157,7 +157,7 @@ _✨ Built-in Web Chat Interface ✨_
<div align="center"> <div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=AstrBotDevs/AstrBot&type=Date)](https://star-history.com/#AstrBotDevs/AstrBot&Date) [![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div> </div>
@@ -169,7 +169,7 @@ _✨ Built-in Web Chat Interface ✨_
<!-- ## ✨ ATRI [Beta] <!-- ## ✨ ATRI [Beta]
Available as plugin: [astrbot_plugin_atri](https://github.com/AstrBotDevs/AstrBot_plugin_atri) Available as plugin: [astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. Qwen1.5-7B-Chat Lora model fine-tuned with ATRI character data 1. Qwen1.5-7B-Chat Lora model fine-tuned with ATRI character data
2. Long-term memory 2. Long-term memory

View File

@@ -10,16 +10,16 @@ _✨ 簡単に使えるマルチプラットフォーム LLM チャットボッ
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a> <a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/AstrBotDevs/AstrBot)](https://github.com/AstrBotDevs/AstrBot/releases/latest) [![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python"> <img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a> <a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"> <img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e) [![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600) ![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/AstrBotDevs/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/AstrBotDevs/AstrBot) [![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
<a href="https://astrbot.app/">ドキュメントを見る</a> <a href="https://astrbot.app/">ドキュメントを見る</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">問題を報告する</a> <a href="https://github.com/Soulter/AstrBot/issues">問題を報告する</a>
</div> </div>
AstrBot は、疎結合、非同期、複数のメッセージプラットフォームに対応したデプロイ、使いやすいプラグインシステム、および包括的な大規模言語モデルLLM接続機能を備えたチャットボットおよび開発フレームワークです。 AstrBot は、疎結合、非同期、複数のメッセージプラットフォームに対応したデプロイ、使いやすいプラグインシステム、および包括的な大規模言語モデルLLM接続機能を備えたチャットボットおよび開発フレームワークです。
@@ -50,7 +50,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
#### Replit デプロイ #### Replit デプロイ
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot) [![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### CasaOS デプロイ #### CasaOS デプロイ

View File

@@ -1,19 +1,20 @@
from astrbot import logger
from astrbot.core import html_renderer, sp
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.register import register_agent as agent from astrbot import logger
from astrbot.core import html_renderer
from astrbot.core import sp
from astrbot.core.star.register import register_llm_tool as llm_tool from astrbot.core.star.register import register_llm_tool as llm_tool
from astrbot.core.star.register import register_agent as agent
from astrbot.core.agent.tool import ToolSet, FunctionTool
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
__all__ = [ __all__ = [
"AstrBotConfig", "AstrBotConfig",
"BaseFunctionToolExecutor", "logger",
"FunctionTool",
"ToolSet",
"agent",
"html_renderer", "html_renderer",
"llm_tool", "llm_tool",
"logger", "agent",
"sp", "sp",
"ToolSet",
"FunctionTool",
"BaseFunctionToolExecutor",
] ]

View File

@@ -1,17 +1,18 @@
from astrbot.core.message.message_event_result import ( from astrbot.core.message.message_event_result import (
MessageEventResult,
MessageChain,
CommandResult, CommandResult,
EventResultType, EventResultType,
MessageChain,
MessageEventResult,
ResultContentType, ResultContentType,
) )
from astrbot.core.platform import AstrMessageEvent from astrbot.core.platform import AstrMessageEvent
__all__ = [ __all__ = [
"AstrMessageEvent", "MessageEventResult",
"MessageChain",
"CommandResult", "CommandResult",
"EventResultType", "EventResultType",
"MessageChain", "AstrMessageEvent",
"MessageEventResult",
"ResultContentType", "ResultContentType",
] ]

View File

@@ -1,52 +1,49 @@
from astrbot.core.star.filter.custom_filter import CustomFilter
from astrbot.core.star.filter.event_message_type import (
EventMessageType,
EventMessageTypeFilter,
)
from astrbot.core.star.filter.permission import PermissionType, PermissionTypeFilter
from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
PlatformAdapterTypeFilter,
)
from astrbot.core.star.register import register_after_message_sent as after_message_sent
from astrbot.core.star.register import register_command as command
from astrbot.core.star.register import register_command_group as command_group
from astrbot.core.star.register import register_custom_filter as custom_filter
from astrbot.core.star.register import register_event_message_type as event_message_type
from astrbot.core.star.register import register_llm_tool as llm_tool
from astrbot.core.star.register import register_on_astrbot_loaded as on_astrbot_loaded
from astrbot.core.star.register import (
register_on_decorating_result as on_decorating_result,
)
from astrbot.core.star.register import register_on_llm_request as on_llm_request
from astrbot.core.star.register import register_on_llm_response as on_llm_response
from astrbot.core.star.register import register_on_platform_loaded as on_platform_loaded
from astrbot.core.star.register import register_permission_type as permission_type
from astrbot.core.star.register import ( from astrbot.core.star.register import (
register_command as command,
register_command_group as command_group,
register_event_message_type as event_message_type,
register_regex as regex,
register_platform_adapter_type as platform_adapter_type, register_platform_adapter_type as platform_adapter_type,
register_permission_type as permission_type,
register_custom_filter as custom_filter,
register_on_astrbot_loaded as on_astrbot_loaded,
register_on_llm_request as on_llm_request,
register_on_llm_response as on_llm_response,
register_llm_tool as llm_tool,
register_on_decorating_result as on_decorating_result,
register_after_message_sent as after_message_sent,
) )
from astrbot.core.star.register import register_regex as regex
from astrbot.core.star.filter.event_message_type import (
EventMessageTypeFilter,
EventMessageType,
)
from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterTypeFilter,
PlatformAdapterType,
)
from astrbot.core.star.filter.permission import PermissionTypeFilter, PermissionType
from astrbot.core.star.filter.custom_filter import CustomFilter
__all__ = [ __all__ = [
"CustomFilter",
"EventMessageType",
"EventMessageTypeFilter",
"PermissionType",
"PermissionTypeFilter",
"PlatformAdapterType",
"PlatformAdapterTypeFilter",
"after_message_sent",
"command", "command",
"command_group", "command_group",
"custom_filter",
"event_message_type", "event_message_type",
"llm_tool",
"on_astrbot_loaded",
"on_decorating_result",
"on_llm_request",
"on_llm_response",
"on_platform_loaded",
"permission_type",
"platform_adapter_type",
"regex", "regex",
"platform_adapter_type",
"permission_type",
"EventMessageTypeFilter",
"EventMessageType",
"PlatformAdapterTypeFilter",
"PlatformAdapterType",
"PermissionTypeFilter",
"CustomFilter",
"custom_filter",
"PermissionType",
"on_astrbot_loaded",
"on_llm_request",
"llm_tool",
"on_decorating_result",
"after_message_sent",
"on_llm_response",
] ]

View File

@@ -1,22 +1,23 @@
from astrbot.core.message.components import *
from astrbot.core.platform import ( from astrbot.core.platform import (
AstrBotMessage,
AstrMessageEvent, AstrMessageEvent,
Group, Platform,
AstrBotMessage,
MessageMember, MessageMember,
MessageType, MessageType,
Platform,
PlatformMetadata, PlatformMetadata,
Group,
) )
from astrbot.core.platform.register import register_platform_adapter from astrbot.core.platform.register import register_platform_adapter
from astrbot.core.message.components import *
__all__ = [ __all__ = [
"AstrBotMessage",
"AstrMessageEvent", "AstrMessageEvent",
"Group", "Platform",
"AstrBotMessage",
"MessageMember", "MessageMember",
"MessageType", "MessageType",
"Platform",
"PlatformMetadata", "PlatformMetadata",
"register_platform_adapter", "register_platform_adapter",
"Group",
] ]

View File

@@ -1,17 +1,17 @@
from astrbot.core.provider import Personality, Provider, STTProvider from astrbot.core.provider import Provider, STTProvider, Personality
from astrbot.core.provider.entities import ( from astrbot.core.provider.entities import (
LLMResponse,
ProviderMetaData,
ProviderRequest, ProviderRequest,
ProviderType, ProviderType,
ProviderMetaData,
LLMResponse,
) )
__all__ = [ __all__ = [
"LLMResponse",
"Personality",
"Provider", "Provider",
"ProviderMetaData", "STTProvider",
"Personality",
"ProviderRequest", "ProviderRequest",
"ProviderType", "ProviderType",
"STTProvider", "ProviderMetaData",
"LLMResponse",
] ]

View File

@@ -1,7 +1,8 @@
from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star.config import *
from astrbot.core.star.register import ( from astrbot.core.star.register import (
register_star as register, # 注册插件Star register_star as register, # 注册插件Star
) )
__all__ = ["Context", "Star", "StarTools", "register"] from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star.config import *
__all__ = ["register", "Context", "Star", "StarTools"]

View File

@@ -1,7 +1,7 @@
from astrbot.core.utils.session_waiter import ( from astrbot.core.utils.session_waiter import (
SessionController,
SessionWaiter, SessionWaiter,
SessionController,
session_waiter, session_waiter,
) )
__all__ = ["SessionController", "SessionWaiter", "session_waiter"] __all__ = ["SessionWaiter", "SessionController", "session_waiter"]

View File

@@ -1,11 +1,11 @@
"""AstrBot CLI入口""" """
AstrBot CLI入口
import sys """
import click import click
import sys
from . import __version__ from . import __version__
from .commands import conf, init, plug, run from .commands import init, run, plug, conf
logo_tmpl = r""" logo_tmpl = r"""
___ _______.___________..______ .______ ______ .___________. ___ _______.___________..______ .______ ______ .___________.

View File

@@ -1,6 +1,6 @@
from .cmd_conf import conf
from .cmd_init import init from .cmd_init import init
from .cmd_plug import plug
from .cmd_run import run from .cmd_run import run
from .cmd_plug import plug
from .cmd_conf import conf
__all__ = ["conf", "init", "plug", "run"] __all__ = ["init", "run", "plug", "conf"]

View File

@@ -1,12 +1,9 @@
import hashlib
import json import json
import zoneinfo
from collections.abc import Callable
from typing import Any
import click import click
import hashlib
from ..utils import check_astrbot_root, get_astrbot_root import zoneinfo
from typing import Any, Callable
from ..utils import get_astrbot_root, check_astrbot_root
def _validate_log_level(value: str) -> str: def _validate_log_level(value: str) -> str:
@@ -14,7 +11,7 @@ def _validate_log_level(value: str) -> str:
value = value.upper() value = value.upper()
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
raise click.ClickException( raise click.ClickException(
"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一", "日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一"
) )
return value return value
@@ -76,7 +73,7 @@ def _load_config() -> dict[str, Any]:
root = get_astrbot_root() root = get_astrbot_root()
if not check_astrbot_root(root): if not check_astrbot_root(root):
raise click.ClickException( raise click.ClickException(
f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init", f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
) )
config_path = root / "data" / "cmd_config.json" config_path = root / "data" / "cmd_config.json"
@@ -91,7 +88,7 @@ def _load_config() -> dict[str, Any]:
try: try:
return json.loads(config_path.read_text(encoding="utf-8-sig")) return json.loads(config_path.read_text(encoding="utf-8-sig"))
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise click.ClickException(f"配置文件解析失败: {e!s}") raise click.ClickException(f"配置文件解析失败: {str(e)}")
def _save_config(config: dict[str, Any]) -> None: def _save_config(config: dict[str, Any]) -> None:
@@ -99,8 +96,7 @@ def _save_config(config: dict[str, Any]) -> None:
config_path = get_astrbot_root() / "data" / "cmd_config.json" config_path = get_astrbot_root() / "data" / "cmd_config.json"
config_path.write_text( config_path.write_text(
json.dumps(config, ensure_ascii=False, indent=2), json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8-sig"
encoding="utf-8-sig",
) )
@@ -112,7 +108,7 @@ def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
obj[part] = {} obj[part] = {}
elif not isinstance(obj[part], dict): elif not isinstance(obj[part], dict):
raise click.ClickException( raise click.ClickException(
f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典", f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典"
) )
obj = obj[part] obj = obj[part]
obj[parts[-1]] = value obj[parts[-1]] = value
@@ -144,6 +140,7 @@ def conf():
- callback_api_base: 回调接口基址 - callback_api_base: 回调接口基址
""" """
pass
@conf.command(name="set") @conf.command(name="set")
@@ -151,7 +148,7 @@ def conf():
@click.argument("value") @click.argument("value")
def set_config(key: str, value: str): def set_config(key: str, value: str):
"""设置配置项的值""" """设置配置项的值"""
if key not in CONFIG_VALIDATORS: if key not in CONFIG_VALIDATORS.keys():
raise click.ClickException(f"不支持的配置项: {key}") raise click.ClickException(f"不支持的配置项: {key}")
config = _load_config() config = _load_config()
@@ -173,17 +170,17 @@ def set_config(key: str, value: str):
except KeyError: except KeyError:
raise click.ClickException(f"未知的配置项: {key}") raise click.ClickException(f"未知的配置项: {key}")
except Exception as e: except Exception as e:
raise click.UsageError(f"设置配置失败: {e!s}") raise click.UsageError(f"设置配置失败: {str(e)}")
@conf.command(name="get") @conf.command(name="get")
@click.argument("key", required=False) @click.argument("key", required=False)
def get_config(key: str | None = None): def get_config(key: str = None):
"""获取配置项的值不提供key则显示所有可配置项""" """获取配置项的值不提供key则显示所有可配置项"""
config = _load_config() config = _load_config()
if key: if key:
if key not in CONFIG_VALIDATORS: if key not in CONFIG_VALIDATORS.keys():
raise click.ClickException(f"不支持的配置项: {key}") raise click.ClickException(f"不支持的配置项: {key}")
try: try:
@@ -194,10 +191,10 @@ def get_config(key: str | None = None):
except KeyError: except KeyError:
raise click.ClickException(f"未知的配置项: {key}") raise click.ClickException(f"未知的配置项: {key}")
except Exception as e: except Exception as e:
raise click.UsageError(f"获取配置失败: {e!s}") raise click.UsageError(f"获取配置失败: {str(e)}")
else: else:
click.echo("当前配置:") click.echo("当前配置:")
for key in CONFIG_VALIDATORS: for key in CONFIG_VALIDATORS.keys():
try: try:
value = ( value = (
"********" "********"

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
from pathlib import Path
import click import click
from filelock import FileLock, Timeout from filelock import FileLock, Timeout
@@ -7,14 +6,14 @@ from filelock import FileLock, Timeout
from ..utils import check_dashboard, get_astrbot_root from ..utils import check_dashboard, get_astrbot_root
async def initialize_astrbot(astrbot_root: Path) -> None: async def initialize_astrbot(astrbot_root) -> None:
"""执行 AstrBot 初始化逻辑""" """执行 AstrBot 初始化逻辑"""
dot_astrbot = astrbot_root / ".astrbot" dot_astrbot = astrbot_root / ".astrbot"
if not dot_astrbot.exists(): if not dot_astrbot.exists():
click.echo(f"Current Directory: {astrbot_root}") click.echo(f"Current Directory: {astrbot_root}")
click.echo( click.echo(
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。", "如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。"
) )
if click.confirm( if click.confirm(
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}", f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",

View File

@@ -1,29 +1,31 @@
import re import re
import shutil
from pathlib import Path from pathlib import Path
import click import click
import shutil
from ..utils import ( from ..utils import (
PluginStatus, get_git_repo,
build_plug_list, build_plug_list,
manage_plugin,
PluginStatus,
check_astrbot_root, check_astrbot_root,
get_astrbot_root, get_astrbot_root,
get_git_repo,
manage_plugin,
) )
@click.group() @click.group()
def plug(): def plug():
"""插件管理""" """插件管理"""
pass
def _get_data_path() -> Path: def _get_data_path() -> Path:
base = get_astrbot_root() base = get_astrbot_root()
if not check_astrbot_root(base): if not check_astrbot_root(base):
raise click.ClickException( raise click.ClickException(
f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init", f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
) )
return (base / "data").resolve() return (base / "data").resolve()
@@ -39,7 +41,7 @@ def display_plugins(plugins, title=None, color=None):
desc = p["desc"][:30] + ("..." if len(p["desc"]) > 30 else "") desc = p["desc"][:30] + ("..." if len(p["desc"]) > 30 else "")
click.echo( click.echo(
f"{p['name']:<20} {p['version']:<10} {p['status']:<10} " f"{p['name']:<20} {p['version']:<10} {p['status']:<10} "
f"{p['author']:<15} {desc:<30}", f"{p['author']:<15} {desc:<30}"
) )
@@ -76,7 +78,7 @@ def new(name: str):
f"desc: {desc}\n" f"desc: {desc}\n"
f"version: {version}\n" f"version: {version}\n"
f"author: {author}\n" f"author: {author}\n"
f"repo: {repo}\n", f"repo: {repo}\n"
) )
# 重写 README.md # 重写 README.md
@@ -84,7 +86,7 @@ def new(name: str):
f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n") f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n")
# 重写 main.py # 重写 main.py
with open(plug_path / "main.py", encoding="utf-8") as f: with open(plug_path / "main.py", "r", encoding="utf-8") as f:
content = f.read() content = f.read()
new_content = content.replace( new_content = content.replace(

View File

@@ -1,18 +1,19 @@
import asyncio
import os import os
import sys import sys
import traceback
from pathlib import Path from pathlib import Path
import click import click
import asyncio
import traceback
from filelock import FileLock, Timeout from filelock import FileLock, Timeout
from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root from ..utils import check_dashboard, check_astrbot_root, get_astrbot_root
async def run_astrbot(astrbot_root: Path): async def run_astrbot(astrbot_root: Path):
"""运行 AstrBot""" """运行 AstrBot"""
from astrbot.core import LogBroker, LogManager, db_helper, logger from astrbot.core import logger, LogManager, LogBroker, db_helper
from astrbot.core.initial_loader import InitialLoader from astrbot.core.initial_loader import InitialLoader
await check_dashboard(astrbot_root / "data") await check_dashboard(astrbot_root / "data")
@@ -37,7 +38,7 @@ def run(reload: bool, port: str) -> None:
if not check_astrbot_root(astrbot_root): if not check_astrbot_root(astrbot_root):
raise click.ClickException( raise click.ClickException(
f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init", f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
) )
os.environ["ASTRBOT_ROOT"] = str(astrbot_root) os.environ["ASTRBOT_ROOT"] = str(astrbot_root)

View File

@@ -1,18 +1,18 @@
from .basic import ( from .basic import (
get_astrbot_root,
check_astrbot_root, check_astrbot_root,
check_dashboard, check_dashboard,
get_astrbot_root,
) )
from .plugin import PluginStatus, build_plug_list, get_git_repo, manage_plugin from .plugin import get_git_repo, manage_plugin, build_plug_list, PluginStatus
from .version_comparator import VersionComparator from .version_comparator import VersionComparator
__all__ = [ __all__ = [
"PluginStatus", "get_astrbot_root",
"VersionComparator",
"build_plug_list",
"check_astrbot_root", "check_astrbot_root",
"check_dashboard", "check_dashboard",
"get_astrbot_root",
"get_git_repo", "get_git_repo",
"manage_plugin", "manage_plugin",
"build_plug_list",
"VersionComparator",
"PluginStatus",
] ]

View File

@@ -21,9 +21,8 @@ def get_astrbot_root() -> Path:
async def check_dashboard(astrbot_root: Path) -> None: async def check_dashboard(astrbot_root: Path) -> None:
"""检查是否安装了dashboard""" """检查是否安装了dashboard"""
from astrbot.core.utils.io import get_dashboard_version, download_dashboard
from astrbot.core.config.default import VERSION from astrbot.core.config.default import VERSION
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from .version_comparator import VersionComparator from .version_comparator import VersionComparator
try: try:
@@ -38,10 +37,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
): ):
click.echo("正在安装管理面板...") click.echo("正在安装管理面板...")
await download_dashboard( await download_dashboard(
path="data/dashboard.zip", path="data/dashboard.zip", extract_path=str(astrbot_root)
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
) )
click.echo("管理面板安装完成") click.echo("管理面板安装完成")
@@ -49,26 +45,21 @@ async def check_dashboard(astrbot_root: Path) -> None:
if VersionComparator.compare_version(VERSION, dashboard_version) <= 0: if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
click.echo("管理面板已是最新版本") click.echo("管理面板已是最新版本")
return return
try: else:
version = dashboard_version.split("v")[1] try:
click.echo(f"管理面板版本: {version}") version = dashboard_version.split("v")[1]
await download_dashboard( click.echo(f"管理面板版本: {version}")
path="data/dashboard.zip", await download_dashboard(
extract_path=str(astrbot_root), path="data/dashboard.zip", extract_path=str(astrbot_root)
version=f"v{VERSION}", )
latest=False, except Exception as e:
) click.echo(f"下载管理面板失败: {e}")
except Exception as e: return
click.echo(f"下载管理面板失败: {e}")
return
except FileNotFoundError: except FileNotFoundError:
click.echo("初始化管理面板目录...") click.echo("初始化管理面板目录...")
try: try:
await download_dashboard( await download_dashboard(
path=str(astrbot_root / "dashboard.zip"), path=str(astrbot_root / "dashboard.zip"), extract_path=str(astrbot_root)
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
) )
click.echo("管理面板初始化完成") click.echo("管理面板初始化完成")
except Exception as e: except Exception as e:

View File

@@ -1,14 +1,14 @@
import shutil import shutil
import tempfile import tempfile
import httpx
import yaml
from enum import Enum from enum import Enum
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from zipfile import ZipFile from zipfile import ZipFile
import click import click
import httpx
import yaml
from .version_comparator import VersionComparator from .version_comparator import VersionComparator
@@ -32,8 +32,7 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
release_url = f"https://api.github.com/repos/{author}/{repo}/releases" release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
try: try:
with httpx.Client( with httpx.Client(
proxy=proxy if proxy else None, proxy=proxy if proxy else None, follow_redirects=True
follow_redirects=True,
) as client: ) as client:
resp = client.get(release_url) resp = client.get(release_url)
resp.raise_for_status() resp.raise_for_status()
@@ -56,8 +55,7 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
# 下载并解压 # 下载并解压
with httpx.Client( with httpx.Client(
proxy=proxy if proxy else None, proxy=proxy if proxy else None, follow_redirects=True
follow_redirects=True,
) as client: ) as client:
resp = client.get(download_url) resp = client.get(download_url)
if ( if (
@@ -91,7 +89,6 @@ def load_yaml_metadata(plugin_dir: Path) -> dict:
Returns: Returns:
dict: 包含元数据的字典,如果读取失败则返回空字典 dict: 包含元数据的字典,如果读取失败则返回空字典
""" """
yaml_path = plugin_dir / "metadata.yaml" yaml_path = plugin_dir / "metadata.yaml"
if yaml_path.exists(): if yaml_path.exists():
@@ -110,7 +107,6 @@ def build_plug_list(plugins_dir: Path) -> list:
Returns: Returns:
list: 包含插件信息的字典列表 list: 包含插件信息的字典列表
""" """
# 获取本地插件信息 # 获取本地插件信息
result = [] result = []
@@ -128,17 +124,15 @@ def build_plug_list(plugins_dir: Path) -> list:
if metadata and all( if metadata and all(
k in metadata for k in ["name", "desc", "version", "author", "repo"] k in metadata for k in ["name", "desc", "version", "author", "repo"]
): ):
result.append( result.append({
{ "name": str(metadata.get("name", "")),
"name": str(metadata.get("name", "")), "desc": str(metadata.get("desc", "")),
"desc": str(metadata.get("desc", "")), "version": str(metadata.get("version", "")),
"version": str(metadata.get("version", "")), "author": str(metadata.get("author", "")),
"author": str(metadata.get("author", "")), "repo": str(metadata.get("repo", "")),
"repo": str(metadata.get("repo", "")), "status": PluginStatus.INSTALLED,
"status": PluginStatus.INSTALLED, "local_path": str(plugin_dir),
"local_path": str(plugin_dir), })
},
)
# 获取在线插件列表 # 获取在线插件列表
online_plugins = [] online_plugins = []
@@ -148,17 +142,15 @@ def build_plug_list(plugins_dir: Path) -> list:
resp.raise_for_status() resp.raise_for_status()
data = resp.json() data = resp.json()
for plugin_id, plugin_info in data.items(): for plugin_id, plugin_info in data.items():
online_plugins.append( online_plugins.append({
{ "name": str(plugin_id),
"name": str(plugin_id), "desc": str(plugin_info.get("desc", "")),
"desc": str(plugin_info.get("desc", "")), "version": str(plugin_info.get("version", "")),
"version": str(plugin_info.get("version", "")), "author": str(plugin_info.get("author", "")),
"author": str(plugin_info.get("author", "")), "repo": str(plugin_info.get("repo", "")),
"repo": str(plugin_info.get("repo", "")), "status": PluginStatus.NOT_INSTALLED,
"status": PluginStatus.NOT_INSTALLED, "local_path": None,
"local_path": None, })
},
)
except Exception as e: except Exception as e:
click.echo(f"获取在线插件列表失败: {e}", err=True) click.echo(f"获取在线插件列表失败: {e}", err=True)
@@ -172,8 +164,7 @@ def build_plug_list(plugins_dir: Path) -> list:
) )
if ( if (
VersionComparator.compare_version( VersionComparator.compare_version(
local_plugin["version"], local_plugin["version"], online_plugin["version"]
online_plugin["version"],
) )
< 0 < 0
): ):
@@ -191,10 +182,7 @@ def build_plug_list(plugins_dir: Path) -> list:
def manage_plugin( def manage_plugin(
plugin: dict, plugin: dict, plugins_dir: Path, is_update: bool = False, proxy: str | None = None
plugins_dir: Path,
is_update: bool = False,
proxy: str | None = None,
) -> None: ) -> None:
"""安装或更新插件 """安装或更新插件
@@ -203,7 +191,6 @@ def manage_plugin(
plugins_dir (Path): 插件目录 plugins_dir (Path): 插件目录
is_update (bool, optional): 是否为更新操作. 默认为 False is_update (bool, optional): 是否为更新操作. 默认为 False
proxy (str, optional): 代理服务器地址 proxy (str, optional): 代理服务器地址
""" """
plugin_name = plugin["name"] plugin_name = plugin["name"]
repo_url = plugin["repo"] repo_url = plugin["repo"]
@@ -221,26 +208,26 @@ def manage_plugin(
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新") raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
# 备份现有插件 # 备份现有插件
if is_update and backup_path is not None and backup_path.exists(): if is_update and backup_path.exists():
shutil.rmtree(backup_path) shutil.rmtree(backup_path)
if is_update and backup_path is not None: if is_update:
shutil.copytree(target_path, backup_path) shutil.copytree(target_path, backup_path)
try: try:
click.echo( click.echo(
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}...", f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}..."
) )
get_git_repo(repo_url, target_path, proxy) get_git_repo(repo_url, target_path, proxy)
# 更新成功,删除备份 # 更新成功,删除备份
if is_update and backup_path is not None and backup_path.exists(): if is_update and backup_path.exists():
shutil.rmtree(backup_path) shutil.rmtree(backup_path)
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功") click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
except Exception as e: except Exception as e:
if target_path.exists(): if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True) shutil.rmtree(target_path, ignore_errors=True)
if is_update and backup_path is not None and backup_path.exists(): if is_update and backup_path.exists():
shutil.move(backup_path, target_path) shutil.move(backup_path, target_path)
raise click.ClickException( raise click.ClickException(
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}", f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}"
) )

View File

@@ -1,4 +1,6 @@
"""拷贝自 astrbot.core.utils.version_comparator""" """
拷贝自 astrbot.core.utils.version_comparator
"""
import re import re
@@ -40,15 +42,15 @@ class VersionComparator:
for i in range(length): for i in range(length):
if v1_parts[i] > v2_parts[i]: if v1_parts[i] > v2_parts[i]:
return 1 return 1
if v1_parts[i] < v2_parts[i]: elif v1_parts[i] < v2_parts[i]:
return -1 return -1
# 比较预发布标签 # 比较预发布标签
if v1_prerelease is None and v2_prerelease is not None: if v1_prerelease is None and v2_prerelease is not None:
return 1 # 没有预发布标签的版本高于有预发布标签的版本 return 1 # 没有预发布标签的版本高于有预发布标签的版本
if v1_prerelease is not None and v2_prerelease is None: elif v1_prerelease is not None and v2_prerelease is None:
return -1 # 有预发布标签的版本低于没有预发布标签的版本 return -1 # 有预发布标签的版本低于没有预发布标签的版本
if v1_prerelease is not None and v2_prerelease is not None: elif v1_prerelease is not None and v2_prerelease is not None:
len_pre = max(len(v1_prerelease), len(v2_prerelease)) len_pre = max(len(v1_prerelease), len(v2_prerelease))
for i in range(len_pre): for i in range(len_pre):
p1 = v1_prerelease[i] if i < len(v1_prerelease) else None p1 = v1_prerelease[i] if i < len(v1_prerelease) else None
@@ -56,21 +58,21 @@ class VersionComparator:
if p1 is None and p2 is not None: if p1 is None and p2 is not None:
return -1 return -1
if p1 is not None and p2 is None: elif p1 is not None and p2 is None:
return 1 return 1
if isinstance(p1, int) and isinstance(p2, str): elif isinstance(p1, int) and isinstance(p2, str):
return -1 return -1
if isinstance(p1, str) and isinstance(p2, int): elif isinstance(p1, str) and isinstance(p2, int):
return 1 return 1
if isinstance(p1, int) and isinstance(p2, int): elif isinstance(p1, int) and isinstance(p2, int):
if p1 > p2: if p1 > p2:
return 1 return 1
if p1 < p2: elif p1 < p2:
return -1 return -1
elif isinstance(p1, str) and isinstance(p2, str): elif isinstance(p1, str) and isinstance(p2, str):
if p1 > p2: if p1 > p2:
return 1 return 1
if p1 < p2: elif p1 < p2:
return -1 return -1
return 0 # 预发布标签完全相同 return 0 # 预发布标签完全相同

View File

@@ -1,14 +1,12 @@
import os import os
from .log import LogManager, LogBroker # noqa
from astrbot.core.config import AstrBotConfig
from astrbot.core.config.default import DB_PATH
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.file_token_service import FileTokenService
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.t2i.renderer import HtmlRenderer from astrbot.core.utils.t2i.renderer import HtmlRenderer
from astrbot.core.utils.shared_preferences import SharedPreferences
from .log import LogBroker, LogManager # noqa from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config.default import DB_PATH
from astrbot.core.config import AstrBotConfig
from astrbot.core.file_token_service import FileTokenService
from .utils.astrbot_path import get_astrbot_data_path from .utils.astrbot_path import get_astrbot_data_path
# 初始化数据存储文件夹 # 初始化数据存储文件夹

View File

@@ -1,14 +1,13 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic
from .hooks import BaseAgentRunHooks
from .run_context import TContext
from .tool import FunctionTool from .tool import FunctionTool
from typing import Generic
from .run_context import TContext
from .hooks import BaseAgentRunHooks
@dataclass @dataclass
class Agent(Generic[TContext]): class Agent(Generic[TContext]):
name: str name: str
instructions: str | None = None instructions: str | None = None
tools: list[str | FunctionTool] | None = None tools: list[str, FunctionTool] | None = None
run_hooks: BaseAgentRunHooks[TContext] | None = None run_hooks: BaseAgentRunHooks[TContext] | None = None

View File

@@ -1,18 +1,14 @@
from typing import Generic from typing import Generic
from .tool import FunctionTool
from .agent import Agent from .agent import Agent
from .run_context import TContext from .run_context import TContext
from .tool import FunctionTool
class HandoffTool(FunctionTool, Generic[TContext]): class HandoffTool(FunctionTool, Generic[TContext]):
"""Handoff tool for delegating tasks to another agent.""" """Handoff tool for delegating tasks to another agent."""
def __init__( def __init__(
self, self, agent: Agent[TContext], parameters: dict | None = None, **kwargs
agent: Agent[TContext],
parameters: dict | None = None,
**kwargs,
): ):
self.agent = agent self.agent = agent
super().__init__( super().__init__(

View File

@@ -1,13 +1,12 @@
from typing import Generic
import mcp import mcp
from dataclasses import dataclass
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.provider.entities import LLMResponse
from .run_context import ContextWrapper, TContext from .run_context import ContextWrapper, TContext
from typing import Generic
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.agent.tool import FunctionTool
@dataclass
class BaseAgentRunHooks(Generic[TContext]): class BaseAgentRunHooks(Generic[TContext]):
async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ... async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
async def on_tool_start( async def on_tool_start(
@@ -24,7 +23,5 @@ class BaseAgentRunHooks(Generic[TContext]):
tool_result: mcp.types.CallToolResult | None, tool_result: mcp.types.CallToolResult | None,
): ... ): ...
async def on_agent_done( async def on_agent_done(
self, self, run_context: ContextWrapper[TContext], llm_response: LLMResponse
run_context: ContextWrapper[TContext],
llm_response: LLMResponse,
): ... ): ...

View File

@@ -1,16 +1,11 @@
import asyncio import asyncio
import logging import logging
from contextlib import AsyncExitStack
from datetime import timedelta from datetime import timedelta
from typing import Generic from typing import Optional
from contextlib import AsyncExitStack
from astrbot import logger from astrbot import logger
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.utils.log_pipe import LogPipe from astrbot.core.utils.log_pipe import LogPipe
from .run_context import TContext
from .tool import FunctionTool
try: try:
import mcp import mcp
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
@@ -21,13 +16,13 @@ try:
from mcp.client.streamable_http import streamablehttp_client from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError): except (ModuleNotFoundError, ImportError):
logger.warning( logger.warning(
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。", "警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
) )
def _prepare_config(config: dict) -> dict: def _prepare_config(config: dict) -> dict:
"""准备配置,处理嵌套格式""" """准备配置,处理嵌套格式"""
if config.get("mcpServers"): if "mcpServers" in config and config["mcpServers"]:
first_key = next(iter(config["mcpServers"])) first_key = next(iter(config["mcpServers"]))
config = config["mcpServers"][first_key] config = config["mcpServers"][first_key]
config.pop("active", None) config.pop("active", None)
@@ -45,15 +40,8 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
timeout = cfg.get("timeout", 10) timeout = cfg.get("timeout", 10)
try: try:
if "transport" in cfg:
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
if transport_type == "streamable_http": if cfg.get("transport") == "streamable_http":
test_payload = { test_payload = {
"jsonrpc": "2.0", "jsonrpc": "2.0",
"method": "initialize", "method": "initialize",
@@ -76,7 +64,8 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
) as response: ) as response:
if response.status == 200: if response.status == 200:
return True, "" return True, ""
return False, f"HTTP {response.status}: {response.reason}" else:
return False, f"HTTP {response.status}: {response.reason}"
else: else:
async with session.get( async with session.get(
url, url,
@@ -88,7 +77,8 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
) as response: ) as response:
if response.status == 200: if response.status == 200:
return True, "" return True, ""
return False, f"HTTP {response.status}: {response.reason}" else:
return False, f"HTTP {response.status}: {response.reason}"
except asyncio.TimeoutError: except asyncio.TimeoutError:
return False, f"连接超时: {timeout}" return False, f"连接超时: {timeout}"
@@ -99,10 +89,10 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
class MCPClient: class MCPClient:
def __init__(self): def __init__(self):
# Initialize session and client objects # Initialize session and client objects
self.session: mcp.ClientSession | None = None self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack() self.exit_stack = AsyncExitStack()
self.name: str | None = None self.name = None
self.active: bool = True self.active: bool = True
self.tools: list[mcp.Tool] = [] self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = [] self.server_errlogs: list[str] = []
@@ -118,7 +108,6 @@ class MCPClient:
Args: Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
""" """
cfg = _prepare_config(mcp_server_config.copy()) cfg = _prepare_config(mcp_server_config.copy())
@@ -132,14 +121,7 @@ class MCPClient:
if not success: if not success:
raise Exception(error_msg) raise Exception(error_msg)
if "transport" in cfg: if cfg.get("transport") != "streamable_http":
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
if transport_type != "streamable_http":
# SSE transport method # SSE transport method
self._streams_context = sse_client( self._streams_context = sse_client(
url=cfg["url"], url=cfg["url"],
@@ -148,22 +130,22 @@ class MCPClient:
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5), sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
) )
streams = await self.exit_stack.enter_async_context( streams = await self.exit_stack.enter_async_context(
self._streams_context, self._streams_context
) )
# Create a new client session # Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60)) read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context( self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession( mcp.ClientSession(
*streams, *streams,
read_timeout_seconds=read_timeout, read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore logging_callback=logging_callback, # type: ignore
), )
) )
else: else:
timeout = timedelta(seconds=cfg.get("timeout", 30)) timeout = timedelta(seconds=cfg.get("timeout", 30))
sse_read_timeout = timedelta( sse_read_timeout = timedelta(
seconds=cfg.get("sse_read_timeout", 60 * 5), seconds=cfg.get("sse_read_timeout", 60 * 5)
) )
self._streams_context = streamablehttp_client( self._streams_context = streamablehttp_client(
url=cfg["url"], url=cfg["url"],
@@ -173,18 +155,18 @@ class MCPClient:
terminate_on_close=cfg.get("terminate_on_close", True), terminate_on_close=cfg.get("terminate_on_close", True),
) )
read_s, write_s, _ = await self.exit_stack.enter_async_context( read_s, write_s, _ = await self.exit_stack.enter_async_context(
self._streams_context, self._streams_context
) )
# Create a new client session # Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60)) read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context( self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession( mcp.ClientSession(
read_stream=read_s, read_stream=read_s,
write_stream=write_s, write_stream=write_s,
read_timeout_seconds=read_timeout, read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore logging_callback=logging_callback, # type: ignore
), )
) )
else: else:
@@ -210,14 +192,12 @@ class MCPClient:
# Create a new client session # Create a new client session
self.session = await self.exit_stack.enter_async_context( self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport), mcp.ClientSession(*stdio_transport)
) )
await self.session.initialize() await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult: async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools""" """List all tools from the server and save them to self.tools"""
if not self.session:
raise Exception("MCP Client is not initialized")
response = await self.session.list_tools() response = await self.session.list_tools()
self.tools = response.tools self.tools = response.tools
return response return response
@@ -226,34 +206,3 @@ class MCPClient:
"""Clean up resources""" """Clean up resources"""
await self.exit_stack.aclose() await self.exit_stack.aclose()
self.running_event.set() # Set the running event to indicate cleanup is done self.running_event.set() # Set the running event to indicate cleanup is done
class MCPTool(FunctionTool, Generic[TContext]):
"""A function tool that calls an MCP service."""
def __init__(
self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs
):
super().__init__(
name=mcp_tool.name,
description=mcp_tool.description or "",
parameters=mcp_tool.inputSchema,
)
self.mcp_tool = mcp_tool
self.mcp_client = mcp_client
self.mcp_server_name = mcp_server_name
async def call(
self, context: ContextWrapper[TContext], **kwargs
) -> mcp.types.CallToolResult:
session = self.mcp_client.session
if not session:
raise ValueError("MCP session is not available for MCP function tools.")
res = await session.call_tool(
name=self.mcp_tool.name,
arguments=kwargs,
read_timeout_seconds=timedelta(
seconds=context.tool_call_timeout,
),
)
return res

View File

@@ -1,168 +0,0 @@
# Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation.
# License: Apache License 2.0
from typing import Any, ClassVar, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema
class ContentPart(BaseModel):
"""A part of the content in a message."""
__content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}
type: str
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`"
type_value = getattr(cls, "type", None)
if type_value is None or not isinstance(type_value, str):
raise ValueError(invalid_subclass_error_msg)
cls.__content_part_registry[type_value] = cls
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# If we're dealing with the base ContentPart class, use custom validation
if cls.__name__ == "ContentPart":
def validate_content_part(value: Any) -> Any:
# if it's already an instance of a ContentPart subclass, return it
if hasattr(value, "__class__") and issubclass(value.__class__, cls):
return value
# if it's a dict with a type field, dispatch to the appropriate subclass
if isinstance(value, dict) and "type" in value:
type_value: Any | None = cast(dict[str, Any], value).get("type")
if not isinstance(type_value, str):
raise ValueError(f"Cannot validate {value} as ContentPart")
target_class = cls.__content_part_registry[type_value]
return target_class.model_validate(value)
raise ValueError(f"Cannot validate {value} as ContentPart")
return core_schema.no_info_plain_validator_function(validate_content_part)
# for subclasses, use the default schema
return handler(source_type)
class TextPart(ContentPart):
"""
>>> TextPart(text="Hello, world!").model_dump()
{'type': 'text', 'text': 'Hello, world!'}
"""
type: str = "text"
text: str
class ImageURLPart(ContentPart):
"""
>>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump()
{'type': 'image_url', 'image_url': 'http://example.com/image.jpg'}
"""
class ImageURL(BaseModel):
url: str
"""The URL of the image, can be data URI scheme like `data:image/png;base64,...`."""
id: str | None = None
"""The ID of the image, to allow LLMs to distinguish different images."""
type: str = "image_url"
image_url: str
class AudioURLPart(ContentPart):
"""
>>> AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump()
{'type': 'audio_url', 'audio_url': {'url': 'https://example.com/audio.mp3', 'id': None}}
"""
class AudioURL(BaseModel):
url: str
"""The URL of the audio, can be data URI scheme like `data:audio/aac;base64,...`."""
id: str | None = None
"""The ID of the audio, to allow LLMs to distinguish different audios."""
type: str = "audio_url"
audio_url: AudioURL
class ToolCall(BaseModel):
"""
A tool call requested by the assistant.
>>> ToolCall(
... id="123",
... function=ToolCall.FunctionBody(
... name="function",
... arguments="{}"
... ),
... ).model_dump()
{'type': 'function', 'id': '123', 'function': {'name': 'function', 'arguments': '{}'}}
"""
class FunctionBody(BaseModel):
name: str
arguments: str | None
type: Literal["function"] = "function"
id: str
"""The ID of the tool call."""
function: FunctionBody
"""The function body of the tool call."""
class ToolCallPart(BaseModel):
"""A part of the tool call."""
arguments_part: str | None = None
"""A part of the arguments of the tool call."""
class Message(BaseModel):
"""A message in a conversation."""
role: Literal[
"system",
"user",
"assistant",
"tool",
]
content: str | list[ContentPart]
"""The content of the message."""
class AssistantMessageSegment(Message):
"""A message segment from the assistant."""
role: Literal["assistant"] = "assistant"
tool_calls: list[ToolCall] | list[dict] | None = None
class ToolCallMessageSegment(Message):
"""A message segment representing a tool call."""
role: Literal["tool"] = "tool"
tool_call_id: str
class UserMessageSegment(Message):
"""A message segment from the user."""
role: Literal["user"] = "user"
class SystemMessageSegment(Message):
"""A message segment from the system."""
role: Literal["system"] = "system"

View File

@@ -1,9 +1,7 @@
import typing as T
from dataclasses import dataclass from dataclasses import dataclass
import typing as T
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.message_event_result import MessageChain
class AgentResponseData(T.TypedDict): class AgentResponseData(T.TypedDict):
chain: MessageChain chain: MessageChain

View File

@@ -1,8 +1,9 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Generic from typing import Any, Generic
from typing_extensions import TypeVar from typing_extensions import TypeVar
from astrbot.core.platform.astr_message_event import AstrMessageEvent
TContext = TypeVar("TContext", default=Any) TContext = TypeVar("TContext", default=Any)
@@ -11,7 +12,6 @@ class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state.""" """A context for running an agent, which can be used to pass additional data or state."""
context: TContext context: TContext
tool_call_timeout: int = 60 # Default tool call timeout in seconds event: AstrMessageEvent
NoContext = ContextWrapper[None] NoContext = ContextWrapper[None]

View File

@@ -1,15 +1,13 @@
import abc import abc
import typing as T import typing as T
from enum import Enum, auto from enum import Enum, auto
from ..run_context import ContextWrapper, TContext
from ..response import AgentResponse
from ..hooks import BaseAgentRunHooks
from ..tool_executor import BaseFunctionToolExecutor
from astrbot.core.provider import Provider from astrbot.core.provider import Provider
from astrbot.core.provider.entities import LLMResponse from astrbot.core.provider.entities import LLMResponse
from ..hooks import BaseAgentRunHooks
from ..response import AgentResponse
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
class AgentState(Enum): class AgentState(Enum):
"""Defines the state of the agent.""" """Defines the state of the agent."""
@@ -30,26 +28,31 @@ class BaseAgentRunner(T.Generic[TContext]):
agent_hooks: BaseAgentRunHooks[TContext], agent_hooks: BaseAgentRunHooks[TContext],
**kwargs: T.Any, **kwargs: T.Any,
) -> None: ) -> None:
"""Reset the agent to its initial state. """
Reset the agent to its initial state.
This method should be called before starting a new run. This method should be called before starting a new run.
""" """
... ...
@abc.abstractmethod @abc.abstractmethod
async def step(self) -> T.AsyncGenerator[AgentResponse, None]: async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
"""Process a single step of the agent.""" """
Process a single step of the agent.
"""
... ...
@abc.abstractmethod @abc.abstractmethod
def done(self) -> bool: def done(self) -> bool:
"""Check if the agent has completed its task. """
Check if the agent has completed its task.
Returns True if the agent is done, False otherwise. Returns True if the agent is done, False otherwise.
""" """
... ...
@abc.abstractmethod @abc.abstractmethod
def get_final_llm_resp(self) -> LLMResponse | None: def get_final_llm_resp(self) -> LLMResponse | None:
"""Get the final observation from the agent. """
Get the final observation from the agent.
This method should be called after the agent is done. This method should be called after the agent is done.
""" """
... ...

View File

@@ -1,33 +1,31 @@
import sys import sys
import traceback import traceback
import typing as T import typing as T
from .base import BaseAgentRunner, AgentResponse, AgentState
from mcp.types import ( from ..hooks import BaseAgentRunHooks
BlobResourceContents, from ..tool_executor import BaseFunctionToolExecutor
CallToolResult, from ..run_context import ContextWrapper, TContext
EmbeddedResource, from ..response import AgentResponseData
ImageContent, from astrbot.core.provider.provider import Provider
TextContent,
TextResourceContents,
)
from astrbot import logger
from astrbot.core.message.message_event_result import ( from astrbot.core.message.message_event_result import (
MessageChain, MessageChain,
) )
from astrbot.core.provider.entities import ( from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest, ProviderRequest,
LLMResponse,
ToolCallMessageSegment,
AssistantMessageSegment,
ToolCallsResult, ToolCallsResult,
) )
from astrbot.core.provider.provider import Provider from mcp.types import (
TextContent,
from ..hooks import BaseAgentRunHooks ImageContent,
from ..message import AssistantMessageSegment, ToolCallMessageSegment EmbeddedResource,
from ..response import AgentResponseData TextResourceContents,
from ..run_context import ContextWrapper, TContext BlobResourceContents,
from ..tool_executor import BaseFunctionToolExecutor CallToolResult,
from .base import AgentResponse, AgentState, BaseAgentRunner )
from astrbot import logger
if sys.version_info >= (3, 12): if sys.version_info >= (3, 12):
from typing import override from typing import override
@@ -72,7 +70,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
@override @override
async def step(self): async def step(self):
"""Process a single step of the agent. """
Process a single step of the agent.
This method should return the result of the step. This method should return the result of the step.
""" """
if not self.req: if not self.req:
@@ -100,7 +99,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
yield AgentResponse( yield AgentResponse(
type="streaming_delta", type="streaming_delta",
data=AgentResponseData( data=AgentResponseData(
chain=MessageChain().message(llm_response.completion_text), chain=MessageChain().message(llm_response.completion_text)
), ),
) )
continue continue
@@ -121,8 +120,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
type="err", type="err",
data=AgentResponseData( data=AgentResponseData(
chain=MessageChain().message( chain=MessageChain().message(
f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}", f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
), )
), ),
) )
@@ -145,7 +144,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
yield AgentResponse( yield AgentResponse(
type="llm_result", type="llm_result",
data=AgentResponseData( data=AgentResponseData(
chain=MessageChain().message(llm_resp.completion_text), chain=MessageChain().message(llm_resp.completion_text)
), ),
) )
@@ -156,7 +155,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
yield AgentResponse( yield AgentResponse(
type="tool_call", type="tool_call",
data=AgentResponseData( data=AgentResponseData(
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}"), chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}")
), ),
) )
async for result in self._handle_function_tools(self.req, llm_resp): async for result in self._handle_function_tools(self.req, llm_resp):
@@ -170,7 +169,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 将结果添加到上下文中 # 将结果添加到上下文中
tool_calls_result = ToolCallsResult( tool_calls_result = ToolCallsResult(
tool_calls_info=AssistantMessageSegment( tool_calls_info=AssistantMessageSegment(
tool_calls=llm_resp.to_openai_to_calls_model(), role="assistant",
tool_calls=llm_resp.to_openai_tool_calls(),
content=llm_resp.completion_text, content=llm_resp.completion_text,
), ),
tool_calls_result=tool_call_result_blocks, tool_calls_result=tool_call_result_blocks,
@@ -198,50 +198,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
func_tool = req.func_tool.get_func(func_tool_name) func_tool = req.func_tool.get_func(func_tool_name)
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}") logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
if not func_tool:
logger.warning(f"未找到指定的工具: {func_tool_name},将跳过。")
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: 未找到工具 {func_tool_name}",
),
)
continue
valid_params = {} # 参数过滤:只传递函数实际需要的参数
# 获取实际的 handler 函数
if func_tool.handler:
logger.debug(
f"工具 {func_tool_name} 期望的参数: {func_tool.parameters}",
)
if func_tool.parameters and func_tool.parameters.get("properties"):
expected_params = set(func_tool.parameters["properties"].keys())
valid_params = {
k: v
for k, v in func_tool_args.items()
if k in expected_params
}
# 记录被忽略的参数
ignored_params = set(func_tool_args.keys()) - set(
valid_params.keys(),
)
if ignored_params:
logger.warning(
f"工具 {func_tool_name} 忽略非期望参数: {ignored_params}",
)
else:
# 如果没有 handler如 MCP 工具),使用所有参数
valid_params = func_tool_args
try: try:
await self.agent_hooks.on_tool_start( await self.agent_hooks.on_tool_start(
self.run_context, self.run_context, func_tool, func_tool_args
func_tool,
valid_params,
) )
except Exception as e: except Exception as e:
logger.error(f"Error in on_tool_start hook: {e}", exc_info=True) logger.error(f"Error in on_tool_start hook: {e}", exc_info=True)
@@ -249,21 +208,18 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
executor = self.tool_executor.execute( executor = self.tool_executor.execute(
tool=func_tool, tool=func_tool,
run_context=self.run_context, run_context=self.run_context,
**valid_params, # 只传递有效的参数 **func_tool_args,
) )
async for resp in executor:
_final_resp: CallToolResult | None = None
async for resp in executor: # type: ignore
if isinstance(resp, CallToolResult): if isinstance(resp, CallToolResult):
res = resp res = resp
_final_resp = resp
if isinstance(res.content[0], TextContent): if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append( tool_call_result_blocks.append(
ToolCallMessageSegment( ToolCallMessageSegment(
role="tool", role="tool",
tool_call_id=func_tool_id, tool_call_id=func_tool_id,
content=res.content[0].text, content=res.content[0].text,
), )
) )
yield MessageChain().message(res.content[0].text) yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent): elif isinstance(res.content[0], ImageContent):
@@ -272,10 +228,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
role="tool", role="tool",
tool_call_id=func_tool_id, tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)", content="返回了图片(已直接发送给用户)",
), )
) )
yield MessageChain(type="tool_direct_result").base64_image( yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data, res.content[0].data
) )
elif isinstance(res.content[0], EmbeddedResource): elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource resource = res.content[0].resource
@@ -285,7 +241,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
role="tool", role="tool",
tool_call_id=func_tool_id, tool_call_id=func_tool_id,
content=resource.text, content=resource.text,
), )
) )
yield MessageChain().message(resource.text) yield MessageChain().message(resource.text)
elif ( elif (
@@ -298,52 +254,72 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
role="tool", role="tool",
tool_call_id=func_tool_id, tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)", content="返回了图片(已直接发送给用户)",
), )
) )
yield MessageChain( yield MessageChain(
type="tool_direct_result", type="tool_direct_result"
).base64_image(resource.blob) ).base64_image(res.content[0].data)
else: else:
tool_call_result_blocks.append( tool_call_result_blocks.append(
ToolCallMessageSegment( ToolCallMessageSegment(
role="tool", role="tool",
tool_call_id=func_tool_id, tool_call_id=func_tool_id,
content="返回的数据类型不受支持", content="返回的数据类型不受支持",
), )
) )
yield MessageChain().message("返回的数据类型不受支持。") yield MessageChain().message("返回的数据类型不受支持。")
try:
await self.agent_hooks.on_tool_end(
self.run_context,
func_tool_name,
func_tool_args,
resp,
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
elif resp is None: elif resp is None:
# Tool 直接请求发送消息给用户 # Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。 # 这里我们将直接结束 Agent Loop。
# 发送消息逻辑在 ToolExecutor 中处理了。
logger.warning(
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
)
self._transition_state(AgentState.DONE) self._transition_state(AgentState.DONE)
if res := self.run_context.event.get_result():
if res.chain:
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool_name, func_tool_args, None
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
else: else:
# 不应该出现其他类型
logger.warning( logger.warning(
f"Tool 返回了不支持的类型: {type(resp)},将忽略。", f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
) )
try: try:
await self.agent_hooks.on_tool_end( await self.agent_hooks.on_tool_end(
self.run_context, self.run_context, func_tool_name, func_tool_args, None
func_tool, )
func_tool_args, except Exception as e:
_final_resp, logger.error(
) f"Error in on_tool_end hook: {e}", exc_info=True
except Exception as e: )
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
self.run_context.event.clear_result()
except Exception as e: except Exception as e:
logger.warning(traceback.format_exc()) logger.warning(traceback.format_exc())
tool_call_result_blocks.append( tool_call_result_blocks.append(
ToolCallMessageSegment( ToolCallMessageSegment(
role="tool", role="tool",
tool_call_id=func_tool_id, tool_call_id=func_tool_id,
content=f"error: {e!s}", content=f"error: {str(e)}",
), )
) )
# 处理函数调用响应 # 处理函数调用响应

View File

@@ -1,77 +1,57 @@
from collections.abc import Awaitable, Callable from dataclasses import dataclass
from typing import Any, Generic
import jsonschema
import mcp
from deprecated import deprecated from deprecated import deprecated
from pydantic import model_validator from typing import Awaitable, Literal, Any, Optional
from pydantic.dataclasses import dataclass from .mcp_client import MCPClient
from .run_context import ContextWrapper, TContext
ParametersType = dict[str, Any]
@dataclass @dataclass
class ToolSchema: class FunctionTool:
"""A class representing the schema of a tool for function calling.""" """A class representing a function tool that can be used in function calling."""
name: str
"""The name of the tool."""
description: str
"""The description of the tool."""
parameters: ParametersType
"""The parameters of the tool, in JSON Schema format."""
@model_validator(mode="after")
def validate_parameters(self) -> "ToolSchema":
jsonschema.validate(
self.parameters, jsonschema.Draft202012Validator.META_SCHEMA
)
return self
@dataclass
class FunctionTool(ToolSchema, Generic[TContext]):
"""A callable tool, for function calling."""
handler: Callable[..., Awaitable[Any]] | None = None
"""a callable that implements the tool's functionality. It should be an async function."""
name: str | None = None
parameters: dict | None = None
description: str | None = None
handler: Awaitable | None = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str | None = None handler_module_path: str | None = None
""" """处理函数的模块路径,当 origin 为 mcp 时,这个为空
The module path of the handler function. This is empty when the origin is mcp.
This field must be retained, as the handler will be wrapped in functools.partial during initialization, 必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
causing the handler's __module__ to be functools
""" """
active: bool = True active: bool = True
""" """是否激活"""
Whether the tool is active. This field is a special field for AstrBot.
You can ignore it when integrating with other frameworks. origin: Literal["local", "mcp"] = "local"
""" """函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
# MCP 相关字段
mcp_server_name: str | None = None
"""MCP 服务名称,当 origin 为 mcp 时有效"""
mcp_client: MCPClient | None = None
"""MCP 客户端,当 origin 为 mcp 时有效"""
def __repr__(self): def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})" return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
async def call( def __dict__(self) -> dict[str, Any]:
self, context: ContextWrapper[TContext], **kwargs """将 FunctionTool 转换为字典格式"""
) -> str | mcp.types.CallToolResult: return {
"""Run the tool with the given arguments. The handler field has priority.""" "name": self.name,
raise NotImplementedError( "parameters": self.parameters,
"FunctionTool.call() must be implemented by subclasses or set a handler." "description": self.description,
) "active": self.active,
"origin": self.origin,
"mcp_server_name": self.mcp_server_name,
}
class ToolSet: class ToolSet:
"""A set of function tools that can be used in function calling. """A set of function tools that can be used in function calling.
This class provides methods to add, remove, and retrieve tools, as well as This class provides methods to add, remove, and retrieve tools, as well as
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI). convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
"""
def __init__(self, tools: list[FunctionTool] | None = None): def __init__(self, tools: list[FunctionTool] = None):
self.tools: list[FunctionTool] = tools or [] self.tools: list[FunctionTool] = tools or []
def empty(self) -> bool: def empty(self) -> bool:
@@ -91,7 +71,7 @@ class ToolSet:
"""Remove a tool by its name.""" """Remove a tool by its name."""
self.tools = [tool for tool in self.tools if tool.name != name] self.tools = [tool for tool in self.tools if tool.name != name]
def get_tool(self, name: str) -> FunctionTool | None: def get_tool(self, name: str) -> Optional[FunctionTool]:
"""Get a tool by its name.""" """Get a tool by its name."""
for tool in self.tools: for tool in self.tools:
if tool.name == name: if tool.name == name:
@@ -99,13 +79,7 @@ class ToolSet:
return None return None
@deprecated(reason="Use add_tool() instead", version="4.0.0") @deprecated(reason="Use add_tool() instead", version="4.0.0")
def add_func( def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
self,
name: str,
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
):
"""Add a function tool to the set.""" """Add a function tool to the set."""
params = { params = {
"type": "object", # hard-coded here "type": "object", # hard-coded here
@@ -130,7 +104,7 @@ class ToolSet:
self.remove_tool(name) self.remove_tool(name)
@deprecated(reason="Use get_tool() instead", version="4.0.0") @deprecated(reason="Use get_tool() instead", version="4.0.0")
def get_func(self, name: str) -> FunctionTool | None: def get_func(self, name: str) -> list[FunctionTool]:
"""Get all function tools.""" """Get all function tools."""
return self.get_tool(name) return self.get_tool(name)
@@ -151,9 +125,7 @@ class ToolSet:
}, },
} }
if ( if tool.parameters.get("properties") or not omit_empty_parameter_field:
tool.parameters and tool.parameters.get("properties")
) or not omit_empty_parameter_field:
func_def["function"]["parameters"] = tool.parameters func_def["function"]["parameters"] = tool.parameters
result.append(func_def) result.append(func_def)
@@ -163,14 +135,14 @@ class ToolSet:
"""Convert tools to Anthropic API format.""" """Convert tools to Anthropic API format."""
result = [] result = []
for tool in self.tools: for tool in self.tools:
input_schema = {"type": "object"}
if tool.parameters:
input_schema["properties"] = tool.parameters.get("properties", {})
input_schema["required"] = tool.parameters.get("required", [])
tool_def = { tool_def = {
"name": tool.name, "name": tool.name,
"description": tool.description, "description": tool.description,
"input_schema": input_schema, "input_schema": {
"type": "object",
"properties": tool.parameters.get("properties", {}),
"required": tool.parameters.get("required", []),
},
} }
result.append(tool_def) result.append(tool_def)
return result return result
@@ -203,8 +175,7 @@ class ToolSet:
if "type" in schema and schema["type"] in supported_types: if "type" in schema and schema["type"] in supported_types:
result["type"] = schema["type"] result["type"] = schema["type"]
if "format" in schema and schema["format"] in supported_formats.get( if "format" in schema and schema["format"] in supported_formats.get(
result["type"], result["type"], set()
set(),
): ):
result["format"] = schema["format"] result["format"] = schema["format"]
else: else:
@@ -239,15 +210,14 @@ class ToolSet:
return result return result
tools = [] tools = [
for tool in self.tools: {
d: dict[str, Any] = {
"name": tool.name, "name": tool.name,
"description": tool.description, "description": tool.description,
"parameters": convert_schema(tool.parameters),
} }
if tool.parameters: for tool in self.tools
d["parameters"] = convert_schema(tool.parameters) ]
tools.append(d)
declarations = {} declarations = {}
if tools: if tools:

View File

@@ -1,17 +1,11 @@
from collections.abc import AsyncGenerator
from typing import Any, Generic
import mcp import mcp
from typing import Any, Generic, AsyncGenerator
from .run_context import ContextWrapper, TContext from .run_context import TContext, ContextWrapper
from .tool import FunctionTool from .tool import FunctionTool
class BaseFunctionToolExecutor(Generic[TContext]): class BaseFunctionToolExecutor(Generic[TContext]):
@classmethod @classmethod
async def execute( async def execute(
cls, cls, tool: FunctionTool, run_context: ContextWrapper[TContext], **tool_args
tool: FunctionTool,
run_context: ContextWrapper[TContext],
**tool_args,
) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ... ) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ...

View File

@@ -1,6 +1,4 @@
from dataclasses import dataclass from dataclasses import dataclass
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest from astrbot.core.provider.entities import ProviderRequest
@@ -11,4 +9,3 @@ class AstrAgentContext:
first_provider_request: ProviderRequest first_provider_request: ProviderRequest
curr_provider_request: ProviderRequest curr_provider_request: ProviderRequest
streaming: bool streaming: bool
event: AstrMessageEvent

View File

@@ -1,14 +1,12 @@
import os import os
import uuid import uuid
from typing import TypedDict, TypeVar
from astrbot.core import AstrBotConfig, logger from astrbot.core import AstrBotConfig, logger
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
from astrbot.core.config.default import DEFAULT_CONFIG from astrbot.core.config.default import DEFAULT_CONFIG
from astrbot.core.platform.message_session import MessageSession from astrbot.core.platform.message_session import MessageSession
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.utils.astrbot_path import get_astrbot_config_path from astrbot.core.utils.astrbot_path import get_astrbot_config_path
from astrbot.core.utils.shared_preferences import SharedPreferences from typing import TypeVar, TypedDict
_VT = TypeVar("_VT") _VT = TypeVar("_VT")
@@ -17,12 +15,14 @@ class ConfInfo(TypedDict):
"""Configuration information for a specific session or platform.""" """Configuration information for a specific session or platform."""
id: str # UUID of the configuration or "default" id: str # UUID of the configuration or "default"
umop: list[str] # Unified Message Origin Pattern
name: str name: str
path: str # File name to the configuration file path: str # File name to the configuration file
DEFAULT_CONFIG_CONF_INFO = ConfInfo( DEFAULT_CONFIG_CONF_INFO = ConfInfo(
id="default", id="default",
umop=["::"],
name="default", name="default",
path=ASTRBOT_CONFIG_PATH, path=ASTRBOT_CONFIG_PATH,
) )
@@ -31,35 +31,18 @@ DEFAULT_CONFIG_CONF_INFO = ConfInfo(
class AstrBotConfigManager: class AstrBotConfigManager:
"""A class to manage the system configuration of AstrBot, aka ACM""" """A class to manage the system configuration of AstrBot, aka ACM"""
def __init__( def __init__(self, default_config: AstrBotConfig, sp: SharedPreferences):
self,
default_config: AstrBotConfig,
ucr: UmopConfigRouter,
sp: SharedPreferences,
):
self.sp = sp self.sp = sp
self.ucr = ucr
self.confs: dict[str, AstrBotConfig] = {} self.confs: dict[str, AstrBotConfig] = {}
"""uuid / "default" -> AstrBotConfig""" """uuid / "default" -> AstrBotConfig"""
self.confs["default"] = default_config self.confs["default"] = default_config
self.abconf_data = None
self._load_all_configs() self._load_all_configs()
def _get_abconf_data(self) -> dict:
"""获取所有的 abconf 数据"""
if self.abconf_data is None:
self.abconf_data = self.sp.get(
"abconf_mapping",
{},
scope="global",
scope_id="global",
)
return self.abconf_data
def _load_all_configs(self): def _load_all_configs(self):
"""Load all configurations from the shared preferences.""" """Load all configurations from the shared preferences."""
abconf_data = self._get_abconf_data() abconf_data = self.sp.get(
self.abconf_data = abconf_data "abconf_mapping", {}, scope="global", scope_id="global"
)
for uuid_, meta in abconf_data.items(): for uuid_, meta in abconf_data.items():
filename = meta["path"] filename = meta["path"]
conf_path = os.path.join(get_astrbot_config_path(), filename) conf_path = os.path.join(get_astrbot_config_path(), filename)
@@ -68,20 +51,30 @@ class AstrBotConfigManager:
self.confs[uuid_] = conf self.confs[uuid_] = conf
else: else:
logger.warning( logger.warning(
f"Config file {conf_path} for UUID {uuid_} does not exist, skipping.", f"Config file {conf_path} for UUID {uuid_} does not exist, skipping."
) )
continue continue
def _is_umo_match(self, p1: str, p2: str) -> bool:
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
p1_ls = p1.split(":")
p2_ls = p2.split(":")
if len(p1_ls) != 3 or len(p2_ls) != 3:
return False # 非法格式
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo: def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo:
"""获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default") """获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default")
Returns: Returns:
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型 ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
""" """
# uuid -> { "path": str, "name": str } # uuid -> { "umop": list, "path": str, "name": str }
abconf_data = self._get_abconf_data() abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
if isinstance(umo, MessageSession): if isinstance(umo, MessageSession):
umo = str(umo) umo = str(umo)
else: else:
@@ -90,13 +83,10 @@ class AstrBotConfigManager:
except Exception: except Exception:
return DEFAULT_CONFIG_CONF_INFO return DEFAULT_CONFIG_CONF_INFO
conf_id = self.ucr.get_conf_id_for_umop(umo) for uuid_, meta in abconf_data.items():
if conf_id: for pattern in meta["umop"]:
meta = abconf_data.get(conf_id) if self._is_umo_match(pattern, umo):
if meta and isinstance(meta, dict): return ConfInfo(**meta, id=uuid_)
# the bind relation between umo and conf is defined in ucr now, so we remove "umop" here
meta.pop("umop", None)
return ConfInfo(**meta, id=conf_id)
return DEFAULT_CONFIG_CONF_INFO return DEFAULT_CONFIG_CONF_INFO
@@ -104,22 +94,27 @@ class AstrBotConfigManager:
self, self,
abconf_path: str, abconf_path: str,
abconf_id: str, abconf_id: str,
umo_parts: list[str] | list[MessageSession],
abconf_name: str | None = None, abconf_name: str | None = None,
) -> None: ) -> None:
"""保存配置文件的映射关系""" """保存配置文件的映射关系"""
for part in umo_parts:
if isinstance(part, MessageSession):
part = str(part)
elif not isinstance(part, str):
raise ValueError(
"umo_parts must be a list of strings or MessageSession instances"
)
abconf_data = self.sp.get( abconf_data = self.sp.get(
"abconf_mapping", "abconf_mapping", {}, scope="global", scope_id="global"
{},
scope="global",
scope_id="global",
) )
random_word = abconf_name or uuid.uuid4().hex[:8] random_word = abconf_name or uuid.uuid4().hex[:8]
abconf_data[abconf_id] = { abconf_data[abconf_id] = {
"umop": umo_parts,
"path": abconf_path, "path": abconf_path,
"name": random_word, "name": random_word,
} }
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global") self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig: def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig:
"""获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。""" """获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。"""
@@ -151,26 +146,31 @@ class AstrBotConfigManager:
def get_conf_list(self) -> list[ConfInfo]: def get_conf_list(self) -> list[ConfInfo]:
"""获取所有配置文件的元数据列表""" """获取所有配置文件的元数据列表"""
conf_list = [] conf_list = []
abconf_mapping = self._get_abconf_data()
for uuid_, meta in abconf_mapping.items():
if not isinstance(meta, dict):
continue
meta.pop("umop", None)
conf_list.append(ConfInfo(**meta, id=uuid_))
conf_list.append(DEFAULT_CONFIG_CONF_INFO) conf_list.append(DEFAULT_CONFIG_CONF_INFO)
abconf_mapping = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
for uuid_, meta in abconf_mapping.items():
conf_list.append(ConfInfo(**meta, id=uuid_))
return conf_list return conf_list
def create_conf( def create_conf(
self, self,
umo_parts: list[str] | list[MessageSession],
config: dict = DEFAULT_CONFIG, config: dict = DEFAULT_CONFIG,
name: str | None = None, name: str | None = None,
) -> str: ) -> str:
"""
umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
umo_parts 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。
"""
conf_uuid = str(uuid.uuid4()) conf_uuid = str(uuid.uuid4())
conf_file_name = f"abconf_{conf_uuid}.json" conf_file_name = f"abconf_{conf_uuid}.json"
conf_path = os.path.join(get_astrbot_config_path(), conf_file_name) conf_path = os.path.join(get_astrbot_config_path(), conf_file_name)
conf = AstrBotConfig(config_path=conf_path, default_config=config) conf = AstrBotConfig(config_path=conf_path, default_config=config)
conf.save_config() conf.save_config()
self._save_conf_mapping(conf_file_name, conf_uuid, abconf_name=name) self._save_conf_mapping(conf_file_name, conf_uuid, umo_parts, abconf_name=name)
self.confs[conf_uuid] = conf self.confs[conf_uuid] = conf
return conf_uuid return conf_uuid
@@ -185,17 +185,13 @@ class AstrBotConfigManager:
Raises: Raises:
ValueError: 如果试图删除默认配置文件 ValueError: 如果试图删除默认配置文件
""" """
if conf_id == "default": if conf_id == "default":
raise ValueError("不能删除默认配置文件") raise ValueError("不能删除默认配置文件")
# 从映射中移除 # 从映射中移除
abconf_data = self.sp.get( abconf_data = self.sp.get(
"abconf_mapping", "abconf_mapping", {}, scope="global", scope_id="global"
{},
scope="global",
scope_id="global",
) )
if conf_id not in abconf_data: if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中") logger.warning(f"配置文件 {conf_id} 不存在于映射中")
@@ -203,8 +199,7 @@ class AstrBotConfigManager:
# 获取配置文件路径 # 获取配置文件路径
conf_path = os.path.join( conf_path = os.path.join(
get_astrbot_config_path(), get_astrbot_config_path(), abconf_data[conf_id]["path"]
abconf_data[conf_id]["path"],
) )
# 删除配置文件 # 删除配置文件
@@ -223,30 +218,28 @@ class AstrBotConfigManager:
# 从映射中移除 # 从映射中移除
del abconf_data[conf_id] del abconf_data[conf_id]
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global") self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
logger.info(f"成功删除配置文件 {conf_id}") logger.info(f"成功删除配置文件 {conf_id}")
return True return True
def update_conf_info(self, conf_id: str, name: str | None = None) -> bool: def update_conf_info(
self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None
) -> bool:
"""更新配置文件信息 """更新配置文件信息
Args: Args:
conf_id: 配置文件的 UUID conf_id: 配置文件的 UUID
name: 新的配置文件名称 (可选) name: 新的配置文件名称 (可选)
umo_parts: 新的 UMO 部分列表 (可选)
Returns: Returns:
bool: 更新是否成功 bool: 更新是否成功
""" """
if conf_id == "default": if conf_id == "default":
raise ValueError("不能更新默认配置文件的信息") raise ValueError("不能更新默认配置文件的信息")
abconf_data = self.sp.get( abconf_data = self.sp.get(
"abconf_mapping", "abconf_mapping", {}, scope="global", scope_id="global"
{},
scope="global",
scope_id="global",
) )
if conf_id not in abconf_data: if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中") logger.warning(f"配置文件 {conf_id} 不存在于映射中")
@@ -256,17 +249,25 @@ class AstrBotConfigManager:
if name is not None: if name is not None:
abconf_data[conf_id]["name"] = name abconf_data[conf_id]["name"] = name
# 更新 UMO 部分
if umo_parts is not None:
# 验证 UMO 部分格式
for part in umo_parts:
if isinstance(part, MessageSession):
part = str(part)
elif not isinstance(part, str):
raise ValueError(
"umo_parts must be a list of strings or MessageSession instances"
)
abconf_data[conf_id]["umop"] = umo_parts
# 保存更新 # 保存更新
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global") self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
logger.info(f"成功更新配置文件 {conf_id} 的信息") logger.info(f"成功更新配置文件 {conf_id} 的信息")
return True return True
def g( def g(
self, self, umo: str | None = None, key: str | None = None, default: _VT = None
umo: str | None = None,
key: str | None = None,
default: _VT = None,
) -> _VT: ) -> _VT:
"""获取配置项。umo 为 None 时使用默认配置""" """获取配置项。umo 为 None 时使用默认配置"""
if umo is None: if umo is None:

View File

@@ -1,9 +1,9 @@
from .default import DEFAULT_CONFIG, VERSION, DB_PATH
from .astrbot_config import * from .astrbot_config import *
from .default import DB_PATH, DEFAULT_CONFIG, VERSION
__all__ = [ __all__ = [
"DB_PATH",
"DEFAULT_CONFIG", "DEFAULT_CONFIG",
"VERSION", "VERSION",
"DB_PATH",
"AstrBotConfig", "AstrBotConfig",
] ]

View File

@@ -1,11 +1,10 @@
import enum import os
import json import json
import logging import logging
import os import enum
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
from typing import Dict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json") ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
logger = logging.getLogger("astrbot") logger = logging.getLogger("astrbot")
@@ -28,7 +27,7 @@ class AstrBotConfig(dict):
self, self,
config_path: str = ASTRBOT_CONFIG_PATH, config_path: str = ASTRBOT_CONFIG_PATH,
default_config: dict = DEFAULT_CONFIG, default_config: dict = DEFAULT_CONFIG,
schema: dict | None = None, schema: dict = None,
): ):
super().__init__() super().__init__()
@@ -46,7 +45,7 @@ class AstrBotConfig(dict):
json.dump(default_config, f, indent=4, ensure_ascii=False) json.dump(default_config, f, indent=4, ensure_ascii=False)
object.__setattr__(self, "first_deploy", True) # 标记第一次部署 object.__setattr__(self, "first_deploy", True) # 标记第一次部署
with open(config_path, encoding="utf-8-sig") as f: with open(config_path, "r", encoding="utf-8-sig") as f:
conf_str = f.read() conf_str = f.read()
conf = json.loads(conf_str) conf = json.loads(conf_str)
@@ -66,7 +65,7 @@ class AstrBotConfig(dict):
for k, v in schema.items(): for k, v in schema.items():
if v["type"] not in DEFAULT_VALUE_MAP: if v["type"] not in DEFAULT_VALUE_MAP:
raise TypeError( raise TypeError(
f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}", f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}"
) )
if "default" in v: if "default" in v:
default = v["default"] default = v["default"]
@@ -83,7 +82,7 @@ class AstrBotConfig(dict):
return conf return conf
def check_config_integrity(self, refer_conf: dict, conf: dict, path=""): def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
"""检查配置完整性,如果有新的配置项或顺序不一致则返回 True""" """检查配置完整性,如果有新的配置项或顺序不一致则返回 True"""
has_new = False has_new = False
@@ -98,28 +97,27 @@ class AstrBotConfig(dict):
logger.info(f"检查到配置项 {path_} 不存在,已插入默认值 {value}") logger.info(f"检查到配置项 {path_} 不存在,已插入默认值 {value}")
new_conf[key] = value new_conf[key] = value
has_new = True has_new = True
elif conf[key] is None: else:
# 配置项为 None使用默认值 if conf[key] is None:
new_conf[key] = value # 配置项为 None使用默认值
has_new = True
elif isinstance(value, dict):
# 递归检查子配置项
if not isinstance(conf[key], dict):
# 类型不匹配,使用默认值
new_conf[key] = value new_conf[key] = value
has_new = True has_new = True
elif isinstance(value, dict):
# 递归检查子配置项
if not isinstance(conf[key], dict):
# 类型不匹配,使用默认值
new_conf[key] = value
has_new = True
else:
# 递归检查并同步顺序
child_has_new = self.check_config_integrity(
value, conf[key], path + "." + key if path else key
)
new_conf[key] = conf[key]
has_new |= child_has_new
else: else:
# 递归检查并同步顺序 # 直接使用现有配置
child_has_new = self.check_config_integrity(
value,
conf[key],
path + "." + key if path else key,
)
new_conf[key] = conf[key] new_conf[key] = conf[key]
has_new |= child_has_new
else:
# 直接使用现有配置
new_conf[key] = conf[key]
# 检查是否存在参考配置中没有的配置项 # 检查是否存在参考配置中没有的配置项
for key in list(conf.keys()): for key in list(conf.keys()):
@@ -142,7 +140,7 @@ class AstrBotConfig(dict):
return has_new return has_new
def save_config(self, replace_config: dict | None = None): def save_config(self, replace_config: Dict = None):
"""将配置写入文件 """将配置写入文件
如果传入 replace_config则将配置替换为 replace_config 如果传入 replace_config则将配置替换为 replace_config

View File

@@ -1,10 +1,12 @@
"""如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。""" """
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
import os import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.5.3" VERSION = "4.0.0"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
# 默认配置 # 默认配置
@@ -49,28 +51,28 @@ DEFAULT_CONFIG = {
"enable": True, "enable": True,
"default_provider_id": "", "default_provider_id": "",
"default_image_caption_provider_id": "", "default_image_caption_provider_id": "",
"default_summarize_provider_id": "",
"context_exceed_calc_method": "token_size",
"max_token_size": 128000,
"max_context_length": 100,
"image_caption_prompt": "Please describe the image using Chinese.", "image_caption_prompt": "Please describe the image using Chinese.",
"provider_pool": ["*"], # "*" 表示使用所有可用的提供者 "provider_pool": ["*"], # "*" 表示使用所有可用的提供者
"wake_prefix": "", "wake_prefix": "",
"web_search": False, "web_search": False,
"websearch_provider": "default", "websearch_provider": "default",
"websearch_tavily_key": [], "websearch_tavily_key": "",
"websearch_baidu_app_builder_key": "",
"web_search_link": False, "web_search_link": False,
"display_reasoning_text": False, "display_reasoning_text": False,
"identifier": False, "identifier": False,
"group_name_display": False,
"datetime_system_prompt": True, "datetime_system_prompt": True,
"default_personality": "default", "default_personality": "default",
"persona_pool": ["*"], "persona_pool": ["*"],
"prompt_prefix": "{{prompt}}", "prompt_prefix": "",
"max_context_length": -1,
"dequeue_context_length": 1, "dequeue_context_length": 1,
"streaming_response": False, "streaming_response": False,
"show_tool_use_status": False, "show_tool_use_status": False,
"streaming_segmented": False, "streaming_segmented": False,
"max_agent_step": 30, "max_agent_step": 30,
"tool_call_timeout": 60,
}, },
"provider_stt_settings": { "provider_stt_settings": {
"enable": False, "enable": False,
@@ -104,7 +106,6 @@ DEFAULT_CONFIG = {
"t2i_strategy": "remote", "t2i_strategy": "remote",
"t2i_endpoint": "", "t2i_endpoint": "",
"t2i_use_file_service": False, "t2i_use_file_service": False,
"t2i_active_template": "base",
"http_proxy": "", "http_proxy": "",
"no_proxy": ["localhost", "127.0.0.1", "::1"], "no_proxy": ["localhost", "127.0.0.1", "::1"],
"dashboard": { "dashboard": {
@@ -116,15 +117,6 @@ DEFAULT_CONFIG = {
"port": 6185, "port": 6185,
}, },
"platform": [], "platform": [],
"platform_specific": {
# 平台特异配置:按平台分类,平台下按功能分组
"lark": {
"pre_ack_emoji": {"enable": False, "emojis": ["Typing"]},
},
"telegram": {
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
},
},
"wake_prefix": ["/"], "wake_prefix": ["/"],
"log_level": "INFO", "log_level": "INFO",
"pip_install_arg": "", "pip_install_arg": "",
@@ -132,11 +124,8 @@ DEFAULT_CONFIG = {
"persona": [], # deprecated "persona": [], # deprecated
"timezone": "Asia/Shanghai", "timezone": "Asia/Shanghai",
"callback_api_base": "", "callback_api_base": "",
"default_kb_collection": "", # 默认知识库名称, 已经过时 "default_kb_collection": "", # 默认知识库名称
"plugin_set": ["*"], # "*" 表示使用所有可用的插件, 空列表表示不使用任何插件 "plugin_set": ["*"], # "*" 表示使用所有可用的插件, 空列表表示不使用任何插件
"kb_names": [], # 默认知识库名称列表
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
"kb_final_top_k": 5, # 知识库检索最终返回结果数量
} }
@@ -163,11 +152,10 @@ CONFIG_METADATA_2 = {
"enable": False, "enable": False,
"appid": "", "appid": "",
"secret": "", "secret": "",
"is_sandbox": False,
"callback_server_host": "0.0.0.0", "callback_server_host": "0.0.0.0",
"port": 6196, "port": 6196,
}, },
"QQ 个人号(OneBot v11)": { "QQ 个人号(aiocqhttp)": {
"id": "default", "id": "default",
"type": "aiocqhttp", "type": "aiocqhttp",
"enable": False, "enable": False,
@@ -175,7 +163,7 @@ CONFIG_METADATA_2 = {
"ws_reverse_port": 6199, "ws_reverse_port": 6199,
"ws_reverse_token": "", "ws_reverse_token": "",
}, },
"WeChatPadPro": { "微信个人号(WeChatPadPro)": {
"id": "wechatpadpro", "id": "wechatpadpro",
"type": "wechatpadpro", "type": "wechatpadpro",
"enable": False, "enable": False,
@@ -211,18 +199,6 @@ CONFIG_METADATA_2 = {
"callback_server_host": "0.0.0.0", "callback_server_host": "0.0.0.0",
"port": 6195, "port": 6195,
}, },
"企业微信智能机器人": {
"id": "wecom_ai_bot",
"type": "wecom_ai_bot",
"enable": True,
"wecomaibot_init_respond_text": "💭 思考中...",
"wecomaibot_friend_message_welcome_text": "",
"wecom_ai_bot_name": "",
"token": "",
"encoding_aes_key": "",
"callback_server_host": "0.0.0.0",
"port": 6198,
},
"飞书(Lark)": { "飞书(Lark)": {
"id": "lark", "id": "lark",
"type": "lark", "type": "lark",
@@ -261,24 +237,6 @@ CONFIG_METADATA_2 = {
"discord_guild_id_for_debug": "", "discord_guild_id_for_debug": "",
"discord_activity_name": "", "discord_activity_name": "",
}, },
"Misskey": {
"id": "misskey",
"type": "misskey",
"enable": False,
"misskey_instance_url": "https://misskey.example",
"misskey_token": "",
"misskey_default_visibility": "public",
"misskey_local_only": False,
"misskey_enable_chat": True,
# download / security options
"misskey_allow_insecure_downloads": False,
"misskey_download_timeout": 15,
"misskey_download_chunk_size": 65536,
"misskey_max_download_bytes": None,
"misskey_enable_file_upload": True,
"misskey_upload_concurrency": 3,
"misskey_upload_folder": "",
},
"Slack": { "Slack": {
"id": "slack", "id": "slack",
"type": "slack", "type": "slack",
@@ -291,71 +249,8 @@ CONFIG_METADATA_2 = {
"slack_webhook_port": 6197, "slack_webhook_port": 6197,
"slack_webhook_path": "/astrbot-slack-webhook/callback", "slack_webhook_path": "/astrbot-slack-webhook/callback",
}, },
"Satori": {
"id": "satori",
"type": "satori",
"enable": False,
"satori_api_base_url": "http://localhost:5140/satori/v1",
"satori_endpoint": "ws://localhost:5140/satori/v1/events",
"satori_token": "",
"satori_auto_reconnect": True,
"satori_heartbeat_interval": 10,
"satori_reconnect_delay": 5,
},
# "WebChat": {
# "id": "webchat",
# "type": "webchat",
# "enable": False,
# "webchat_link_path": "",
# "webchat_present_type": "fullscreen",
# },
}, },
"items": { "items": {
# "webchat_link_path": {
# "description": "链接路径",
# "_special": "webchat_link_path",
# "type": "string",
# },
# "webchat_present_type": {
# "_special": "webchat_present_type",
# "description": "展现形式",
# "type": "string",
# "options": ["fullscreen", "embedded"],
# },
"is_sandbox": {
"description": "沙箱模式",
"type": "bool",
},
"satori_api_base_url": {
"description": "Satori API 终结点",
"type": "string",
"hint": "Satori API 的基础地址。",
},
"satori_endpoint": {
"description": "Satori WebSocket 终结点",
"type": "string",
"hint": "Satori 事件的 WebSocket 端点。",
},
"satori_token": {
"description": "Satori 令牌",
"type": "string",
"hint": "用于 Satori API 身份验证的令牌。",
},
"satori_auto_reconnect": {
"description": "启用自动重连",
"type": "bool",
"hint": "断开连接时是否自动重新连接 WebSocket。",
},
"satori_heartbeat_interval": {
"description": "Satori 心跳间隔",
"type": "int",
"hint": "发送心跳消息的间隔(秒)。",
},
"satori_reconnect_delay": {
"description": "Satori 重连延迟",
"type": "int",
"hint": "尝试重新连接前的延迟时间(秒)。",
},
"slack_connection_mode": { "slack_connection_mode": {
"description": "Slack Connection Mode", "description": "Slack Connection Mode",
"type": "string", "type": "string",
@@ -402,67 +297,6 @@ CONFIG_METADATA_2 = {
"type": "string", "type": "string",
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。", "hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
}, },
"misskey_instance_url": {
"description": "Misskey 实例 URL",
"type": "string",
"hint": "例如 https://misskey.example填写 Bot 账号所在的 Misskey 实例地址",
},
"misskey_token": {
"description": "Misskey Access Token",
"type": "string",
"hint": "连接服务设置生成的 API 鉴权访问令牌Access token",
},
"misskey_default_visibility": {
"description": "默认帖子可见性",
"type": "string",
"options": ["public", "home", "followers"],
"hint": "机器人发帖时的默认可见性设置。public公开home主页时间线followers仅关注者。",
},
"misskey_local_only": {
"description": "仅限本站(不参与联合)",
"type": "bool",
"hint": "启用后,机器人发出的帖子将仅在本实例可见,不会联合到其他实例",
},
"misskey_enable_chat": {
"description": "启用聊天消息响应",
"type": "bool",
"hint": "启用后,机器人将会监听和响应私信聊天消息",
},
"misskey_enable_file_upload": {
"description": "启用文件上传到 Misskey",
"type": "bool",
"hint": "启用后,适配器会尝试将消息链中的文件上传到 Misskey。URL 文件会先尝试服务器端上传,异步上传失败时会回退到下载后本地上传。",
},
"misskey_allow_insecure_downloads": {
"description": "允许不安全下载(禁用 SSL 验证)",
"type": "bool",
"hint": "当远端服务器存在证书问题导致无法正常下载时,自动禁用 SSL 验证作为回退方案。适用于某些图床的证书配置问题。启用有安全风险,仅在必要时使用。",
},
"misskey_download_timeout": {
"description": "远端下载超时时间(秒)",
"type": "int",
"hint": "下载远程文件时的超时时间(秒),用于异步上传回退到本地上传的场景。",
},
"misskey_download_chunk_size": {
"description": "流式下载分块大小(字节)",
"type": "int",
"hint": "流式下载和计算 MD5 时使用的每次读取字节数,过小会增加开销,过大会占用内存。",
},
"misskey_max_download_bytes": {
"description": "最大允许下载字节数(超出则中止)",
"type": "int",
"hint": "如果希望限制下载文件的最大大小以防止 OOM请填写最大字节数留空或 null 表示不限制。",
},
"misskey_upload_concurrency": {
"description": "并发上传限制",
"type": "int",
"hint": "同时进行的文件上传任务上限(整数,默认 3",
},
"misskey_upload_folder": {
"description": "上传到网盘的目标文件夹 ID",
"type": "string",
"hint": "可选:填写 Misskey 网盘中目标文件夹的 ID上传的文件将放置到该文件夹内。留空则使用账号网盘根目录。",
},
"telegram_command_register": { "telegram_command_register": {
"description": "Telegram 命令注册", "description": "Telegram 命令注册",
"type": "bool", "type": "bool",
@@ -514,38 +348,24 @@ CONFIG_METADATA_2 = {
"hint": "启用后,机器人可以接收到频道的私聊消息。", "hint": "启用后,机器人可以接收到频道的私聊消息。",
}, },
"ws_reverse_host": { "ws_reverse_host": {
"description": "反向 Websocket 主机", "description": "反向 Websocket 主机地址(AstrBot 为服务器端)",
"type": "string", "type": "string",
"hint": "AstrBot 将作为服务器端", "hint": "aiocqhttp 适配器的反向 Websocket 服务器 IP 地址,不包含端口号",
}, },
"ws_reverse_port": { "ws_reverse_port": {
"description": "反向 Websocket 端口", "description": "反向 Websocket 端口",
"type": "int", "type": "int",
"hint": "aiocqhttp 适配器的反向 Websocket 端口。",
}, },
"ws_reverse_token": { "ws_reverse_token": {
"description": "反向 Websocket Token", "description": "反向 Websocket Token",
"type": "string", "type": "string",
"hint": "反向 Websocket Token。未设置则不启用 Token 验证。", "hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。",
},
"wecom_ai_bot_name": {
"description": "企业微信智能机器人的名字",
"type": "string",
"hint": "请务必填写正确,否则无法使用一些指令。",
},
"wecomaibot_init_respond_text": {
"description": "企业微信智能机器人初始响应文本",
"type": "string",
"hint": "当机器人收到消息时,首先回复的文本内容。留空则使用默认值。",
},
"wecomaibot_friend_message_welcome_text": {
"description": "企业微信智能机器人私聊欢迎语",
"type": "string",
"hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。",
}, },
"lark_bot_name": { "lark_bot_name": {
"description": "飞书机器人的名字", "description": "飞书机器人的名字",
"type": "string", "type": "string",
"hint": "请务必填写正确,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。", "hint": "请务必填,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
}, },
"discord_token": { "discord_token": {
"description": "Discord Bot Token", "description": "Discord Bot Token",
@@ -740,7 +560,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.openai.com/v1", "api_base": "https://api.openai.com/v1",
"timeout": 120, "timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4}, "model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"], "modalities": ["text", "image", "tool_use"],
"hint": "也兼容所有与 OpenAI API 兼容的服务。", "hint": "也兼容所有与 OpenAI API 兼容的服务。",
}, },
@@ -755,7 +574,6 @@ CONFIG_METADATA_2 = {
"api_base": "", "api_base": "",
"timeout": 120, "timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4}, "model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"], "modalities": ["text", "image", "tool_use"],
}, },
"xAI": { "xAI": {
@@ -768,8 +586,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.x.ai/v1", "api_base": "https://api.x.ai/v1",
"timeout": 120, "timeout": 120,
"model_config": {"model": "grok-2-latest", "temperature": 0.4}, "model_config": {"model": "grok-2-latest", "temperature": 0.4},
"custom_extra_body": {},
"xai_native_search": False,
"modalities": ["text", "image", "tool_use"], "modalities": ["text", "image", "tool_use"],
}, },
"Anthropic": { "Anthropic": {
@@ -799,7 +615,6 @@ CONFIG_METADATA_2 = {
"key": ["ollama"], # ollama 的 key 默认是 ollama "key": ["ollama"], # ollama 的 key 默认是 ollama
"api_base": "http://localhost:11434/v1", "api_base": "http://localhost:11434/v1",
"model_config": {"model": "llama3.1-8b", "temperature": 0.4}, "model_config": {"model": "llama3.1-8b", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"], "modalities": ["text", "image", "tool_use"],
}, },
"LM Studio": { "LM Studio": {
@@ -813,7 +628,6 @@ CONFIG_METADATA_2 = {
"model_config": { "model_config": {
"model": "llama-3.1-8b", "model": "llama-3.1-8b",
}, },
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"], "modalities": ["text", "image", "tool_use"],
}, },
"Gemini(OpenAI兼容)": { "Gemini(OpenAI兼容)": {
@@ -829,7 +643,6 @@ CONFIG_METADATA_2 = {
"model": "gemini-1.5-flash", "model": "gemini-1.5-flash",
"temperature": 0.4, "temperature": 0.4,
}, },
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"], "modalities": ["text", "image", "tool_use"],
}, },
"Gemini": { "Gemini": {
@@ -870,8 +683,7 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.deepseek.com/v1", "api_base": "https://api.deepseek.com/v1",
"timeout": 120, "timeout": 120,
"model_config": {"model": "deepseek-chat", "temperature": 0.4}, "model_config": {"model": "deepseek-chat", "temperature": 0.4},
"custom_extra_body": {}, "modalities": ["text", "image", "tool_use"],
"modalities": ["text", "tool_use"],
}, },
"302.AI": { "302.AI": {
"id": "302ai", "id": "302ai",
@@ -883,7 +695,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.302.ai/v1", "api_base": "https://api.302.ai/v1",
"timeout": 120, "timeout": 120,
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4}, "model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"], "modalities": ["text", "image", "tool_use"],
}, },
"硅基流动": { "硅基流动": {
@@ -899,7 +710,6 @@ CONFIG_METADATA_2 = {
"model": "deepseek-ai/DeepSeek-V3", "model": "deepseek-ai/DeepSeek-V3",
"temperature": 0.4, "temperature": 0.4,
}, },
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"], "modalities": ["text", "image", "tool_use"],
}, },
"PPIO派欧云": { "PPIO派欧云": {
@@ -915,22 +725,6 @@ CONFIG_METADATA_2 = {
"model": "deepseek/deepseek-r1", "model": "deepseek/deepseek-r1",
"temperature": 0.4, "temperature": 0.4,
}, },
"custom_extra_body": {},
},
"小马算力": {
"id": "tokenpony",
"provider": "tokenpony",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.tokenpony.cn/v1",
"timeout": 120,
"model_config": {
"model": "kimi-k2-instruct-0905",
"temperature": 0.7,
},
"custom_extra_body": {},
}, },
"优云智算": { "优云智算": {
"id": "compshare", "id": "compshare",
@@ -944,7 +738,6 @@ CONFIG_METADATA_2 = {
"model_config": { "model_config": {
"model": "moonshotai/Kimi-K2-Instruct", "model": "moonshotai/Kimi-K2-Instruct",
}, },
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"], "modalities": ["text", "image", "tool_use"],
}, },
"Kimi": { "Kimi": {
@@ -957,7 +750,6 @@ CONFIG_METADATA_2 = {
"timeout": 120, "timeout": 120,
"api_base": "https://api.moonshot.cn/v1", "api_base": "https://api.moonshot.cn/v1",
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4}, "model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"], "modalities": ["text", "image", "tool_use"],
}, },
"智谱 AI": { "智谱 AI": {
@@ -989,18 +781,6 @@ CONFIG_METADATA_2 = {
"timeout": 60, "timeout": 60,
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!", "hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
}, },
"Coze": {
"id": "coze",
"provider": "coze",
"provider_type": "chat_completion",
"type": "coze",
"enable": True,
"coze_api_key": "",
"bot_id": "",
"coze_api_base": "https://api.coze.cn",
"timeout": 60,
"auto_save_history": True,
},
"阿里云百炼应用": { "阿里云百炼应用": {
"id": "dashscope", "id": "dashscope",
"provider": "dashscope", "provider": "dashscope",
@@ -1028,7 +808,6 @@ CONFIG_METADATA_2 = {
"timeout": 120, "timeout": 120,
"api_base": "https://api-inference.modelscope.cn/v1", "api_base": "https://api-inference.modelscope.cn/v1",
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4}, "model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"], "modalities": ["text", "image", "tool_use"],
}, },
"FastGPT": { "FastGPT": {
@@ -1040,7 +819,6 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"api_base": "https://api.fastgpt.in/api/v1", "api_base": "https://api.fastgpt.in/api/v1",
"timeout": 60, "timeout": 60,
"custom_extra_body": {},
}, },
"Whisper(API)": { "Whisper(API)": {
"id": "whisper", "id": "whisper",
@@ -1091,9 +869,6 @@ CONFIG_METADATA_2 = {
"provider_type": "text_to_speech", "provider_type": "text_to_speech",
"enable": False, "enable": False,
"edge-tts-voice": "zh-CN-XiaoxiaoNeural", "edge-tts-voice": "zh-CN-XiaoxiaoNeural",
"rate": "+0%",
"volume": "+0%",
"pitch": "+0Hz",
"timeout": 20, "timeout": 20,
}, },
"GSV TTS(本地加载)": { "GSV TTS(本地加载)": {
@@ -1152,7 +927,6 @@ CONFIG_METADATA_2 = {
"timeout": "20", "timeout": "20",
}, },
"阿里云百炼 TTS(API)": { "阿里云百炼 TTS(API)": {
"hint": "API Key 从 https://bailian.console.aliyun.com/?tab=model#/api-key 获取。模型和音色的选择文档请参考: 阿里云百炼语音合成音色名称。具体可参考 https://help.aliyun.com/zh/model-studio/speech-synthesis-and-speech-recognition",
"id": "dashscope_tts", "id": "dashscope_tts",
"provider": "dashscope", "provider": "dashscope",
"type": "dashscope_tts", "type": "dashscope_tts",
@@ -1261,38 +1035,8 @@ CONFIG_METADATA_2 = {
"rerank_model": "BAAI/bge-reranker-base", "rerank_model": "BAAI/bge-reranker-base",
"timeout": 20, "timeout": 20,
}, },
"Xinference Rerank": {
"id": "xinference_rerank",
"type": "xinference_rerank",
"provider": "xinference",
"provider_type": "rerank",
"enable": True,
"rerank_api_key": "",
"rerank_api_base": "http://127.0.0.1:9997",
"rerank_model": "BAAI/bge-reranker-base",
"timeout": 20,
"launch_model_if_not_running": False,
},
"Xinference STT": {
"id": "xinference_stt",
"type": "xinference_stt",
"provider": "xinference",
"provider_type": "speech_to_text",
"enable": False,
"api_key": "",
"api_base": "http://127.0.0.1:9997",
"model": "whisper-large-v3",
"timeout": 180,
"launch_model_if_not_running": False,
},
}, },
"items": { "items": {
"xai_native_search": {
"description": "启用原生搜索功能",
"type": "bool",
"hint": "启用后,将通过 xAI 的 Chat Completions 原生 Live Search 进行联网检索(按需计费)。仅对 xAI 提供商生效。",
"condition": {"provider": "xai"},
},
"rerank_api_base": { "rerank_api_base": {
"description": "重排序模型 API Base URL", "description": "重排序模型 API Base URL",
"type": "string", "type": "string",
@@ -1307,11 +1051,6 @@ CONFIG_METADATA_2 = {
"description": "重排序模型名称", "description": "重排序模型名称",
"type": "string", "type": "string",
}, },
"launch_model_if_not_running": {
"description": "模型未运行时自动启动",
"type": "bool",
"hint": "如果模型当前未在 Xinference 服务中运行,是否尝试自动启动它。在生产环境中建议关闭。",
},
"modalities": { "modalities": {
"description": "模型能力", "description": "模型能力",
"type": "list", "type": "list",
@@ -1321,12 +1060,6 @@ CONFIG_METADATA_2 = {
"render_type": "checkbox", "render_type": "checkbox",
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。", "hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
}, },
"custom_extra_body": {
"description": "自定义请求体参数",
"type": "dict",
"items": {},
"hint": "此处添加的键值对将被合并到发送给 API 的 extra_body 中。值可以是字符串、数字或布尔值。",
},
"provider": { "provider": {
"type": "string", "type": "string",
"invisible": True, "invisible": True,
@@ -1455,7 +1188,6 @@ CONFIG_METADATA_2 = {
"description": "嵌入维度", "description": "嵌入维度",
"type": "int", "type": "int",
"hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。", "hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。",
"_special": "get_embedding_dim",
}, },
"embedding_model": { "embedding_model": {
"description": "嵌入模型", "description": "嵌入模型",
@@ -1568,7 +1300,11 @@ CONFIG_METADATA_2 = {
"description": "服务订阅密钥", "description": "服务订阅密钥",
"hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)", "hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)",
}, },
"dashscope_tts_voice": {"description": "音色", "type": "string"}, "dashscope_tts_voice": {
"description": "语音合成模型",
"type": "string",
"hint": "阿里云百炼语音合成模型名称。具体可参考 https://help.aliyun.com/zh/model-studio/developer-reference/cosyvoice-python-api 等内容",
},
"gm_resp_image_modal": { "gm_resp_image_modal": {
"description": "启用图片模态", "description": "启用图片模态",
"type": "bool", "type": "bool",
@@ -1900,26 +1636,6 @@ CONFIG_METADATA_2 = {
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。", "hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
"obvious": True, "obvious": True,
}, },
"coze_api_key": {
"description": "Coze API Key",
"type": "string",
"hint": "Coze API 密钥,用于访问 Coze 服务。",
},
"bot_id": {
"description": "Bot ID",
"type": "string",
"hint": "Coze 机器人的 ID在 Coze 平台上创建机器人后获得。",
},
"coze_api_base": {
"description": "API Base URL",
"type": "string",
"hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn",
},
"auto_save_history": {
"description": "由 Coze 管理对话记录",
"type": "bool",
"hint": "启用后,将由 Coze 进行对话历史记录管理, 此时 AstrBot 本地保存的上下文不会生效(仅供浏览), 对 AstrBot 的上下文进行的操作也不会生效。如果为禁用, 则使用 AstrBot 管理上下文。",
},
}, },
}, },
"provider_settings": { "provider_settings": {
@@ -1946,9 +1662,6 @@ CONFIG_METADATA_2 = {
"identifier": { "identifier": {
"type": "bool", "type": "bool",
}, },
"group_name_display": {
"type": "bool",
},
"datetime_system_prompt": { "datetime_system_prompt": {
"type": "bool", "type": "bool",
}, },
@@ -1977,10 +1690,6 @@ CONFIG_METADATA_2 = {
"description": "工具调用轮数上限", "description": "工具调用轮数上限",
"type": "int", "type": "int",
}, },
"tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
},
}, },
}, },
"provider_stt_settings": { "provider_stt_settings": {
@@ -2103,9 +1812,6 @@ CONFIG_METADATA_2 = {
"default_kb_collection": { "default_kb_collection": {
"type": "string", "type": "string",
}, },
"kb_names": {"type": "list", "items": {"type": "string"}},
"kb_fusion_top_k": {"type": "int", "default": 20},
"kb_final_top_k": {"type": "int", "default": 5},
}, },
}, },
} }
@@ -2129,39 +1835,51 @@ CONFIG_METADATA_3 = {
"_special": "select_provider", "_special": "select_provider",
"hint": "留空时使用第一个模型。", "hint": "留空时使用第一个模型。",
}, },
"provider_settings.default_summarize_provider_id": {
"description": "默认对话总结模型",
"type": "string",
"_special": "select_provider",
"hint": "留空代表不进行对话总结。可用于压缩上下文以减少 token 用量,并一定程度上保持历史聊天记忆。",
},
"provider_settings.default_image_caption_provider_id": { "provider_settings.default_image_caption_provider_id": {
"description": "默认图片转述模型", "description": "默认图片转述模型",
"type": "string", "type": "string",
"_special": "select_provider", "_special": "select_provider",
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。", "hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
}, },
"provider_stt_settings.enable": {
"description": "启用语音转文本",
"type": "bool",
"hint": "STT 总开关。",
},
"provider_stt_settings.provider_id": { "provider_stt_settings.provider_id": {
"description": "默认语音转文本模型", "description": "语音转文本模型",
"type": "string", "type": "string",
"hint": "用户也可使用 /provider 指令单独选择会话的 STT 模型", "hint": "留空代表不使用",
"_special": "select_provider_stt", "_special": "select_provider_stt",
"condition": {
"provider_stt_settings.enable": True,
},
},
"provider_tts_settings.enable": {
"description": "启用文本转语音",
"type": "bool",
"hint": "TTS 总开关。当关闭时,会话启用 TTS 也不会生效。",
}, },
"provider_tts_settings.provider_id": { "provider_tts_settings.provider_id": {
"description": "默认文本转语音模型", "description": "文本转语音模型",
"type": "string", "type": "string",
"hint": "用户也可使用 /provider 单独选择会话的 TTS 模型", "hint": "留空代表不使用",
"_special": "select_provider_tts", "_special": "select_provider_tts",
},
"provider_settings.context_exceed_calc_method": {
"description": "上下文超限的触发策略",
"type": "string",
"options": ["token_size", "context_length"],
"labels": ["基于 Token 长度(估算)", "基于对话轮数"],
"hint": "如配置了对话总结模型,则触发时总结对话内容,否则丢弃最旧部分。"
},
"provider_settings.max_context_length": {
"description": "对话轮数上限",
"type": "int",
"condition": { "condition": {
"provider_tts_settings.enable": True, "provider_settings.context_exceed_calc_method": "context_length"
}, }
},
"provider_settings.max_token_size": {
"description": "Token 长度上限(估算)",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分。",
"condition": {
"provider_settings.context_exceed_calc_method": "token_size"
}
}, },
"provider_settings.image_caption_prompt": { "provider_settings.image_caption_prompt": {
"description": "图片转述提示词", "description": "图片转述提示词",
@@ -2184,22 +1902,10 @@ CONFIG_METADATA_3 = {
"description": "知识库", "description": "知识库",
"type": "object", "type": "object",
"items": { "items": {
"kb_names": { "default_kb_collection": {
"description": "知识库列表", "description": "默认使用的知识库",
"type": "list", "type": "string",
"items": {"type": "string"},
"_special": "select_knowledgebase", "_special": "select_knowledgebase",
"hint": "支持多选",
},
"kb_fusion_top_k": {
"description": "融合检索结果数",
"type": "int",
"hint": "多个知识库检索结果融合后的返回结果数量",
},
"kb_final_top_k": {
"description": "最终返回结果数",
"type": "int",
"hint": "从知识库中检索到的结果数量,越大可能获得越多相关信息,但也可能引入噪音。建议根据实际需求调整",
}, },
}, },
}, },
@@ -2214,25 +1920,15 @@ CONFIG_METADATA_3 = {
"provider_settings.websearch_provider": { "provider_settings.websearch_provider": {
"description": "网页搜索提供商", "description": "网页搜索提供商",
"type": "string", "type": "string",
"options": ["default", "tavily", "baidu_ai_search"], "options": ["default", "tavily"],
}, },
"provider_settings.websearch_tavily_key": { "provider_settings.websearch_tavily_key": {
"description": "Tavily API Key", "description": "Tavily API Key",
"type": "list", "type": "string",
"items": {"type": "string"},
"hint": "可添加多个 Key 进行轮询。",
"condition": { "condition": {
"provider_settings.websearch_provider": "tavily", "provider_settings.websearch_provider": "tavily",
}, },
}, },
"provider_settings.websearch_baidu_app_builder_key": {
"description": "百度千帆智能云 APP Builder API Key",
"type": "string",
"hint": "参考https://console.bce.baidu.com/iam/#/iam/apikey/list",
"condition": {
"provider_settings.websearch_provider": "baidu_ai_search",
},
},
"provider_settings.web_search_link": { "provider_settings.web_search_link": {
"description": "显示来源引用", "description": "显示来源引用",
"type": "bool", "type": "bool",
@@ -2248,14 +1944,9 @@ CONFIG_METADATA_3 = {
"type": "bool", "type": "bool",
}, },
"provider_settings.identifier": { "provider_settings.identifier": {
"description": "用户识别", "description": "用户感知",
"type": "bool", "type": "bool",
}, },
"provider_settings.group_name_display": {
"description": "显示群名称",
"type": "bool",
"hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
},
"provider_settings.datetime_system_prompt": { "provider_settings.datetime_system_prompt": {
"description": "现实世界时间感知", "description": "现实世界时间感知",
"type": "bool", "type": "bool",
@@ -2268,10 +1959,6 @@ CONFIG_METADATA_3 = {
"description": "工具调用轮数上限", "description": "工具调用轮数上限",
"type": "int", "type": "int",
}, },
"provider_settings.tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
},
"provider_settings.streaming_response": { "provider_settings.streaming_response": {
"description": "流式回复", "description": "流式回复",
"type": "bool", "type": "bool",
@@ -2280,11 +1967,6 @@ CONFIG_METADATA_3 = {
"description": "不支持流式回复的平台采取分段输出", "description": "不支持流式回复的平台采取分段输出",
"type": "bool", "type": "bool",
}, },
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条。-1 为不限制。",
},
"provider_settings.dequeue_context_length": { "provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数", "description": "丢弃对话轮数",
"type": "int", "type": "int",
@@ -2293,14 +1975,12 @@ CONFIG_METADATA_3 = {
"provider_settings.wake_prefix": { "provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ", "description": "LLM 聊天额外唤醒前缀 ",
"type": "string", "type": "string",
"hint": "如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
}, },
"provider_settings.prompt_prefix": { "provider_settings.prompt_prefix": {
"description": "用户提示词", "description": "额外前缀提示词",
"type": "string", "type": "string",
"hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。",
}, },
"provider_tts_settings.dual_output": { "provider_settings.dual_output": {
"description": "开启 TTS 时同时输出语音和文字内容", "description": "开启 TTS 时同时输出语音和文字内容",
"type": "bool", "type": "bool",
}, },
@@ -2409,41 +2089,41 @@ CONFIG_METADATA_3 = {
"description": "内容安全", "description": "内容安全",
"type": "object", "type": "object",
"items": { "items": {
"content_safety.also_use_in_response": { "platform_settings.content_safety.also_use_in_response": {
"description": "同时检查模型的响应内容", "description": "同时检查模型的响应内容",
"type": "bool", "type": "bool",
}, },
"content_safety.baidu_aip.enable": { "platform_settings.content_safety.baidu_aip.enable": {
"description": "使用百度内容安全审核", "description": "使用百度内容安全审核",
"type": "bool", "type": "bool",
"hint": "您需要手动安装 baidu-aip 库。", "hint": "您需要手动安装 baidu-aip 库。",
}, },
"content_safety.baidu_aip.app_id": { "platform_settings.content_safety.baidu_aip.app_id": {
"description": "App ID", "description": "App ID",
"type": "string", "type": "string",
"condition": { "condition": {
"content_safety.baidu_aip.enable": True, "platform_settings.content_safety.baidu_aip.enable": True,
}, },
}, },
"content_safety.baidu_aip.api_key": { "platform_settings.content_safety.baidu_aip.api_key": {
"description": "API Key", "description": "API Key",
"type": "string", "type": "string",
"condition": { "condition": {
"content_safety.baidu_aip.enable": True, "platform_settings.content_safety.baidu_aip.enable": True,
}, },
}, },
"content_safety.baidu_aip.secret_key": { "platform_settings.content_safety.baidu_aip.secret_key": {
"description": "Secret Key", "description": "Secret Key",
"type": "string", "type": "string",
"condition": { "condition": {
"content_safety.baidu_aip.enable": True, "platform_settings.content_safety.baidu_aip.enable": True,
}, },
}, },
"content_safety.internal_keywords.enable": { "platform_settings.content_safety.internal_keywords.enable": {
"description": "关键词检查", "description": "关键词检查",
"type": "bool", "type": "bool",
}, },
"content_safety.internal_keywords.extra_keywords": { "platform_settings.content_safety.internal_keywords.extra_keywords": {
"description": "额外关键词", "description": "额外关键词",
"type": "list", "type": "list",
"items": {"type": "string"}, "items": {"type": "string"},
@@ -2481,32 +2161,6 @@ CONFIG_METADATA_3 = {
"description": "用户权限不足时是否回复", "description": "用户权限不足时是否回复",
"type": "bool", "type": "bool",
}, },
"platform_specific.lark.pre_ack_emoji.enable": {
"description": "[飞书] 启用预回应表情",
"type": "bool",
},
"platform_specific.lark.pre_ack_emoji.emojis": {
"description": "表情列表(飞书表情枚举名)",
"type": "list",
"items": {"type": "string"},
"hint": "表情枚举名参考https://open.feishu.cn/document/server-docs/im-v1/message-reaction/emojis-introduce",
"condition": {
"platform_specific.lark.pre_ack_emoji.enable": True,
},
},
"platform_specific.telegram.pre_ack_emoji.enable": {
"description": "[Telegram] 启用预回应表情",
"type": "bool",
},
"platform_specific.telegram.pre_ack_emoji.emojis": {
"description": "表情列表Unicode",
"type": "list",
"items": {"type": "string"},
"hint": "Telegram 仅支持固定反应集合参考https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9",
"condition": {
"platform_specific.telegram.pre_ack_emoji.enable": True,
},
},
}, },
}, },
}, },
@@ -2660,13 +2314,7 @@ CONFIG_METADATA_3_SYSTEM = {
"condition": { "condition": {
"t2i_strategy": "remote", "t2i_strategy": "remote",
}, },
"_special": "t2i_template", "_special": "t2i_template"
},
"t2i_active_template": {
"description": "当前应用的文转图渲染模板",
"type": "string",
"hint": "此处的值由文转图模板管理页面进行维护。",
"invisible": True,
}, },
"log_level": { "log_level": {
"description": "控制台日志级别", "description": "控制台日志级别",
@@ -2699,15 +2347,10 @@ CONFIG_METADATA_3_SYSTEM = {
"type": "string", "type": "string",
"hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`", "hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`",
}, },
"no_proxy": {
"description": "直连地址列表",
"type": "list",
"items": {"type": "string"},
},
}, },
}, }
}, },
}, }
} }

View File

@@ -1,14 +1,13 @@
"""AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库. """
AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库
在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话, 在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
在一个会话中可以建立多个对话, 并且支持对话的切换和删除 在一个会话中可以建立多个对话, 并且支持对话的切换和删除
""" """
import json import json
from collections.abc import Awaitable, Callable
from astrbot.core import sp from astrbot.core import sp
from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment from typing import Dict, List
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation, ConversationV2 from astrbot.core.db.po import Conversation, ConversationV2
@@ -17,45 +16,10 @@ class ConversationManager:
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
def __init__(self, db_helper: BaseDatabase): def __init__(self, db_helper: BaseDatabase):
self.session_conversations: dict[str, str] = {} self.session_conversations: Dict[str, str] = {}
self.db = db_helper self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次 self.save_interval = 60 # 每 60 秒保存一次
# 会话删除回调函数列表(用于级联清理,如知识库配置)
self._on_session_deleted_callbacks: list[Callable[[str], Awaitable[None]]] = []
def register_on_session_deleted(
self,
callback: Callable[[str], Awaitable[None]],
) -> None:
"""注册会话删除回调函数.
其他模块可以注册回调来响应会话删除事件,实现级联清理。
例如:知识库模块可以注册回调来清理会话的知识库配置。
Args:
callback: 回调函数接收会话ID (unified_msg_origin) 作为参数
"""
self._on_session_deleted_callbacks.append(callback)
async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:
"""触发会话删除回调.
Args:
unified_msg_origin: 会话ID
"""
for callback in self._on_session_deleted_callbacks:
try:
await callback(unified_msg_origin)
except Exception as e:
from astrbot.core import logger
logger.error(
f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}",
)
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation: def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
"""将 ConversationV2 对象转换为 Conversation 对象""" """将 ConversationV2 对象转换为 Conversation 对象"""
created_at = int(conv_v2.created_at.timestamp()) created_at = int(conv_v2.created_at.timestamp())
@@ -79,13 +43,12 @@ class ConversationManager:
title: str | None = None, title: str | None = None,
persona_id: str | None = None, persona_id: str | None = None,
) -> str: ) -> str:
"""新建对话,并将当前会话的对话转移到新对话. """新建对话,并将当前会话的对话转移到新对话
Args: Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns: Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串 conversation_id (str): 对话 ID, 是 uuid 格式的字符串
""" """
if not platform_id: if not platform_id:
# 如果没有提供 platform_id则从 unified_msg_origin 中解析 # 如果没有提供 platform_id则从 unified_msg_origin 中解析
@@ -111,46 +74,30 @@ class ConversationManager:
Args: Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串 conversation_id (str): 对话 ID, 是 uuid 格式的字符串
""" """
self.session_conversations[unified_msg_origin] = conversation_id self.session_conversations[unified_msg_origin] = conversation_id
await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id) await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)
async def delete_conversation( async def delete_conversation(
self, self, unified_msg_origin: str, conversation_id: str | None = None
unified_msg_origin: str,
conversation_id: str | None = None,
): ):
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话 """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
Args: Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串 conversation_id (str): 对话 ID, 是 uuid 格式的字符串
""" """
f = False
if not conversation_id: if not conversation_id:
conversation_id = self.session_conversations.get(unified_msg_origin) conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
f = True
if conversation_id: if conversation_id:
await self.db.delete_conversation(cid=conversation_id) await self.db.delete_conversation(cid=conversation_id)
curr_cid = await self.get_curr_conversation_id(unified_msg_origin) if f:
if curr_cid == conversation_id:
self.session_conversations.pop(unified_msg_origin, None) self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id") await sp.session_remove(unified_msg_origin, "sel_conv_id")
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
"""删除会话的所有对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
"""
await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
# 触发会话删除回调(级联清理)
await self._trigger_session_deleted(unified_msg_origin)
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None: async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
"""获取会话当前的对话 ID """获取会话当前的对话 ID
@@ -158,7 +105,6 @@ class ConversationManager:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns: Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串 conversation_id (str): 对话 ID, 是 uuid 格式的字符串
""" """
ret = self.session_conversations.get(unified_msg_origin, None) ret = self.session_conversations.get(unified_msg_origin, None)
if not ret: if not ret:
@@ -173,15 +119,13 @@ class ConversationManager:
conversation_id: str, conversation_id: str,
create_if_not_exists: bool = False, create_if_not_exists: bool = False,
) -> Conversation | None: ) -> Conversation | None:
"""获取会话的对话. """获取会话的对话
Args: Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串 conversation_id (str): 对话 ID, 是 uuid 格式的字符串
create_if_not_exists (bool): 如果对话不存在,是否创建一个新的对话
Returns: Returns:
conversation (Conversation): 对话对象 conversation (Conversation): 对话对象
""" """
conv = await self.db.get_conversation_by_id(cid=conversation_id) conv = await self.db.get_conversation_by_id(cid=conversation_id)
if not conv and create_if_not_exists: if not conv and create_if_not_exists:
@@ -194,22 +138,18 @@ class ConversationManager:
return conv_res return conv_res
async def get_conversations( async def get_conversations(
self, self, unified_msg_origin: str | None = None, platform_id: str | None = None
unified_msg_origin: str | None = None, ) -> List[Conversation]:
platform_id: str | None = None, """获取对话列表
) -> list[Conversation]:
"""获取对话列表.
Args: Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id可选 unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id可选
platform_id (str): 平台 ID, 可选参数, 用于过滤对话 platform_id (str): 平台 ID, 可选参数, 用于过滤对话
Returns: Returns:
conversations (List[Conversation]): 对话对象列表 conversations (List[Conversation]): 对话对象列表
""" """
convs = await self.db.get_conversations( convs = await self.db.get_conversations(
user_id=unified_msg_origin, user_id=unified_msg_origin, platform_id=platform_id
platform_id=platform_id,
) )
convs_res = [] convs_res = []
for conv in convs: for conv in convs:
@@ -225,7 +165,7 @@ class ConversationManager:
search_query: str = "", search_query: str = "",
**kwargs, **kwargs,
) -> tuple[list[Conversation], int]: ) -> tuple[list[Conversation], int]:
"""获取过滤后的对话列表. """获取过滤后的对话列表
Args: Args:
page (int): 页码, 默认为 1 page (int): 页码, 默认为 1
@@ -234,7 +174,6 @@ class ConversationManager:
search_query (str): 搜索查询字符串, 可选 search_query (str): 搜索查询字符串, 可选
Returns: Returns:
conversations (list[Conversation]): 对话对象列表 conversations (list[Conversation]): 对话对象列表
""" """
convs, cnt = await self.db.get_filtered_conversations( convs, cnt = await self.db.get_filtered_conversations(
page=page, page=page,
@@ -256,14 +195,13 @@ class ConversationManager:
history: list[dict] | None = None, history: list[dict] | None = None,
title: str | None = None, title: str | None = None,
persona_id: str | None = None, persona_id: str | None = None,
) -> None: ):
"""更新会话的对话. """更新会话的对话
Args: Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串 conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段 history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
""" """
if not conversation_id: if not conversation_id:
# 如果没有提供 conversation_id则获取当前的 # 如果没有提供 conversation_id则获取当前的
@@ -277,20 +215,16 @@ class ConversationManager:
) )
async def update_conversation_title( async def update_conversation_title(
self, self, unified_msg_origin: str, title: str, conversation_id: str | None = None
unified_msg_origin: str, ):
title: str, """更新会话的对话标题
conversation_id: str | None = None,
) -> None:
"""更新会话的对话标题.
Args: Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
title (str): 对话标题 title (str): 对话标题
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
Deprecated: Deprecated:
Use `update_conversation` with `title` parameter instead. Use `update_conversation` with `title` parameter instead.
""" """
await self.update_conversation( await self.update_conversation(
unified_msg_origin=unified_msg_origin, unified_msg_origin=unified_msg_origin,
@@ -303,16 +237,15 @@ class ConversationManager:
unified_msg_origin: str, unified_msg_origin: str,
persona_id: str, persona_id: str,
conversation_id: str | None = None, conversation_id: str | None = None,
) -> None: ):
"""更新会话的对话 Persona ID. """更新会话的对话 Persona ID
Args: Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
persona_id (str): 对话 Persona ID persona_id (str): 对话 Persona ID
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
Deprecated: Deprecated:
Use `update_conversation` with `persona_id` parameter instead. Use `update_conversation` with `persona_id` parameter instead.
""" """
await self.update_conversation( await self.update_conversation(
unified_msg_origin=unified_msg_origin, unified_msg_origin=unified_msg_origin,
@@ -320,85 +253,40 @@ class ConversationManager:
persona_id=persona_id, persona_id=persona_id,
) )
async def add_message_pair(
self,
cid: str,
user_message: UserMessageSegment | dict,
assistant_message: AssistantMessageSegment | dict,
) -> None:
"""Add a user-assistant message pair to the conversation history.
Args:
cid (str): Conversation ID
user_message (UserMessageSegment | dict): OpenAI-format user message object or dict
assistant_message (AssistantMessageSegment | dict): OpenAI-format assistant message object or dict
Raises:
Exception: If the conversation with the given ID is not found
"""
conv = await self.db.get_conversation_by_id(cid=cid)
if not conv:
raise Exception(f"Conversation with id {cid} not found")
history = conv.content or []
if isinstance(user_message, UserMessageSegment):
user_msg_dict = user_message.model_dump()
else:
user_msg_dict = user_message
if isinstance(assistant_message, AssistantMessageSegment):
assistant_msg_dict = assistant_message.model_dump()
else:
assistant_msg_dict = assistant_message
history.append(user_msg_dict)
history.append(assistant_msg_dict)
await self.db.update_conversation(
cid=cid,
content=history,
)
async def get_human_readable_context( async def get_human_readable_context(
self, self, unified_msg_origin, conversation_id, page=1, page_size=10
unified_msg_origin: str, ):
conversation_id: str, """获取人类可读的上下文
page: int = 1,
page_size: int = 10,
) -> tuple[list[str], int]:
"""获取人类可读的上下文.
Args: Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串 conversation_id (str): 对话 ID, 是 uuid 格式的字符串
page (int): 页码 page (int): 页码
page_size (int): 每页大小 page_size (int): 每页大小
""" """
conversation = await self.get_conversation(unified_msg_origin, conversation_id) conversation = await self.get_conversation(unified_msg_origin, conversation_id)
if not conversation:
return [], 0
history = json.loads(conversation.history) history = json.loads(conversation.history)
# contexts_groups 存放按顺序的段落(每个段落是一个 str 列表), contexts = []
# 之后会被展平成一个扁平的 str 列表返回。 temp_contexts = []
contexts_groups: list[list[str]] = []
temp_contexts: list[str] = []
for record in history: for record in history:
if record["role"] == "user": if record["role"] == "user":
temp_contexts.append(f"User: {record['content']}") temp_contexts.append(f"User: {record['content']}")
elif record["role"] == "assistant": elif record["role"] == "assistant":
if record.get("content"): if "content" in record and record["content"]:
temp_contexts.append(f"Assistant: {record['content']}") temp_contexts.append(f"Assistant: {record['content']}")
elif "tool_calls" in record: elif "tool_calls" in record:
tool_calls_str = json.dumps( tool_calls_str = json.dumps(
record["tool_calls"], record["tool_calls"], ensure_ascii=False
ensure_ascii=False,
) )
temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}") temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
else: else:
temp_contexts.append("Assistant: [未知的内容]") temp_contexts.append("Assistant: [未知的内容]")
contexts_groups.insert(0, temp_contexts) contexts.insert(0, temp_contexts)
temp_contexts = [] temp_contexts = []
# 展平分组后的 contexts 列表为单层字符串列表 # 展平 contexts 列表
contexts = [item for sublist in contexts_groups for item in sublist] contexts = [item for sublist in contexts for item in sublist]
# 计算分页 # 计算分页
paged_contexts = contexts[(page - 1) * page_size : page * page_size] paged_contexts = contexts[(page - 1) * page_size : page * page_size]

View File

@@ -1,5 +1,5 @@
"""Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作. """
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。 该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
该类还负责加载和执行插件, 以及处理事件总线的分发。 该类还负责加载和执行插件, 以及处理事件总线的分发。
@@ -9,44 +9,42 @@
3. 执行启动完成事件钩子 3. 执行启动完成事件钩子
""" """
import asyncio
import os
import threading
import time
import traceback import traceback
import asyncio
import time
import threading
import os
from .event_bus import EventBus
from . import astrbot_config, html_renderer
from asyncio import Queue from asyncio import Queue
from typing import List
from astrbot.core import LogBroker, logger, sp from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.star.context import Context
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core import LogBroker
from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger, sp
from astrbot.core.config.default import VERSION from astrbot.core.config.default import VERSION
from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.db import BaseDatabase
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.provider.manager import ProviderManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.star import PluginManager from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.context import Context from astrbot.core.star.star_handler import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.updator import AstrBotUpdator
from . import astrbot_config, html_renderer
from .event_bus import EventBus
class AstrBotCoreLifecycle: class AstrBotCoreLifecycle:
"""AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作. """
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、 该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、
EventBus 等。 EventBus 等。
该类还负责加载和执行插件, 以及处理事件总线的分发。 该类还负责加载和执行插件, 以及处理事件总线的分发。
""" """
def __init__(self, log_broker: LogBroker, db: BaseDatabase) -> None: def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker # 初始化日志代理 self.log_broker = log_broker # 初始化日志代理
self.astrbot_config = astrbot_config # 初始化配置 self.astrbot_config = astrbot_config # 初始化配置
self.db = db # 初始化数据库 self.db = db # 初始化数据库
@@ -70,11 +68,11 @@ class AstrBotCoreLifecycle:
del os.environ["no_proxy"] del os.environ["no_proxy"]
logger.debug("HTTP proxy cleared") logger.debug("HTTP proxy cleared")
async def initialize(self) -> None: async def initialize(self):
"""初始化 AstrBot 核心生命周期管理类.
负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
""" """
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
"""
# 初始化日志代理 # 初始化日志代理
logger.info("AstrBot v" + VERSION) logger.info("AstrBot v" + VERSION)
if os.environ.get("TESTING", ""): if os.environ.get("TESTING", ""):
@@ -86,23 +84,11 @@ class AstrBotCoreLifecycle:
await html_renderer.initialize() await html_renderer.initialize()
# 初始化 UMOP 配置路由器
self.umop_config_router = UmopConfigRouter(sp=sp)
# 初始化 AstrBot 配置管理器 # 初始化 AstrBot 配置管理器
self.astrbot_config_mgr = AstrBotConfigManager( self.astrbot_config_mgr = AstrBotConfigManager(
default_config=self.astrbot_config, default_config=self.astrbot_config, sp=sp
ucr=self.umop_config_router,
sp=sp,
) )
# 4.5 to 4.6 migration for umop_config_router
try:
await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
except Exception as e:
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
logger.error(traceback.format_exc())
# 初始化事件队列 # 初始化事件队列
self.event_queue = Queue() self.event_queue = Queue()
@@ -112,9 +98,7 @@ class AstrBotCoreLifecycle:
# 初始化供应商管理器 # 初始化供应商管理器
self.provider_manager = ProviderManager( self.provider_manager = ProviderManager(
self.astrbot_config_mgr, self.astrbot_config_mgr, self.db, self.persona_mgr
self.db,
self.persona_mgr,
) )
# 初始化平台管理器 # 初始化平台管理器
@@ -126,9 +110,6 @@ class AstrBotCoreLifecycle:
# 初始化平台消息历史管理器 # 初始化平台消息历史管理器
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db) self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
# 初始化知识库管理器
self.kb_manager = KnowledgeBaseManager(self.provider_manager)
# 初始化提供给插件的上下文 # 初始化提供给插件的上下文
self.star_context = Context( self.star_context = Context(
self.event_queue, self.event_queue,
@@ -140,7 +121,6 @@ class AstrBotCoreLifecycle:
self.platform_message_history_manager, self.platform_message_history_manager,
self.persona_mgr, self.persona_mgr,
self.astrbot_config_mgr, self.astrbot_config_mgr,
self.kb_manager,
) )
# 初始化插件管理器 # 初始化插件管理器
@@ -152,9 +132,8 @@ class AstrBotCoreLifecycle:
# 根据配置实例化各个 Provider # 根据配置实例化各个 Provider
await self.provider_manager.initialize() await self.provider_manager.initialize()
await self.kb_manager.initialize()
# 初始化消息事件流水线调度器 # 初始化消息事件流水线调度器
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler() self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
# 初始化更新器 # 初始化更新器
@@ -162,16 +141,14 @@ class AstrBotCoreLifecycle:
# 初始化事件总线 # 初始化事件总线
self.event_bus = EventBus( self.event_bus = EventBus(
self.event_queue, self.event_queue, self.pipeline_scheduler_mapping, self.astrbot_config_mgr
self.pipeline_scheduler_mapping,
self.astrbot_config_mgr,
) )
# 记录启动时间 # 记录启动时间
self.start_time = int(time.time()) self.start_time = int(time.time())
# 初始化当前任务列表 # 初始化当前任务列表
self.curr_tasks: list[asyncio.Task] = [] self.curr_tasks: List[asyncio.Task] = []
# 根据配置实例化各个平台适配器 # 根据配置实例化各个平台适配器
await self.platform_manager.initialize() await self.platform_manager.initialize()
@@ -179,13 +156,13 @@ class AstrBotCoreLifecycle:
# 初始化关闭控制面板的事件 # 初始化关闭控制面板的事件
self.dashboard_shutdown_event = asyncio.Event() self.dashboard_shutdown_event = asyncio.Event()
def _load(self) -> None: def _load(self):
"""加载事件总线和任务并初始化.""" """加载事件总线和任务并初始化"""
# 创建一个异步任务来执行事件总线的 dispatch() 方法 # 创建一个异步任务来执行事件总线的 dispatch() 方法
# dispatch是一个无限循环的协程, 从事件队列中获取事件并处理 # dispatch是一个无限循环的协程, 从事件队列中获取事件并处理
event_bus_task = asyncio.create_task( event_bus_task = asyncio.create_task(
self.event_bus.dispatch(), self.event_bus.dispatch(), name="event_bus"
name="event_bus",
) )
# 把插件中注册的所有协程函数注册到事件总线中并执行 # 把插件中注册的所有协程函数注册到事件总线中并执行
@@ -196,17 +173,16 @@ class AstrBotCoreLifecycle:
tasks_ = [event_bus_task, *extra_tasks] tasks_ = [event_bus_task, *extra_tasks]
for task in tasks_: for task in tasks_:
self.curr_tasks.append( self.curr_tasks.append(
asyncio.create_task(self._task_wrapper(task), name=task.get_name()), asyncio.create_task(self._task_wrapper(task), name=task.get_name())
) )
self.start_time = int(time.time()) self.start_time = int(time.time())
async def _task_wrapper(self, task: asyncio.Task) -> None: async def _task_wrapper(self, task: asyncio.Task):
"""异步任务包装器, 用于处理异步任务执行中出现的各种异常. """异步任务包装器, 用于处理异步任务执行中出现的各种异常
Args: Args:
task (asyncio.Task): 要执行的异步任务 task (asyncio.Task): 要执行的异步任务
""" """
try: try:
await task await task
@@ -219,22 +195,19 @@ class AstrBotCoreLifecycle:
logger.error(f"| {line}") logger.error(f"| {line}")
logger.error("-------") logger.error("-------")
async def start(self) -> None: async def start(self):
"""启动 AstrBot 核心生命周期管理类. """启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子"""
用load加载事件总线和任务并初始化, 执行启动完成事件钩子
"""
self._load() self._load()
logger.info("AstrBot 启动完成。") logger.info("AstrBot 启动完成。")
# 执行启动完成事件钩子 # 执行启动完成事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type( handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAstrBotLoadedEvent, EventType.OnAstrBotLoadedEvent
) )
for handler in handlers: for handler in handlers:
try: try:
logger.info( logger.info(
f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}", f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
) )
await handler.handler() await handler.handler()
except BaseException: except BaseException:
@@ -243,8 +216,8 @@ class AstrBotCoreLifecycle:
# 同时运行curr_tasks中的所有任务 # 同时运行curr_tasks中的所有任务
await asyncio.gather(*self.curr_tasks, return_exceptions=True) await asyncio.gather(*self.curr_tasks, return_exceptions=True)
async def stop(self) -> None: async def stop(self):
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器.""" """停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器"""
# 请求停止所有正在运行的异步任务 # 请求停止所有正在运行的异步任务
for task in self.curr_tasks: for task in self.curr_tasks:
task.cancel() task.cancel()
@@ -255,12 +228,11 @@ class AstrBotCoreLifecycle:
except Exception as e: except Exception as e:
logger.warning(traceback.format_exc()) logger.warning(traceback.format_exc())
logger.warning( logger.warning(
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。", f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。"
) )
await self.provider_manager.terminate() await self.provider_manager.terminate()
await self.platform_manager.terminate() await self.platform_manager.terminate()
await self.kb_manager.terminate()
self.dashboard_shutdown_event.set() self.dashboard_shutdown_event.set()
# 再次遍历curr_tasks等待每个任务真正结束 # 再次遍历curr_tasks等待每个任务真正结束
@@ -272,19 +244,16 @@ class AstrBotCoreLifecycle:
except Exception as e: except Exception as e:
logger.error(f"任务 {task.get_name()} 发生错误: {e}") logger.error(f"任务 {task.get_name()} 发生错误: {e}")
async def restart(self) -> None: async def restart(self):
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例""" """重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
await self.provider_manager.terminate() await self.provider_manager.terminate()
await self.platform_manager.terminate() await self.platform_manager.terminate()
await self.kb_manager.terminate()
self.dashboard_shutdown_event.set() self.dashboard_shutdown_event.set()
threading.Thread( threading.Thread(
target=self.astrbot_updator._reboot, target=self.astrbot_updator._reboot, name="restart", daemon=True
name="restart",
daemon=True,
).start() ).start()
def load_platform(self) -> list[asyncio.Task]: def load_platform(self) -> List[asyncio.Task]:
"""加载平台实例并返回所有平台实例的异步任务列表""" """加载平台实例并返回所有平台实例的异步任务列表"""
tasks = [] tasks = []
platform_insts = self.platform_manager.get_insts() platform_insts = self.platform_manager.get_insts()
@@ -293,38 +262,36 @@ class AstrBotCoreLifecycle:
asyncio.create_task( asyncio.create_task(
platform_inst.run(), platform_inst.run(),
name=f"{platform_inst.meta().id}({platform_inst.meta().name})", name=f"{platform_inst.meta().id}({platform_inst.meta().name})",
), )
) )
return tasks return tasks
async def load_pipeline_scheduler(self) -> dict[str, PipelineScheduler]: async def load_pipeline_scheduler(self) -> dict[str, PipelineScheduler]:
"""加载消息事件流水线调度器. """加载消息事件流水线调度器
Returns: Returns:
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射 dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
""" """
mapping = {} mapping = {}
for conf_id, ab_config in self.astrbot_config_mgr.confs.items(): for conf_id, ab_config in self.astrbot_config_mgr.confs.items():
scheduler = PipelineScheduler( scheduler = PipelineScheduler(
PipelineContext(ab_config, self.plugin_manager, conf_id), PipelineContext(ab_config, self.plugin_manager, conf_id)
) )
await scheduler.initialize() await scheduler.initialize()
mapping[conf_id] = scheduler mapping[conf_id] = scheduler
return mapping return mapping
async def reload_pipeline_scheduler(self, conf_id: str) -> None: async def reload_pipeline_scheduler(self, conf_id: str):
"""重新加载消息事件流水线调度器. """重新加载消息事件流水线调度器
Returns: Returns:
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射 dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
""" """
ab_config = self.astrbot_config_mgr.confs.get(conf_id) ab_config = self.astrbot_config_mgr.confs.get(conf_id)
if not ab_config: if not ab_config:
raise ValueError(f"配置文件 {conf_id} 不存在") raise ValueError(f"配置文件 {conf_id} 不存在")
scheduler = PipelineScheduler( scheduler = PipelineScheduler(
PipelineContext(ab_config, self.plugin_manager, conf_id), PipelineContext(ab_config, self.plugin_manager, conf_id)
) )
await scheduler.initialize() await scheduler.initialize()
self.pipeline_scheduler_mapping[conf_id] = scheduler self.pipeline_scheduler_mapping[conf_id] = scheduler

View File

@@ -1,27 +1,27 @@
import abc import abc
import datetime import datetime
import typing as T import typing as T
from contextlib import asynccontextmanager
from dataclasses import dataclass
from deprecated import deprecated from deprecated import deprecated
from dataclasses import dataclass
from astrbot.core.db.po import (
Stats,
PlatformStat,
ConversationV2,
PlatformMessageHistory,
Attachment,
Persona,
Preference,
)
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from astrbot.core.db.po import (
Attachment,
ConversationV2,
Persona,
PlatformMessageHistory,
PlatformStat,
Preference,
Stats,
)
@dataclass @dataclass
class BaseDatabase(abc.ABC): class BaseDatabase(abc.ABC):
"""数据库基类""" """
数据库基类
"""
DATABASE_URL = "" DATABASE_URL = ""
@@ -32,13 +32,12 @@ class BaseDatabase(abc.ABC):
future=True, future=True,
) )
self.AsyncSessionLocal = sessionmaker( self.AsyncSessionLocal = sessionmaker(
self.engine, self.engine, class_=AsyncSession, expire_on_commit=False
class_=AsyncSession,
expire_on_commit=False,
) )
async def initialize(self): async def initialize(self):
"""初始化数据库连接""" """初始化数据库连接"""
pass
@asynccontextmanager @asynccontextmanager
async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]: async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
@@ -92,9 +91,7 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
async def get_conversations( async def get_conversations(
self, self, user_id: str | None = None, platform_id: str | None = None
user_id: str | None = None,
platform_id: str | None = None,
) -> list[ConversationV2]: ) -> list[ConversationV2]:
"""Get all conversations for a specific user and platform_id(optional). """Get all conversations for a specific user and platform_id(optional).
@@ -109,9 +106,7 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
async def get_all_conversations( async def get_all_conversations(
self, self, page: int = 1, page_size: int = 20
page: int = 1,
page_size: int = 20,
) -> list[ConversationV2]: ) -> list[ConversationV2]:
"""Get all conversations with pagination.""" """Get all conversations with pagination."""
... ...
@@ -159,17 +154,12 @@ class BaseDatabase(abc.ABC):
"""Delete a conversation by its ID.""" """Delete a conversation by its ID."""
... ...
@abc.abstractmethod
async def delete_conversations_by_user_id(self, user_id: str) -> None:
"""Delete all conversations for a specific user."""
...
@abc.abstractmethod @abc.abstractmethod
async def insert_platform_message_history( async def insert_platform_message_history(
self, self,
platform_id: str, platform_id: str,
user_id: str, user_id: str,
content: dict, content: list[dict],
sender_id: str | None = None, sender_id: str | None = None,
sender_name: str | None = None, sender_name: str | None = None,
) -> None: ) -> None:
@@ -178,10 +168,7 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
async def delete_platform_message_offset( async def delete_platform_message_offset(
self, self, platform_id: str, user_id: str, offset_sec: int = 86400
platform_id: str,
user_id: str,
offset_sec: int = 86400,
) -> None: ) -> None:
"""Delete platform message history records older than the specified offset.""" """Delete platform message history records older than the specified offset."""
... ...
@@ -251,11 +238,7 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
async def insert_preference_or_update( async def insert_preference_or_update(
self, self, scope: str, scope_id: str, key: str, value: dict
scope: str,
scope_id: str,
key: str,
value: dict,
) -> Preference: ) -> Preference:
"""Insert a new preference record.""" """Insert a new preference record."""
... ...
@@ -267,10 +250,7 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
async def get_preferences( async def get_preferences(
self, self, scope: str, scope_id: str | None = None, key: str | None = None
scope: str,
scope_id: str | None = None,
key: str | None = None,
) -> list[Preference]: ) -> list[Preference]:
"""Get all preferences for a specific scope ID or key.""" """Get all preferences for a specific scope ID or key."""
... ...
@@ -302,14 +282,3 @@ class BaseDatabase(abc.ABC):
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]: # async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
# """Get all LLM messages for a specific conversation.""" # """Get all LLM messages for a specific conversation."""
# ... # ...
@abc.abstractmethod
async def get_session_conversations(
self,
page: int = 1,
page_size: int = 20,
search_query: str | None = None,
platform: str | None = None,
) -> tuple[list[dict], int]:
"""Get paginated session conversations with joined conversation and persona details, support search and platform filter."""
...

View File

@@ -1,33 +1,27 @@
import os import os
from astrbot.api import logger, sp
from astrbot.core.config import AstrBotConfig
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.db import BaseDatabase
from astrbot.core.config import AstrBotConfig
from astrbot.api import logger, sp
from .migra_3_to_4 import ( from .migra_3_to_4 import (
migration_conversation_table, migration_conversation_table,
migration_persona_data,
migration_platform_table, migration_platform_table,
migration_preferences,
migration_webchat_data, migration_webchat_data,
migration_persona_data,
migration_preferences,
) )
async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool: async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
"""检查是否需要进行数据库迁移 """
检查是否需要进行数据库迁移
如果存在 data_v3.db 并且 preference 中没有 migration_done_v4则需要进行迁移。 如果存在 data_v3.db 并且 preference 中没有 migration_done_v4则需要进行迁移。
""" """
# 仅当 data 目录下存在旧版本数据data_v3.db 文件)时才考虑迁移 data_v3_exists = os.path.exists(get_astrbot_data_path())
data_dir = get_astrbot_data_path() if not data_v3_exists:
data_v3_db = os.path.join(data_dir, "data_v3.db")
if not os.path.exists(data_v3_db):
return False return False
migration_done = await db_helper.get_preference( migration_done = await db_helper.get_preference(
"global", "global", "global", "migration_done_v4"
"global",
"migration_done_v4",
) )
if migration_done: if migration_done:
return False return False
@@ -38,8 +32,9 @@ async def do_migration_v4(
db_helper: BaseDatabase, db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]], platform_id_map: dict[str, dict[str, str]],
astrbot_config: AstrBotConfig, astrbot_config: AstrBotConfig,
) -> None: ):
"""执行数据库迁移 """
执行数据库迁移
迁移旧的 webchat_conversation 表到新的 conversation 表。 迁移旧的 webchat_conversation 表到新的 conversation 表。
迁移旧的 platform 到新的 platform_stats 表。 迁移旧的 platform 到新的 platform_stats 表。
""" """
@@ -58,7 +53,7 @@ async def do_migration_v4(
await migration_webchat_data(db_helper, platform_id_map) await migration_webchat_data(db_helper, platform_id_map)
# 执行偏好设置迁移 # 执行偏好设置迁移
await migration_preferences(db_helper, platform_id_map) await migration_preferences(db_helper,platform_id_map)
# 执行平台统计表迁移 # 执行平台统计表迁移
await migration_platform_table(db_helper, platform_id_map) await migration_platform_table(db_helper, platform_id_map)

View File

@@ -1,18 +1,15 @@
import datetime
import json import json
import datetime
from sqlalchemy import text from .. import BaseDatabase
from sqlalchemy.ext.asyncio import AsyncSession from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
from .shared_preferences_v3 import sp as sp_v3
from astrbot.core.config.default import DB_PATH
from astrbot.api import logger, sp from astrbot.api import logger, sp
from astrbot.core.config import AstrBotConfig from astrbot.core.config import AstrBotConfig
from astrbot.core.config.default import DB_PATH
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.astr_message_event import MessageSesion
from sqlalchemy.ext.asyncio import AsyncSession
from .. import BaseDatabase from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
from .shared_preferences_v3 import sp as sp_v3 from sqlalchemy import text
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
""" """
1. 迁移旧的 webchat_conversation 表到新的 conversation 表。 1. 迁移旧的 webchat_conversation 表到新的 conversation 表。
@@ -21,8 +18,7 @@ from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
def get_platform_id( def get_platform_id(
platform_id_map: dict[str, dict[str, str]], platform_id_map: dict[str, dict[str, str]], old_platform_name: str
old_platform_name: str,
) -> str: ) -> str:
return platform_id_map.get( return platform_id_map.get(
old_platform_name, old_platform_name,
@@ -31,8 +27,7 @@ def get_platform_id(
def get_platform_type( def get_platform_type(
platform_id_map: dict[str, dict[str, str]], platform_id_map: dict[str, dict[str, str]], old_platform_name: str
old_platform_name: str,
) -> str: ) -> str:
return platform_id_map.get( return platform_id_map.get(
old_platform_name, old_platform_name,
@@ -41,15 +36,13 @@ def get_platform_type(
async def migration_conversation_table( async def migration_conversation_table(
db_helper: BaseDatabase, db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
platform_id_map: dict[str, dict[str, str]],
): ):
db_helper_v3 = SQLiteV3DatabaseV3( db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
) )
conversations, total_cnt = db_helper_v3.get_all_conversations( conversations, total_cnt = db_helper_v3.get_all_conversations(
page=1, page=1, page_size=10000000
page_size=10000000,
) )
logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...") logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...")
@@ -68,14 +61,13 @@ async def migration_conversation_table(
) )
if not conv: if not conv:
logger.info( logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。", f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
) )
if ":" not in conv.user_id: if ":" not in conv.user_id:
continue continue
session = MessageSesion.from_str(session_str=conv.user_id) session = MessageSesion.from_str(session_str=conv.user_id)
platform_id = get_platform_id( platform_id = get_platform_id(
platform_id_map, platform_id_map, session.platform_name
session.platform_name,
) )
session.platform_id = platform_id # 更新平台名称为新的 ID session.platform_id = platform_id # 更新平台名称为新的 ID
conv_v2 = ConversationV2( conv_v2 = ConversationV2(
@@ -98,11 +90,10 @@ async def migration_conversation_table(
async def migration_platform_table( async def migration_platform_table(
db_helper: BaseDatabase, db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
platform_id_map: dict[str, dict[str, str]],
): ):
db_helper_v3 = SQLiteV3DatabaseV3( db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
) )
secs_from_2023_4_10_to_now = ( secs_from_2023_4_10_to_now = (
datetime.datetime.now(datetime.timezone.utc) datetime.datetime.now(datetime.timezone.utc)
@@ -143,12 +134,10 @@ async def migration_platform_table(
if cnt == 0: if cnt == 0:
continue continue
platform_id = get_platform_id( platform_id = get_platform_id(
platform_id_map, platform_id_map, platform_stats_v3[idx].name
platform_stats_v3[idx].name,
) )
platform_type = get_platform_type( platform_type = get_platform_type(
platform_id_map, platform_id_map, platform_stats_v3[idx].name
platform_stats_v3[idx].name,
) )
try: try:
await dbsession.execute( await dbsession.execute(
@@ -160,8 +149,7 @@ async def migration_platform_table(
"""), """),
{ {
"timestamp": datetime.datetime.fromtimestamp( "timestamp": datetime.datetime.fromtimestamp(
bucket_end, bucket_end, tz=datetime.timezone.utc
tz=datetime.timezone.utc,
), ),
"platform_id": platform_id, "platform_id": platform_id,
"platform_type": platform_type, "platform_type": platform_type,
@@ -177,16 +165,14 @@ async def migration_platform_table(
async def migration_webchat_data( async def migration_webchat_data(
db_helper: BaseDatabase, db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
platform_id_map: dict[str, dict[str, str]],
): ):
"""迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中""" """迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中"""
db_helper_v3 = SQLiteV3DatabaseV3( db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
) )
conversations, total_cnt = db_helper_v3.get_all_conversations( conversations, total_cnt = db_helper_v3.get_all_conversations(
page=1, page=1, page_size=10000000
page_size=10000000,
) )
logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...") logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...")
@@ -205,7 +191,7 @@ async def migration_webchat_data(
) )
if not conv: if not conv:
logger.info( logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。", f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
) )
if ":" in conv.user_id: if ":" in conv.user_id:
continue continue
@@ -232,10 +218,10 @@ async def migration_webchat_data(
async def migration_persona_data( async def migration_persona_data(
db_helper: BaseDatabase, db_helper: BaseDatabase, astrbot_config: AstrBotConfig
astrbot_config: AstrBotConfig,
): ):
"""迁移 Persona 数据到新的表中。 """
迁移 Persona 数据到新的表中。
旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。 旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
""" """
v3_persona_config: list[dict] = astrbot_config.get("persona", []) v3_persona_config: list[dict] = astrbot_config.get("persona", [])
@@ -250,15 +236,14 @@ async def migration_persona_data(
try: try:
begin_dialogs = persona.get("begin_dialogs", []) begin_dialogs = persona.get("begin_dialogs", [])
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", []) mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
parts = [] mood_prompt = ""
user_turn = True user_turn = True
for mood_dialog in mood_imitation_dialogs: for mood_dialog in mood_imitation_dialogs:
if user_turn: if user_turn:
parts.append(f"A: {mood_dialog}\n") mood_prompt += f"A: {mood_dialog}\n"
else: else:
parts.append(f"B: {mood_dialog}\n") mood_prompt += f"B: {mood_dialog}\n"
user_turn = not user_turn user_turn = not user_turn
mood_prompt = "".join(parts)
system_prompt = persona.get("prompt", "") system_prompt = persona.get("prompt", "")
if mood_prompt: if mood_prompt:
system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}" system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
@@ -268,15 +253,14 @@ async def migration_persona_data(
begin_dialogs=begin_dialogs, begin_dialogs=begin_dialogs,
) )
logger.info( logger.info(
f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。", f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
) )
except Exception as e: except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}") logger.error(f"解析 Persona 配置失败:{e}")
async def migration_preferences( async def migration_preferences(
db_helper: BaseDatabase, db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
platform_id_map: dict[str, dict[str, str]],
): ):
# 1. global scope migration # 1. global scope migration
keys = [ keys = [
@@ -345,13 +329,10 @@ async def migration_preferences(
for provider_type, provider_id in perf.items(): for provider_type, provider_id in perf.items():
await sp.put_async( await sp.put_async(
"umo", "umo", str(session), f"provider_perf_{provider_type}", provider_id
str(session),
f"provider_perf_{provider_type}",
provider_id,
) )
logger.info( logger.info(
f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}", f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}"
) )
except Exception as e: except Exception as e:
logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True) logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True)

View File

@@ -1,44 +0,0 @@
from astrbot.api import logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.umop_config_router import UmopConfigRouter
async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
abconf_data = acm.abconf_data
if not isinstance(abconf_data, dict):
# should be unreachable
logger.warning(
f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}",
)
return
# 如果任何一项带有 umop则说明需要迁移
need_migration = False
for conf_id, conf_info in abconf_data.items():
if isinstance(conf_info, dict) and "umop" in conf_info:
need_migration = True
break
if not need_migration:
return
logger.info("Starting migration from version 4.5 to 4.6")
# extract umo->conf_id mapping
umo_to_conf_id = {}
for conf_id, conf_info in abconf_data.items():
if isinstance(conf_info, dict) and "umop" in conf_info:
umop_ls = conf_info.pop("umop")
if not isinstance(umop_ls, list):
continue
for umo in umop_ls:
if isinstance(umo, str) and umo not in umo_to_conf_id:
umo_to_conf_id[umo] = conf_id
# update the abconf data
await sp.global_put("abconf_mapping", abconf_data)
# update the umop config router
await ucr.update_routing_data(umo_to_conf_id)
logger.info("Migration from version 45 to 46 completed successfully")

View File

@@ -1,12 +1,10 @@
import json import json
import os import os
from typing import TypeVar from typing import TypeVar
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
_VT = TypeVar("_VT") _VT = TypeVar("_VT")
class SharedPreferences: class SharedPreferences:
def __init__(self, path=None): def __init__(self, path=None):
if path is None: if path is None:
@@ -17,7 +15,7 @@ class SharedPreferences:
def _load_preferences(self): def _load_preferences(self):
if os.path.exists(self.path): if os.path.exists(self.path):
try: try:
with open(self.path) as f: with open(self.path, "r") as f:
return json.load(f) return json.load(f)
except json.JSONDecodeError: except json.JSONDecodeError:
os.remove(self.path) os.remove(self.path)
@@ -44,5 +42,4 @@ class SharedPreferences:
self._data.clear() self._data.clear()
self._save_preferences() self._save_preferences()
sp = SharedPreferences() sp = SharedPreferences()

View File

@@ -1,10 +1,8 @@
import sqlite3 import sqlite3
import time import time
from dataclasses import dataclass
from typing import Any
from astrbot.core.db.po import Platform, Stats from astrbot.core.db.po import Platform, Stats
from typing import Tuple, List, Dict, Any
from dataclasses import dataclass
@dataclass @dataclass
class Conversation: class Conversation:
@@ -78,7 +76,7 @@ PRAGMA encoding = 'UTF-8';
""" """
class SQLiteDatabase: class SQLiteDatabase():
def __init__(self, db_path: str) -> None: def __init__(self, db_path: str) -> None:
super().__init__() super().__init__()
self.db_path = db_path self.db_path = db_path
@@ -95,7 +93,7 @@ class SQLiteDatabase:
c.execute( c.execute(
""" """
PRAGMA table_info(webchat_conversation) PRAGMA table_info(webchat_conversation)
""", """
) )
res = c.fetchall() res = c.fetchall()
has_title = False has_title = False
@@ -109,14 +107,14 @@ class SQLiteDatabase:
c.execute( c.execute(
""" """
ALTER TABLE webchat_conversation ADD COLUMN title TEXT; ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
""", """
) )
self.conn.commit() self.conn.commit()
if not has_persona_id: if not has_persona_id:
c.execute( c.execute(
""" """
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT; ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
""", """
) )
self.conn.commit() self.conn.commit()
@@ -127,7 +125,7 @@ class SQLiteDatabase:
conn.text_factory = str conn.text_factory = str
return conn return conn
def _exec_sql(self, sql: str, params: tuple = None): def _exec_sql(self, sql: str, params: Tuple = None):
conn = self.conn conn = self.conn
try: try:
c = self.conn.cursor() c = self.conn.cursor()
@@ -175,7 +173,7 @@ class SQLiteDatabase:
""" """
SELECT * FROM platform SELECT * FROM platform
""" """
+ where_clause, + where_clause
) )
platform = [] platform = []
@@ -195,7 +193,7 @@ class SQLiteDatabase:
c.execute( c.execute(
""" """
SELECT SUM(count) FROM platform SELECT SUM(count) FROM platform
""", """
) )
res = c.fetchone() res = c.fetchone()
c.close() c.close()
@@ -215,7 +213,7 @@ class SQLiteDatabase:
SELECT name, SUM(count), timestamp FROM platform SELECT name, SUM(count), timestamp FROM platform
""" """
+ where_clause + where_clause
+ " GROUP BY name", + " GROUP BY name"
) )
platform = [] platform = []
@@ -243,7 +241,7 @@ class SQLiteDatabase:
c.close() c.close()
if not res: if not res:
return None return
return Conversation(*res) return Conversation(*res)
@@ -258,7 +256,7 @@ class SQLiteDatabase:
(user_id, cid, history, updated_at, created_at), (user_id, cid, history, updated_at, created_at),
) )
def get_conversations(self, user_id: str) -> tuple: def get_conversations(self, user_id: str) -> Tuple:
try: try:
c = self.conn.cursor() c = self.conn.cursor()
except sqlite3.ProgrammingError: except sqlite3.ProgrammingError:
@@ -281,7 +279,7 @@ class SQLiteDatabase:
title = row[3] title = row[3]
persona_id = row[4] persona_id = row[4]
conversations.append( conversations.append(
Conversation("", cid, "[]", created_at, updated_at, title, persona_id), Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
) )
return conversations return conversations
@@ -320,10 +318,8 @@ class SQLiteDatabase:
) )
def get_all_conversations( def get_all_conversations(
self, self, page: int = 1, page_size: int = 20
page: int = 1, ) -> Tuple[List[Dict[str, Any]], int]:
page_size: int = 20,
) -> tuple[list[dict[str, Any]], int]:
"""获取所有对话,支持分页,按更新时间降序排序""" """获取所有对话,支持分页,按更新时间降序排序"""
try: try:
c = self.conn.cursor() c = self.conn.cursor()
@@ -369,7 +365,7 @@ class SQLiteDatabase:
"persona_id": persona_id or "", "persona_id": persona_id or "",
"created_at": created_at or 0, "created_at": created_at or 0,
"updated_at": updated_at or 0, "updated_at": updated_at or 0,
}, }
) )
return conversations, total_count return conversations, total_count
@@ -384,12 +380,12 @@ class SQLiteDatabase:
self, self,
page: int = 1, page: int = 1,
page_size: int = 20, page_size: int = 20,
platforms: list[str] | None = None, platforms: List[str] = None,
message_types: list[str] | None = None, message_types: List[str] = None,
search_query: str | None = None, search_query: str = None,
exclude_ids: list[str] | None = None, exclude_ids: List[str] = None,
exclude_platforms: list[str] | None = None, exclude_platforms: List[str] = None,
) -> tuple[list[dict[str, Any]], int]: ) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表""" """获取筛选后的对话列表"""
try: try:
c = self.conn.cursor() c = self.conn.cursor()
@@ -425,7 +421,7 @@ class SQLiteDatabase:
if search_query: if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8") search_query = search_query.encode("unicode_escape").decode("utf-8")
where_clauses.append( where_clauses.append(
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)", "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
) )
search_param = f"%{search_query}%" search_param = f"%{search_query}%"
params.extend([search_param, search_param, search_param, search_param]) params.extend([search_param, search_param, search_param, search_param])
@@ -485,7 +481,7 @@ class SQLiteDatabase:
"persona_id": persona_id or "", "persona_id": persona_id or "",
"created_at": created_at or 0, "created_at": created_at or 0,
"updated_at": updated_at or 0, "updated_at": updated_at or 0,
}, }
) )
return conversations, total_count return conversations, total_count

View File

@@ -1,15 +1,15 @@
import uuid import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TypedDict
from datetime import datetime, timezone
from dataclasses import dataclass, field
from sqlmodel import ( from sqlmodel import (
JSON,
Field,
SQLModel, SQLModel,
Text, Text,
JSON,
UniqueConstraint, UniqueConstraint,
Field,
) )
from typing import Optional, TypedDict
class PlatformStat(SQLModel, table=True): class PlatformStat(SQLModel, table=True):
@@ -40,8 +40,7 @@ class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations" __tablename__ = "conversations"
inner_conversation_id: int = Field( inner_conversation_id: int = Field(
primary_key=True, primary_key=True, sa_column_kwargs={"autoincrement": True}
sa_column_kwargs={"autoincrement": True},
) )
conversation_id: str = Field( conversation_id: str = Field(
max_length=36, max_length=36,
@@ -51,14 +50,14 @@ class ConversationV2(SQLModel, table=True):
) )
platform_id: str = Field(nullable=False) platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False) user_id: str = Field(nullable=False)
content: list | None = Field(default=None, sa_type=JSON) content: Optional[list] = Field(default=None, sa_type=JSON)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field( updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc), default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
) )
title: str | None = Field(default=None, max_length=255) title: Optional[str] = Field(default=None, max_length=255)
persona_id: str | None = Field(default=None) persona_id: Optional[str] = Field(default=None)
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(
@@ -76,16 +75,12 @@ class Persona(SQLModel, table=True):
__tablename__ = "personas" __tablename__ = "personas"
id: int | None = Field( id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
)
persona_id: str = Field(max_length=255, nullable=False) persona_id: str = Field(max_length=255, nullable=False)
system_prompt: str = Field(sa_type=Text, nullable=False) system_prompt: str = Field(sa_type=Text, nullable=False)
begin_dialogs: list | None = Field(default=None, sa_type=JSON) begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
"""a list of strings, each representing a dialog to start with""" """a list of strings, each representing a dialog to start with"""
tools: list | None = Field(default=None, sa_type=JSON) tools: Optional[list] = Field(default=None, sa_type=JSON)
"""None means use ALL tools for default, empty list means no tools, otherwise a list of tool names.""" """None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field( updated_at: datetime = Field(
@@ -107,9 +102,7 @@ class Preference(SQLModel, table=True):
__tablename__ = "preferences" __tablename__ = "preferences"
id: int | None = Field( id: int | None = Field(
default=None, default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
primary_key=True,
sa_column_kwargs={"autoincrement": True},
) )
scope: str = Field(nullable=False) scope: str = Field(nullable=False)
"""Scope of the preference, such as 'global', 'umo', 'plugin'.""" """Scope of the preference, such as 'global', 'umo', 'plugin'."""
@@ -142,16 +135,12 @@ class PlatformMessageHistory(SQLModel, table=True):
__tablename__ = "platform_message_history" __tablename__ = "platform_message_history"
id: int | None = Field( id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
)
platform_id: str = Field(nullable=False) platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False) # An id of group, user in platform user_id: str = Field(nullable=False) # An id of group, user in platform
sender_id: str | None = Field(default=None) # ID of the sender in the platform sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
sender_name: str | None = Field( sender_name: Optional[str] = Field(
default=None, default=None
) # Name of the sender in the platform ) # Name of the sender in the platform
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@@ -169,10 +158,8 @@ class Attachment(SQLModel, table=True):
__tablename__ = "attachments" __tablename__ = "attachments"
inner_attachment_id: int | None = Field( inner_attachment_id: int = Field(
primary_key=True, primary_key=True, sa_column_kwargs={"autoincrement": True}
sa_column_kwargs={"autoincrement": True},
default=None,
) )
attachment_id: str = Field( attachment_id: str = Field(
max_length=36, max_length=36,

View File

@@ -1,27 +1,23 @@
import asyncio import asyncio
import threading
import typing as T import typing as T
import threading
from datetime import datetime, timedelta from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import ( from astrbot.core.db.po import (
Attachment,
ConversationV2, ConversationV2,
Persona,
PlatformMessageHistory,
PlatformStat, PlatformStat,
PlatformMessageHistory,
Attachment,
Persona,
Preference, Preference,
Stats as DeprecatedStats,
Platform as DeprecatedPlatformStat,
SQLModel, SQLModel,
) )
from astrbot.core.db.po import (
Platform as DeprecatedPlatformStat, from sqlalchemy import select, update, delete, text
) from sqlalchemy.ext.asyncio import AsyncSession
from astrbot.core.db.po import ( from sqlalchemy.sql import func
Stats as DeprecatedStats,
)
NOT_GIVEN = T.TypeVar("NOT_GIVEN") NOT_GIVEN = T.TypeVar("NOT_GIVEN")
@@ -37,12 +33,6 @@ class SQLiteDatabase(BaseDatabase):
"""Initialize the database by creating tables if they do not exist.""" """Initialize the database by creating tables if they do not exist."""
async with self.engine.begin() as conn: async with self.engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all) await conn.run_sync(SQLModel.metadata.create_all)
await conn.execute(text("PRAGMA journal_mode=WAL"))
await conn.execute(text("PRAGMA synchronous=NORMAL"))
await conn.execute(text("PRAGMA cache_size=20000"))
await conn.execute(text("PRAGMA temp_store=MEMORY"))
await conn.execute(text("PRAGMA mmap_size=134217728"))
await conn.execute(text("PRAGMA optimize"))
await conn.commit() await conn.commit()
# ==== # ====
@@ -51,10 +41,10 @@ class SQLiteDatabase(BaseDatabase):
async def insert_platform_stats( async def insert_platform_stats(
self, self,
platform_id, platform_id: str,
platform_type, platform_type: str,
count=1, count: int = 1,
timestamp=None, timestamp: datetime = None,
) -> None: ) -> None:
"""Insert a new platform statistic record.""" """Insert a new platform statistic record."""
async with self.get_db() as session: async with self.get_db() as session:
@@ -62,9 +52,7 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin(): async with session.begin():
if timestamp is None: if timestamp is None:
timestamp = datetime.now().replace( timestamp = datetime.now().replace(
minute=0, minute=0, second=0, microsecond=0
second=0,
microsecond=0,
) )
current_hour = timestamp current_hour = timestamp
await session.execute( await session.execute(
@@ -87,14 +75,12 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session: async with self.get_db() as session:
session: AsyncSession session: AsyncSession
result = await session.execute( result = await session.execute(
select(func.count(col(PlatformStat.platform_id))).select_from( select(func.count(PlatformStat.platform_id)).select_from(PlatformStat)
PlatformStat,
),
) )
count = result.scalar_one_or_none() count = result.scalar_one_or_none()
return count if count is not None else 0 return count if count is not None else 0
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]: async def get_platform_stats(self, offset_sec: int = 86400) -> T.List[PlatformStat]:
"""Get platform statistics within the specified offset in seconds and group by platform_id.""" """Get platform statistics within the specified offset in seconds and group by platform_id."""
async with self.get_db() as session: async with self.get_db() as session:
session: AsyncSession session: AsyncSession
@@ -109,7 +95,7 @@ class SQLiteDatabase(BaseDatabase):
"""), """),
{"start_time": start_time}, {"start_time": start_time},
) )
return list(result.scalars().all()) return result.scalars().all()
# ==== # ====
# Conversation Management # Conversation Management
@@ -125,7 +111,7 @@ class SQLiteDatabase(BaseDatabase):
if platform_id: if platform_id:
query = query.where(ConversationV2.platform_id == platform_id) query = query.where(ConversationV2.platform_id == platform_id)
# order by # order by
query = query.order_by(desc(ConversationV2.created_at)) query = query.order_by(ConversationV2.created_at.desc())
result = await session.execute(query) result = await session.execute(query)
return result.scalars().all() return result.scalars().all()
@@ -143,9 +129,9 @@ class SQLiteDatabase(BaseDatabase):
offset = (page - 1) * page_size offset = (page - 1) * page_size
result = await session.execute( result = await session.execute(
select(ConversationV2) select(ConversationV2)
.order_by(desc(ConversationV2.created_at)) .order_by(ConversationV2.created_at.desc())
.offset(offset) .offset(offset)
.limit(page_size), .limit(page_size)
) )
return result.scalars().all() return result.scalars().all()
@@ -164,26 +150,11 @@ class SQLiteDatabase(BaseDatabase):
if platform_ids: if platform_ids:
base_query = base_query.where( base_query = base_query.where(
col(ConversationV2.platform_id).in_(platform_ids), ConversationV2.platform_id.in_(platform_ids)
) )
if search_query: if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
base_query = base_query.where( base_query = base_query.where(
or_( ConversationV2.title.ilike(f"%{search_query}%")
col(ConversationV2.title).ilike(f"%{search_query}%"),
col(ConversationV2.content).ilike(f"%{search_query}%"),
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
),
)
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
for msg_type in kwargs["message_types"]:
base_query = base_query.where(
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"),
)
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
base_query = base_query.where(
col(ConversationV2.platform_id).in_(kwargs["platforms"]),
) )
# Get total count matching the filters # Get total count matching the filters
@@ -194,7 +165,7 @@ class SQLiteDatabase(BaseDatabase):
# Get paginated results # Get paginated results
offset = (page - 1) * page_size offset = (page - 1) * page_size
result_query = ( result_query = (
base_query.order_by(desc(ConversationV2.created_at)) base_query.order_by(ConversationV2.created_at.desc())
.offset(offset) .offset(offset)
.limit(page_size) .limit(page_size)
) )
@@ -240,7 +211,7 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession session: AsyncSession
async with session.begin(): async with session.begin():
query = update(ConversationV2).where( query = update(ConversationV2).where(
col(ConversationV2.conversation_id) == cid, ConversationV2.conversation_id == cid
) )
values = {} values = {}
if title is not None: if title is not None:
@@ -250,7 +221,7 @@ class SQLiteDatabase(BaseDatabase):
if content is not None: if content is not None:
values["content"] = content values["content"] = content
if not values: if not values:
return None return
query = query.values(**values) query = query.values(**values)
await session.execute(query) await session.execute(query)
return await self.get_conversation_by_id(cid) return await self.get_conversation_by_id(cid)
@@ -260,130 +231,9 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession session: AsyncSession
async with session.begin(): async with session.begin():
await session.execute( await session.execute(
delete(ConversationV2).where( delete(ConversationV2).where(ConversationV2.conversation_id == cid)
col(ConversationV2.conversation_id) == cid,
),
) )
async def delete_conversations_by_user_id(self, user_id: str) -> None:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(
col(ConversationV2.user_id) == user_id
),
)
async def get_session_conversations(
self,
page=1,
page_size=20,
search_query=None,
platform=None,
) -> tuple[list[dict], int]:
"""Get paginated session conversations with joined conversation and persona details."""
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
base_query = (
select(
col(Preference.scope_id).label("session_id"),
func.json_extract(Preference.value, "$.val").label(
"conversation_id",
), # type: ignore
col(ConversationV2.persona_id).label("persona_id"),
col(ConversationV2.title).label("title"),
col(Persona.persona_id).label("persona_name"),
)
.select_from(Preference)
.outerjoin(
ConversationV2,
func.json_extract(Preference.value, "$.val")
== ConversationV2.conversation_id,
)
.outerjoin(
Persona,
col(ConversationV2.persona_id) == Persona.persona_id,
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
# 搜索筛选
if search_query:
search_pattern = f"%{search_query}%"
base_query = base_query.where(
or_(
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
),
)
# 平台筛选
if platform:
platform_pattern = f"{platform}:%"
base_query = base_query.where(
col(Preference.scope_id).like(platform_pattern),
)
# 排序
base_query = base_query.order_by(Preference.scope_id)
# 分页结果
result_query = base_query.offset(offset).limit(page_size)
result = await session.execute(result_query)
rows = result.fetchall()
# 查询总数(应用相同的筛选条件)
count_base_query = (
select(func.count(col(Preference.scope_id)))
.select_from(Preference)
.outerjoin(
ConversationV2,
func.json_extract(Preference.value, "$.val")
== ConversationV2.conversation_id,
)
.outerjoin(
Persona,
col(ConversationV2.persona_id) == Persona.persona_id,
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
# 应用相同的搜索和平台筛选条件到计数查询
if search_query:
search_pattern = f"%{search_query}%"
count_base_query = count_base_query.where(
or_(
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
),
)
if platform:
platform_pattern = f"{platform}:%"
count_base_query = count_base_query.where(
col(Preference.scope_id).like(platform_pattern),
)
total_result = await session.execute(count_base_query)
total = total_result.scalar() or 0
sessions_data = [
{
"session_id": row.session_id,
"conversation_id": row.conversation_id,
"persona_id": row.persona_id,
"title": row.title,
"persona_name": row.persona_name,
}
for row in rows
]
return sessions_data, total
async def insert_platform_message_history( async def insert_platform_message_history(
self, self,
platform_id, platform_id,
@@ -407,10 +257,7 @@ class SQLiteDatabase(BaseDatabase):
return new_history return new_history
async def delete_platform_message_offset( async def delete_platform_message_offset(
self, self, platform_id, user_id, offset_sec=86400
platform_id,
user_id,
offset_sec=86400,
): ):
"""Delete platform message history records older than the specified offset.""" """Delete platform message history records older than the specified offset."""
async with self.get_db() as session: async with self.get_db() as session:
@@ -420,18 +267,14 @@ class SQLiteDatabase(BaseDatabase):
cutoff_time = now - timedelta(seconds=offset_sec) cutoff_time = now - timedelta(seconds=offset_sec)
await session.execute( await session.execute(
delete(PlatformMessageHistory).where( delete(PlatformMessageHistory).where(
col(PlatformMessageHistory.platform_id) == platform_id, PlatformMessageHistory.platform_id == platform_id,
col(PlatformMessageHistory.user_id) == user_id, PlatformMessageHistory.user_id == user_id,
col(PlatformMessageHistory.created_at) < cutoff_time, PlatformMessageHistory.created_at < cutoff_time,
), )
) )
async def get_platform_message_history( async def get_platform_message_history(
self, self, platform_id, user_id, page=1, page_size=20
platform_id,
user_id,
page=1,
page_size=20,
): ):
"""Get platform message history records.""" """Get platform message history records."""
async with self.get_db() as session: async with self.get_db() as session:
@@ -443,7 +286,7 @@ class SQLiteDatabase(BaseDatabase):
PlatformMessageHistory.platform_id == platform_id, PlatformMessageHistory.platform_id == platform_id,
PlatformMessageHistory.user_id == user_id, PlatformMessageHistory.user_id == user_id,
) )
.order_by(desc(PlatformMessageHistory.created_at)) .order_by(PlatformMessageHistory.created_at.desc())
) )
result = await session.execute(query.offset(offset).limit(page_size)) result = await session.execute(query.offset(offset).limit(page_size))
return result.scalars().all() return result.scalars().all()
@@ -465,16 +308,12 @@ class SQLiteDatabase(BaseDatabase):
"""Get an attachment by its ID.""" """Get an attachment by its ID."""
async with self.get_db() as session: async with self.get_db() as session:
session: AsyncSession session: AsyncSession
query = select(Attachment).where(Attachment.attachment_id == attachment_id) query = select(Attachment).where(Attachment.id == attachment_id)
result = await session.execute(query) result = await session.execute(query)
return result.scalar_one_or_none() return result.scalar_one_or_none()
async def insert_persona( async def insert_persona(
self, self, persona_id, system_prompt, begin_dialogs=None, tools=None
persona_id,
system_prompt,
begin_dialogs=None,
tools=None,
): ):
"""Insert a new persona record.""" """Insert a new persona record."""
async with self.get_db() as session: async with self.get_db() as session:
@@ -506,17 +345,13 @@ class SQLiteDatabase(BaseDatabase):
return result.scalars().all() return result.scalars().all()
async def update_persona( async def update_persona(
self, self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN
persona_id,
system_prompt=None,
begin_dialogs=None,
tools=NOT_GIVEN,
): ):
"""Update a persona's system prompt or begin dialogs.""" """Update a persona's system prompt or begin dialogs."""
async with self.get_db() as session: async with self.get_db() as session:
session: AsyncSession session: AsyncSession
async with session.begin(): async with session.begin():
query = update(Persona).where(col(Persona.persona_id) == persona_id) query = update(Persona).where(Persona.persona_id == persona_id)
values = {} values = {}
if system_prompt is not None: if system_prompt is not None:
values["system_prompt"] = system_prompt values["system_prompt"] = system_prompt
@@ -525,7 +360,7 @@ class SQLiteDatabase(BaseDatabase):
if tools is not NOT_GIVEN: if tools is not NOT_GIVEN:
values["tools"] = tools values["tools"] = tools
if not values: if not values:
return None return
query = query.values(**values) query = query.values(**values)
await session.execute(query) await session.execute(query)
return await self.get_persona_by_id(persona_id) return await self.get_persona_by_id(persona_id)
@@ -536,7 +371,7 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession session: AsyncSession
async with session.begin(): async with session.begin():
await session.execute( await session.execute(
delete(Persona).where(col(Persona.persona_id) == persona_id), delete(Persona).where(Persona.persona_id == persona_id)
) )
async def insert_preference_or_update(self, scope, scope_id, key, value): async def insert_preference_or_update(self, scope, scope_id, key, value):
@@ -555,10 +390,7 @@ class SQLiteDatabase(BaseDatabase):
existing_preference.value = value existing_preference.value = value
else: else:
new_preference = Preference( new_preference = Preference(
scope=scope, scope=scope, scope_id=scope_id, key=key, value=value
scope_id=scope_id,
key=key,
value=value,
) )
session.add(new_preference) session.add(new_preference)
return existing_preference or new_preference return existing_preference or new_preference
@@ -594,10 +426,10 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin(): async with session.begin():
await session.execute( await session.execute(
delete(Preference).where( delete(Preference).where(
col(Preference.scope) == scope, Preference.scope == scope,
col(Preference.scope_id) == scope_id, Preference.scope_id == scope_id,
col(Preference.key) == key, Preference.key == key,
), )
) )
await session.commit() await session.commit()
@@ -608,9 +440,8 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin(): async with session.begin():
await session.execute( await session.execute(
delete(Preference).where( delete(Preference).where(
col(Preference.scope) == scope, Preference.scope == scope, Preference.scope_id == scope_id
col(Preference.scope_id) == scope_id, )
),
) )
await session.commit() await session.commit()
@@ -627,7 +458,7 @@ class SQLiteDatabase(BaseDatabase):
now = datetime.now() now = datetime.now()
start_time = now - timedelta(seconds=offset_sec) start_time = now - timedelta(seconds=offset_sec)
result = await session.execute( result = await session.execute(
select(PlatformStat).where(PlatformStat.timestamp >= start_time), select(PlatformStat).where(PlatformStat.timestamp >= start_time)
) )
all_datas = result.scalars().all() all_datas = result.scalars().all()
deprecated_stats = DeprecatedStats() deprecated_stats = DeprecatedStats()
@@ -636,8 +467,8 @@ class SQLiteDatabase(BaseDatabase):
DeprecatedPlatformStat( DeprecatedPlatformStat(
name=data.platform_id, name=data.platform_id,
count=data.count, count=data.count,
timestamp=int(data.timestamp.timestamp()), timestamp=data.timestamp.timestamp(),
), )
) )
return deprecated_stats return deprecated_stats
@@ -659,7 +490,7 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session: async with self.get_db() as session:
session: AsyncSession session: AsyncSession
result = await session.execute( result = await session.execute(
select(func.sum(PlatformStat.count)).select_from(PlatformStat), select(func.sum(PlatformStat.count)).select_from(PlatformStat)
) )
total_count = result.scalar_one_or_none() total_count = result.scalar_one_or_none()
return total_count if total_count is not None else 0 return total_count if total_count is not None else 0
@@ -685,7 +516,7 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute( result = await session.execute(
select(PlatformStat.platform_id, func.sum(PlatformStat.count)) select(PlatformStat.platform_id, func.sum(PlatformStat.count))
.where(PlatformStat.timestamp >= start_time) .where(PlatformStat.timestamp >= start_time)
.group_by(PlatformStat.platform_id), .group_by(PlatformStat.platform_id)
) )
grouped_stats = result.all() grouped_stats = result.all()
deprecated_stats = DeprecatedStats() deprecated_stats = DeprecatedStats()
@@ -694,8 +525,8 @@ class SQLiteDatabase(BaseDatabase):
DeprecatedPlatformStat( DeprecatedPlatformStat(
name=platform_id, name=platform_id,
count=count, count=count,
timestamp=int(start_time.timestamp()), timestamp=start_time.timestamp(),
), )
) )
return deprecated_stats return deprecated_stats

View File

@@ -10,47 +10,22 @@ class Result:
class BaseVecDB: class BaseVecDB:
async def initialize(self): async def initialize(self):
"""初始化向量数据库""" """
初始化向量数据库
"""
pass
@abc.abstractmethod @abc.abstractmethod
async def insert( async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
self, """
content: str, 插入一条文本和其对应向量,自动生成 ID 并保持一致性。
metadata: dict | None = None,
id: str | None = None,
) -> int:
"""插入一条文本和其对应向量,自动生成 ID 并保持一致性。"""
...
@abc.abstractmethod
async def insert_batch(
self,
contents: list[str],
metadatas: list[dict] | None = None,
ids: list[str] | None = None,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
) -> int:
"""批量插入文本和其对应向量,自动生成 ID 并保持一致性。
Args:
progress_callback: 进度回调函数,接收参数 (current, total)
""" """
... ...
@abc.abstractmethod @abc.abstractmethod
async def retrieve( async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
self, """
query: str, 搜索最相似的文档。
top_k: int = 5,
fetch_k: int = 20,
rerank: bool = False,
metadata_filters: dict | None = None,
) -> list[Result]:
"""搜索最相似的文档。
Args: Args:
query (str): 查询文本 query (str): 查询文本
top_k (int): 返回的最相似文档的数量 top_k (int): 返回的最相似文档的数量
@@ -61,13 +36,11 @@ class BaseVecDB:
@abc.abstractmethod @abc.abstractmethod
async def delete(self, doc_id: str) -> bool: async def delete(self, doc_id: str) -> bool:
"""删除指定文档。 """
删除指定文档。
Args: Args:
doc_id (str): 要删除的文档 ID doc_id (str): 要删除的文档 ID
Returns: Returns:
bool: 删除是否成功 bool: 删除是否成功
""" """
... ...
@abc.abstractmethod
async def close(self): ...

View File

@@ -1,232 +1,59 @@
import json import aiosqlite
import os import os
from contextlib import asynccontextmanager
from datetime import datetime
from sqlalchemy import Column, Text
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import Field, MetaData, SQLModel, col, func, select, text
from astrbot.core import logger
class BaseDocModel(SQLModel, table=False):
metadata = MetaData()
class Document(BaseDocModel, table=True):
"""SQLModel for documents table."""
__tablename__ = "documents" # type: ignore
id: int | None = Field(
default=None,
primary_key=True,
sa_column_kwargs={"autoincrement": True},
)
doc_id: str = Field(nullable=False)
text: str = Field(nullable=False)
metadata_: str | None = Field(default=None, sa_column=Column("metadata", Text))
created_at: datetime | None = Field(default=None)
updated_at: datetime | None = Field(default=None)
class DocumentStorage: class DocumentStorage:
def __init__(self, db_path: str): def __init__(self, db_path: str):
self.db_path = db_path self.db_path = db_path
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" self.connection = None
self.engine: AsyncEngine | None = None
self.async_session_maker: sessionmaker | None = None
self.sqlite_init_path = os.path.join( self.sqlite_init_path = os.path.join(
os.path.dirname(__file__), os.path.dirname(__file__), "sqlite_init.sql"
"sqlite_init.sql",
) )
async def initialize(self): async def initialize(self):
"""Initialize the SQLite database and create the documents table if it doesn't exist.""" """Initialize the SQLite database and create the documents table if it doesn't exist."""
await self.connect() if not os.path.exists(self.db_path):
async with self.engine.begin() as conn: # type: ignore await self.connect()
# Create tables using SQLModel async with self.connection.cursor() as cursor:
await conn.run_sync(BaseDocModel.metadata.create_all) with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
sql_script = f.read()
try: await cursor.executescript(sql_script)
await conn.execute( await self.connection.commit()
text( else:
"ALTER TABLE documents ADD COLUMN kb_doc_id TEXT " await self.connect()
"GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED",
),
)
await conn.execute(
text(
"ALTER TABLE documents ADD COLUMN user_id TEXT "
"GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED",
),
)
# Create indexes
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)",
),
)
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)",
),
)
except BaseException:
pass
await conn.commit()
async def connect(self): async def connect(self):
"""Connect to the SQLite database.""" """Connect to the SQLite database."""
if self.engine is None: self.connection = await aiosqlite.connect(self.db_path)
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
future=True,
)
self.async_session_maker = sessionmaker(
self.engine, # type: ignore
class_=AsyncSession,
expire_on_commit=False,
) # type: ignore
@asynccontextmanager async def get_documents(self, metadata_filters: dict, ids: list = None):
async def get_session(self):
"""Context manager for database sessions."""
async with self.async_session_maker() as session: # type: ignore
yield session
async def get_documents(
self,
metadata_filters: dict,
ids: list | None = None,
offset: int | None = 0,
limit: int | None = 100,
) -> list[dict]:
"""Retrieve documents by metadata filters and ids. """Retrieve documents by metadata filters and ids.
Args: Args:
metadata_filters (dict): The metadata filters to apply. metadata_filters (dict): The metadata filters to apply.
ids (list | None): Optional list of document IDs to filter.
offset (int | None): Offset for pagination.
limit (int | None): Limit for pagination.
Returns: Returns:
list: The list of documents that match the filters. list: The list of document IDs(primary key, not doc_id) that match the filters.
""" """
if self.engine is None: # metadata filter -> SQL WHERE clause
logger.warning( where_clauses = []
"Database connection is not initialized, returning empty result", values = []
) for key, val in metadata_filters.items():
return [] where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
values.append(val)
if ids is not None and len(ids) > 0:
ids = [str(i) for i in ids if i != -1]
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
values.extend(ids)
where_sql = " AND ".join(where_clauses) or "1=1"
async with self.get_session() as session: result = []
query = select(Document) async with self.connection.cursor() as cursor:
sql = "SELECT * FROM documents WHERE " + where_sql
for key, val in metadata_filters.items(): await cursor.execute(sql, values)
query = query.where( for row in await cursor.fetchall():
text(f"json_extract(metadata, '$.{key}') = :filter_{key}"), result.append(await self.tuple_to_dict(row))
).params(**{f"filter_{key}": val}) return result
if ids is not None and len(ids) > 0:
valid_ids = [int(i) for i in ids if i != -1]
if valid_ids:
query = query.where(col(Document.id).in_(valid_ids))
if limit is not None:
query = query.limit(limit)
if offset is not None:
query = query.offset(offset)
result = await session.execute(query)
documents = result.scalars().all()
return [self._document_to_dict(doc) for doc in documents]
async def insert_document(self, doc_id: str, text: str, metadata: dict) -> int:
"""Insert a single document and return its integer ID.
Args:
doc_id (str): The document ID (UUID string).
text (str): The document text.
metadata (dict): The document metadata.
Returns:
int: The integer ID of the inserted document.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session, session.begin():
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
session.add(document)
await session.flush() # Flush to get the ID
return document.id # type: ignore
async def insert_documents_batch(
self,
doc_ids: list[str],
texts: list[str],
metadatas: list[dict],
) -> list[int]:
"""Batch insert documents and return their integer IDs.
Args:
doc_ids (list[str]): List of document IDs (UUID strings).
texts (list[str]): List of document texts.
metadatas (list[dict]): List of document metadata.
Returns:
list[int]: List of integer IDs of the inserted documents.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session, session.begin():
import json
documents = []
for doc_id, text, metadata in zip(doc_ids, texts, metadatas):
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
documents.append(document)
session.add(document)
await session.flush() # Flush to get all IDs
return [doc.id for doc in documents] # type: ignore
async def delete_document_by_doc_id(self, doc_id: str):
"""Delete a document by its doc_id.
Args:
doc_id (str): The doc_id of the document to delete.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session, session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
await session.delete(document)
async def get_document_by_doc_id(self, doc_id: str): async def get_document_by_doc_id(self, doc_id: str):
"""Retrieve a document by its doc_id. """Retrieve a document by its doc_id.
@@ -235,134 +62,40 @@ class DocumentStorage:
doc_id (str): The doc_id of the document to retrieve. doc_id (str): The doc_id of the document to retrieve.
Returns: Returns:
dict: The document data or None if not found. dict: The document data.
""" """
assert self.engine is not None, "Database connection is not initialized." async with self.connection.cursor() as cursor:
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
async with self.get_session() as session: row = await cursor.fetchone()
query = select(Document).where(col(Document.doc_id) == doc_id) if row:
result = await session.execute(query) return await self.tuple_to_dict(row)
document = result.scalar_one_or_none() else:
return None
if document:
return self._document_to_dict(document)
return None
async def update_document_by_doc_id(self, doc_id: str, new_text: str): async def update_document_by_doc_id(self, doc_id: str, new_text: str):
"""Update a document by its doc_id. """Retrieve a document by its doc_id.
Args: Args:
doc_id (str): The doc_id. doc_id (str): The doc_id.
new_text (str): The new text to update the document with. new_text (str): The new text to update the document with.
""" """
assert self.engine is not None, "Database connection is not initialized." async with self.connection.cursor() as cursor:
await cursor.execute(
async with self.get_session() as session, session.begin(): "UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
document.text = new_text
document.updated_at = datetime.now()
session.add(document)
async def delete_documents(self, metadata_filters: dict):
"""Delete documents by their metadata filters.
Args:
metadata_filters (dict): The metadata filters to apply.
"""
if self.engine is None:
logger.warning(
"Database connection is not initialized, skipping delete operation",
) )
return await self.connection.commit()
async with self.get_session() as session, session.begin():
query = select(Document)
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}"),
).params(**{f"filter_{key}": val})
result = await session.execute(query)
documents = result.scalars().all()
for doc in documents:
await session.delete(doc)
async def count_documents(self, metadata_filters: dict | None = None) -> int:
"""Count documents in the database.
Args:
metadata_filters (dict | None): Metadata filters to apply.
Returns:
int: The count of documents.
"""
if self.engine is None:
logger.warning("Database connection is not initialized, returning 0")
return 0
async with self.get_session() as session:
query = select(func.count(col(Document.id)))
if metadata_filters:
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}"),
).params(**{f"filter_{key}": val})
result = await session.execute(query)
count = result.scalar_one_or_none()
return count if count is not None else 0
async def get_user_ids(self) -> list[str]: async def get_user_ids(self) -> list[str]:
"""Retrieve all user IDs from the documents table. """Retrieve all user IDs from the documents table.
Returns: Returns:
list: A list of user IDs. list: A list of user IDs.
""" """
assert self.engine is not None, "Database connection is not initialized." async with self.connection.cursor() as cursor:
await cursor.execute("SELECT DISTINCT user_id FROM documents")
async with self.get_session() as session: rows = await cursor.fetchall()
query = text(
"SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL",
)
result = await session.execute(query)
rows = result.fetchall()
return [row[0] for row in rows] return [row[0] for row in rows]
def _document_to_dict(self, document: Document) -> dict:
"""Convert a Document model to a dictionary.
Args:
document (Document): The document to convert.
Returns:
dict: The converted dictionary.
"""
return {
"id": document.id,
"doc_id": document.doc_id,
"text": document.text,
"metadata": document.metadata_,
"created_at": document.created_at.isoformat()
if isinstance(document.created_at, datetime)
else document.created_at,
"updated_at": document.updated_at.isoformat()
if isinstance(document.updated_at, datetime)
else document.updated_at,
}
async def tuple_to_dict(self, row): async def tuple_to_dict(self, row):
"""Convert a tuple to a dictionary. """Convert a tuple to a dictionary.
@@ -371,9 +104,6 @@ class DocumentStorage:
Returns: Returns:
dict: The converted dictionary. dict: The converted dictionary.
Note: This method is kept for backward compatibility but is no longer used internally.
""" """
return { return {
"id": row[0], "id": row[0],
@@ -386,7 +116,6 @@ class DocumentStorage:
async def close(self): async def close(self):
"""Close the connection to the SQLite database.""" """Close the connection to the SQLite database."""
if self.engine: if self.connection:
await self.engine.dispose() await self.connection.close()
self.engine = None self.connection = None
self.async_session_maker = None

View File

@@ -2,15 +2,14 @@ try:
import faiss import faiss
except ModuleNotFoundError: except ModuleNotFoundError:
raise ImportError( raise ImportError(
"faiss 未安装。请使用 'pip install faiss-cpu''pip install faiss-gpu' 安装。", "faiss 未安装。请使用 'pip install faiss-cpu''pip install faiss-gpu' 安装。"
) )
import os import os
import numpy as np import numpy as np
class EmbeddingStorage: class EmbeddingStorage:
def __init__(self, dimension: int, path: str | None = None): def __init__(self, dimension: int, path: str = None):
self.dimension = dimension self.dimension = dimension
self.path = path self.path = path
self.index = None self.index = None
@@ -19,6 +18,7 @@ class EmbeddingStorage:
else: else:
base_index = faiss.IndexFlatL2(dimension) base_index = faiss.IndexFlatL2(dimension)
self.index = faiss.IndexIDMap(base_index) self.index = faiss.IndexIDMap(base_index)
self.storage = {}
async def insert(self, vector: np.ndarray, id: int): async def insert(self, vector: np.ndarray, id: int):
"""插入向量 """插入向量
@@ -28,32 +28,13 @@ class EmbeddingStorage:
id (int): 向量的ID id (int): 向量的ID
Raises: Raises:
ValueError: 如果向量的维度与存储的维度不匹配 ValueError: 如果向量的维度与存储的维度不匹配
""" """
assert self.index is not None, "FAISS index is not initialized."
if vector.shape[0] != self.dimension: if vector.shape[0] != self.dimension:
raise ValueError( raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}", f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
) )
self.index.add_with_ids(vector.reshape(1, -1), np.array([id])) self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
await self.save_index() self.storage[id] = vector
async def insert_batch(self, vectors: np.ndarray, ids: list[int]):
"""批量插入向量
Args:
vectors (np.ndarray): 要插入的向量数组
ids (list[int]): 向量的ID列表
Raises:
ValueError: 如果向量的维度与存储的维度不匹配
"""
assert self.index is not None, "FAISS index is not initialized."
if vectors.shape[1] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}",
)
self.index.add_with_ids(vectors, np.array(ids))
await self.save_index() await self.save_index()
async def search(self, vector: np.ndarray, k: int) -> tuple: async def search(self, vector: np.ndarray, k: int) -> tuple:
@@ -64,30 +45,15 @@ class EmbeddingStorage:
k (int): 返回的最相似向量的数量 k (int): 返回的最相似向量的数量
Returns: Returns:
tuple: (距离, 索引) tuple: (距离, 索引)
""" """
assert self.index is not None, "FAISS index is not initialized."
faiss.normalize_L2(vector) faiss.normalize_L2(vector)
distances, indices = self.index.search(vector, k) distances, indices = self.index.search(vector, k)
return distances, indices return distances, indices
async def delete(self, ids: list[int]):
"""删除向量
Args:
ids (list[int]): 要删除的向量ID列表
"""
assert self.index is not None, "FAISS index is not initialized."
id_array = np.array(ids, dtype=np.int64)
self.index.remove_ids(id_array)
await self.save_index()
async def save_index(self): async def save_index(self):
"""保存索引 """保存索引
Args: Args:
path (str): 保存索引的路径 path (str): 保存索引的路径
""" """
faiss.write_index(self.index, self.path) faiss.write_index(self.index, self.path)

View File

@@ -1,18 +1,17 @@
import time
import uuid import uuid
import json
import numpy as np import numpy as np
from astrbot import logger
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from ..base import BaseVecDB, Result
from .document_storage import DocumentStorage from .document_storage import DocumentStorage
from .embedding_storage import EmbeddingStorage from .embedding_storage import EmbeddingStorage
from ..base import Result, BaseVecDB
from astrbot.core.provider.provider import EmbeddingProvider
from astrbot.core.provider.provider import RerankProvider
class FaissVecDB(BaseVecDB): class FaissVecDB(BaseVecDB):
"""A class to represent a vector database.""" """
A class to represent a vector database.
"""
def __init__( def __init__(
self, self,
@@ -26,8 +25,7 @@ class FaissVecDB(BaseVecDB):
self.embedding_provider = embedding_provider self.embedding_provider = embedding_provider
self.document_storage = DocumentStorage(doc_store_path) self.document_storage = DocumentStorage(doc_store_path)
self.embedding_storage = EmbeddingStorage( self.embedding_storage = EmbeddingStorage(
embedding_provider.get_dim(), embedding_provider.get_dim(), index_store_path
index_store_path,
) )
self.embedding_provider = embedding_provider self.embedding_provider = embedding_provider
self.rerank_provider = rerank_provider self.rerank_provider = rerank_provider
@@ -36,69 +34,28 @@ class FaissVecDB(BaseVecDB):
await self.document_storage.initialize() await self.document_storage.initialize()
async def insert( async def insert(
self, self, content: str, metadata: dict | None = None, id: str | None = None
content: str,
metadata: dict | None = None,
id: str | None = None,
) -> int: ) -> int:
"""插入一条文本和其对应向量,自动生成 ID 并保持一致性。""" """
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
metadata = metadata or {} metadata = metadata or {}
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
vector = await self.embedding_provider.get_embedding(content) vector = await self.embedding_provider.get_embedding(content)
vector = np.array(vector, dtype=np.float32) vector = np.array(vector, dtype=np.float32)
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute(
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
(str_id, content, json.dumps(metadata)),
)
await self.document_storage.connection.commit()
result = await self.document_storage.get_document_by_doc_id(str_id)
int_id = result["id"]
# 使用 DocumentStorage 的方法插入文档 # 插入向量到 FAISS
int_id = await self.document_storage.insert_document(str_id, content, metadata) await self.embedding_storage.insert(vector, int_id)
return int_id
# 插入向量到 FAISS
await self.embedding_storage.insert(vector, int_id)
return int_id
async def insert_batch(
self,
contents: list[str],
metadatas: list[dict] | None = None,
ids: list[str] | None = None,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
) -> list[int]:
"""批量插入文本和其对应向量,自动生成 ID 并保持一致性。
Args:
progress_callback: 进度回调函数,接收参数 (current, total)
"""
metadatas = metadatas or [{} for _ in contents]
ids = ids or [str(uuid.uuid4()) for _ in contents]
start = time.time()
logger.debug(f"Generating embeddings for {len(contents)} contents...")
vectors = await self.embedding_provider.get_embeddings_batch(
contents,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
)
end = time.time()
logger.debug(
f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds.",
)
# 使用 DocumentStorage 的批量插入方法
int_ids = await self.document_storage.insert_documents_batch(
ids,
contents,
metadatas,
)
# 批量插入向量到 FAISS
vectors_array = np.array(vectors).astype("float32")
await self.embedding_storage.insert_batch(vectors_array, int_ids)
return int_ids
async def retrieve( async def retrieve(
self, self,
@@ -108,7 +65,8 @@ class FaissVecDB(BaseVecDB):
rerank: bool = False, rerank: bool = False,
metadata_filters: dict | None = None, metadata_filters: dict | None = None,
) -> list[Result]: ) -> list[Result]:
"""搜索最相似的文档。 """
搜索最相似的文档。
Args: Args:
query (str): 查询文本 query (str): 查询文本
@@ -119,7 +77,6 @@ class FaissVecDB(BaseVecDB):
Returns: Returns:
List[Result]: 查询结果 List[Result]: 查询结果
""" """
embedding = await self.embedding_provider.get_embedding(query) embedding = await self.embedding_provider.get_embedding(query)
scores, indices = await self.embedding_storage.search( scores, indices = await self.embedding_storage.search(
@@ -132,8 +89,7 @@ class FaissVecDB(BaseVecDB):
scores[0] = 1.0 - (scores[0] / 2.0) scores[0] = 1.0 - (scores[0] / 2.0)
# NOTE: maybe the size is less than k. # NOTE: maybe the size is less than k.
fetched_docs = await self.document_storage.get_documents( fetched_docs = await self.document_storage.get_documents(
metadata_filters=metadata_filters or {}, metadata_filters=metadata_filters or {}, ids=indices[0]
ids=indices[0],
) )
if not fetched_docs: if not fetched_docs:
return [] return []
@@ -154,51 +110,31 @@ class FaissVecDB(BaseVecDB):
documents = [doc.data["text"] for doc in top_k_results] documents = [doc.data["text"] for doc in top_k_results]
reranked_results = await self.rerank_provider.rerank(query, documents) reranked_results = await self.rerank_provider.rerank(query, documents)
reranked_results = sorted( reranked_results = sorted(
reranked_results, reranked_results, key=lambda x: x.relevance_score, reverse=True
key=lambda x: x.relevance_score,
reverse=True,
) )
top_k_results = [ top_k_results = [
top_k_results[reranked_result.index] top_k_results[reranked_result.index] for reranked_result in reranked_results
for reranked_result in reranked_results
] ]
return top_k_results return top_k_results
async def delete(self, doc_id: str): async def delete(self, doc_id: int):
"""删除一条文档块chunk""" """
# 获得对应的 int id 删除一条文档
result = await self.document_storage.get_document_by_doc_id(doc_id) """
int_id = result["id"] if result else None await self.document_storage.connection.execute(
if int_id is None: "DELETE FROM documents WHERE doc_id = ?", (doc_id,)
return )
await self.document_storage.connection.commit()
# 使用 DocumentStorage 的删除方法
await self.document_storage.delete_document_by_doc_id(doc_id)
await self.embedding_storage.delete([int_id])
async def close(self): async def close(self):
await self.document_storage.close() await self.document_storage.close()
async def count_documents(self, metadata_filter: dict | None = None) -> int: async def count_documents(self) -> int:
"""计算文档数量
Args:
metadata_filter (dict | None): 元数据过滤器
""" """
count = await self.document_storage.count_documents( 计算文档数量
metadata_filters=metadata_filter or {}, """
) async with self.document_storage.connection.cursor() as cursor:
return count await cursor.execute("SELECT COUNT(*) FROM documents")
count = await cursor.fetchone()
async def delete_documents(self, metadata_filters: dict): return count[0] if count else 0
"""根据元数据过滤器删除文档"""
docs = await self.document_storage.get_documents(
metadata_filters=metadata_filters,
offset=None,
limit=None,
)
doc_ids: list[int] = [doc["id"] for doc in docs]
await self.embedding_storage.delete(doc_ids)
await self.document_storage.delete_documents(metadata_filters=metadata_filters)

View File

@@ -1,4 +1,5 @@
"""事件总线, 用于处理事件的分发和处理 """
事件总线, 用于处理事件的分发和处理
事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理 事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑 其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑
@@ -12,12 +13,10 @@ class:
import asyncio import asyncio
from asyncio import Queue from asyncio import Queue
from astrbot.core import logger
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.pipeline.scheduler import PipelineScheduler from astrbot.core.pipeline.scheduler import PipelineScheduler
from astrbot.core import logger
from .platform import AstrMessageEvent from .platform import AstrMessageEvent
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
class EventBus: class EventBus:
@@ -47,15 +46,14 @@ class EventBus:
Args: Args:
event (AstrMessageEvent): 事件对象 event (AstrMessageEvent): 事件对象
""" """
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要 # 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
if event.get_sender_name(): if event.get_sender_name():
logger.info( logger.info(
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}", f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
) )
# 没有发送者名称: [平台名] 发送者ID: 消息概要 # 没有发送者名称: [平台名] 发送者ID: 消息概要
else: else:
logger.info( logger.info(
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}", f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}"
) )

View File

@@ -1,9 +1,9 @@
import asyncio import asyncio
import os import os
import platform
import time
import uuid import uuid
from urllib.parse import unquote, urlparse import time
from urllib.parse import urlparse, unquote
import platform
class FileTokenService: class FileTokenService:
@@ -23,12 +23,7 @@ class FileTokenService:
for token in expired_tokens: for token in expired_tokens:
self.staged_files.pop(token, None) self.staged_files.pop(token, None)
async def check_token_expired(self, file_token: str) -> bool: async def register_file(self, file_path: str, timeout: float = None) -> str:
async with self.lock:
await self._cleanup_expired_tokens()
return file_token not in self.staged_files
async def register_file(self, file_path: str, timeout: float | None = None) -> str:
"""向令牌服务注册一个文件。 """向令牌服务注册一个文件。
Args: Args:
@@ -40,8 +35,8 @@ class FileTokenService:
Raises: Raises:
FileNotFoundError: 当路径不存在时抛出 FileNotFoundError: 当路径不存在时抛出
""" """
# 处理 file:/// # 处理 file:///
try: try:
parsed_uri = urlparse(file_path) parsed_uri = urlparse(file_path)
@@ -61,7 +56,7 @@ class FileTokenService:
if not os.path.exists(local_path): if not os.path.exists(local_path):
raise FileNotFoundError( raise FileNotFoundError(
f"文件不存在: {local_path} (原始输入: {file_path})", f"文件不存在: {local_path} (原始输入: {file_path})"
) )
file_token = str(uuid.uuid4()) file_token = str(uuid.uuid4())
@@ -84,7 +79,6 @@ class FileTokenService:
Raises: Raises:
KeyError: 当令牌不存在或已过期时抛出 KeyError: 当令牌不存在或已过期时抛出
FileNotFoundError: 当文件本身已被删除时抛出 FileNotFoundError: 当文件本身已被删除时抛出
""" """
async with self.lock: async with self.lock:
await self._cleanup_expired_tokens() await self._cleanup_expired_tokens()

View File

@@ -1,4 +1,5 @@
"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。 """
AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。
工作流程: 工作流程:
1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期 1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期
@@ -7,10 +8,10 @@
import asyncio import asyncio
import traceback import traceback
from astrbot.core import logger
from astrbot.core import LogBroker, logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
from astrbot.core import LogBroker
from astrbot.dashboard.server import AstrBotDashboard from astrbot.dashboard.server import AstrBotDashboard
@@ -21,7 +22,6 @@ class InitialLoader:
self.db = db self.db = db
self.logger = logger self.logger = logger
self.log_broker = log_broker self.log_broker = log_broker
self.webui_dir: str | None = None
async def start(self): async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db) core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
@@ -35,21 +35,13 @@ class InitialLoader:
core_task = core_lifecycle.start() core_task = core_lifecycle.start()
webui_dir = self.webui_dir
self.dashboard_server = AstrBotDashboard( self.dashboard_server = AstrBotDashboard(
core_lifecycle, core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
self.db,
core_lifecycle.dashboard_shutdown_event,
webui_dir,
) )
task = asyncio.gather(
core_task, self.dashboard_server.run()
) # 启动核心任务和仪表板服务器
coro = self.dashboard_server.run()
if coro:
# 启动核心任务和仪表板服务器
task = asyncio.gather(core_task, coro)
else:
task = core_task
try: try:
await task # 整个AstrBot在这里运行 await task # 整个AstrBot在这里运行
except asyncio.CancelledError: except asyncio.CancelledError:

View File

@@ -1,9 +0,0 @@
"""文档分块模块"""
from .base import BaseChunker
from .fixed_size import FixedSizeChunker
__all__ = [
"BaseChunker",
"FixedSizeChunker",
]

View File

@@ -1,25 +0,0 @@
"""文档分块器基类
定义了文档分块处理的抽象接口。
"""
from abc import ABC, abstractmethod
class BaseChunker(ABC):
"""分块器基类
所有分块器都应该继承此类并实现 chunk 方法。
"""
@abstractmethod
async def chunk(self, text: str, **kwargs) -> list[str]:
"""将文本分块
Args:
text: 输入文本
Returns:
list[str]: 分块后的文本列表
"""

View File

@@ -1,59 +0,0 @@
"""固定大小分块器
按照固定的字符数将文本分块,支持重叠区域。
"""
from .base import BaseChunker
class FixedSizeChunker(BaseChunker):
"""固定大小分块器
按照固定的字符数分块,并支持块之间的重叠。
"""
def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
"""初始化分块器
Args:
chunk_size: 块的大小(字符数)
chunk_overlap: 块之间的重叠字符数
"""
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
async def chunk(self, text: str, **kwargs) -> list[str]:
"""固定大小分块
Args:
text: 输入文本
chunk_size: 每个文本块的最大大小
chunk_overlap: 每个文本块之间的重叠部分大小
Returns:
list[str]: 分块后的文本列表
"""
chunk_size = kwargs.get("chunk_size", self.chunk_size)
chunk_overlap = kwargs.get("chunk_overlap", self.chunk_overlap)
chunks = []
start = 0
text_len = len(text)
while start < text_len:
end = start + chunk_size
chunk = text[start:end]
if chunk:
chunks.append(chunk)
# 移动窗口,保留重叠部分
start = end - chunk_overlap
# 防止无限循环: 如果重叠过大,直接移到end
if start >= end or chunk_overlap >= chunk_size:
start = end
return chunks

View File

@@ -1,161 +0,0 @@
from collections.abc import Callable
from .base import BaseChunker
class RecursiveCharacterChunker(BaseChunker):
def __init__(
self,
chunk_size: int = 500,
chunk_overlap: int = 100,
length_function: Callable[[str], int] = len,
is_separator_regex: bool = False,
separators: list[str] | None = None,
):
"""初始化递归字符文本分割器
Args:
chunk_size: 每个文本块的最大大小
chunk_overlap: 每个文本块之间的重叠部分大小
length_function: 计算文本长度的函数
is_separator_regex: 分隔符是否为正则表达式
separators: 用于分割文本的分隔符列表,按优先级排序
"""
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.length_function = length_function
self.is_separator_regex = is_separator_regex
# 默认分隔符列表,按优先级从高到低
self.separators = separators or [
"\n\n", # 段落
"\n", # 换行
"", # 中文句子
"", # 中文逗号
". ", # 句子
", ", # 逗号分隔
" ", # 单词
"", # 字符
]
async def chunk(self, text: str, **kwargs) -> list[str]:
"""递归地将文本分割成块
Args:
text: 要分割的文本
chunk_size: 每个文本块的最大大小
chunk_overlap: 每个文本块之间的重叠部分大小
Returns:
分割后的文本块列表
"""
if not text:
return []
overlap = kwargs.get("chunk_overlap", self.chunk_overlap)
chunk_size = kwargs.get("chunk_size", self.chunk_size)
text_length = self.length_function(text)
if text_length <= chunk_size:
return [text]
for separator in self.separators:
if separator == "":
return self._split_by_character(text, chunk_size, overlap)
if separator in text:
splits = text.split(separator)
# 重新添加分隔符(除了最后一个片段)
splits = [s + separator for s in splits[:-1]] + [splits[-1]]
splits = [s for s in splits if s]
if len(splits) == 1:
continue
# 递归合并分割后的文本块
final_chunks = []
current_chunk = []
current_chunk_length = 0
for split in splits:
split_length = self.length_function(split)
# 如果单个分割部分已经超过了chunk_size需要递归分割
if split_length > chunk_size:
# 先处理当前积累的块
if current_chunk:
combined_text = "".join(current_chunk)
final_chunks.extend(
await self.chunk(
combined_text,
chunk_size=chunk_size,
chunk_overlap=overlap,
),
)
current_chunk = []
current_chunk_length = 0
# 递归分割过大的部分
final_chunks.extend(
await self.chunk(
split,
chunk_size=chunk_size,
chunk_overlap=overlap,
),
)
# 如果添加这部分会使当前块超过chunk_size
elif current_chunk_length + split_length > chunk_size:
# 合并当前块并添加到结果中
combined_text = "".join(current_chunk)
final_chunks.append(combined_text)
# 处理重叠部分
overlap_start = max(0, len(combined_text) - overlap)
if overlap_start > 0:
overlap_text = combined_text[overlap_start:]
current_chunk = [overlap_text, split]
current_chunk_length = (
self.length_function(overlap_text) + split_length
)
else:
current_chunk = [split]
current_chunk_length = split_length
else:
# 添加到当前块
current_chunk.append(split)
current_chunk_length += split_length
# 处理剩余的块
if current_chunk:
final_chunks.append("".join(current_chunk))
return final_chunks
return [text]
def _split_by_character(
self,
text: str,
chunk_size: int | None = None,
overlap: int | None = None,
) -> list[str]:
"""按字符级别分割文本
Args:
text: 要分割的文本
Returns:
分割后的文本块列表
"""
chunk_size = chunk_size or self.chunk_size
overlap = overlap or self.chunk_overlap
result = []
for i in range(0, len(text), chunk_size - overlap):
end = min(i + chunk_size, len(text))
result.append(text[i:end])
if end == len(text):
break
return result

View File

@@ -1,301 +0,0 @@
from contextlib import asynccontextmanager
from pathlib import Path
from sqlalchemy import delete, func, select, text, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlmodel import col, desc
from astrbot.core import logger
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.models import (
BaseKBModel,
KBDocument,
KBMedia,
KnowledgeBase,
)
class KBSQLiteDatabase:
def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
"""初始化知识库数据库
Args:
db_path: 数据库文件路径, 默认为 data/knowledge_base/kb.db
"""
self.db_path = db_path
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
self.inited = False
# 确保目录存在
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
# 创建异步引擎
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
pool_pre_ping=True,
pool_recycle=3600,
)
# 创建会话工厂
self.async_session = async_sessionmaker(
self.engine,
class_=AsyncSession,
expire_on_commit=False,
)
@asynccontextmanager
async def get_db(self):
"""获取数据库会话
用法:
async with kb_db.get_db() as session:
# 执行数据库操作
result = await session.execute(stmt)
"""
async with self.async_session() as session:
yield session
async def initialize(self) -> None:
"""初始化数据库,创建表并配置 SQLite 参数"""
async with self.engine.begin() as conn:
# 创建所有知识库相关表
await conn.run_sync(BaseKBModel.metadata.create_all)
# 配置 SQLite 性能优化参数
await conn.execute(text("PRAGMA journal_mode=WAL"))
await conn.execute(text("PRAGMA synchronous=NORMAL"))
await conn.execute(text("PRAGMA cache_size=20000"))
await conn.execute(text("PRAGMA temp_store=MEMORY"))
await conn.execute(text("PRAGMA mmap_size=134217728"))
await conn.execute(text("PRAGMA optimize"))
await conn.commit()
self.inited = True
async def migrate_to_v1(self) -> None:
"""执行知识库数据库 v1 迁移
创建所有必要的索引以优化查询性能
"""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
# 创建知识库表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_kb_id "
"ON knowledge_bases(kb_id)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_name "
"ON knowledge_bases(kb_name)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_created_at "
"ON knowledge_bases(created_at)",
),
)
# 创建文档表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_doc_id "
"ON kb_documents(doc_id)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_kb_id "
"ON kb_documents(kb_id)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_name "
"ON kb_documents(doc_name)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_type "
"ON kb_documents(file_type)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_created_at "
"ON kb_documents(created_at)",
),
)
# 创建多媒体表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_media_id "
"ON kb_media(media_id)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_doc_id "
"ON kb_media(doc_id)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_type "
"ON kb_media(media_type)",
),
)
await session.commit()
async def close(self) -> None:
"""关闭数据库连接"""
await self.engine.dispose()
logger.info(f"知识库数据库已关闭: {self.db_path}")
async def get_kb_by_id(self, kb_id: str) -> KnowledgeBase | None:
"""根据 ID 获取知识库"""
async with self.get_db() as session:
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_kb_by_name(self, kb_name: str) -> KnowledgeBase | None:
"""根据名称获取知识库"""
async with self.get_db() as session:
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_name) == kb_name)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
"""列出所有知识库"""
async with self.get_db() as session:
stmt = (
select(KnowledgeBase)
.offset(offset)
.limit(limit)
.order_by(desc(KnowledgeBase.created_at))
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def count_kbs(self) -> int:
"""统计知识库数量"""
async with self.get_db() as session:
stmt = select(func.count(col(KnowledgeBase.id)))
result = await session.execute(stmt)
return result.scalar() or 0
# ===== 文档查询 =====
async def get_document_by_id(self, doc_id: str) -> KBDocument | None:
"""根据 ID 获取文档"""
async with self.get_db() as session:
stmt = select(KBDocument).where(col(KBDocument.doc_id) == doc_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_documents_by_kb(
self,
kb_id: str,
offset: int = 0,
limit: int = 100,
) -> list[KBDocument]:
"""列出知识库的所有文档"""
async with self.get_db() as session:
stmt = (
select(KBDocument)
.where(col(KBDocument.kb_id) == kb_id)
.offset(offset)
.limit(limit)
.order_by(desc(KBDocument.created_at))
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def count_documents_by_kb(self, kb_id: str) -> int:
"""统计知识库的文档数量"""
async with self.get_db() as session:
stmt = select(func.count(col(KBDocument.id))).where(
col(KBDocument.kb_id) == kb_id,
)
result = await session.execute(stmt)
return result.scalar() or 0
async def get_document_with_metadata(self, doc_id: str) -> dict | None:
async with self.get_db() as session:
stmt = (
select(KBDocument, KnowledgeBase)
.join(KnowledgeBase, col(KBDocument.kb_id) == col(KnowledgeBase.kb_id))
.where(col(KBDocument.doc_id) == doc_id)
)
result = await session.execute(stmt)
row = result.first()
if not row:
return None
return {
"document": row[0],
"knowledge_base": row[1],
}
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB):
"""删除单个文档及其相关数据"""
# 在知识库表中删除
async with self.get_db() as session, session.begin():
# 删除文档记录
delete_stmt = delete(KBDocument).where(col(KBDocument.doc_id) == doc_id)
await session.execute(delete_stmt)
await session.commit()
# 在 vec db 中删除相关向量
await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id})
# ===== 多媒体查询 =====
async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
"""列出文档的所有多媒体资源"""
async with self.get_db() as session:
stmt = select(KBMedia).where(col(KBMedia.doc_id) == doc_id)
result = await session.execute(stmt)
return list(result.scalars().all())
async def get_media_by_id(self, media_id: str) -> KBMedia | None:
"""根据 ID 获取多媒体资源"""
async with self.get_db() as session:
stmt = select(KBMedia).where(col(KBMedia.media_id) == media_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def update_kb_stats(self, kb_id: str, vec_db: FaissVecDB) -> None:
"""更新知识库统计信息"""
chunk_cnt = await vec_db.count_documents()
async with self.get_db() as session, session.begin():
update_stmt = (
update(KnowledgeBase)
.where(col(KnowledgeBase.kb_id) == kb_id)
.values(
doc_count=select(func.count(col(KBDocument.id)))
.where(col(KBDocument.kb_id) == kb_id)
.scalar_subquery(),
chunk_count=chunk_cnt,
)
)
await session.execute(update_stmt)
await session.commit()

View File

@@ -1,361 +0,0 @@
import json
import uuid
from pathlib import Path
import aiofiles
from astrbot.core import logger
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from .chunking.base import BaseChunker
from .kb_db_sqlite import KBSQLiteDatabase
from .models import KBDocument, KBMedia, KnowledgeBase
from .parsers.util import select_parser
class KBHelper:
vec_db: BaseVecDB
kb: KnowledgeBase
def __init__(
self,
kb_db: KBSQLiteDatabase,
kb: KnowledgeBase,
provider_manager: ProviderManager,
kb_root_dir: str,
chunker: BaseChunker,
):
self.kb_db = kb_db
self.kb = kb
self.prov_mgr = provider_manager
self.kb_root_dir = kb_root_dir
self.chunker = chunker
self.kb_dir = Path(self.kb_root_dir) / self.kb.kb_id
self.kb_medias_dir = Path(self.kb_dir) / "medias" / self.kb.kb_id
self.kb_files_dir = Path(self.kb_dir) / "files" / self.kb.kb_id
self.kb_medias_dir.mkdir(parents=True, exist_ok=True)
self.kb_files_dir.mkdir(parents=True, exist_ok=True)
async def initialize(self):
await self._ensure_vec_db()
async def get_ep(self) -> EmbeddingProvider:
if not self.kb.embedding_provider_id:
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id(
self.kb.embedding_provider_id,
) # type: ignore
if not ep:
raise ValueError(
f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider",
)
return ep
async def get_rp(self) -> RerankProvider | None:
if not self.kb.rerank_provider_id:
return None
rp: RerankProvider = await self.prov_mgr.get_provider_by_id(
self.kb.rerank_provider_id,
) # type: ignore
if not rp:
raise ValueError(
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider",
)
return rp
async def _ensure_vec_db(self) -> FaissVecDB:
if not self.kb.embedding_provider_id:
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
ep = await self.get_ep()
rp = await self.get_rp()
vec_db = FaissVecDB(
doc_store_path=str(self.kb_dir / "doc.db"),
index_store_path=str(self.kb_dir / "index.faiss"),
embedding_provider=ep,
rerank_provider=rp,
)
await vec_db.initialize()
self.vec_db = vec_db
return vec_db
async def delete_vec_db(self):
"""删除知识库的向量数据库和所有相关文件"""
import shutil
await self.terminate()
if self.kb_dir.exists():
shutil.rmtree(self.kb_dir)
async def terminate(self):
if self.vec_db:
await self.vec_db.close()
async def upload_document(
self,
file_name: str,
file_content: bytes,
file_type: str,
chunk_size: int = 512,
chunk_overlap: int = 50,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
) -> KBDocument:
"""上传并处理文档(带原子性保证和失败清理)
流程:
1. 保存原始文件
2. 解析文档内容
3. 提取多媒体资源
4. 分块处理
5. 生成向量并存储
6. 保存元数据(事务)
7. 更新统计
Args:
progress_callback: 进度回调函数,接收参数 (stage, current, total)
- stage: 当前阶段 ('parsing', 'chunking', 'embedding')
- current: 当前进度
- total: 总数
"""
await self._ensure_vec_db()
doc_id = str(uuid.uuid4())
media_paths: list[Path] = []
# file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
# async with aiofiles.open(file_path, "wb") as f:
# await f.write(file_content)
try:
# 阶段1: 解析文档
if progress_callback:
await progress_callback("parsing", 0, 100)
parser = await select_parser(f".{file_type}")
parse_result = await parser.parse(file_content, file_name)
text_content = parse_result.text
media_items = parse_result.media
if progress_callback:
await progress_callback("parsing", 100, 100)
# 保存媒体文件
saved_media = []
for media_item in media_items:
media = await self._save_media(
doc_id=doc_id,
media_type=media_item.media_type,
file_name=media_item.file_name,
content=media_item.content,
mime_type=media_item.mime_type,
)
saved_media.append(media)
media_paths.append(Path(media.file_path))
# 阶段2: 分块
if progress_callback:
await progress_callback("chunking", 0, 100)
chunks_text = await self.chunker.chunk(
text_content,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
contents = []
metadatas = []
for idx, chunk_text in enumerate(chunks_text):
contents.append(chunk_text)
metadatas.append(
{
"kb_id": self.kb.kb_id,
"kb_doc_id": doc_id,
"chunk_index": idx,
},
)
if progress_callback:
await progress_callback("chunking", 100, 100)
# 阶段3: 生成向量(带进度回调)
async def embedding_progress_callback(current, total):
if progress_callback:
await progress_callback("embedding", current, total)
await self.vec_db.insert_batch(
contents=contents,
metadatas=metadatas,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=embedding_progress_callback,
)
# 保存文档的元数据
doc = KBDocument(
doc_id=doc_id,
kb_id=self.kb.kb_id,
doc_name=file_name,
file_type=file_type,
file_size=len(file_content),
# file_path=str(file_path),
file_path="",
chunk_count=len(chunks_text),
media_count=0,
)
async with self.kb_db.get_db() as session:
async with session.begin():
session.add(doc)
for media in saved_media:
session.add(media)
await session.commit()
await session.refresh(doc)
vec_db: FaissVecDB = self.vec_db # type: ignore
await self.kb_db.update_kb_stats(kb_id=self.kb.kb_id, vec_db=vec_db)
await self.refresh_kb()
await self.refresh_document(doc_id)
return doc
except Exception as e:
logger.error(f"上传文档失败: {e}")
# if file_path.exists():
# file_path.unlink()
for media_path in media_paths:
try:
if media_path.exists():
media_path.unlink()
except Exception as me:
logger.warning(f"清理多媒体文件失败 {media_path}: {me}")
raise e
async def list_documents(
self,
offset: int = 0,
limit: int = 100,
) -> list[KBDocument]:
"""列出知识库的所有文档"""
docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit)
return docs
async def get_document(self, doc_id: str) -> KBDocument | None:
"""获取单个文档"""
doc = await self.kb_db.get_document_by_id(doc_id)
return doc
async def delete_document(self, doc_id: str):
"""删除单个文档及其相关数据"""
await self.kb_db.delete_document_by_id(
doc_id=doc_id,
vec_db=self.vec_db, # type: ignore
)
await self.kb_db.update_kb_stats(
kb_id=self.kb.kb_id,
vec_db=self.vec_db, # type: ignore
)
await self.refresh_kb()
async def delete_chunk(self, chunk_id: str, doc_id: str):
"""删除单个文本块及其相关数据"""
vec_db: FaissVecDB = self.vec_db # type: ignore
await vec_db.delete(chunk_id)
await self.kb_db.update_kb_stats(
kb_id=self.kb.kb_id,
vec_db=self.vec_db, # type: ignore
)
await self.refresh_kb()
await self.refresh_document(doc_id)
async def refresh_kb(self):
if self.kb:
kb = await self.kb_db.get_kb_by_id(self.kb.kb_id)
if kb:
self.kb = kb
async def refresh_document(self, doc_id: str) -> None:
"""更新文档的元数据"""
doc = await self.get_document(doc_id)
if not doc:
raise ValueError(f"无法找到 ID 为 {doc_id} 的文档")
chunk_count = await self.get_chunk_count_by_doc_id(doc_id)
doc.chunk_count = chunk_count
async with self.kb_db.get_db() as session:
async with session.begin():
session.add(doc)
await session.commit()
await session.refresh(doc)
async def get_chunks_by_doc_id(
self,
doc_id: str,
offset: int = 0,
limit: int = 100,
) -> list[dict]:
"""获取文档的所有块及其元数据"""
vec_db: FaissVecDB = self.vec_db # type: ignore
chunks = await vec_db.document_storage.get_documents(
metadata_filters={"kb_doc_id": doc_id},
offset=offset,
limit=limit,
)
result = []
for chunk in chunks:
chunk_md = json.loads(chunk["metadata"])
result.append(
{
"chunk_id": chunk["doc_id"],
"doc_id": chunk_md["kb_doc_id"],
"kb_id": chunk_md["kb_id"],
"chunk_index": chunk_md["chunk_index"],
"content": chunk["text"],
"char_count": len(chunk["text"]),
},
)
return result
async def get_chunk_count_by_doc_id(self, doc_id: str) -> int:
"""获取文档的块数量"""
vec_db: FaissVecDB = self.vec_db # type: ignore
count = await vec_db.count_documents(metadata_filter={"kb_doc_id": doc_id})
return count
async def _save_media(
self,
doc_id: str,
media_type: str,
file_name: str,
content: bytes,
mime_type: str,
) -> KBMedia:
"""保存多媒体资源"""
media_id = str(uuid.uuid4())
ext = Path(file_name).suffix
# 保存文件
file_path = self.kb_medias_dir / doc_id / f"{media_id}{ext}"
file_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(file_path, "wb") as f:
await f.write(content)
media = KBMedia(
media_id=media_id,
doc_id=doc_id,
kb_id=self.kb.kb_id,
media_type=media_type,
file_name=file_name,
file_path=str(file_path),
file_size=len(content),
mime_type=mime_type,
)
return media

View File

@@ -1,286 +0,0 @@
import traceback
from pathlib import Path
from astrbot.core import logger
from astrbot.core.provider.manager import ProviderManager
# from .chunking.fixed_size import FixedSizeChunker
from .chunking.recursive import RecursiveCharacterChunker
from .kb_db_sqlite import KBSQLiteDatabase
from .kb_helper import KBHelper
from .models import KnowledgeBase
from .retrieval.manager import RetrievalManager, RetrievalResult
from .retrieval.rank_fusion import RankFusion
from .retrieval.sparse_retriever import SparseRetriever
FILES_PATH = "data/knowledge_base"
DB_PATH = Path(FILES_PATH) / "kb.db"
"""Knowledge Base storage root directory"""
CHUNKER = RecursiveCharacterChunker()
class KnowledgeBaseManager:
kb_db: KBSQLiteDatabase
retrieval_manager: RetrievalManager
def __init__(
self,
provider_manager: ProviderManager,
):
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
self.provider_manager = provider_manager
self._session_deleted_callback_registered = False
self.kb_insts: dict[str, KBHelper] = {}
async def initialize(self):
"""初始化知识库模块"""
try:
logger.info("正在初始化知识库模块...")
# 初始化数据库
await self._init_kb_database()
# 初始化检索管理器
sparse_retriever = SparseRetriever(self.kb_db)
rank_fusion = RankFusion(self.kb_db)
self.retrieval_manager = RetrievalManager(
sparse_retriever=sparse_retriever,
rank_fusion=rank_fusion,
kb_db=self.kb_db,
)
await self.load_kbs()
except ImportError as e:
logger.error(f"知识库模块导入失败: {e}")
logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
except Exception as e:
logger.error(f"知识库模块初始化失败: {e}")
logger.error(traceback.format_exc())
async def _init_kb_database(self):
self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix())
await self.kb_db.initialize()
await self.kb_db.migrate_to_v1()
logger.info(f"KnowledgeBase database initialized: {DB_PATH}")
async def load_kbs(self):
"""加载所有知识库实例"""
kb_records = await self.kb_db.list_kbs()
for record in kb_records:
kb_helper = KBHelper(
kb_db=self.kb_db,
kb=record,
provider_manager=self.provider_manager,
kb_root_dir=FILES_PATH,
chunker=CHUNKER,
)
await kb_helper.initialize()
self.kb_insts[record.kb_id] = kb_helper
async def create_kb(
self,
kb_name: str,
description: str | None = None,
emoji: str | None = None,
embedding_provider_id: str | None = None,
rerank_provider_id: str | None = None,
chunk_size: int | None = None,
chunk_overlap: int | None = None,
top_k_dense: int | None = None,
top_k_sparse: int | None = None,
top_m_final: int | None = None,
) -> KBHelper:
"""创建新的知识库实例"""
kb = KnowledgeBase(
kb_name=kb_name,
description=description,
emoji=emoji or "📚",
embedding_provider_id=embedding_provider_id,
rerank_provider_id=rerank_provider_id,
chunk_size=chunk_size if chunk_size is not None else 512,
chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
top_k_dense=top_k_dense if top_k_dense is not None else 50,
top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
top_m_final=top_m_final if top_m_final is not None else 5,
)
async with self.kb_db.get_db() as session:
session.add(kb)
await session.commit()
await session.refresh(kb)
kb_helper = KBHelper(
kb_db=self.kb_db,
kb=kb,
provider_manager=self.provider_manager,
kb_root_dir=FILES_PATH,
chunker=CHUNKER,
)
await kb_helper.initialize()
self.kb_insts[kb.kb_id] = kb_helper
return kb_helper
async def get_kb(self, kb_id: str) -> KBHelper | None:
"""获取知识库实例"""
if kb_id in self.kb_insts:
return self.kb_insts[kb_id]
async def get_kb_by_name(self, kb_name: str) -> KBHelper | None:
"""通过名称获取知识库实例"""
for kb_helper in self.kb_insts.values():
if kb_helper.kb.kb_name == kb_name:
return kb_helper
return None
async def delete_kb(self, kb_id: str) -> bool:
"""删除知识库实例"""
kb_helper = await self.get_kb(kb_id)
if not kb_helper:
return False
await kb_helper.delete_vec_db()
async with self.kb_db.get_db() as session:
await session.delete(kb_helper.kb)
await session.commit()
self.kb_insts.pop(kb_id, None)
return True
async def list_kbs(self) -> list[KnowledgeBase]:
"""列出所有知识库实例"""
kbs = [kb_helper.kb for kb_helper in self.kb_insts.values()]
return kbs
async def update_kb(
self,
kb_id: str,
kb_name: str,
description: str | None = None,
emoji: str | None = None,
embedding_provider_id: str | None = None,
rerank_provider_id: str | None = None,
chunk_size: int | None = None,
chunk_overlap: int | None = None,
top_k_dense: int | None = None,
top_k_sparse: int | None = None,
top_m_final: int | None = None,
) -> KBHelper | None:
"""更新知识库实例"""
kb_helper = await self.get_kb(kb_id)
if not kb_helper:
return None
kb = kb_helper.kb
if kb_name is not None:
kb.kb_name = kb_name
if description is not None:
kb.description = description
if emoji is not None:
kb.emoji = emoji
if embedding_provider_id is not None:
kb.embedding_provider_id = embedding_provider_id
kb.rerank_provider_id = rerank_provider_id # 允许设置为 None
if chunk_size is not None:
kb.chunk_size = chunk_size
if chunk_overlap is not None:
kb.chunk_overlap = chunk_overlap
if top_k_dense is not None:
kb.top_k_dense = top_k_dense
if top_k_sparse is not None:
kb.top_k_sparse = top_k_sparse
if top_m_final is not None:
kb.top_m_final = top_m_final
async with self.kb_db.get_db() as session:
session.add(kb)
await session.commit()
await session.refresh(kb)
return kb_helper
async def retrieve(
self,
query: str,
kb_names: list[str],
top_k_fusion: int = 20,
top_m_final: int = 5,
) -> dict | None:
"""从指定知识库中检索相关内容"""
kb_ids = []
kb_id_helper_map = {}
for kb_name in kb_names:
if kb_helper := await self.get_kb_by_name(kb_name):
kb_ids.append(kb_helper.kb.kb_id)
kb_id_helper_map[kb_helper.kb.kb_id] = kb_helper
if not kb_ids:
return {}
results = await self.retrieval_manager.retrieve(
query=query,
kb_ids=kb_ids,
kb_id_helper_map=kb_id_helper_map,
top_k_fusion=top_k_fusion,
top_m_final=top_m_final,
)
if not results:
return None
context_text = self._format_context(results)
results_dict = [
{
"chunk_id": r.chunk_id,
"doc_id": r.doc_id,
"kb_id": r.kb_id,
"kb_name": r.kb_name,
"doc_name": r.doc_name,
"chunk_index": r.metadata.get("chunk_index", 0),
"content": r.content,
"score": r.score,
"char_count": r.metadata.get("char_count", 0),
}
for r in results
]
return {
"context_text": context_text,
"results": results_dict,
}
def _format_context(self, results: list[RetrievalResult]) -> str:
"""格式化知识上下文
Args:
results: 检索结果列表
Returns:
str: 格式化的上下文文本
"""
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
for i, result in enumerate(results, 1):
lines.append(f"【知识 {i}")
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
lines.append(f"内容: {result.content}")
lines.append(f"相关度: {result.score:.2f}")
lines.append("")
return "\n".join(lines)
async def terminate(self):
"""终止所有知识库实例,关闭数据库连接"""
for kb_id, kb_helper in self.kb_insts.items():
try:
await kb_helper.terminate()
except Exception as e:
logger.error(f"关闭知识库 {kb_id} 失败: {e}")
self.kb_insts.clear()
# 关闭元数据数据库
if hasattr(self, "kb_db") and self.kb_db:
try:
await self.kb_db.close()
except Exception as e:
logger.error(f"关闭知识库元数据数据库失败: {e}")

View File

@@ -1,120 +0,0 @@
import uuid
from datetime import datetime, timezone
from sqlmodel import Field, MetaData, SQLModel, Text, UniqueConstraint
class BaseKBModel(SQLModel, table=False):
metadata = MetaData()
class KnowledgeBase(BaseKBModel, table=True):
"""知识库表
存储知识库的基本信息和统计数据。
"""
__tablename__ = "knowledge_bases" # type: ignore
id: int | None = Field(
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
)
kb_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
index=True,
)
kb_name: str = Field(max_length=100, nullable=False)
description: str | None = Field(default=None, sa_type=Text)
emoji: str | None = Field(default="📚", max_length=10)
embedding_provider_id: str | None = Field(default=None, max_length=100)
rerank_provider_id: str | None = Field(default=None, max_length=100)
# 分块配置参数
chunk_size: int | None = Field(default=512, nullable=True)
chunk_overlap: int | None = Field(default=50, nullable=True)
# 检索配置参数
top_k_dense: int | None = Field(default=50, nullable=True)
top_k_sparse: int | None = Field(default=50, nullable=True)
top_m_final: int | None = Field(default=5, nullable=True)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
doc_count: int = Field(default=0, nullable=False)
chunk_count: int = Field(default=0, nullable=False)
__table_args__ = (
UniqueConstraint(
"kb_name",
name="uix_kb_name",
),
)
class KBDocument(BaseKBModel, table=True):
"""文档表
存储上传到知识库的文档元数据。
"""
__tablename__ = "kb_documents" # type: ignore
id: int | None = Field(
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
)
doc_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
index=True,
)
kb_id: str = Field(max_length=36, nullable=False, index=True)
doc_name: str = Field(max_length=255, nullable=False)
file_type: str = Field(max_length=20, nullable=False)
file_size: int = Field(nullable=False)
file_path: str = Field(max_length=512, nullable=False)
chunk_count: int = Field(default=0, nullable=False)
media_count: int = Field(default=0, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
class KBMedia(BaseKBModel, table=True):
"""多媒体资源表
存储从文档中提取的图片、视频等多媒体资源。
"""
__tablename__ = "kb_media" # type: ignore
id: int | None = Field(
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
)
media_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
index=True,
)
doc_id: str = Field(max_length=36, nullable=False, index=True)
kb_id: str = Field(max_length=36, nullable=False, index=True)
media_type: str = Field(max_length=20, nullable=False)
file_name: str = Field(max_length=255, nullable=False)
file_path: str = Field(max_length=512, nullable=False)
file_size: int = Field(nullable=False)
mime_type: str = Field(max_length=100, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

View File

@@ -1,13 +0,0 @@
"""文档解析器模块"""
from .base import BaseParser, MediaItem, ParseResult
from .pdf_parser import PDFParser
from .text_parser import TextParser
__all__ = [
"BaseParser",
"MediaItem",
"PDFParser",
"ParseResult",
"TextParser",
]

View File

@@ -1,51 +0,0 @@
"""文档解析器基类和数据结构
定义了文档解析器的抽象接口和相关数据类。
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass
class MediaItem:
"""多媒体项
表示从文档中提取的多媒体资源。
"""
media_type: str # image, video
file_name: str
content: bytes
mime_type: str
@dataclass
class ParseResult:
"""解析结果
包含解析后的文本内容和提取的多媒体资源。
"""
text: str
media: list[MediaItem]
class BaseParser(ABC):
"""文档解析器基类
所有文档解析器都应该继承此类并实现 parse 方法。
"""
@abstractmethod
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
"""解析文档
Args:
file_content: 文件内容
file_name: 文件名
Returns:
ParseResult: 解析结果
"""

View File

@@ -1,26 +0,0 @@
import io
import os
from markitdown_no_magika import MarkItDown, StreamInfo
from astrbot.core.knowledge_base.parsers.base import (
BaseParser,
ParseResult,
)
class MarkitdownParser(BaseParser):
"""解析 docx, xls, xlsx 格式"""
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
md = MarkItDown(enable_plugins=False)
bio = io.BytesIO(file_content)
stream_info = StreamInfo(
extension=os.path.splitext(file_name)[1].lower(),
filename=file_name,
)
result = md.convert(bio, stream_info=stream_info)
return ParseResult(
text=result.markdown,
media=[],
)

View File

@@ -1,101 +0,0 @@
"""PDF 文件解析器
支持解析 PDF 文件中的文本和图片资源。
"""
import io
from pypdf import PdfReader
from astrbot.core.knowledge_base.parsers.base import (
BaseParser,
MediaItem,
ParseResult,
)
class PDFParser(BaseParser):
"""PDF 文档解析器
提取 PDF 中的文本内容和嵌入的图片资源。
"""
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
"""解析 PDF 文件
Args:
file_content: 文件内容
file_name: 文件名
Returns:
ParseResult: 包含文本和图片的解析结果
"""
pdf_file = io.BytesIO(file_content)
reader = PdfReader(pdf_file)
text_parts = []
media_items = []
# 提取文本
for page in reader.pages:
text = page.extract_text()
if text:
text_parts.append(text)
# 提取图片
image_counter = 0
for page_num, page in enumerate(reader.pages):
try:
# 安全检查 Resources
if "/Resources" not in page:
continue
resources = page["/Resources"]
if not resources or "/XObject" not in resources: # type: ignore
continue
xobjects = resources["/XObject"].get_object() # type: ignore
if not xobjects:
continue
for obj_name in xobjects:
try:
obj = xobjects[obj_name]
if obj.get("/Subtype") != "/Image":
continue
# 提取图片数据
image_data = obj.get_data()
# 确定格式
filter_type = obj.get("/Filter", "")
if filter_type == "/DCTDecode":
ext = "jpg"
mime_type = "image/jpeg"
elif filter_type == "/FlateDecode":
ext = "png"
mime_type = "image/png"
else:
ext = "png"
mime_type = "image/png"
image_counter += 1
media_items.append(
MediaItem(
media_type="image",
file_name=f"page_{page_num}_img_{image_counter}.{ext}",
content=image_data,
mime_type=mime_type,
),
)
except Exception:
# 单个图片提取失败不影响整体
continue
except Exception:
# 页面处理失败不影响其他页面
continue
full_text = "\n\n".join(text_parts)
return ParseResult(text=full_text, media=media_items)

View File

@@ -1,42 +0,0 @@
"""文本文件解析器
支持解析 TXT 和 Markdown 文件。
"""
from astrbot.core.knowledge_base.parsers.base import BaseParser, ParseResult
class TextParser(BaseParser):
"""TXT/MD 文本解析器
支持多种字符编码的自动检测。
"""
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
"""解析文本文件
尝试使用多种编码解析文件内容。
Args:
file_content: 文件内容
file_name: 文件名
Returns:
ParseResult: 解析结果,不包含多媒体资源
Raises:
ValueError: 如果无法解码文件
"""
# 尝试多种编码
for encoding in ["utf-8", "gbk", "gb2312", "gb18030"]:
try:
text = file_content.decode(encoding)
break
except UnicodeDecodeError:
continue
else:
raise ValueError(f"无法解码文件: {file_name}")
# 文本文件无多媒体资源
return ParseResult(text=text, media=[])

View File

@@ -1,13 +0,0 @@
from .base import BaseParser
async def select_parser(ext: str) -> BaseParser:
if ext in {".md", ".txt", ".markdown", ".xlsx", ".docx", ".xls"}:
from .markitdown_parser import MarkitdownParser
return MarkitdownParser()
if ext == ".pdf":
from .pdf_parser import PDFParser
return PDFParser()
raise ValueError(f"暂时不支持的文件格式: {ext}")

View File

@@ -1,14 +0,0 @@
"""检索模块"""
from .manager import RetrievalManager, RetrievalResult
from .rank_fusion import FusedResult, RankFusion
from .sparse_retriever import SparseResult, SparseRetriever
__all__ = [
"FusedResult",
"RankFusion",
"RetrievalManager",
"RetrievalResult",
"SparseResult",
"SparseRetriever",
]

View File

@@ -1,767 +0,0 @@
———
》),
)÷(1-
”,
)、
:
&
*
一一
~~~~
.
.一
./
--
=″
[⑤]]
[①D]
ng昉
//
[②e]
[②g]
}
,也
[①⑥]
[②B]
[①a]
[④a]
[①③]
[③h]
③]
[②b]
×××
[①⑧]
[⑤b]
[②c]
[④b]
[②③]
[③a]
[④c]
[①⑤]
[①⑦]
[①g]
∈[
[①⑨]
[①④]
[①c]
[②f]
[②⑧]
[②①]
[①C]
[③c]
[③g]
[②⑤]
[②②]
一.
[①h]
.数
[①B]
数/
[①i]
[③e]
[①①]
[④d]
[④e]
[③b]
[⑤a]
[①A]
[②⑧]
[②⑦]
[①d]
[②j]
://
′∈
[②④
[⑤e]
...
...................
…………………………………………………③
[③F]
[①o]
]∧′=[
∪φ∈
②c
[③①]
[①E]
Ψ
.日
[②d]
[②
[②⑦]
[②②]
[③e]
[①i]
[①B]
[①h]
[①d]
[①g]
[①②]
[②a]
[⑩]
[①e]
[②h]
[②⑥]
[③d]
[②⑩]
元/吨
[②⑩]
[①]
::
[②]
[③]
[④]
[⑤]
[⑥]
[⑦]
[⑧]
[⑨]
……
——
?
,
'
?
·
———
──
?
<
>
[
]
(
)
-
+
×
/
В
"
;
#
@
γ
μ
φ
φ.
×
Δ
sub
exp
sup
sub
Lex
+ξ
-β
<±
<Δ
<λ
<φ
=
=☆
>λ
_
~±
[⑤f]
[⑤d]
[②i]
[②G]
[①f]
......
[③⑩]
第二
一番
一直
一个
一些
许多
有的是
也就是说
末##末
哎呀
哎哟
俺们
按照
吧哒
罢了
本着
比方
比如
鄙人
彼此
别的
别说
并且
不比
不成
不单
不但
不独
不管
不光
不过
不仅
不拘
不论
不怕
不然
不如
不特
不惟
不问
不只
朝着
趁着
除此之外
除非
除了
此间
此外
从而
但是
当着
的话
等等
叮咚
对于
多少
而况
而且
而是
而外
而言
而已
尔后
反过来
反过来说
反之
非但
非徒
否则
嘎登
各个
各位
各种
各自
根据
故此
固然
关于
果然
果真
哈哈
何处
何况
何时
哼唷
呼哧
还是
还有
换句话说
换言之
或是
或者
极了
及其
及至
即便
即或
即令
即若
即使
几时
既然
既是
继而
加之
假如
假若
假使
鉴于
较之
接着
结果
紧接着
进而
尽管
经过
就是
就是说
具体地说
具体说来
开始
开外
可见
可是
可以
况且
来着
例如
连同
两者
另外
另一方面
慢说
漫说
每当
莫若
某个
某些
哪边
哪儿
哪个
哪里
哪年
哪怕
哪天
哪些
哪样
那边
那儿
那个
那会儿
那里
那么
那么些
那么样
那时
那些
那样
乃至
你们
宁可
宁肯
宁愿
啪达
旁人
凭借
其次
其二
其他
其它
其一
其余
其中
起见
起见
岂但
恰恰相反
前后
前者
然而
然后
然则
人家
任何
任凭
如此
如果
如何
如其
如若
如上所述
若非
若是
上下
尚且
设若
设使
甚而
甚么
甚至
省得
时候
什么
什么样
使得
是的
首先
谁知
顺着
似的
虽然
虽说
虽则
随着
所以
他们
他人
它们
她们
倘或
倘然
倘若
倘使
通过
同时
万一
为何
为了
为什么
为着
嗡嗡
我们
呜呼
乌乎
无论
无宁
毋宁
相对而言
向着
沿
沿着
要不
要不然
要不是
要么
要是
也罢
也好
一般
一旦
一方面
一来
一切
一样
一则
依照
以便
以及
以免
以至
以至于
以致
抑或
因此
因而
因为
由此可见
由于
有的
有关
有些
于是
于是乎
与此同时
与否
与其
越是
云云
再说
再者
在下
咱们
怎么
怎么办
怎么样
怎样
照着
这边
这儿
这个
这会儿
这就是说
这里
这么
这么点儿
这么些
这么样
这时
这些
这样
正如
之类
之所以
之一
只是
只限
只要
只有
至于
诸位
着呢
自从
自个儿
自各儿
自己
自家
自身
综上所述
总的来看
总的来说
总的说来
总而言之
总之
纵令
纵然
纵使
遵照
作为
喔唷

View File

@@ -1,276 +0,0 @@
"""检索管理器
协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
"""
import time
from dataclasses import dataclass
from astrbot import logger
from astrbot.core.db.vec_db.base import Result
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
from astrbot.core.provider.provider import RerankProvider
from ..kb_helper import KBHelper
@dataclass
class RetrievalResult:
"""检索结果"""
chunk_id: str
doc_id: str
doc_name: str
kb_id: str
kb_name: str
content: str
score: float
metadata: dict
class RetrievalManager:
"""检索管理器
职责:
- 协调稠密检索、稀疏检索和 Rerank
- 结果融合和排序
"""
def __init__(
self,
sparse_retriever: SparseRetriever,
rank_fusion: RankFusion,
kb_db: KBSQLiteDatabase,
):
"""初始化检索管理器
Args:
vec_db_factory: 向量数据库工厂
sparse_retriever: 稀疏检索器
rank_fusion: 结果融合器
kb_db: 知识库数据库实例
"""
self.sparse_retriever = sparse_retriever
self.rank_fusion = rank_fusion
self.kb_db = kb_db
async def retrieve(
self,
query: str,
kb_ids: list[str],
kb_id_helper_map: dict[str, KBHelper],
top_k_fusion: int = 20,
top_m_final: int = 5,
) -> list[RetrievalResult]:
"""混合检索
流程:
1. 稠密检索 (向量相似度)
2. 稀疏检索 (BM25)
3. 结果融合 (RRF)
4. Rerank 重排序
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
top_m_final: 最终返回数量
enable_rerank: 是否启用 Rerank
Returns:
List[RetrievalResult]: 检索结果列表
"""
if not kb_ids:
return []
kb_options: dict = {}
new_kb_ids = []
for kb_id in kb_ids:
kb_helper = kb_id_helper_map.get(kb_id)
if kb_helper:
kb = kb_helper.kb
kb_options[kb_id] = {
"top_k_dense": kb.top_k_dense or 50,
"top_k_sparse": kb.top_k_sparse or 50,
"top_m_final": kb.top_m_final or 5,
"vec_db": kb_helper.vec_db,
"rerank_provider_id": kb.rerank_provider_id,
}
new_kb_ids.append(kb_id)
else:
logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索")
kb_ids = new_kb_ids
# 1. 稠密检索
time_start = time.time()
dense_results = await self._dense_retrieve(
query=query,
kb_ids=kb_ids,
kb_options=kb_options,
)
time_end = time.time()
logger.debug(
f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results.",
)
# 2. 稀疏检索
time_start = time.time()
sparse_results = await self.sparse_retriever.retrieve(
query=query,
kb_ids=kb_ids,
kb_options=kb_options,
)
time_end = time.time()
logger.debug(
f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results.",
)
# 3. 结果融合
time_start = time.time()
fused_results = await self.rank_fusion.fuse(
dense_results=dense_results,
sparse_results=sparse_results,
top_k=top_k_fusion,
)
time_end = time.time()
logger.debug(
f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results.",
)
# 4. 转换为 RetrievalResult (获取元数据)
retrieval_results = []
for fr in fused_results:
metadata_dict = await self.kb_db.get_document_with_metadata(fr.doc_id)
if metadata_dict:
retrieval_results.append(
RetrievalResult(
chunk_id=fr.chunk_id,
doc_id=fr.doc_id,
doc_name=metadata_dict["document"].doc_name,
kb_id=fr.kb_id,
kb_name=metadata_dict["knowledge_base"].kb_name,
content=fr.content,
score=fr.score,
metadata={
"chunk_index": fr.chunk_index,
"char_count": len(fr.content),
},
),
)
# 5. Rerank
first_rerank = None
for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
if (
vec_db
and vec_db.rerank_provider
and rerank_pi
and rerank_pi == vec_db.rerank_provider.meta().id
):
first_rerank = vec_db.rerank_provider
break
if first_rerank and retrieval_results:
retrieval_results = await self._rerank(
query=query,
results=retrieval_results,
top_k=top_m_final,
rerank_provider=first_rerank,
)
return retrieval_results[:top_m_final]
async def _dense_retrieve(
self,
query: str,
kb_ids: list[str],
kb_options: dict,
):
"""稠密检索 (向量相似度)
为每个知识库使用独立的向量数据库进行检索,然后合并结果。
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
top_k: 返回结果数量
Returns:
List[Result]: 检索结果列表
"""
all_results: list[Result] = []
for kb_id in kb_ids:
if kb_id not in kb_options:
continue
try:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
dense_k = int(kb_options[kb_id]["top_k_dense"])
vec_results = await vec_db.retrieve(
query=query,
k=dense_k,
fetch_k=dense_k * 2,
rerank=False, # 稠密检索阶段不进行 rerank
metadata_filters={"kb_id": kb_id},
)
all_results.extend(vec_results)
except Exception as e:
from astrbot.core import logger
logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
continue
# 按相似度排序并返回 top_k
all_results.sort(key=lambda x: x.similarity, reverse=True)
# return all_results[: len(all_results) // len(kb_ids)]
return all_results
async def _rerank(
self,
query: str,
results: list[RetrievalResult],
top_k: int,
rerank_provider: RerankProvider,
) -> list[RetrievalResult]:
"""Rerank 重排序
Args:
query: 查询文本
results: 检索结果列表
top_k: 返回结果数量
Returns:
List[RetrievalResult]: 重排序后的结果列表
"""
if not results:
return []
# 准备文档列表
docs = [r.content for r in results]
# 调用 Rerank Provider
rerank_results = await rerank_provider.rerank(
query=query,
documents=docs,
)
# 更新分数并重新排序
reranked_list = []
for rerank_result in rerank_results:
idx = rerank_result.index
if idx < len(results):
result = results[idx]
result.score = rerank_result.relevance_score
reranked_list.append(result)
reranked_list.sort(key=lambda x: x.score, reverse=True)
return reranked_list[:top_k]

View File

@@ -1,142 +0,0 @@
"""检索结果融合器
使用 Reciprocal Rank Fusion (RRF) 算法融合稠密检索和稀疏检索的结果
"""
import json
from dataclasses import dataclass
from astrbot.core.db.vec_db.base import Result
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult
@dataclass
class FusedResult:
"""融合后的检索结果"""
chunk_id: str
chunk_index: int
doc_id: str
kb_id: str
content: str
score: float
class RankFusion:
"""检索结果融合器
职责:
- 融合稠密检索和稀疏检索的结果
- 使用 Reciprocal Rank Fusion (RRF) 算法
"""
def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60):
"""初始化结果融合器
Args:
kb_db: 知识库数据库实例
k: RRF 参数,用于平滑排名
"""
self.kb_db = kb_db
self.k = k
async def fuse(
self,
dense_results: list[Result],
sparse_results: list[SparseResult],
top_k: int = 20,
) -> list[FusedResult]:
"""融合稠密和稀疏检索结果
RRF 公式:
score(doc) = sum(1 / (k + rank_i))
Args:
dense_results: 稠密检索结果
sparse_results: 稀疏检索结果
top_k: 返回结果数量
Returns:
List[FusedResult]: 融合后的结果列表
"""
# 1. 构建排名映射
dense_ranks = {
r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results)
} # 这里的 doc_id 实际上是 chunk_id
sparse_ranks = {r.chunk_id: (idx + 1) for idx, r in enumerate(sparse_results)}
# 2. 收集所有唯一的 ID
# 需要统一为 chunk_id
all_chunk_ids = set()
vec_doc_id_to_dense: dict[str, Result] = {} # vec_doc_id -> Result
chunk_id_to_sparse: dict[str, SparseResult] = {} # chunk_id -> SparseResult
# 处理稀疏检索结果
for r in sparse_results:
all_chunk_ids.add(r.chunk_id)
chunk_id_to_sparse[r.chunk_id] = r
# 处理稠密检索结果 (需要转换 vec_doc_id 到 chunk_id)
for r in dense_results:
vec_doc_id = r.data["doc_id"]
all_chunk_ids.add(vec_doc_id)
vec_doc_id_to_dense[vec_doc_id] = r
# 3. 计算 RRF 分数
rrf_scores: dict[str, float] = {}
for identifier in all_chunk_ids:
score = 0.0
# 来自稠密检索的贡献
if identifier in dense_ranks:
score += 1.0 / (self.k + dense_ranks[identifier])
# 来自稀疏检索的贡献
if identifier in sparse_ranks:
score += 1.0 / (self.k + sparse_ranks[identifier])
rrf_scores[identifier] = score
# 4. 排序
sorted_ids = sorted(
rrf_scores.keys(),
key=lambda cid: rrf_scores[cid],
reverse=True,
)[:top_k]
# 5. 构建融合结果
fused_results = []
for identifier in sorted_ids:
# 优先从稀疏检索获取完整信息
if identifier in chunk_id_to_sparse:
sr = chunk_id_to_sparse[identifier]
fused_results.append(
FusedResult(
chunk_id=sr.chunk_id,
chunk_index=sr.chunk_index,
doc_id=sr.doc_id,
kb_id=sr.kb_id,
content=sr.content,
score=rrf_scores[identifier],
),
)
elif identifier in vec_doc_id_to_dense:
# 从向量检索获取信息,需要从数据库获取块的详细信息
vec_result = vec_doc_id_to_dense[identifier]
chunk_md = json.loads(vec_result.data["metadata"])
fused_results.append(
FusedResult(
chunk_id=identifier,
chunk_index=chunk_md["chunk_index"],
doc_id=chunk_md["kb_doc_id"],
kb_id=chunk_md["kb_id"],
content=vec_result.data["text"],
score=rrf_scores[identifier],
),
)
return fused_results

View File

@@ -1,136 +0,0 @@
"""稀疏检索器
使用 BM25 算法进行基于关键词的文档检索
"""
import json
import os
from dataclasses import dataclass
import jieba
from rank_bm25 import BM25Okapi
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
@dataclass
class SparseResult:
"""稀疏检索结果"""
chunk_index: int
chunk_id: str
doc_id: str
kb_id: str
content: str
score: float
class SparseRetriever:
"""BM25 稀疏检索器
职责:
- 基于关键词的文档检索
- 使用 BM25 算法计算相关度
"""
def __init__(self, kb_db: KBSQLiteDatabase):
"""初始化稀疏检索器
Args:
kb_db: 知识库数据库实例
"""
self.kb_db = kb_db
self._index_cache = {} # 缓存 BM25 索引
with open(
os.path.join(os.path.dirname(__file__), "hit_stopwords.txt"),
encoding="utf-8",
) as f:
self.hit_stopwords = {
word.strip() for word in set(f.read().splitlines()) if word.strip()
}
async def retrieve(
self,
query: str,
kb_ids: list[str],
kb_options: dict,
) -> list[SparseResult]:
"""执行稀疏检索
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
kb_options: 每个知识库的检索选项
Returns:
List[SparseResult]: 检索结果列表
"""
# 1. 获取所有相关块
top_k_sparse = 0
chunks = []
for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options.get(kb_id, {}).get("vec_db")
if not vec_db:
continue
result = await vec_db.document_storage.get_documents(
metadata_filters={},
limit=None,
offset=None,
)
chunk_mds = [json.loads(doc["metadata"]) for doc in result]
result = [
{
"chunk_id": doc["doc_id"],
"chunk_index": chunk_md["chunk_index"],
"doc_id": chunk_md["kb_doc_id"],
"kb_id": kb_id,
"text": doc["text"],
}
for doc, chunk_md in zip(result, chunk_mds)
]
chunks.extend(result)
top_k_sparse += kb_options.get(kb_id, {}).get("top_k_sparse", 50)
if not chunks:
return []
# 2. 准备文档和索引
corpus = [chunk["text"] for chunk in chunks]
tokenized_corpus = [list(jieba.cut(doc)) for doc in corpus]
tokenized_corpus = [
[word for word in doc if word not in self.hit_stopwords]
for doc in tokenized_corpus
]
# 3. 构建 BM25 索引
bm25 = BM25Okapi(tokenized_corpus)
# 4. 执行检索
tokenized_query = list(jieba.cut(query))
tokenized_query = [
word for word in tokenized_query if word not in self.hit_stopwords
]
scores = bm25.get_scores(tokenized_query)
# 5. 排序并返回 Top-K
results = []
for idx, score in enumerate(scores):
chunk = chunks[idx]
results.append(
SparseResult(
chunk_id=chunk["chunk_id"],
chunk_index=chunk["chunk_index"],
doc_id=chunk["doc_id"],
kb_id=chunk["kb_id"],
content=chunk["text"],
score=float(score),
),
)
results.sort(key=lambda x: x.score, reverse=True)
# return results[: len(results) // len(kb_ids)]
return results[:top_k_sparse]

View File

@@ -1,4 +1,5 @@
"""日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能 """
日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
const: const:
CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量 CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量
@@ -20,14 +21,14 @@ function:
4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流 4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流
""" """
import asyncio
import logging import logging
import colorlog
import asyncio
import os import os
import sys import sys
from asyncio import Queue
from collections import deque from collections import deque
from asyncio import Queue
import colorlog from typing import List
# 日志缓存大小 # 日志缓存大小
CACHED_SIZE = 200 CACHED_SIZE = 200
@@ -51,7 +52,6 @@ def is_plugin_path(pathname):
Returns: Returns:
bool: 如果路径来自插件目录,则返回 True否则返回 False bool: 如果路径来自插件目录,则返回 True否则返回 False
""" """
if not pathname: if not pathname:
return False return False
@@ -68,7 +68,6 @@ def get_short_level_name(level_name):
Returns: Returns:
str: 四个字母的日志级别缩写 str: 四个字母的日志级别缩写
""" """
level_map = { level_map = {
"DEBUG": "DBUG", "DEBUG": "DBUG",
@@ -88,14 +87,13 @@ class LogBroker:
def __init__(self): def __init__(self):
self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志 self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
self.subscribers: list[Queue] = [] # 订阅者列表 self.subscribers: List[Queue] = [] # 订阅者列表
def register(self) -> Queue: def register(self) -> Queue:
"""注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列 """注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
Returns: Returns:
Queue: 订阅者的队列, 可用于接收日志消息 Queue: 订阅者的队列, 可用于接收日志消息
""" """
q = Queue(maxsize=CACHED_SIZE + 10) q = Queue(maxsize=CACHED_SIZE + 10)
self.subscribers.append(q) self.subscribers.append(q)
@@ -106,7 +104,6 @@ class LogBroker:
Args: Args:
q (Queue): 需要取消订阅的队列 q (Queue): 需要取消订阅的队列
""" """
self.subscribers.remove(q) self.subscribers.remove(q)
@@ -116,7 +113,6 @@ class LogBroker:
Args: Args:
log_entry (dict): 日志消息, 包含日志级别和日志内容. log_entry (dict): 日志消息, 包含日志级别和日志内容.
example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"} example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
""" """
self.log_cache.append(log_entry) self.log_cache.append(log_entry)
for q in self.subscribers: for q in self.subscribers:
@@ -142,7 +138,6 @@ class LogQueueHandler(logging.Handler):
Args: Args:
record (logging.LogRecord): 日志记录对象, 包含日志信息 record (logging.LogRecord): 日志记录对象, 包含日志信息
""" """
log_entry = self.format(record) log_entry = self.format(record)
self.log_broker.publish( self.log_broker.publish(
@@ -150,7 +145,7 @@ class LogQueueHandler(logging.Handler):
"level": record.levelname, "level": record.levelname,
"time": record.asctime, "time": record.asctime,
"data": log_entry, "data": log_entry,
}, }
) )
@@ -169,7 +164,6 @@ class LogManager:
Returns: Returns:
logging.Logger: 返回配置好的日志记录器 logging.Logger: 返回配置好的日志记录器
""" """
logger = logging.getLogger(log_name) logger = logging.getLogger(log_name)
# 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置 # 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置
@@ -177,10 +171,10 @@ class LogManager:
return logger return logger
# 如果logger没有处理器 # 如果logger没有处理器
console_handler = logging.StreamHandler( console_handler = logging.StreamHandler(
sys.stdout, sys.stdout
) # 创建一个StreamHandler用于控制台输出 ) # 创建一个StreamHandler用于控制台输出
console_handler.setLevel( console_handler.setLevel(
logging.DEBUG, logging.DEBUG
) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG ) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
# 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息 # 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息
@@ -201,8 +195,7 @@ class LogManager:
class FileNameFilter(logging.Filter): class FileNameFilter(logging.Filter):
"""文件名过滤器类, 用于修改日志记录的文件名格式 """文件名过滤器类, 用于修改日志记录的文件名格式
例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式 例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式"""
"""
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py # 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
def filter(self, record): def filter(self, record):
@@ -238,7 +231,6 @@ class LogManager:
Args: Args:
logger (logging.Logger): 日志记录器 logger (logging.Logger): 日志记录器
log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息 log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息
""" """
handler = LogQueueHandler(log_broker) handler = LogQueueHandler(log_broker)
handler.setLevel(logging.DEBUG) handler.setLevel(logging.DEBUG)
@@ -248,7 +240,7 @@ class LogManager:
# 为队列处理器设置相同格式的formatter # 为队列处理器设置相同格式的formatter
handler.setFormatter( handler.setFormatter(
logging.Formatter( logging.Formatter(
"[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s", "[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s"
), )
) )
logger.addHandler(handler) logger.addHandler(handler)

View File

@@ -1,4 +1,5 @@
"""MIT License """
MIT License
Copyright (c) 2021 Lxns-Network Copyright (c) 2021 Lxns-Network
@@ -25,6 +26,7 @@ import asyncio
import base64 import base64
import json import json
import os import os
import typing as T
import uuid import uuid
from enum import Enum from enum import Enum
@@ -35,37 +37,61 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64 from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
class ComponentType(str, Enum): class ComponentType(Enum):
# Basic Segment Types Plain = "Plain" # 纯文本消息
Plain = "Plain" # plain text message Face = "Face" # QQ表情
Image = "Image" # image Record = "Record" # 语音
Record = "Record" # audio Video = "Video" # 视频
Video = "Video" # video At = "At" # At
File = "File" # file attachment Node = "Node" # 转发消息的一个节点
Nodes = "Nodes" # 转发消息的多个节点
Poke = "Poke" # QQ 戳一戳
Image = "Image" # 图片
Reply = "Reply" # 回复
Forward = "Forward" # 转发消息
File = "File" # 文件
# IM-specific Segment Types
Face = "Face" # Emoji segment for Tencent QQ platform
At = "At" # mention a user in IM apps
Node = "Node" # a node in a forwarded message
Nodes = "Nodes" # a forwarded message consisting of multiple nodes
Poke = "Poke" # a poke message for Tencent QQ platform
Reply = "Reply" # a reply message segment
Forward = "Forward" # a forwarded message segment
RPS = "RPS" # TODO RPS = "RPS" # TODO
Dice = "Dice" # TODO Dice = "Dice" # TODO
Shake = "Shake" # TODO Shake = "Shake" # TODO
Anonymous = "Anonymous" # TODO
Share = "Share" Share = "Share"
Contact = "Contact" # TODO Contact = "Contact" # TODO
Location = "Location" # TODO Location = "Location" # TODO
Music = "Music" Music = "Music"
RedBag = "RedBag"
Xml = "Xml"
Json = "Json" Json = "Json"
CardImage = "CardImage"
TTS = "TTS"
Unknown = "Unknown" Unknown = "Unknown"
WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包 WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包
class BaseMessageComponent(BaseModel): class BaseMessageComponent(BaseModel):
type: ComponentType type: ComponentType
def toString(self):
output = f"[CQ:{self.type.lower()}"
for k, v in self.__dict__.items():
if k == "type" or v is None:
continue
if k == "_type":
k = "type"
if isinstance(v, bool):
v = 1 if v else 0
output += ",%s=%s" % (
k,
str(v)
.replace("&", "&amp;")
.replace(",", "&#44;")
.replace("[", "&#91;")
.replace("]", "&#93;"),
)
output += "]"
return output
def toDict(self): def toDict(self):
data = {} data = {}
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
@@ -82,22 +108,28 @@ class BaseMessageComponent(BaseModel):
class Plain(BaseMessageComponent): class Plain(BaseMessageComponent):
type = ComponentType.Plain type: ComponentType = "Plain"
text: str text: str
convert: bool | None = True convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息
def __init__(self, text: str, convert: bool = True, **_): def __init__(self, text: str, convert: bool = True, **_):
super().__init__(text=text, convert=convert, **_) super().__init__(text=text, convert=convert, **_)
def toString(self): # 没有 [CQ:plain] 这种东西,所以直接导出纯文本
if not self.convert:
return self.text
return (
self.text.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;")
)
def toDict(self): def toDict(self):
return {"type": "text", "data": {"text": self.text.strip()}} return {"type": "text", "data": {"text": self.text.strip()}}
async def to_dict(self): async def to_dict(self):
return {"type": "text", "data": {"text": self.text}} return {"type": "text", "data": {"text": self.text}}
class Face(BaseMessageComponent): class Face(BaseMessageComponent):
type = ComponentType.Face type: ComponentType = "Face"
id: int id: int
def __init__(self, **_): def __init__(self, **_):
@@ -105,18 +137,18 @@ class Face(BaseMessageComponent):
class Record(BaseMessageComponent): class Record(BaseMessageComponent):
type = ComponentType.Record type: ComponentType = "Record"
file: str | None = "" file: T.Optional[str] = ""
magic: bool | None = False magic: T.Optional[bool] = False
url: str | None = "" url: T.Optional[str] = ""
cache: bool | None = True cache: T.Optional[bool] = True
proxy: bool | None = True proxy: T.Optional[bool] = True
timeout: int | None = 0 timeout: T.Optional[int] = 0
# 额外 # 额外
path: str | None path: T.Optional[str]
def __init__(self, file: str | None, **_): def __init__(self, file: T.Optional[str], **_):
for k in _: for k in _.keys():
if k == "url": if k == "url":
pass pass
# Protocol.warn(f"go-cqhttp doesn't support send {self.type} by {k}") # Protocol.warn(f"go-cqhttp doesn't support send {self.type} by {k}")
@@ -132,25 +164,19 @@ class Record(BaseMessageComponent):
return Record(file=url, **_) return Record(file=url, **_)
raise Exception("not a valid url") raise Exception("not a valid url")
@staticmethod
def fromBase64(bs64_data: str, **_):
return Record(file=f"base64://{bs64_data}", **_)
async def convert_to_file_path(self) -> str: async def convert_to_file_path(self) -> str:
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。 """将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
Returns: Returns:
str: 语音的本地路径,以绝对路径表示。 str: 语音的本地路径,以绝对路径表示。
""" """
if not self.file: if self.file and self.file.startswith("file:///"):
raise Exception(f"not a valid file: {self.file}") file_path = self.file[8:]
if self.file.startswith("file:///"): return file_path
return self.file[8:] elif self.file and self.file.startswith("http"):
if self.file.startswith("http"):
file_path = await download_image_by_url(self.file) file_path = await download_image_by_url(self.file)
return os.path.abspath(file_path) return os.path.abspath(file_path)
if self.file.startswith("base64://"): elif self.file and self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://") bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data) image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -158,26 +184,25 @@ class Record(BaseMessageComponent):
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(image_bytes) f.write(image_bytes)
return os.path.abspath(file_path) return os.path.abspath(file_path)
if os.path.exists(self.file): elif os.path.exists(self.file):
return os.path.abspath(self.file) file_path = self.file
raise Exception(f"not a valid file: {self.file}") return os.path.abspath(file_path)
else:
raise Exception(f"not a valid file: {self.file}")
async def convert_to_base64(self) -> str: async def convert_to_base64(self) -> str:
"""将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。 """将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。
Returns: Returns:
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。 str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
""" """
# convert to base64 # convert to base64
if not self.file: if self.file and self.file.startswith("file:///"):
raise Exception(f"not a valid file: {self.file}")
if self.file.startswith("file:///"):
bs64_data = file_to_base64(self.file[8:]) bs64_data = file_to_base64(self.file[8:])
elif self.file.startswith("http"): elif self.file and self.file.startswith("http"):
file_path = await download_image_by_url(self.file) file_path = await download_image_by_url(self.file)
bs64_data = file_to_base64(file_path) bs64_data = file_to_base64(file_path)
elif self.file.startswith("base64://"): elif self.file and self.file.startswith("base64://"):
bs64_data = self.file bs64_data = self.file
elif os.path.exists(self.file): elif os.path.exists(self.file):
bs64_data = file_to_base64(self.file) bs64_data = file_to_base64(self.file)
@@ -187,14 +212,14 @@ class Record(BaseMessageComponent):
return bs64_data return bs64_data
async def register_to_file_service(self) -> str: async def register_to_file_service(self) -> str:
"""将语音注册到文件服务。 """
将语音注册到文件服务。
Returns: Returns:
str: 注册后的URL str: 注册后的URL
Raises: Raises:
Exception: 如果未配置 callback_api_base Exception: 如果未配置 callback_api_base
""" """
callback_host = astrbot_config.get("callback_api_base") callback_host = astrbot_config.get("callback_api_base")
@@ -211,12 +236,12 @@ class Record(BaseMessageComponent):
class Video(BaseMessageComponent): class Video(BaseMessageComponent):
type = ComponentType.Video type: ComponentType = "Video"
file: str file: str
cover: str | None = "" cover: T.Optional[str] = ""
c: int | None = 2 c: T.Optional[int] = 2
# 额外 # 额外
path: str | None = "" path: T.Optional[str] = ""
def __init__(self, file: str, **_): def __init__(self, file: str, **_):
super().__init__(file=file, **_) super().__init__(file=file, **_)
@@ -236,31 +261,32 @@ class Video(BaseMessageComponent):
Returns: Returns:
str: 视频的本地路径,以绝对路径表示。 str: 视频的本地路径,以绝对路径表示。
""" """
url = self.file url = self.file
if url and url.startswith("file:///"): if url and url.startswith("file:///"):
return url[8:] return url[8:]
if url and url.startswith("http"): elif url and url.startswith("http"):
download_dir = os.path.join(get_astrbot_data_path(), "temp") download_dir = os.path.join(get_astrbot_data_path(), "temp")
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}") video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
await download_file(url, video_file_path) await download_file(url, video_file_path)
if os.path.exists(video_file_path): if os.path.exists(video_file_path):
return os.path.abspath(video_file_path) return os.path.abspath(video_file_path)
raise Exception(f"download failed: {url}") else:
if os.path.exists(url): raise Exception(f"download failed: {url}")
elif os.path.exists(url):
return os.path.abspath(url) return os.path.abspath(url)
raise Exception(f"not a valid file: {url}") else:
raise Exception(f"not a valid file: {url}")
async def register_to_file_service(self): async def register_to_file_service(self):
"""将视频注册到文件服务。 """
将视频注册到文件服务。
Returns: Returns:
str: 注册后的URL str: 注册后的URL
Raises: Raises:
Exception: 如果未配置 callback_api_base Exception: 如果未配置 callback_api_base
""" """
callback_host = astrbot_config.get("callback_api_base") callback_host = astrbot_config.get("callback_api_base")
@@ -296,9 +322,9 @@ class Video(BaseMessageComponent):
class At(BaseMessageComponent): class At(BaseMessageComponent):
type = ComponentType.At type: ComponentType = "At"
qq: int | str # 此处str为all时代表所有人 qq: T.Union[int, str] # 此处str为all时代表所有人
name: str | None = "" name: T.Optional[str] = ""
def __init__(self, **_): def __init__(self, **_):
super().__init__(**_) super().__init__(**_)
@@ -318,66 +344,74 @@ class AtAll(At):
class RPS(BaseMessageComponent): # TODO class RPS(BaseMessageComponent): # TODO
type = ComponentType.RPS type: ComponentType = "RPS"
def __init__(self, **_): def __init__(self, **_):
super().__init__(**_) super().__init__(**_)
class Dice(BaseMessageComponent): # TODO class Dice(BaseMessageComponent): # TODO
type = ComponentType.Dice type: ComponentType = "Dice"
def __init__(self, **_): def __init__(self, **_):
super().__init__(**_) super().__init__(**_)
class Shake(BaseMessageComponent): # TODO class Shake(BaseMessageComponent): # TODO
type = ComponentType.Shake type: ComponentType = "Shake"
def __init__(self, **_):
super().__init__(**_)
class Anonymous(BaseMessageComponent): # TODO
type: ComponentType = "Anonymous"
ignore: T.Optional[bool] = False
def __init__(self, **_): def __init__(self, **_):
super().__init__(**_) super().__init__(**_)
class Share(BaseMessageComponent): class Share(BaseMessageComponent):
type = ComponentType.Share type: ComponentType = "Share"
url: str url: str
title: str title: str
content: str | None = "" content: T.Optional[str] = ""
image: str | None = "" image: T.Optional[str] = ""
def __init__(self, **_): def __init__(self, **_):
super().__init__(**_) super().__init__(**_)
class Contact(BaseMessageComponent): # TODO class Contact(BaseMessageComponent): # TODO
type = ComponentType.Contact type: ComponentType = "Contact"
_type: str # type 字段冲突 _type: str # type 字段冲突
id: int | None = 0 id: T.Optional[int] = 0
def __init__(self, **_): def __init__(self, **_):
super().__init__(**_) super().__init__(**_)
class Location(BaseMessageComponent): # TODO class Location(BaseMessageComponent): # TODO
type = ComponentType.Location type: ComponentType = "Location"
lat: float lat: float
lon: float lon: float
title: str | None = "" title: T.Optional[str] = ""
content: str | None = "" content: T.Optional[str] = ""
def __init__(self, **_): def __init__(self, **_):
super().__init__(**_) super().__init__(**_)
class Music(BaseMessageComponent): class Music(BaseMessageComponent):
type = ComponentType.Music type: ComponentType = "Music"
_type: str _type: str
id: int | None = 0 id: T.Optional[int] = 0
url: str | None = "" url: T.Optional[str] = ""
audio: str | None = "" audio: T.Optional[str] = ""
title: str | None = "" title: T.Optional[str] = ""
content: str | None = "" content: T.Optional[str] = ""
image: str | None = "" image: T.Optional[str] = ""
def __init__(self, **_): def __init__(self, **_):
# for k in _.keys(): # for k in _.keys():
@@ -387,19 +421,19 @@ class Music(BaseMessageComponent):
class Image(BaseMessageComponent): class Image(BaseMessageComponent):
type = ComponentType.Image type: ComponentType = "Image"
file: str | None = "" file: T.Optional[str] = ""
_type: str | None = "" _type: T.Optional[str] = ""
subType: int | None = 0 subType: T.Optional[int] = 0
url: str | None = "" url: T.Optional[str] = ""
cache: bool | None = True cache: T.Optional[bool] = True
id: int | None = 40000 id: T.Optional[int] = 40000
c: int | None = 2 c: T.Optional[int] = 2
# 额外 # 额外
path: str | None = "" path: T.Optional[str] = ""
file_unique: str | None = "" # 某些平台可能有图片缓存的唯一标识 file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
def __init__(self, file: str | None, **_): def __init__(self, file: T.Optional[str], **_):
super().__init__(file=file, **_) super().__init__(file=file, **_)
@staticmethod @staticmethod
@@ -429,17 +463,15 @@ class Image(BaseMessageComponent):
Returns: Returns:
str: 图片的本地路径,以绝对路径表示。 str: 图片的本地路径,以绝对路径表示。
""" """
url = self.url or self.file url = self.url if self.url else self.file
if not url: if url and url.startswith("file:///"):
raise ValueError("No valid file or URL provided") image_file_path = url[8:]
if url.startswith("file:///"): return image_file_path
return url[8:] elif url and url.startswith("http"):
if url.startswith("http"):
image_file_path = await download_image_by_url(url) image_file_path = await download_image_by_url(url)
return os.path.abspath(image_file_path) return os.path.abspath(image_file_path)
if url.startswith("base64://"): elif url and url.startswith("base64://"):
bs64_data = url.removeprefix("base64://") bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data) image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -447,27 +479,26 @@ class Image(BaseMessageComponent):
with open(image_file_path, "wb") as f: with open(image_file_path, "wb") as f:
f.write(image_bytes) f.write(image_bytes)
return os.path.abspath(image_file_path) return os.path.abspath(image_file_path)
if os.path.exists(url): elif os.path.exists(url):
return os.path.abspath(url) image_file_path = url
raise Exception(f"not a valid file: {url}") return os.path.abspath(image_file_path)
else:
raise Exception(f"not a valid file: {url}")
async def convert_to_base64(self) -> str: async def convert_to_base64(self) -> str:
"""将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。 """将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。
Returns: Returns:
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。 str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
""" """
# convert to base64 # convert to base64
url = self.url or self.file url = self.url if self.url else self.file
if not url: if url and url.startswith("file:///"):
raise ValueError("No valid file or URL provided")
if url.startswith("file:///"):
bs64_data = file_to_base64(url[8:]) bs64_data = file_to_base64(url[8:])
elif url.startswith("http"): elif url and url.startswith("http"):
image_file_path = await download_image_by_url(url) image_file_path = await download_image_by_url(url)
bs64_data = file_to_base64(image_file_path) bs64_data = file_to_base64(image_file_path)
elif url.startswith("base64://"): elif url and url.startswith("base64://"):
bs64_data = url bs64_data = url
elif os.path.exists(url): elif os.path.exists(url):
bs64_data = file_to_base64(url) bs64_data = file_to_base64(url)
@@ -477,14 +508,14 @@ class Image(BaseMessageComponent):
return bs64_data return bs64_data
async def register_to_file_service(self) -> str: async def register_to_file_service(self) -> str:
"""将图片注册到文件服务。 """
将图片注册到文件服务。
Returns: Returns:
str: 注册后的URL str: 注册后的URL
Raises: Raises:
Exception: 如果未配置 callback_api_base Exception: 如果未配置 callback_api_base
""" """
callback_host = astrbot_config.get("callback_api_base") callback_host = astrbot_config.get("callback_api_base")
@@ -501,35 +532,43 @@ class Image(BaseMessageComponent):
class Reply(BaseMessageComponent): class Reply(BaseMessageComponent):
type = ComponentType.Reply type: ComponentType = "Reply"
id: str | int id: T.Union[str, int]
"""所引用的消息 ID""" """所引用的消息 ID"""
chain: list["BaseMessageComponent"] | None = [] chain: T.Optional[T.List["BaseMessageComponent"]] = []
"""被引用的消息段列表""" """被引用的消息段列表"""
sender_id: int | None | str = 0 sender_id: T.Optional[int] | T.Optional[str] = 0
"""被引用的消息对应的发送者的 ID""" """被引用的消息对应的发送者的 ID"""
sender_nickname: str | None = "" sender_nickname: T.Optional[str] = ""
"""被引用的消息对应的发送者的昵称""" """被引用的消息对应的发送者的昵称"""
time: int | None = 0 time: T.Optional[int] = 0
"""被引用的消息发送时间""" """被引用的消息发送时间"""
message_str: str | None = "" message_str: T.Optional[str] = ""
"""被引用的消息解析后的纯文本消息字符串""" """被引用的消息解析后的纯文本消息字符串"""
text: str | None = "" text: T.Optional[str] = ""
"""deprecated""" """deprecated"""
qq: int | None = 0 qq: T.Optional[int] = 0
"""deprecated""" """deprecated"""
seq: int | None = 0 seq: T.Optional[int] = 0
"""deprecated""" """deprecated"""
def __init__(self, **_): def __init__(self, **_):
super().__init__(**_) super().__init__(**_)
class RedBag(BaseMessageComponent):
type: ComponentType = "RedBag"
title: str
def __init__(self, **_):
super().__init__(**_)
class Poke(BaseMessageComponent): class Poke(BaseMessageComponent):
type: str = ComponentType.Poke type: str = ""
id: int | None = 0 id: T.Optional[int] = 0
qq: int | None = 0 qq: T.Optional[int] = 0
def __init__(self, type: str, **_): def __init__(self, type: str, **_):
type = f"Poke:{type}" type = f"Poke:{type}"
@@ -537,7 +576,7 @@ class Poke(BaseMessageComponent):
class Forward(BaseMessageComponent): class Forward(BaseMessageComponent):
type = ComponentType.Forward type: ComponentType = "Forward"
id: str id: str
def __init__(self, **_): def __init__(self, **_):
@@ -547,13 +586,13 @@ class Forward(BaseMessageComponent):
class Node(BaseMessageComponent): class Node(BaseMessageComponent):
"""群合并转发消息""" """群合并转发消息"""
type = ComponentType.Node type: ComponentType = "Node"
id: int | None = 0 # 忽略 id: T.Optional[int] = 0 # 忽略
name: str | None = "" # qq昵称 name: T.Optional[str] = "" # qq昵称
uin: str | None = "0" # qq号 uin: T.Optional[str] = "0" # qq号
content: list[BaseMessageComponent] | None = [] content: T.Optional[list[BaseMessageComponent]] = []
seq: str | list | None = "" # 忽略 seq: T.Optional[T.Union[str, list]] = "" # 忽略
time: int | None = 0 # 忽略 time: T.Optional[int] = 0 # 忽略
def __init__(self, content: list[BaseMessageComponent], **_): def __init__(self, content: list[BaseMessageComponent], **_):
if isinstance(content, Node): if isinstance(content, Node):
@@ -571,7 +610,7 @@ class Node(BaseMessageComponent):
{ {
"type": comp.type.lower(), "type": comp.type.lower(),
"data": {"file": f"base64://{bs64}"}, "data": {"file": f"base64://{bs64}"},
}, }
) )
elif isinstance(comp, Plain): elif isinstance(comp, Plain):
# For Plain segments, we need to handle the plain differently # For Plain segments, we need to handle the plain differently
@@ -599,10 +638,10 @@ class Node(BaseMessageComponent):
class Nodes(BaseMessageComponent): class Nodes(BaseMessageComponent):
type = ComponentType.Nodes type: ComponentType = "Nodes"
nodes: list[Node] nodes: T.List[Node]
def __init__(self, nodes: list[Node], **_): def __init__(self, nodes: T.List[Node], **_):
super().__init__(nodes=nodes, **_) super().__init__(nodes=nodes, **_)
def toDict(self): def toDict(self):
@@ -624,10 +663,19 @@ class Nodes(BaseMessageComponent):
return ret return ret
class Xml(BaseMessageComponent):
type: ComponentType = "Xml"
data: str
resid: T.Optional[int] = 0
def __init__(self, **_):
super().__init__(**_)
class Json(BaseMessageComponent): class Json(BaseMessageComponent):
type = ComponentType.Json type: ComponentType = "Json"
data: str | dict data: T.Union[str, dict]
resid: int | None = 0 resid: T.Optional[int] = 0
def __init__(self, data, **_): def __init__(self, data, **_):
if isinstance(data, dict): if isinstance(data, dict):
@@ -635,18 +683,50 @@ class Json(BaseMessageComponent):
super().__init__(data=data, **_) super().__init__(data=data, **_)
class Unknown(BaseMessageComponent): class CardImage(BaseMessageComponent):
type = ComponentType.Unknown type: ComponentType = "CardImage"
file: str
cache: T.Optional[bool] = True
minwidth: T.Optional[int] = 400
minheight: T.Optional[int] = 400
maxwidth: T.Optional[int] = 500
maxheight: T.Optional[int] = 500
source: T.Optional[str] = ""
icon: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
@staticmethod
def fromFileSystem(path, **_):
return CardImage(file=f"file:///{os.path.abspath(path)}", **_)
class TTS(BaseMessageComponent):
type: ComponentType = "TTS"
text: str text: str
def __init__(self, **_):
super().__init__(**_)
class Unknown(BaseMessageComponent):
type: ComponentType = "Unknown"
text: str
def toString(self):
return ""
class File(BaseMessageComponent): class File(BaseMessageComponent):
"""文件消息段""" """
文件消息段
"""
type = ComponentType.File type: ComponentType = "File"
name: str | None = "" # 名字 name: T.Optional[str] = "" # 名字
file_: str | None = "" # 本地路径 file_: T.Optional[str] = "" # 本地路径
url: str | None = "" # url url: T.Optional[str] = "" # url
def __init__(self, name: str, file: str = "", url: str = ""): def __init__(self, name: str, file: str = "", url: str = ""):
"""文件消息段。""" """文件消息段。"""
@@ -654,11 +734,11 @@ class File(BaseMessageComponent):
@property @property
def file(self) -> str: def file(self) -> str:
"""获取文件路径如果文件不存在但有URL则同步下载文件 """
获取文件路径如果文件不存在但有URL则同步下载文件
Returns: Returns:
str: 文件路径 str: 文件路径
""" """
if self.file_ and os.path.exists(self.file_): if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_) return os.path.abspath(self.file_)
@@ -668,16 +748,19 @@ class File(BaseMessageComponent):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
if loop.is_running(): if loop.is_running():
logger.warning( logger.warning(
"不可以在异步上下文中同步等待下载! " (
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。" "不可以在异步上下文中同步等待下载! "
"请使用 await get_file() 代替直接获取 <File>.file 字段", "这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段"
)
) )
return "" return ""
# 等待下载完成 else:
loop.run_until_complete(self._download_file()) # 等待下载完成
loop.run_until_complete(self._download_file())
if self.file_ and os.path.exists(self.file_): if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_) return os.path.abspath(self.file_)
except Exception as e: except Exception as e:
logger.error(f"文件下载失败: {e}") logger.error(f"文件下载失败: {e}")
@@ -685,11 +768,11 @@ class File(BaseMessageComponent):
@file.setter @file.setter
def file(self, value: str): def file(self, value: str):
"""向前兼容, 设置file属性, 传入的参数可能是文件路径或URL """
向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
Args: Args:
value (str): 文件路径或URL value (str): 文件路径或URL
""" """
if value.startswith("http://") or value.startswith("https://"): if value.startswith("http://") or value.startswith("https://"):
self.url = value self.url = value
@@ -704,7 +787,6 @@ class File(BaseMessageComponent):
注意,如果为 True也可能返回文件路径。 注意,如果为 True也可能返回文件路径。
Returns: Returns:
str: 文件路径或者 http 下载链接 str: 文件路径或者 http 下载链接
""" """
if allow_return_url and self.url: if allow_return_url and self.url:
return self.url return self.url
@@ -727,14 +809,14 @@ class File(BaseMessageComponent):
self.file_ = os.path.abspath(file_path) self.file_ = os.path.abspath(file_path)
async def register_to_file_service(self): async def register_to_file_service(self):
"""将文件注册到文件服务。 """
将文件注册到文件服务。
Returns: Returns:
str: 注册后的URL str: 注册后的URL
Raises: Raises:
Exception: 如果未配置 callback_api_base Exception: 如果未配置 callback_api_base
""" """
callback_host = astrbot_config.get("callback_api_base") callback_host = astrbot_config.get("callback_api_base")
@@ -771,39 +853,42 @@ class File(BaseMessageComponent):
class WechatEmoji(BaseMessageComponent): class WechatEmoji(BaseMessageComponent):
type = ComponentType.WechatEmoji type: ComponentType = "WechatEmoji"
md5: str | None = "" md5: T.Optional[str] = ""
md5_len: int | None = 0 md5_len: T.Optional[int] = 0
cdnurl: str | None = "" cdnurl: T.Optional[str] = ""
def __init__(self, **_): def __init__(self, **_):
super().__init__(**_) super().__init__(**_)
ComponentTypes = { ComponentTypes = {
# Basic Message Segments
"plain": Plain, "plain": Plain,
"text": Plain, "text": Plain,
"image": Image, "face": Face,
"record": Record, "record": Record,
"video": Video, "video": Video,
"file": File,
# IM-specific Message Segments
"face": Face,
"at": At, "at": At,
"rps": RPS, "rps": RPS,
"dice": Dice, "dice": Dice,
"shake": Shake, "shake": Shake,
"anonymous": Anonymous,
"share": Share, "share": Share,
"contact": Contact, "contact": Contact,
"location": Location, "location": Location,
"music": Music, "music": Music,
"image": Image,
"reply": Reply, "reply": Reply,
"redbag": RedBag,
"poke": Poke, "poke": Poke,
"forward": Forward, "forward": Forward,
"node": Node, "node": Node,
"nodes": Nodes, "nodes": Nodes,
"xml": Xml,
"json": Json, "json": Json,
"cardimage": CardImage,
"tts": TTS,
"unknown": Unknown, "unknown": Unknown,
"file": File,
"WechatEmoji": WechatEmoji, "WechatEmoji": WechatEmoji,
} }

View File

@@ -1,16 +1,15 @@
import enum import enum
from collections.abc import AsyncGenerator
from typing import List, Optional, Union, AsyncGenerator
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing_extensions import deprecated
from astrbot.core.message.components import ( from astrbot.core.message.components import (
BaseMessageComponent,
Plain,
Image,
At, At,
AtAll, AtAll,
BaseMessageComponent,
Image,
Plain,
) )
from typing_extensions import deprecated
@dataclass @dataclass
@@ -21,18 +20,18 @@ class MessageChain:
Attributes: Attributes:
`chain` (list): 用于顺序存储各个组件。 `chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 `use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
""" """
chain: list[BaseMessageComponent] = field(default_factory=list) chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: bool | None = None # None 为跟随用户设置 use_t2i_: Optional[bool] = None # None 为跟随用户设置
type: str | None = None type: Optional[str] = None
"""消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。""" """消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。"""
def message(self, message: str): def message(self, message: str):
"""添加一条文本消息到消息链 `chain` 中。 """添加一条文本消息到消息链 `chain` 中。
Example: Example:
CommandResult().message("Hello ").message("world!") CommandResult().message("Hello ").message("world!")
# 输出 Hello world! # 输出 Hello world!
@@ -40,10 +39,11 @@ class MessageChain:
self.chain.append(Plain(message)) self.chain.append(Plain(message))
return self return self
def at(self, name: str, qq: str | int): def at(self, name: str, qq: Union[str, int]):
"""添加一条 At 消息到消息链 `chain` 中。 """添加一条 At 消息到消息链 `chain` 中。
Example: Example:
CommandResult().at("张三", "12345678910") CommandResult().at("张三", "12345678910")
# 输出 @张三 # 输出 @张三
@@ -55,6 +55,7 @@ class MessageChain:
"""添加一条 AtAll 消息到消息链 `chain` 中。 """添加一条 AtAll 消息到消息链 `chain` 中。
Example: Example:
CommandResult().at_all() CommandResult().at_all()
# 输出 @所有人 # 输出 @所有人
@@ -67,6 +68,7 @@ class MessageChain:
"""添加一条错误消息到消息链 `chain` 中 """添加一条错误消息到消息链 `chain` 中
Example: Example:
CommandResult().error("解析失败") CommandResult().error("解析失败")
""" """
@@ -80,6 +82,7 @@ class MessageChain:
如果需要发送本地图片,请使用 `file_image` 方法。 如果需要发送本地图片,请使用 `file_image` 方法。
Example: Example:
CommandResult().image("https://example.com/image.jpg") CommandResult().image("https://example.com/image.jpg")
""" """
@@ -93,7 +96,6 @@ class MessageChain:
如果需要发送网络图片,请使用 `url_image` 方法。 如果需要发送网络图片,请使用 `url_image` 方法。
CommandResult().image("image.jpg") CommandResult().image("image.jpg")
""" """
self.chain.append(Image.fromFileSystem(path)) self.chain.append(Image.fromFileSystem(path))
return self return self
@@ -112,7 +114,6 @@ class MessageChain:
Args: Args:
use_t2i (bool): 是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 use_t2i (bool): 是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
""" """
self.use_t2i_ = use_t2i self.use_t2i_ = use_t2i
return self return self
@@ -124,7 +125,7 @@ class MessageChain:
def squash_plain(self): def squash_plain(self):
"""将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。""" """将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。"""
if not self.chain: if not self.chain:
return None return
new_chain = [] new_chain = []
first_plain = None first_plain = None
@@ -152,7 +153,6 @@ class EventResultType(enum.Enum):
Attributes: Attributes:
CONTINUE: 事件将会继续传播 CONTINUE: 事件将会继续传播
STOP: 事件将会终止传播 STOP: 事件将会终止传播
""" """
CONTINUE = enum.auto() CONTINUE = enum.auto()
@@ -181,18 +181,17 @@ class MessageEventResult(MessageChain):
`chain` (list): 用于顺序存储各个组件。 `chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 `use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`result_type` (EventResultType): 事件处理的结果类型。 `result_type` (EventResultType): 事件处理的结果类型。
""" """
result_type: EventResultType | None = field( result_type: Optional[EventResultType] = field(
default_factory=lambda: EventResultType.CONTINUE, default_factory=lambda: EventResultType.CONTINUE
) )
result_content_type: ResultContentType | None = field( result_content_type: Optional[ResultContentType] = field(
default_factory=lambda: ResultContentType.GENERAL_RESULT, default_factory=lambda: ResultContentType.GENERAL_RESULT
) )
async_stream: AsyncGenerator | None = None async_stream: Optional[AsyncGenerator] = None
"""异步流""" """异步流"""
def stop_event(self) -> "MessageEventResult": def stop_event(self) -> "MessageEventResult":
@@ -206,7 +205,9 @@ class MessageEventResult(MessageChain):
return self return self
def is_stopped(self) -> bool: def is_stopped(self) -> bool:
"""是否终止事件传播。""" """
是否终止事件传播。
"""
return self.result_type == EventResultType.STOP return self.result_type == EventResultType.STOP
def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult": def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult":
@@ -219,7 +220,6 @@ class MessageEventResult(MessageChain):
Args: Args:
result_type (EventResultType): 事件处理的结果类型。 result_type (EventResultType): 事件处理的结果类型。
""" """
self.result_content_type = typ self.result_content_type = typ
return self return self

View File

@@ -1,8 +1,8 @@
from astrbot import logger
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Persona, Personality from astrbot.core.db.po import Persona, Personality
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.platform.message_session import MessageSession from astrbot.core.platform.message_session import MessageSession
from astrbot import logger
DEFAULT_PERSONALITY = Personality( DEFAULT_PERSONALITY = Personality(
prompt="You are a helpful and friendly assistant.", prompt="You are a helpful and friendly assistant.",
@@ -41,14 +41,12 @@ class PersonaManager:
return persona return persona
async def get_default_persona_v3( async def get_default_persona_v3(
self, self, umo: str | MessageSession | None = None
umo: str | MessageSession | None = None,
) -> Personality: ) -> Personality:
"""获取默认 persona""" """获取默认 persona"""
cfg = self.acm.get_conf(umo) cfg = self.acm.get_conf(umo)
default_persona_id = cfg.get("provider_settings", {}).get( default_persona_id = cfg.get("provider_settings", {}).get(
"default_personality", "default_personality", "default"
"default",
) )
if not default_persona_id or default_persona_id == "default": if not default_persona_id or default_persona_id == "default":
return DEFAULT_PERSONALITY return DEFAULT_PERSONALITY
@@ -68,19 +66,16 @@ class PersonaManager:
async def update_persona( async def update_persona(
self, self,
persona_id: str, persona_id: str,
system_prompt: str | None = None, system_prompt: str = None,
begin_dialogs: list[str] | None = None, begin_dialogs: list[str] = None,
tools: list[str] | None = None, tools: list[str] = None,
): ):
"""更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具""" """更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
existing_persona = await self.db.get_persona_by_id(persona_id) existing_persona = await self.db.get_persona_by_id(persona_id)
if not existing_persona: if not existing_persona:
raise ValueError(f"Persona with ID {persona_id} does not exist.") raise ValueError(f"Persona with ID {persona_id} does not exist.")
persona = await self.db.update_persona( persona = await self.db.update_persona(
persona_id, persona_id, system_prompt, begin_dialogs, tools=tools
system_prompt,
begin_dialogs,
tools=tools,
) )
if persona: if persona:
for i, p in enumerate(self.personas): for i, p in enumerate(self.personas):
@@ -105,10 +100,7 @@ class PersonaManager:
if await self.db.get_persona_by_id(persona_id): if await self.db.get_persona_by_id(persona_id):
raise ValueError(f"Persona with ID {persona_id} already exists.") raise ValueError(f"Persona with ID {persona_id} already exists.")
new_persona = await self.db.insert_persona( new_persona = await self.db.insert_persona(
persona_id, persona_id, system_prompt, begin_dialogs, tools=tools
system_prompt,
begin_dialogs,
tools=tools,
) )
self.personas.append(new_persona) self.personas.append(new_persona)
self.get_v3_persona_data() self.get_v3_persona_data()
@@ -123,7 +115,6 @@ class PersonaManager:
- list[dict]: 包含 persona 配置的字典列表。 - list[dict]: 包含 persona 配置的字典列表。
- list[Personality]: 包含 Personality 对象的列表。 - list[Personality]: 包含 Personality 对象的列表。
- Personality: 默认选择的 Personality 对象。 - Personality: 默认选择的 Personality 对象。
""" """
v3_persona_config = [ v3_persona_config = [
{ {
@@ -145,7 +136,7 @@ class PersonaManager:
if begin_dialogs: if begin_dialogs:
if len(begin_dialogs) % 2 != 0: if len(begin_dialogs) % 2 != 0:
logger.error( logger.error(
f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。", f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。"
) )
begin_dialogs = [] begin_dialogs = []
user_turn = True user_turn = True
@@ -155,7 +146,7 @@ class PersonaManager:
"role": "user" if user_turn else "assistant", "role": "user" if user_turn else "assistant",
"content": dialog, "content": dialog,
"_no_save": None, # 不持久化到 db "_no_save": None, # 不持久化到 db
}, }
) )
user_turn = not user_turn user_turn = not user_turn

View File

@@ -27,15 +27,15 @@ STAGES_ORDER = [
] ]
__all__ = [ __all__ = [
"ContentSafetyCheckStage",
"EventResultType",
"MessageEventResult",
"PreProcessStage",
"ProcessStage",
"RateLimitStage",
"RespondStage",
"ResultDecorateStage",
"SessionStatusCheckStage",
"WakingCheckStage", "WakingCheckStage",
"WhitelistCheckStage", "WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
"MessageEventResult",
"EventResultType",
] ]

View File

@@ -1,11 +1,9 @@
from collections.abc import AsyncGenerator from typing import Union, AsyncGenerator
from astrbot.core import logger
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from ..context import PipelineContext
from ..stage import Stage, register_stage from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core import logger
from .strategies.strategy import StrategySelector from .strategies.strategy import StrategySelector
@@ -21,10 +19,8 @@ class ContentSafetyCheckStage(Stage):
self.strategy_selector = StrategySelector(config) self.strategy_selector = StrategySelector(config)
async def process( async def process(
self, self, event: AstrMessageEvent, check_text: str = None
event: AstrMessageEvent, ) -> Union[None, AsyncGenerator[None, None]]:
check_text: str | None = None,
) -> None | AsyncGenerator[None, None]:
"""检查内容安全""" """检查内容安全"""
text = check_text if check_text else event.get_message_str() text = check_text if check_text else event.get_message_str()
ok, info = self.strategy_selector.check(text) ok, info = self.strategy_selector.check(text)
@@ -32,8 +28,8 @@ class ContentSafetyCheckStage(Stage):
if event.is_at_or_wake_command: if event.is_at_or_wake_command:
event.set_result( event.set_result(
MessageEventResult().message( MessageEventResult().message(
"你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。", "你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。"
), )
) )
yield yield
event.stop_event() event.stop_event()

View File

@@ -1,7 +1,8 @@
import abc import abc
from typing import Tuple
class ContentSafetyStrategy(abc.ABC): class ContentSafetyStrategy(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def check(self, content: str) -> tuple[bool, str]: def check(self, content: str) -> Tuple[bool, str]:
raise NotImplementedError raise NotImplementedError

View File

@@ -1,8 +1,9 @@
"""使用此功能应该先 pip install baidu-aip""" """
使用此功能应该先 pip install baidu-aip
from aip import AipContentCensor """
from . import ContentSafetyStrategy from . import ContentSafetyStrategy
from aip import AipContentCensor
class BaiduAipStrategy(ContentSafetyStrategy): class BaiduAipStrategy(ContentSafetyStrategy):
@@ -12,18 +13,18 @@ class BaiduAipStrategy(ContentSafetyStrategy):
self.secret_key = sk self.secret_key = sk
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key) self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
def check(self, content: str) -> tuple[bool, str]: def check(self, content: str):
res = self.client.textCensorUserDefined(content) res = self.client.textCensorUserDefined(content)
if "conclusionType" not in res: if "conclusionType" not in res:
return False, "" return False, ""
if res["conclusionType"] == 1: if res["conclusionType"] == 1:
return True, "" return True, ""
if "data" not in res: else:
return False, "" if "data" not in res:
count = len(res["data"]) return False, ""
parts = [f"百度审核服务发现 {count} 处违规:\n"] count = len(res["data"])
for i in res["data"]: info = f"百度审核服务发现 {count} 处违规:\n"
parts.append(f"{i['msg']}\n") for i in res["data"]:
parts.append("\n判断结果:" + res["conclusion"]) info += f"{i['msg']}\n"
info = "".join(parts) info += "\n判断结果:" + res["conclusion"]
return False, info return False, info

View File

@@ -1,5 +1,4 @@
import re import re
from . import ContentSafetyStrategy from . import ContentSafetyStrategy
@@ -17,7 +16,7 @@ class KeywordsStrategy(ContentSafetyStrategy):
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"] # json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
# ) # )
def check(self, content: str) -> tuple[bool, str]: def check(self, content: str) -> bool:
for keyword in self.keywords: for keyword in self.keywords:
if re.search(keyword, content): if re.search(keyword, content):
return False, "内容安全检查不通过,匹配到敏感词。" return False, "内容安全检查不通过,匹配到敏感词。"

View File

@@ -1,16 +1,16 @@
from astrbot import logger
from . import ContentSafetyStrategy from . import ContentSafetyStrategy
from typing import List, Tuple
from astrbot import logger
class StrategySelector: class StrategySelector:
def __init__(self, config: dict) -> None: def __init__(self, config: dict) -> None:
self.enabled_strategies: list[ContentSafetyStrategy] = [] self.enabled_strategies: List[ContentSafetyStrategy] = []
if config["internal_keywords"]["enable"]: if config["internal_keywords"]["enable"]:
from .keywords import KeywordsStrategy from .keywords import KeywordsStrategy
self.enabled_strategies.append( self.enabled_strategies.append(
KeywordsStrategy(config["internal_keywords"]["extra_keywords"]), KeywordsStrategy(config["internal_keywords"]["extra_keywords"])
) )
if config["baidu_aip"]["enable"]: if config["baidu_aip"]["enable"]:
try: try:
@@ -23,10 +23,10 @@ class StrategySelector:
config["baidu_aip"]["app_id"], config["baidu_aip"]["app_id"],
config["baidu_aip"]["api_key"], config["baidu_aip"]["api_key"],
config["baidu_aip"]["secret_key"], config["baidu_aip"]["secret_key"],
), )
) )
def check(self, content: str) -> tuple[bool, str]: def check(self, content: str) -> Tuple[bool, str]:
for strategy in self.enabled_strategies: for strategy in self.enabled_strategies:
ok, info = strategy.check(content) ok, info = strategy.check(content)
if not ok: if not ok:

View File

@@ -1,9 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from astrbot.core.config import AstrBotConfig from astrbot.core.config import AstrBotConfig
from astrbot.core.star import PluginManager from astrbot.core.star import PluginManager
from .context_utils import call_handler, call_event_hook
from .context_utils import call_event_hook, call_handler, call_local_llm_tool
@dataclass @dataclass
@@ -15,4 +13,3 @@ class PipelineContext:
astrbot_config_id: str astrbot_config_id: str
call_handler = call_handler call_handler = call_handler
call_event_hook = call_event_hook call_event_hook = call_event_hook
call_local_llm_tool = call_local_llm_tool

Some files were not shown because too many files have changed in this diff Show More