Compare commits

..

4 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
479a737a03 Implement googleSearch tool in openai_source.py for Gemini(OpenAI Compatible) provider
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-17 02:07:36 +00:00
copilot-swe-agent[bot]
e324b69bf1 Add googleSearch tool alias for OpenAI-compatible Gemini API
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-16 16:39:20 +00:00
copilot-swe-agent[bot]
df86913fbf Initial setup and analysis of googleSearch tool feature request
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-16 16:33:47 +00:00
copilot-swe-agent[bot]
ca77cd6a83 Initial plan 2025-08-16 16:27:25 +00:00
355 changed files with 15198 additions and 39107 deletions

View File

@@ -16,7 +16,7 @@ body:
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
不熟悉 JSON ?可以从 [此处](https://plugins.astrbot.app/submit) 生成 JSON ,生成后记得复制粘贴过来.
不熟悉 JSON 现在可以从 [这里](https://plugins.astrbot.app/#/submit) 获取你的 JSON 啦!获取到了记得复制粘贴过来哦!
- type: textarea
id: plugin-info
@@ -26,13 +26,12 @@ body:
value: |
```json
{
"name": "插件名,请以 astrbot_plugin_ 开头",
"display_name": "用于展示的插件名,方便人类阅读",
"desc": "插件的简短介绍",
"name": "插件名",
"desc": "插件介绍",
"author": "作者名",
"repo": "插件仓库链接",
"tags": [],
"social_link": "",
"social_link": ""
}
```
validations:

View File

@@ -6,13 +6,13 @@ body:
- type: markdown
attributes:
value: |
感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。
- type: textarea
attributes:
label: 发生了什么
description: 描述你遇到的异常
placeholder: >
一个清晰且具体的描述这个异常是什么。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
一个清晰且具体的描述这个异常是什么。
validations:
required: true
@@ -55,7 +55,7 @@ body:
attributes:
label: 报错日志
description: >
如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!
placeholder: >
请提供完整的报错日志或截图。
validations:

View File

@@ -1,46 +1,19 @@
<!-- 如果有的话,指定 PR 旨在解决的 ISSUE 编号。 -->
<!-- If applicable, please specify the ISSUE number this PR aims to resolve. -->
<!-- 如果有的话,指定这个 PR 解决的 ISSUE -->
解决了 #XYZ
fixes #XYZ
### Motivation
---
<!--解释为什么要改动-->
### Motivation / 动机
### Modifications
<!--请描述此项更改的动机:它解决了什么问题?(例如:修复了 XX 错误,添加了 YY 功能)-->
<!--Please describe the motivation for this change: What problem does it solve? (e.g., Fixes XX bug, adds YY feature)-->
<!--简单解释你的改动-->
### Modifications / 改动点
### Check
<!--请总结你的改动:哪些核心文件被修改了?实现了什么功能?-->
<!--Please summarize your changes: What core files were modified? What functionality was implemented?-->
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容-->
### Verification Steps / 验证步骤
<!--请为审查者 (Reviewer) 提供清晰、可复现的验证步骤例如1. 导航到... 2. 点击...)。-->
<!--Please provide clear and reproducible verification steps for the Reviewer (e.g., 1. Navigate to... 2. Click...).-->
### Screenshots or Test Results / 运行截图或测试结果
<!--请粘贴截图、GIF 或测试日志,作为执行“验证步骤”的证据,证明此改动有效。-->
<!--Please paste screenshots, GIFs, or test logs here as evidence of executing the "Verification Steps" to prove this change is effective.-->
### Compatibility & Breaking Changes / 兼容性与破坏性变更
<!--请说明此变更的兼容性:哪些是破坏性变更?哪些地方做了向后兼容处理?是否提供了数据迁移方法?-->
<!--Please explain the compatibility of this change: What are the breaking changes? What backward-compatible measures were taken? Are data migration paths provided?-->
- [ ] 这是一个破坏性变更 (Breaking Change)。/ This is a breaking change.
- [ ] 这不是一个破坏性变更。/ This is NOT a breaking change.
---
### Checklist / 检查清单
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
- [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
- [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt``pyproject.toml` 文件相应位置。/ I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in `requirements.txt` and `pyproject.toml`.
- [ ] 😮 我的更改没有引入恶意代码。/ My changes do not introduce malicious code.
- [ ] 😊 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
- [ ] 👀 我的更改经过良好的测试
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt``pyproject.toml` 文件相应位置。
- [ ] 😮 我的更改没有引入恶意代码

View File

@@ -1,38 +0,0 @@
# Set to true to add reviewers to pull requests
addReviewers: true
# Set to true to add assignees to pull requests
addAssignees: false
# A list of reviewers to be added to pull requests (GitHub user name)
reviewers:
- Soulter
- Raven95676
- Larch-C
- anka-afk
- advent259141
- Fridemn
- LIghtJUNction
# - zouyonghe
# A number of reviewers added to the pull request
# Set 0 to add all the reviewers (default: 0)
numberOfReviewers: 2
# A list of assignees, overrides reviewers if set
# assignees:
# - assigneeA
# A number of assignees to add to the pull request
# Set to 0 to add all of the assignees.
# Uses numberOfReviewers if unset.
# numberOfAssignees: 2
# A list of keywords to be skipped the process that add reviewers if pull requests include it
skipKeywords:
- wip
- draft
# A list of users to be skipped by both the add reviewers and add assignees processes
# skipUsers:
# - dependabot[bot]

View File

@@ -73,7 +73,7 @@ jobs:
uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
with:
python-version: '3.10'

View File

@@ -1,34 +0,0 @@
name: Code Format Check
on:
pull_request:
branches: [ master ]
push:
branches: [ master ]
jobs:
format-check:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.10'
- name: Install UV
run: pip install uv
- name: Install dependencies
run: uv sync
- name: Check code formatting with ruff
run: |
uv run ruff format --check .
- name: Check code style with ruff
run: |
uv run ruff check .

View File

@@ -60,7 +60,7 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v4
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
@@ -88,6 +88,6 @@ jobs:
exit 1
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v4
uses: github/codeql-action/analyze@v3
with:
category: "/language:${{matrix.language}}"

View File

@@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
- name: Install dependencies
run: |

View File

@@ -13,18 +13,11 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v5
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 'latest'
- name: npm install, build
run: |
cd dashboard
npm install pnpm -g
pnpm install
pnpm i --save-dev @types/markdown-it
pnpm run build
npm install
npm run build
- name: Inject Commit SHA
id: get_sha
@@ -32,24 +25,11 @@ jobs:
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
mkdir -p dashboard/dist/assets
echo $COMMIT_SHA > dashboard/dist/assets/version
cd dashboard
zip -r dist.zip dist
- name: Archive production artifacts
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v4
with:
name: dist-without-markdown
path: |
dashboard/dist
!dist/**/*.md
- name: Create GitHub Release
if: github.event_name == 'push'
uses: ncipollo/release-action@v1
with:
tag: release-${{ github.sha }}
owner: AstrBotDevs
repo: astrbot-release-harbour
body: "Automated release from commit ${{ github.sha }}"
token: ${{ secrets.ASTRBOT_HARBOUR_TOKEN }}
artifacts: "dashboard/dist.zip"

View File

@@ -27,33 +27,6 @@ jobs:
if: github.event_name == 'workflow_dispatch'
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
- name: Check if version is pre-release
id: check-prerelease
run: |
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
version="${{ steps.get-latest-tag.outputs.latest_tag }}"
else
version="${{ github.ref_name }}"
fi
if [[ "$version" == *"beta"* ]] || [[ "$version" == *"alpha"* ]]; then
echo "is_prerelease=true" >> $GITHUB_OUTPUT
echo "Version $version is a pre-release, will not push latest tag"
else
echo "is_prerelease=false" >> $GITHUB_OUTPUT
echo "Version $version is a stable release, will push latest tag"
fi
- name: Build Dashboard
run: |
cd dashboard
npm install
npm run build
mkdir -p dist/assets
echo $(git rev-parse HEAD) > dist/assets/version
cd ..
mkdir -p data
cp -r dashboard/dist data/
- name: Set QEMU
uses: docker/setup-qemu-action@v3
@@ -80,9 +53,9 @@ jobs:
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && format('{0}/astrbot:latest', secrets.DOCKER_HUB_USERNAME) || '' }}
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && 'ghcr.io/soulter/astrbot:latest' || '' }}
ghcr.io/soulter/astrbot:latest
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
- name: Post build notifications

View File

@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/stale@v10
- uses: actions/stale@v9
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'Stale issue message'

4
.gitignore vendored
View File

@@ -30,6 +30,4 @@ packages/python_interpreter/workplace
.conda/
.idea
pytest.ini
.astrbot
uv.lock
.astrbot

View File

@@ -6,20 +6,8 @@ ci:
autoupdate_schedule: weekly
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.14.1
hooks:
# Run the linter.
- id: ruff-check
types_or: [ python, pyi ]
args: [ --fix ]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi ]
- repo: https://github.com/asottile/pyupgrade
rev: v3.21.0
hooks:
- id: pyupgrade
args: [--py310-plus]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.2
hooks:
- id: ruff
- id: ruff-format

View File

@@ -4,6 +4,8 @@ WORKDIR /AstrBot
COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \
nodejs \
npm \
gcc \
build-essential \
python3-dev \
@@ -11,20 +13,23 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libssl-dev \
ca-certificates \
bash \
ffmpeg \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN apt-get update && apt-get install -y curl gnupg && \
curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
apt-get install -y nodejs && \
rm -rf /var/lib/apt/lists/*
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pilk --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
EXPOSE 6185
# 释出 ffmpeg
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
EXPOSE 6185
EXPOSE 6186
CMD [ "python", "main.py" ]

203
README.md
View File

@@ -1,40 +1,33 @@
<p align="center">
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
</p>
<div align="center">
<br>
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<div>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=1" alt="FeaturedHelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</div>
<br>
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot?style=for-the-badge&color=76bad9)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7日消息量&cacheSeconds=3600&style=for-the-badge&color=3b618e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://astrbot.app/">查看文档</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
</div>
<br>
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型LLM接入功能的聊天机器人及开发框架。
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://astrbot.app/">文档</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">路线图</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">问题提交</a>
</div>
AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架。
## 主要功能
## ✨ 主要功能
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
@@ -42,9 +35,9 @@ AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
5. **WebUI**。可视化配置和管理机器人,功能齐全。
## 部署方式
## ✨ 使用方式
#### Docker 部署(推荐 🥳)
#### Docker 部署
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
@@ -72,7 +65,7 @@ AstrBot 已由雨云官方上架至云应用平台,可一键部署。
社区贡献的部署方式。
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### Windows 一键安装器部署
@@ -86,7 +79,9 @@ AstrBot 已由雨云官方上架至云应用平台,可一键部署。
#### 手动部署
首先安装 uv
> 推荐使用 `uv`。
首先,安装 uv
```bash
pip install uv
@@ -101,101 +96,53 @@ uv run main.py
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
## 🌍 社区
### QQ 群组
- 1 群322154837
- 3 群630166526
- 5 群822130018
- 6 群753075035
- 开发者群975206796
### Telegram 群组
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
### Discord 群组
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## ⚡ 消息平台支持情况
**官方维护**
| 平台 | 支持性 |
| -------- | ------- |
| QQ(官方平台) | ✔ |
| QQ(官方机器人接口) | ✔ |
| QQ(OneBot) | ✔ |
| Telegram | ✔ |
| 企微应用 | ✔ |
| 企微智能机器人 | ✔ |
| 企业微信 | ✔ |
| 微信客服 | ✔ |
| 微信公众号 | ✔ |
| 飞书 | ✔ |
| 钉钉 | ✔ |
| Slack | ✔ |
| Discord | ✔ |
| Satori | ✔ |
| Misskey | ✔ |
| Whatsapp | 将支持 |
| LINE | 将支持 |
**社区维护**
| 平台 | 支持性 |
| -------- | ------- |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
| [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) | |
| [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11) | |
| 微信对话开放平台 | 🚧 |
| WhatsApp | 🚧 |
| 小爱音响 | 🚧 |
## ⚡ 提供商支持情况
**大模型服务**
| 名称 | 支持性 | 类型 | 备注 |
| -------- | ------- | ------- | ------- |
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Gemini、Kimi、xAI 等兼容 OpenAI API 的服务 |
| Claude API | ✔ | 文本生成 | |
| Google Gemini API | ✔ | 文本生成 | |
| Dify | ✔ | LLMOps | |
| 阿里云百炼应用 | ✔ | LLMOps | |
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
| 硅基流动 | ✔ | 模型 API 服务平台 | |
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
| OneAPI | ✔ | LLM 分发系统 | |
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
| OpenAI TTS API | ✔ | 文本转语音 | |
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | 文本转语音 | GPT-Sovits-Inference |
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI | ✔ | 支持任何兼容 OpenAI API 的服务 |
| Anthropic | ✔ | |
| Google Gemini | ✔ | |
| Moonshot AI | ✔ | |
| 智谱 AI | ✔ | |
| DeepSeek | ✔ | |
| Ollama | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| LM Studio | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | |
| [小马算力](https://www.tokenpony.cn/3YPyf) | ✔ | |
| 硅基流动 | ✔ | |
| PPIO 派欧云 | ✔ | |
| ModelScope | ✔ | |
| OneAPI | ✔ | |
| Dify | ✔ | |
| 阿里云百炼应用 | ✔ | |
| Coze | ✔ | |
**语音转文本服务**
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| Whisper | ✔ | 支持 API、本地部署 |
| SenseVoice | ✔ | 本地部署 |
**文本转语音服务**
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI TTS | ✔ | |
| Gemini TTS | ✔ | |
| GSVI | ✔ | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | GPT-Sovits |
| FishAudio | ✔ | |
| Edge TTS | ✔ | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | |
| Azure TTS | ✔ | |
| Minimax TTS | ✔ | |
| 火山引擎 TTS | ✔ | |
## ❤️ 贡献
@@ -210,11 +157,44 @@ uv run main.py
AstrBot 使用 `ruff` 进行代码格式化和检查。
```bash
git clone https://github.com/AstrBotDevs/AstrBot
git clone https://github.com/Soulter/AstrBot
pip install pre-commit
pre-commit install
```
## 🌟 支持
- Star 这个项目!
- 在[爱发电](https://afdian.com/a/soulter)支持我!
## ✨ Demo
<details><summary>👉 点击展开多张 Demo 截图 👈</summary>
<div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
_✨基于 Docker 的沙箱化代码执行器Beta 测试)✨_
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ 插件系统——部分插件展示 ✨_
<img src="https://github.com/user-attachments/assets/0cdbf564-2f59-4da5-b524-ce0e7ef3d978" width=600>
_✨ WebUI ✨_
</div>
</details>
## ❤️ Special Thanks
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
@@ -223,21 +203,24 @@ pre-commit install
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>
此外,本项目的诞生离不开以下开源项目的帮助
此外,本项目的诞生离不开以下开源项目:
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
## ⭐ Star History
> [!TIP]
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我维护这个开源项目的动力 <3
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我维护这个开源项目的动力 <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=astrbotdevs/astrbot&type=Date)](https://star-history.com/#astrbotdevs/astrbot&Date)
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div>
</details>
![10k-star-banner-credit-by-kevin](https://github.com/user-attachments/assets/c97fc5fb-20b9-4bc8-9998-c20b930ab097)
_私は、高性能ですから!_

View File

@@ -10,16 +10,16 @@ _✨ Easy-to-use Multi-platform LLM Chatbot & Development Framework ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/AstrBotDevs/AstrBot)](https://github.com/AstrBotDevs/AstrBot/releases/latest)
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/AstrBotDevs/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/AstrBotDevs/AstrBot)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
<a href="https://astrbot.app/">Documentation</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracking</a>
<a href="https://github.com/Soulter/AstrBot/issues">Issue Tracking</a>
</div>
AstrBot is a loosely coupled, asynchronous chatbot and development framework that supports multi-platform deployment, featuring an easy-to-use plugin system and comprehensive Large Language Model (LLM) integration capabilities.
@@ -49,7 +49,7 @@ Requires Python (>3.10). See docs: [Windows Installer Guide](https://astrbot.app
#### Replit Deployment
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### CasaOS Deployment
@@ -67,8 +67,8 @@ See docs: [Source Code Deployment](https://astrbot.app/deploy/astrbot/cli.html)
| QQ (Official Bot) | ✔ | Private/Group chats | Text, Images |
| QQ (OneBot) | ✔ | Private/Group chats | Text, Images, Voice |
| WeChat (Personal) | ✔ | Private/Group chats | Text, Images, Voice |
| [Telegram](https://github.com/AstrBotDevs/AstrBot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
| [WeChat Work](https://github.com/AstrBotDevs/AstrBot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
| [WeChat Work](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
| Feishu | ✔ | Group chats | Text, Images |
| WeChat Open Platform | 🚧 | Planned | - |
| Discord | 🚧 | Planned | - |
@@ -157,7 +157,7 @@ _✨ Built-in Web Chat Interface ✨_
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=AstrBotDevs/AstrBot&type=Date)](https://star-history.com/#AstrBotDevs/AstrBot&Date)
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div>
@@ -169,7 +169,7 @@ _✨ Built-in Web Chat Interface ✨_
<!-- ## ✨ ATRI [Beta]
Available as plugin: [astrbot_plugin_atri](https://github.com/AstrBotDevs/AstrBot_plugin_atri)
Available as plugin: [astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. Qwen1.5-7B-Chat Lora model fine-tuned with ATRI character data
2. Long-term memory

View File

@@ -10,16 +10,16 @@ _✨ 簡単に使えるマルチプラットフォーム LLM チャットボッ
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/AstrBotDevs/AstrBot)](https://github.com/AstrBotDevs/AstrBot/releases/latest)
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/AstrBotDevs/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/AstrBotDevs/AstrBot)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
<a href="https://astrbot.app/">ドキュメントを見る</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">問題を報告する</a>
<a href="https://github.com/Soulter/AstrBot/issues">問題を報告する</a>
</div>
AstrBot は、疎結合、非同期、複数のメッセージプラットフォームに対応したデプロイ、使いやすいプラグインシステム、および包括的な大規模言語モデルLLM接続機能を備えたチャットボットおよび開発フレームワークです。
@@ -50,7 +50,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
#### Replit デプロイ
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### CasaOS デプロイ

View File

View File

@@ -3,18 +3,5 @@ from astrbot import logger
from astrbot.core import html_renderer
from astrbot.core import sp
from astrbot.core.star.register import register_llm_tool as llm_tool
from astrbot.core.star.register import register_agent as agent
from astrbot.core.agent.tool import ToolSet, FunctionTool
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
__all__ = [
"AstrBotConfig",
"logger",
"html_renderer",
"llm_tool",
"agent",
"sp",
"ToolSet",
"FunctionTool",
"BaseFunctionToolExecutor",
]
__all__ = ["AstrBotConfig", "logger", "html_renderer", "llm_tool", "sp"]

View File

@@ -7,7 +7,6 @@ from astrbot.core.star.register import (
register_permission_type as permission_type,
register_custom_filter as custom_filter,
register_on_astrbot_loaded as on_astrbot_loaded,
register_on_platform_loaded as on_platform_loaded,
register_on_llm_request as on_llm_request,
register_on_llm_response as on_llm_response,
register_llm_tool as llm_tool,
@@ -42,7 +41,6 @@ __all__ = [
"custom_filter",
"PermissionType",
"on_astrbot_loaded",
"on_platform_loaded",
"on_llm_request",
"llm_tool",
"on_decorating_result",

View File

@@ -37,10 +37,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
):
click.echo("正在安装管理面板...")
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
path="data/dashboard.zip", extract_path=str(astrbot_root)
)
click.echo("管理面板安装完成")
@@ -53,10 +50,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
version = dashboard_version.split("v")[1]
click.echo(f"管理面板版本: {version}")
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
path="data/dashboard.zip", extract_path=str(astrbot_root)
)
except Exception as e:
click.echo(f"下载管理面板失败: {e}")
@@ -65,10 +59,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
click.echo("初始化管理面板目录...")
try:
await download_dashboard(
path=str(astrbot_root / "dashboard.zip"),
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
path=str(astrbot_root / "dashboard.zip"), extract_path=str(astrbot_root)
)
click.echo("管理面板初始化完成")
except Exception as e:

View File

@@ -1,4 +1,5 @@
import os
import asyncio
from .log import LogManager, LogBroker # noqa
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from astrbot.core.utils.shared_preferences import SharedPreferences
@@ -20,7 +21,7 @@ html_renderer = HtmlRenderer(t2i_base_url)
logger = LogManager.GetLogger(log_name="astrbot")
db_helper = SQLiteDatabase(DB_PATH)
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
sp = SharedPreferences(db_helper=db_helper)
sp = SharedPreferences()
# 文件令牌服务
file_token_service = FileTokenService()
pip_installer = PipInstaller(

View File

@@ -1,13 +0,0 @@
from dataclasses import dataclass
from .tool import FunctionTool
from typing import Generic
from .run_context import TContext
from .hooks import BaseAgentRunHooks
@dataclass
class Agent(Generic[TContext]):
name: str
instructions: str | None = None
tools: list[str | FunctionTool] | None = None
run_hooks: BaseAgentRunHooks[TContext] | None = None

View File

@@ -1,34 +0,0 @@
from typing import Generic
from .tool import FunctionTool
from .agent import Agent
from .run_context import TContext
class HandoffTool(FunctionTool, Generic[TContext]):
"""Handoff tool for delegating tasks to another agent."""
def __init__(
self, agent: Agent[TContext], parameters: dict | None = None, **kwargs
):
self.agent = agent
super().__init__(
name=f"transfer_to_{agent.name}",
parameters=parameters or self.default_parameters(),
description=agent.instructions or self.default_description(agent.name),
**kwargs,
)
def default_parameters(self) -> dict:
return {
"type": "object",
"properties": {
"input": {
"type": "string",
"description": "The input to be handed off to another agent. This should be a clear and concise request or task.",
},
},
}
def default_description(self, agent_name: str | None) -> str:
agent_name = agent_name or "another"
return f"Delegate tasks to {self.name} agent to handle the request."

View File

@@ -1,27 +0,0 @@
import mcp
from dataclasses import dataclass
from .run_context import ContextWrapper, TContext
from typing import Generic
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.agent.tool import FunctionTool
@dataclass
class BaseAgentRunHooks(Generic[TContext]):
async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
async def on_tool_start(
self,
run_context: ContextWrapper[TContext],
tool: FunctionTool,
tool_args: dict | None,
): ...
async def on_tool_end(
self,
run_context: ContextWrapper[TContext],
tool: FunctionTool,
tool_args: dict | None,
tool_result: mcp.types.CallToolResult | None,
): ...
async def on_agent_done(
self, run_context: ContextWrapper[TContext], llm_response: LLMResponse
): ...

View File

@@ -1,224 +0,0 @@
import asyncio
import logging
from datetime import timedelta
from typing import Optional
from contextlib import AsyncExitStack
from astrbot import logger
from astrbot.core.utils.log_pipe import LogPipe
try:
import mcp
from mcp.client.sse import sse_client
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
try:
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
)
def _prepare_config(config: dict) -> dict:
"""准备配置,处理嵌套格式"""
if "mcpServers" in config and config["mcpServers"]:
first_key = next(iter(config["mcpServers"]))
config = config["mcpServers"][first_key]
config.pop("active", None)
return config
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"""快速测试 MCP 服务器可达性"""
import aiohttp
cfg = _prepare_config(config.copy())
url = cfg["url"]
headers = cfg.get("headers", {})
timeout = cfg.get("timeout", 10)
try:
if "transport" in cfg:
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
async with aiohttp.ClientSession() as session:
if transport_type == "streamable_http":
test_payload = {
"jsonrpc": "2.0",
"method": "initialize",
"id": 0,
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.2.3"},
},
}
async with session.post(
url,
headers={
**headers,
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
},
json=test_payload,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
else:
return False, f"HTTP {response.status}: {response.reason}"
else:
async with session.get(
url,
headers={
**headers,
"Accept": "application/json, text/event-stream",
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
else:
return False, f"HTTP {response.status}: {response.reason}"
except asyncio.TimeoutError:
return False, f"连接超时: {timeout}"
except Exception as e:
return False, f"{e!s}"
class MCPClient:
def __init__(self):
# Initialize session and client objects
self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack()
self.name: str | None = None
self.active: bool = True
self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = []
self.running_event = asyncio.Event()
async def connect_to_server(self, mcp_server_config: dict, name: str):
"""连接到 MCP 服务器
如果 `url` 参数存在:
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
cfg = _prepare_config(mcp_server_config.copy())
def logging_callback(msg: str):
# 处理 MCP 服务的错误日志
print(f"MCP Server {name} Error: {msg}")
self.server_errlogs.append(msg)
if "url" in cfg:
success, error_msg = await _quick_test_mcp_connection(cfg)
if not success:
raise Exception(error_msg)
if "transport" in cfg:
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
if transport_type != "streamable_http":
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=cfg.get("timeout", 5),
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
)
streams = await self.exit_stack.enter_async_context(
self._streams_context
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
*streams,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
)
)
else:
timeout = timedelta(seconds=cfg.get("timeout", 30))
sse_read_timeout = timedelta(
seconds=cfg.get("sse_read_timeout", 60 * 5)
)
self._streams_context = streamablehttp_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
read_s, write_s, _ = await self.exit_stack.enter_async_context(
self._streams_context
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
read_stream=read_s,
write_stream=write_s,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
)
)
else:
server_params = mcp.StdioServerParameters(
**cfg,
)
def callback(msg: str):
# 处理 MCP 服务的错误日志
self.server_errlogs.append(msg)
stdio_transport = await self.exit_stack.enter_async_context(
mcp.stdio_client(
server_params,
errlog=LogPipe(
level=logging.ERROR,
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
), # type: ignore
),
)
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport)
)
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
if not self.session:
raise Exception("MCP Client is not initialized")
response = await self.session.list_tools()
self.tools = response.tools
return response
async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()
self.running_event.set() # Set the running event to indicate cleanup is done

View File

@@ -1,13 +0,0 @@
from dataclasses import dataclass
import typing as T
from astrbot.core.message.message_event_result import MessageChain
class AgentResponseData(T.TypedDict):
chain: MessageChain
@dataclass
class AgentResponse:
type: str
data: AgentResponseData

View File

@@ -1,18 +0,0 @@
from dataclasses import dataclass
from typing import Any, Generic
from typing_extensions import TypeVar
from astrbot.core.platform.astr_message_event import AstrMessageEvent
TContext = TypeVar("TContext", default=Any)
@dataclass
class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state."""
context: TContext
event: AstrMessageEvent
NoContext = ContextWrapper[None]

View File

@@ -1,3 +0,0 @@
from .base import BaseAgentRunner
__all__ = ["BaseAgentRunner"]

View File

@@ -1,267 +0,0 @@
from dataclasses import dataclass
from deprecated import deprecated
from typing import Awaitable, Callable, Literal, Any, Optional
from .mcp_client import MCPClient
@dataclass
class FunctionTool:
"""A class representing a function tool that can be used in function calling."""
name: str
parameters: dict | None = None
description: str | None = None
handler: Callable[..., Awaitable[Any]] | None = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str | None = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
"""
active: bool = True
"""是否激活"""
origin: Literal["local", "mcp"] = "local"
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
# MCP 相关字段
mcp_server_name: str | None = None
"""MCP 服务名称,当 origin 为 mcp 时有效"""
mcp_client: MCPClient | None = None
"""MCP 客户端,当 origin 为 mcp 时有效"""
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
def __dict__(self) -> dict[str, Any]:
"""将 FunctionTool 转换为字典格式"""
return {
"name": self.name,
"parameters": self.parameters,
"description": self.description,
"active": self.active,
"origin": self.origin,
"mcp_server_name": self.mcp_server_name,
}
class ToolSet:
"""A set of function tools that can be used in function calling.
This class provides methods to add, remove, and retrieve tools, as well as
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
def __init__(self, tools: list[FunctionTool] | None = None):
self.tools: list[FunctionTool] = tools or []
def empty(self) -> bool:
"""Check if the tool set is empty."""
return len(self.tools) == 0
def add_tool(self, tool: FunctionTool):
"""Add a tool to the set."""
# 检查是否已存在同名工具
for i, existing_tool in enumerate(self.tools):
if existing_tool.name == tool.name:
self.tools[i] = tool
return
self.tools.append(tool)
def remove_tool(self, name: str):
"""Remove a tool by its name."""
self.tools = [tool for tool in self.tools if tool.name != name]
def get_tool(self, name: str) -> Optional[FunctionTool]:
"""Get a tool by its name."""
for tool in self.tools:
if tool.name == name:
return tool
return None
@deprecated(reason="Use add_tool() instead", version="4.0.0")
def add_func(
self,
name: str,
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
):
"""Add a function tool to the set."""
params = {
"type": "object", # hard-coded here
"properties": {},
}
for param in func_args:
params["properties"][param["name"]] = {
"type": param["type"],
"description": param["description"],
}
_func = FunctionTool(
name=name,
parameters=params,
description=desc,
handler=handler,
)
self.add_tool(_func)
@deprecated(reason="Use remove_tool() instead", version="4.0.0")
def remove_func(self, name: str):
"""Remove a function tool by its name."""
self.remove_tool(name)
@deprecated(reason="Use get_tool() instead", version="4.0.0")
def get_func(self, name: str) -> FunctionTool | None:
"""Get all function tools."""
return self.get_tool(name)
@property
def func_list(self) -> list[FunctionTool]:
"""Get the list of function tools."""
return self.tools
def openai_schema(self, omit_empty_parameter_field: bool = False) -> list[dict]:
"""Convert tools to OpenAI API function calling schema format."""
result = []
for tool in self.tools:
func_def = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
},
}
if (
tool.parameters
and tool.parameters.get("properties")
or not omit_empty_parameter_field
):
func_def["function"]["parameters"] = tool.parameters
result.append(func_def)
return result
def anthropic_schema(self) -> list[dict]:
"""Convert tools to Anthropic API format."""
result = []
for tool in self.tools:
input_schema = {"type": "object"}
if tool.parameters:
input_schema["properties"] = tool.parameters.get("properties", {})
input_schema["required"] = tool.parameters.get("required", [])
tool_def = {
"name": tool.name,
"description": tool.description,
"input_schema": input_schema,
}
result.append(tool_def)
return result
def google_schema(self) -> dict:
"""Convert tools to Google GenAI API format."""
def convert_schema(schema: dict) -> dict:
"""Convert schema to Gemini API format."""
supported_types = {
"string",
"number",
"integer",
"boolean",
"array",
"object",
"null",
}
supported_formats = {
"string": {"enum", "date-time"},
"integer": {"int32", "int64"},
"number": {"float", "double"},
}
if "anyOf" in schema:
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
result = {}
if "type" in schema and schema["type"] in supported_types:
result["type"] = schema["type"]
if "format" in schema and schema["format"] in supported_formats.get(
result["type"], set()
):
result["format"] = schema["format"]
else:
result["type"] = "null"
support_fields = {
"title",
"description",
"enum",
"minimum",
"maximum",
"maxItems",
"minItems",
"nullable",
"required",
}
result.update({k: schema[k] for k in support_fields if k in schema})
if "properties" in schema:
properties = {}
for key, value in schema["properties"].items():
prop_value = convert_schema(value)
if "default" in prop_value:
del prop_value["default"]
properties[key] = prop_value
if properties:
result["properties"] = properties
if "items" in schema:
result["items"] = convert_schema(schema["items"])
return result
tools = []
for tool in self.tools:
d = {
"name": tool.name,
"description": tool.description,
}
if tool.parameters:
d["parameters"] = convert_schema(tool.parameters)
tools.append(d)
declarations = {}
if tools:
declarations["function_declarations"] = tools
return declarations
@deprecated(reason="Use openai_schema() instead", version="4.0.0")
def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False):
return self.openai_schema(omit_empty_parameter_field)
@deprecated(reason="Use anthropic_schema() instead", version="4.0.0")
def get_func_desc_anthropic_style(self):
return self.anthropic_schema()
@deprecated(reason="Use google_schema() instead", version="4.0.0")
def get_func_desc_google_genai_style(self):
return self.google_schema()
def names(self) -> list[str]:
"""获取所有工具的名称列表"""
return [tool.name for tool in self.tools]
def __len__(self):
return len(self.tools)
def __bool__(self):
return len(self.tools) > 0
def __iter__(self):
return iter(self.tools)
def __repr__(self):
return f"ToolSet(tools={self.tools})"
def __str__(self):
return f"ToolSet(tools={self.tools})"

View File

@@ -1,11 +0,0 @@
import mcp
from typing import Any, Generic, AsyncGenerator
from .run_context import TContext, ContextWrapper
from .tool import FunctionTool
class BaseFunctionToolExecutor(Generic[TContext]):
@classmethod
async def execute(
cls, tool: FunctionTool, run_context: ContextWrapper[TContext], **tool_args
) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ...

View File

@@ -1,12 +0,0 @@
from dataclasses import dataclass
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest
@dataclass
class AstrAgentContext:
provider: Provider
first_provider_request: ProviderRequest
curr_provider_request: ProviderRequest
streaming: bool
tool_call_timeout: int = 60 # Default tool call timeout in seconds

View File

@@ -1,255 +0,0 @@
import os
import uuid
from astrbot.core import AstrBotConfig, logger
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
from astrbot.core.config.default import DEFAULT_CONFIG
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.utils.astrbot_path import get_astrbot_config_path
from typing import TypeVar, TypedDict
_VT = TypeVar("_VT")
class ConfInfo(TypedDict):
"""Configuration information for a specific session or platform."""
id: str # UUID of the configuration or "default"
name: str
path: str # File name to the configuration file
DEFAULT_CONFIG_CONF_INFO = ConfInfo(
id="default",
name="default",
path=ASTRBOT_CONFIG_PATH,
)
class AstrBotConfigManager:
"""A class to manage the system configuration of AstrBot, aka ACM"""
def __init__(
self,
default_config: AstrBotConfig,
ucr: UmopConfigRouter,
sp: SharedPreferences,
):
self.sp = sp
self.ucr = ucr
self.confs: dict[str, AstrBotConfig] = {}
"""uuid / "default" -> AstrBotConfig"""
self.confs["default"] = default_config
self.abconf_data = None
self._load_all_configs()
def _get_abconf_data(self) -> dict:
"""获取所有的 abconf 数据"""
if self.abconf_data is None:
self.abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
return self.abconf_data
def _load_all_configs(self):
"""Load all configurations from the shared preferences."""
abconf_data = self._get_abconf_data()
self.abconf_data = abconf_data
for uuid_, meta in abconf_data.items():
filename = meta["path"]
conf_path = os.path.join(get_astrbot_config_path(), filename)
if os.path.exists(conf_path):
conf = AstrBotConfig(config_path=conf_path)
self.confs[uuid_] = conf
else:
logger.warning(
f"Config file {conf_path} for UUID {uuid_} does not exist, skipping."
)
continue
def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo:
"""获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default")
Returns:
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
"""
# uuid -> { "path": str, "name": str }
abconf_data = self._get_abconf_data()
if isinstance(umo, MessageSession):
umo = str(umo)
else:
try:
umo = str(MessageSession.from_str(umo)) # validate
except Exception:
return DEFAULT_CONFIG_CONF_INFO
conf_id = self.ucr.get_conf_id_for_umop(umo)
if conf_id:
meta = abconf_data.get(conf_id)
if meta and isinstance(meta, dict):
# the bind relation between umo and conf is defined in ucr now, so we remove "umop" here
meta.pop("umop", None)
return ConfInfo(**meta, id=conf_id)
return DEFAULT_CONFIG_CONF_INFO
def _save_conf_mapping(
self,
abconf_path: str,
abconf_id: str,
abconf_name: str | None = None,
) -> None:
"""保存配置文件的映射关系"""
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
random_word = abconf_name or uuid.uuid4().hex[:8]
abconf_data[abconf_id] = {
"path": abconf_path,
"name": random_word,
}
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig:
"""获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。"""
if not umo:
return self.confs["default"]
if isinstance(umo, MessageSession):
umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}"
uuid_ = self._load_conf_mapping(umo)["id"]
conf = self.confs.get(uuid_)
if not conf:
conf = self.confs["default"] # default MUST exists
return conf
@property
def default_conf(self) -> AstrBotConfig:
"""获取默认配置文件"""
return self.confs["default"]
def get_conf_info(self, umo: str | MessageSession) -> ConfInfo:
"""获取指定 umo 的配置文件元数据"""
if isinstance(umo, MessageSession):
umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}"
return self._load_conf_mapping(umo)
def get_conf_list(self) -> list[ConfInfo]:
"""获取所有配置文件的元数据列表"""
conf_list = []
abconf_mapping = self._get_abconf_data()
for uuid_, meta in abconf_mapping.items():
if not isinstance(meta, dict):
continue
meta.pop("umop", None)
conf_list.append(ConfInfo(**meta, id=uuid_))
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
return conf_list
def create_conf(
self,
config: dict = DEFAULT_CONFIG,
name: str | None = None,
) -> str:
conf_uuid = str(uuid.uuid4())
conf_file_name = f"abconf_{conf_uuid}.json"
conf_path = os.path.join(get_astrbot_config_path(), conf_file_name)
conf = AstrBotConfig(config_path=conf_path, default_config=config)
conf.save_config()
self._save_conf_mapping(conf_file_name, conf_uuid, abconf_name=name)
self.confs[conf_uuid] = conf
return conf_uuid
def delete_conf(self, conf_id: str) -> bool:
"""删除指定配置文件
Args:
conf_id: 配置文件的 UUID
Returns:
bool: 删除是否成功
Raises:
ValueError: 如果试图删除默认配置文件
"""
if conf_id == "default":
raise ValueError("不能删除默认配置文件")
# 从映射中移除
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
return False
# 获取配置文件路径
conf_path = os.path.join(
get_astrbot_config_path(), abconf_data[conf_id]["path"]
)
# 删除配置文件
try:
if os.path.exists(conf_path):
os.remove(conf_path)
logger.info(f"已删除配置文件: {conf_path}")
except Exception as e:
logger.error(f"删除配置文件 {conf_path} 失败: {e}")
return False
# 从内存中移除
if conf_id in self.confs:
del self.confs[conf_id]
# 从映射中移除
del abconf_data[conf_id]
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
logger.info(f"成功删除配置文件 {conf_id}")
return True
def update_conf_info(self, conf_id: str, name: str | None = None) -> bool:
"""更新配置文件信息
Args:
conf_id: 配置文件的 UUID
name: 新的配置文件名称 (可选)
Returns:
bool: 更新是否成功
"""
if conf_id == "default":
raise ValueError("不能更新默认配置文件的信息")
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
return False
# 更新名称
if name is not None:
abconf_data[conf_id]["name"] = name
# 保存更新
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
logger.info(f"成功更新配置文件 {conf_id} 的信息")
return True
def g(
self, umo: str | None = None, key: str | None = None, default: _VT = None
) -> _VT:
"""获取配置项。umo 为 None 时使用默认配置"""
if umo is None:
return self.confs["default"].get(key, default)
conf = self.get_conf(umo)
return conf.get(key, default)

File diff suppressed because it is too large Load Diff

View File

@@ -5,76 +5,40 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
"""
import uuid
import json
import asyncio
from astrbot.core import sp
from typing import Dict, List, Callable, Awaitable
from typing import Dict, List
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation, ConversationV2
from astrbot.core.db.po import Conversation
class ConversationManager:
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
def __init__(self, db_helper: BaseDatabase):
self.session_conversations: Dict[str, str] = {}
# session_conversations 字典记录会话ID-对话ID 映射关系
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
self._start_periodic_save()
# 会话删除回调函数列表(用于级联清理,如知识库配置)
self._on_session_deleted_callbacks: List[Callable[[str], Awaitable[None]]] = []
def _start_periodic_save(self):
"""启动定时保存任务"""
asyncio.create_task(self._periodic_save())
def register_on_session_deleted(
self, callback: Callable[[str], Awaitable[None]]
) -> None:
"""注册会话删除回调函数
async def _periodic_save(self):
"""定时保存会话对话映射关系到存储中"""
while True:
await asyncio.sleep(self.save_interval)
self._save_to_storage()
其他模块可以注册回调来响应会话删除事件,实现级联清理。
例如:知识库模块可以注册回调来清理会话的知识库配置。
def _save_to_storage(self):
"""保存会话对话映射关系到存储中"""
sp.put("session_conversation", self.session_conversations)
Args:
callback: 回调函数接收会话ID (unified_msg_origin) 作为参数
"""
self._on_session_deleted_callbacks.append(callback)
async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:
"""触发会话删除回调
Args:
unified_msg_origin: 会话ID
"""
for callback in self._on_session_deleted_callbacks:
try:
await callback(unified_msg_origin)
except Exception as e:
from astrbot.core import logger
logger.error(
f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}"
)
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
"""将 ConversationV2 对象转换为 Conversation 对象"""
created_at = int(conv_v2.created_at.timestamp())
updated_at = int(conv_v2.updated_at.timestamp())
return Conversation(
platform_id=conv_v2.platform_id,
user_id=conv_v2.user_id,
cid=conv_v2.conversation_id,
history=json.dumps(conv_v2.content or []),
title=conv_v2.title,
persona_id=conv_v2.persona_id,
created_at=created_at,
updated_at=updated_at,
)
async def new_conversation(
self,
unified_msg_origin: str,
platform_id: str | None = None,
content: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
) -> str:
async def new_conversation(self, unified_msg_origin: str) -> str:
"""新建对话,并将当前会话的对话转移到新对话
Args:
@@ -82,23 +46,11 @@ class ConversationManager:
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
if not platform_id:
# 如果没有提供 platform_id则从 unified_msg_origin 中解析
parts = unified_msg_origin.split(":")
if len(parts) >= 3:
platform_id = parts[0]
if not platform_id:
platform_id = "unknown"
conv = await self.db.create_conversation(
user_id=unified_msg_origin,
platform_id=platform_id,
content=content,
title=title,
persona_id=persona_id,
)
self.session_conversations[unified_msg_origin] = conv.conversation_id
await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
return conv.conversation_id
conversation_id = str(uuid.uuid4())
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
return conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
"""切换会话的对话
@@ -108,10 +60,10 @@ class ConversationManager:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
self.session_conversations[unified_msg_origin] = conversation_id
await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)
sp.put("session_conversation", self.session_conversations)
async def delete_conversation(
self, unified_msg_origin: str, conversation_id: str | None = None
self, unified_msg_origin: str, conversation_id: str = None
):
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
@@ -119,29 +71,13 @@ class ConversationManager:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
if not conversation_id:
conversation_id = self.session_conversations.get(unified_msg_origin)
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
await self.db.delete_conversation(cid=conversation_id)
curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
if curr_cid == conversation_id:
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
del self.session_conversations[unified_msg_origin]
sp.put("session_conversation", self.session_conversations)
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
"""删除会话的所有对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
"""
await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
# 触发会话删除回调(级联清理)
await self._trigger_session_deleted(unified_msg_origin)
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
"""获取会话当前的对话 ID
Args:
@@ -149,19 +85,14 @@ class ConversationManager:
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
ret = self.session_conversations.get(unified_msg_origin, None)
if not ret:
ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None)
if ret:
self.session_conversations[unified_msg_origin] = ret
return ret
return self.session_conversations.get(unified_msg_origin, None)
async def get_conversation(
self,
unified_msg_origin: str,
conversation_id: str,
create_if_not_exists: bool = False,
) -> Conversation | None:
) -> Conversation:
"""获取会话的对话
Args:
@@ -170,74 +101,27 @@ class ConversationManager:
Returns:
conversation (Conversation): 对话对象
"""
conv = await self.db.get_conversation_by_id(cid=conversation_id)
conv = self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
if not conv and create_if_not_exists:
# 如果对话不存在且需要创建,则新建一个对话
conversation_id = await self.new_conversation(unified_msg_origin)
conv = await self.db.get_conversation_by_id(cid=conversation_id)
conv_res = None
if conv:
conv_res = self._convert_conv_from_v2_to_v1(conv)
return conv_res
return self.db.get_conversation_by_user_id(
unified_msg_origin, conversation_id
)
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
async def get_conversations(
self, unified_msg_origin: str | None = None, platform_id: str | None = None
) -> List[Conversation]:
"""获取对话列表
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
"""获取会话的所有对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id,可选
platform_id (str): 平台 ID, 可选参数, 用于过滤对话
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversations (List[Conversation]): 对话对象列表
"""
convs = await self.db.get_conversations(
user_id=unified_msg_origin, platform_id=platform_id
)
convs_res = []
for conv in convs:
conv_res = self._convert_conv_from_v2_to_v1(conv)
convs_res.append(conv_res)
return convs_res
async def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platform_ids: list[str] | None = None,
search_query: str = "",
**kwargs,
) -> tuple[list[Conversation], int]:
"""获取过滤后的对话列表
Args:
page (int): 页码, 默认为 1
page_size (int): 每页大小, 默认为 20
platform_ids (list[str]): 平台 ID 列表, 可选
search_query (str): 搜索查询字符串, 可选
Returns:
conversations (list[Conversation]): 对话对象列表
"""
convs, cnt = await self.db.get_filtered_conversations(
page=page,
page_size=page_size,
platform_ids=platform_ids,
search_query=search_query,
**kwargs,
)
convs_res = []
for conv in convs:
conv_res = self._convert_conv_from_v2_to_v1(conv)
convs_res.append(conv_res)
return convs_res, cnt
return self.db.get_conversations(unified_msg_origin)
async def update_conversation(
self,
unified_msg_origin: str,
conversation_id: str | None = None,
history: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
):
"""更新会话的对话
@@ -246,55 +130,40 @@ class ConversationManager:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
"""
if not conversation_id:
# 如果没有提供 conversation_id则获取当前的
conversation_id = await self.get_curr_conversation_id(unified_msg_origin)
if conversation_id:
await self.db.update_conversation(
self.db.update_conversation(
user_id=unified_msg_origin,
cid=conversation_id,
title=title,
persona_id=persona_id,
content=history,
history=json.dumps(history),
)
async def update_conversation_title(
self, unified_msg_origin: str, title: str, conversation_id: str | None = None
):
async def update_conversation_title(self, unified_msg_origin: str, title: str):
"""更新会话的对话标题
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
title (str): 对话标题
Deprecated:
Use `update_conversation` with `title` parameter instead.
"""
await self.update_conversation(
unified_msg_origin=unified_msg_origin,
conversation_id=conversation_id,
title=title,
)
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_title(
user_id=unified_msg_origin, cid=conversation_id, title=title
)
async def update_conversation_persona_id(
self,
unified_msg_origin: str,
persona_id: str,
conversation_id: str | None = None,
self, unified_msg_origin: str, persona_id: str
):
"""更新会话的对话 Persona ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
persona_id (str): 对话 Persona ID
Deprecated:
Use `update_conversation` with `persona_id` parameter instead.
"""
await self.update_conversation(
unified_msg_origin=unified_msg_origin,
conversation_id=conversation_id,
persona_id=persona_id,
)
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_persona_id(
user_id=unified_msg_origin, cid=conversation_id, persona_id=persona_id
)
async def get_human_readable_context(
self, unified_msg_origin, conversation_id, page=1, page_size=10

View File

@@ -15,27 +15,22 @@ import time
import threading
import os
from .event_bus import EventBus
from . import astrbot_config, html_renderer
from . import astrbot_config
from asyncio import Queue
from typing import List
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.star.context import Context
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core import LogBroker
from astrbot.core.db import BaseDatabase
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger, sp
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
class AstrBotCoreLifecycle:
@@ -52,23 +47,12 @@ class AstrBotCoreLifecycle:
self.db = db # 初始化数据库
# 设置代理
proxy_config = self.astrbot_config.get("http_proxy", "")
if proxy_config != "":
os.environ["https_proxy"] = proxy_config
os.environ["http_proxy"] = proxy_config
logger.debug(f"Using proxy: {proxy_config}")
# 设置 no_proxy
no_proxy_list = self.astrbot_config.get("no_proxy", [])
os.environ["no_proxy"] = ",".join(no_proxy_list)
else:
# 清空代理环境变量
if "https_proxy" in os.environ:
del os.environ["https_proxy"]
if "http_proxy" in os.environ:
del os.environ["http_proxy"]
if "no_proxy" in os.environ:
del os.environ["no_proxy"]
logger.debug("HTTP proxy cleared")
if self.astrbot_config.get("http_proxy", ""):
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
if proxy := os.environ.get("https_proxy"):
logger.debug(f"Using proxy: {proxy}")
os.environ["no_proxy"] = "localhost"
async def initialize(self):
"""
@@ -82,36 +66,11 @@ class AstrBotCoreLifecycle:
else:
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
await self.db.initialize()
await html_renderer.initialize()
# 初始化 UMOP 配置路由器
self.umop_config_router = UmopConfigRouter(sp=sp)
# 初始化 AstrBot 配置管理器
self.astrbot_config_mgr = AstrBotConfigManager(
default_config=self.astrbot_config, ucr=self.umop_config_router, sp=sp
)
# 4.5 to 4.6 migration for umop_config_router
try:
await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
except Exception as e:
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
logger.error(traceback.format_exc())
# 初始化事件队列
self.event_queue = Queue()
# 初始化人格管理器
self.persona_mgr = PersonaManager(self.db, self.astrbot_config_mgr)
await self.persona_mgr.initialize()
# 初始化供应商管理器
self.provider_manager = ProviderManager(
self.astrbot_config_mgr, self.db, self.persona_mgr
)
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
# 初始化平台管理器
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
@@ -119,12 +78,6 @@ class AstrBotCoreLifecycle:
# 初始化对话管理器
self.conversation_manager = ConversationManager(self.db)
# 初始化平台消息历史管理器
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
# 初始化知识库管理器
self.kb_manager = KnowledgeBaseManager(self.provider_manager)
# 初始化提供给插件的上下文
self.star_context = Context(
self.event_queue,
@@ -133,10 +86,6 @@ class AstrBotCoreLifecycle:
self.provider_manager,
self.platform_manager,
self.conversation_manager,
self.platform_message_history_manager,
self.persona_mgr,
self.astrbot_config_mgr,
self.kb_manager,
)
# 初始化插件管理器
@@ -148,24 +97,23 @@ class AstrBotCoreLifecycle:
# 根据配置实例化各个 Provider
await self.provider_manager.initialize()
await self.kb_manager.initialize()
# 初始化消息事件流水线调度器
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
self.pipeline_scheduler = PipelineScheduler(
PipelineContext(self.astrbot_config, self.plugin_manager)
)
await self.pipeline_scheduler.initialize()
# 初始化更新器
self.astrbot_updator = AstrBotUpdator()
# 初始化事件总线
self.event_bus = EventBus(
self.event_queue, self.pipeline_scheduler_mapping, self.astrbot_config_mgr
)
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
# 记录启动时间
self.start_time = int(time.time())
# 初始化当前任务列表
self.curr_tasks: list[asyncio.Task] = []
self.curr_tasks: List[asyncio.Task] = []
# 根据配置实例化各个平台适配器
await self.platform_manager.initialize()
@@ -250,7 +198,6 @@ class AstrBotCoreLifecycle:
await self.provider_manager.terminate()
await self.platform_manager.terminate()
await self.kb_manager.terminate()
self.dashboard_shutdown_event.set()
# 再次遍历curr_tasks等待每个任务真正结束
@@ -266,51 +213,17 @@ class AstrBotCoreLifecycle:
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
await self.provider_manager.terminate()
await self.platform_manager.terminate()
await self.kb_manager.terminate()
self.dashboard_shutdown_event.set()
threading.Thread(
target=self.astrbot_updator._reboot, name="restart", daemon=True
).start()
def load_platform(self) -> list[asyncio.Task]:
def load_platform(self) -> List[asyncio.Task]:
"""加载平台实例并返回所有平台实例的异步任务列表"""
tasks = []
platform_insts = self.platform_manager.get_insts()
for platform_inst in platform_insts:
tasks.append(
asyncio.create_task(
platform_inst.run(),
name=f"{platform_inst.meta().id}({platform_inst.meta().name})",
)
asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name)
)
return tasks
async def load_pipeline_scheduler(self) -> dict[str, PipelineScheduler]:
"""加载消息事件流水线调度器
Returns:
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
"""
mapping = {}
for conf_id, ab_config in self.astrbot_config_mgr.confs.items():
scheduler = PipelineScheduler(
PipelineContext(ab_config, self.plugin_manager, conf_id)
)
await scheduler.initialize()
mapping[conf_id] = scheduler
return mapping
async def reload_pipeline_scheduler(self, conf_id: str):
"""重新加载消息事件流水线调度器
Returns:
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
"""
ab_config = self.astrbot_config_mgr.confs.get(conf_id)
if not ab_config:
raise ValueError(f"配置文件 {conf_id} 不存在")
scheduler = PipelineScheduler(
PipelineContext(ab_config, self.plugin_manager, conf_id)
)
await scheduler.initialize()
self.pipeline_scheduler_mapping[conf_id] = scheduler

View File

@@ -1,20 +1,7 @@
import abc
import datetime
import typing as T
from deprecated import deprecated
from dataclasses import dataclass
from astrbot.core.db.po import (
Stats,
PlatformStat,
ConversationV2,
PlatformMessageHistory,
Attachment,
Persona,
Preference,
)
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from typing import List, Dict, Any, Tuple
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
@dataclass
@@ -23,278 +10,152 @@ class BaseDatabase(abc.ABC):
数据库基类
"""
DATABASE_URL = ""
def __init__(self) -> None:
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
future=True,
)
self.AsyncSessionLocal = sessionmaker(
self.engine, class_=AsyncSession, expire_on_commit=False
)
async def initialize(self):
"""初始化数据库连接"""
pass
@asynccontextmanager
async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
"""Get a database session."""
if not self.inited:
await self.initialize()
self.inited = True
async with self.AsyncSessionLocal() as session:
yield session
def insert_base_metrics(self, metrics: dict):
"""插入基础指标数据"""
self.insert_platform_metrics(metrics["platform_stats"])
self.insert_plugin_metrics(metrics["plugin_stats"])
self.insert_command_metrics(metrics["command_stats"])
self.insert_llm_metrics(metrics["llm_stats"])
@abc.abstractmethod
def insert_platform_metrics(self, metrics: dict):
"""插入平台指标数据"""
raise NotImplementedError
@abc.abstractmethod
def insert_plugin_metrics(self, metrics: dict):
"""插入插件指标数据"""
raise NotImplementedError
@abc.abstractmethod
def insert_command_metrics(self, metrics: dict):
"""插入指令指标数据"""
raise NotImplementedError
@abc.abstractmethod
def insert_llm_metrics(self, metrics: dict):
"""插入 LLM 指标数据"""
raise NotImplementedError
@abc.abstractmethod
def update_llm_history(self, session_id: str, content: str, provider_type: str):
"""更新 LLM 历史记录。当不存在 session_id 时插入"""
raise NotImplementedError
@abc.abstractmethod
def get_llm_history(
self, session_id: str = None, provider_type: str = None
) -> List[LLMHistory]:
"""获取 LLM 历史记录, 如果 session_id 为 None, 返回所有"""
raise NotImplementedError
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
@abc.abstractmethod
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取基础统计数据"""
raise NotImplementedError
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
@abc.abstractmethod
def get_total_message_count(self) -> int:
"""获取总消息数"""
raise NotImplementedError
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
@abc.abstractmethod
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取基础统计数据(合并)"""
raise NotImplementedError
# New methods in v4.0.0
@abc.abstractmethod
def insert_atri_vision_data(self, vision_data: ATRIVision):
"""插入 ATRI 视觉数据"""
raise NotImplementedError
@abc.abstractmethod
async def insert_platform_stats(
self,
platform_id: str,
platform_type: str,
count: int = 1,
timestamp: datetime.datetime | None = None,
) -> None:
"""Insert a new platform statistic record."""
...
def get_atri_vision_data(self) -> List[ATRIVision]:
"""获取 ATRI 视觉数据"""
raise NotImplementedError
@abc.abstractmethod
async def count_platform_stats(self) -> int:
"""Count the number of platform statistics records."""
...
def get_atri_vision_data_by_path_or_id(
self, url_or_path: str, id: str
) -> ATRIVision:
"""通过 url 或 path 获取 ATRI 视觉数据"""
raise NotImplementedError
@abc.abstractmethod
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
...
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
"""通过 user_id 和 cid 获取 Conversation"""
raise NotImplementedError
@abc.abstractmethod
async def get_conversations(
self, user_id: str | None = None, platform_id: str | None = None
) -> list[ConversationV2]:
"""Get all conversations for a specific user and platform_id(optional).
content is not included in the result.
"""
...
def new_conversation(self, user_id: str, cid: str):
"""新建 Conversation"""
raise NotImplementedError
@abc.abstractmethod
async def get_conversation_by_id(self, cid: str) -> ConversationV2:
"""Get a specific conversation by its ID."""
...
def get_conversations(self, user_id: str) -> List[Conversation]:
raise NotImplementedError
@abc.abstractmethod
async def get_all_conversations(
def update_conversation(self, user_id: str, cid: str, history: str):
"""更新 Conversation"""
raise NotImplementedError
@abc.abstractmethod
def delete_conversation(self, user_id: str, cid: str):
"""删除 Conversation"""
raise NotImplementedError
@abc.abstractmethod
def update_conversation_title(self, user_id: str, cid: str, title: str):
"""更新 Conversation 标题"""
raise NotImplementedError
@abc.abstractmethod
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
"""更新 Conversation Persona ID"""
raise NotImplementedError
@abc.abstractmethod
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> list[ConversationV2]:
"""Get all conversations with pagination."""
...
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页
Args:
page: 页码从1开始
page_size: 每页数量
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
"""
raise NotImplementedError
@abc.abstractmethod
async def get_filtered_conversations(
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platform_ids: list[str] | None = None,
search_query: str = "",
**kwargs,
) -> tuple[list[ConversationV2], int]:
"""Get conversations filtered by platform IDs and search query."""
...
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表
@abc.abstractmethod
async def create_conversation(
self,
user_id: str,
platform_id: str,
content: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
cid: str | None = None,
created_at: datetime.datetime | None = None,
updated_at: datetime.datetime | None = None,
) -> ConversationV2:
"""Create a new conversation."""
...
Args:
page: 页码
page_size: 每页数量
platforms: 平台筛选列表
message_types: 消息类型筛选列表
search_query: 搜索关键词
exclude_ids: 排除的用户ID列表
exclude_platforms: 排除的平台列表
@abc.abstractmethod
async def update_conversation(
self,
cid: str,
title: str | None = None,
persona_id: str | None = None,
content: list[dict] | None = None,
) -> None:
"""Update a conversation's history."""
...
@abc.abstractmethod
async def delete_conversation(self, cid: str) -> None:
"""Delete a conversation by its ID."""
...
@abc.abstractmethod
async def delete_conversations_by_user_id(self, user_id: str) -> None:
"""Delete all conversations for a specific user."""
...
@abc.abstractmethod
async def insert_platform_message_history(
self,
platform_id: str,
user_id: str,
content: dict,
sender_id: str | None = None,
sender_name: str | None = None,
) -> None:
"""Insert a new platform message history record."""
...
@abc.abstractmethod
async def delete_platform_message_offset(
self, platform_id: str, user_id: str, offset_sec: int = 86400
) -> None:
"""Delete platform message history records older than the specified offset."""
...
@abc.abstractmethod
async def get_platform_message_history(
self,
platform_id: str,
user_id: str,
page: int = 1,
page_size: int = 20,
) -> list[PlatformMessageHistory]:
"""Get platform message history for a specific user."""
...
@abc.abstractmethod
async def insert_attachment(
self,
path: str,
type: str,
mime_type: str,
):
"""Insert a new attachment record."""
...
@abc.abstractmethod
async def get_attachment_by_id(self, attachment_id: str) -> Attachment:
"""Get an attachment by its ID."""
...
@abc.abstractmethod
async def insert_persona(
self,
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] | None = None,
tools: list[str] | None = None,
) -> Persona:
"""Insert a new persona record."""
...
@abc.abstractmethod
async def get_persona_by_id(self, persona_id: str) -> Persona:
"""Get a persona by its ID."""
...
@abc.abstractmethod
async def get_personas(self) -> list[Persona]:
"""Get all personas for a specific bot."""
...
@abc.abstractmethod
async def update_persona(
self,
persona_id: str,
system_prompt: str | None = None,
begin_dialogs: list[str] | None = None,
tools: list[str] | None = None,
) -> Persona | None:
"""Update a persona's system prompt or begin dialogs."""
...
@abc.abstractmethod
async def delete_persona(self, persona_id: str) -> None:
"""Delete a persona by its ID."""
...
@abc.abstractmethod
async def insert_preference_or_update(
self, scope: str, scope_id: str, key: str, value: dict
) -> Preference:
"""Insert a new preference record."""
...
@abc.abstractmethod
async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference:
"""Get a preference by scope ID and key."""
...
@abc.abstractmethod
async def get_preferences(
self, scope: str, scope_id: str | None = None, key: str | None = None
) -> list[Preference]:
"""Get all preferences for a specific scope ID or key."""
...
@abc.abstractmethod
async def remove_preference(self, scope: str, scope_id: str, key: str) -> None:
"""Remove a preference by scope ID and key."""
...
@abc.abstractmethod
async def clear_preferences(self, scope: str, scope_id: str) -> None:
"""Clear all preferences for a specific scope ID."""
...
# @abc.abstractmethod
# async def insert_llm_message(
# self,
# cid: str,
# role: str,
# content: list,
# tool_calls: list = None,
# tool_call_id: str = None,
# parent_id: str = None,
# ) -> LLMMessage:
# """Insert a new LLM message into the conversation."""
# ...
# @abc.abstractmethod
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
# """Get all LLM messages for a specific conversation."""
# ...
@abc.abstractmethod
async def get_session_conversations(
self,
page: int = 1,
page_size: int = 20,
search_query: str | None = None,
platform: str | None = None,
) -> tuple[list[dict], int]:
"""Get paginated session conversations with joined conversation and persona details, support search and platform filter."""
...
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
"""
raise NotImplementedError

View File

@@ -1,67 +0,0 @@
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.db import BaseDatabase
from astrbot.core.config import AstrBotConfig
from astrbot.api import logger, sp
from .migra_3_to_4 import (
migration_conversation_table,
migration_platform_table,
migration_webchat_data,
migration_persona_data,
migration_preferences,
)
async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
"""
检查是否需要进行数据库迁移
如果存在 data_v3.db 并且 preference 中没有 migration_done_v4则需要进行迁移。
"""
# 仅当 data 目录下存在旧版本数据data_v3.db 文件)时才考虑迁移
data_dir = get_astrbot_data_path()
data_v3_db = os.path.join(data_dir, "data_v3.db")
if not os.path.exists(data_v3_db):
return False
migration_done = await db_helper.get_preference(
"global", "global", "migration_done_v4"
)
if migration_done:
return False
return True
async def do_migration_v4(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
astrbot_config: AstrBotConfig,
) -> None:
"""
执行数据库迁移
迁移旧的 webchat_conversation 表到新的 conversation 表。
迁移旧的 platform 到新的 platform_stats 表。
"""
if not await check_migration_needed_v4(db_helper):
return
logger.info("开始执行数据库迁移...")
# 执行会话表迁移
await migration_conversation_table(db_helper, platform_id_map)
# 执行人格数据迁移
await migration_persona_data(db_helper, astrbot_config)
# 执行 WebChat 数据迁移
await migration_webchat_data(db_helper, platform_id_map)
# 执行偏好设置迁移
await migration_preferences(db_helper, platform_id_map)
# 执行平台统计表迁移
await migration_platform_table(db_helper, platform_id_map)
# 标记迁移完成
await sp.put_async("global", "global", "migration_done_v4", True)
logger.info("数据库迁移完成。")

View File

@@ -1,338 +0,0 @@
import json
import datetime
from .. import BaseDatabase
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
from .shared_preferences_v3 import sp as sp_v3
from astrbot.core.config.default import DB_PATH
from astrbot.api import logger, sp
from astrbot.core.config import AstrBotConfig
from astrbot.core.platform.astr_message_event import MessageSesion
from sqlalchemy.ext.asyncio import AsyncSession
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
from sqlalchemy import text
"""
1. 迁移旧的 webchat_conversation 表到新的 conversation 表。
2. 迁移旧的 platform 到新的 platform_stats 表。
"""
def get_platform_id(
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
) -> str:
return platform_id_map.get(
old_platform_name,
{"platform_id": old_platform_name, "platform_type": old_platform_name},
).get("platform_id", old_platform_name)
def get_platform_type(
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
) -> str:
return platform_id_map.get(
old_platform_name,
{"platform_id": old_platform_name, "platform_type": old_platform_name},
).get("platform_type", old_platform_name)
async def migration_conversation_table(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
)
conversations, total_cnt = db_helper_v3.get_all_conversations(
page=1, page_size=10000000
)
logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...")
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
for idx, conversation in enumerate(conversations):
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
progress = int((idx + 1) / total_cnt * 100)
if progress % 10 == 0:
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
try:
conv = db_helper_v3.get_conversation_by_user_id(
user_id=conversation.get("user_id", "unknown"),
cid=conversation.get("cid", "unknown"),
)
if not conv:
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" not in conv.user_id:
continue
session = MessageSesion.from_str(session_str=conv.user_id)
platform_id = get_platform_id(
platform_id_map, session.platform_name
)
session.platform_id = platform_id # 更新平台名称为新的 ID
conv_v2 = ConversationV2(
user_id=str(session),
content=json.loads(conv.history) if conv.history else [],
platform_id=platform_id,
title=conv.title,
persona_id=conv.persona_id,
conversation_id=conv.cid,
created_at=datetime.datetime.fromtimestamp(conv.created_at),
updated_at=datetime.datetime.fromtimestamp(conv.updated_at),
)
dbsession.add(conv_v2)
except Exception as e:
logger.error(
f"迁移旧会话 {conversation.get('cid', 'unknown')} 失败: {e}",
exc_info=True,
)
logger.info(f"成功迁移 {total_cnt} 条旧的会话数据到新表。")
async def migration_platform_table(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
)
secs_from_2023_4_10_to_now = (
datetime.datetime.now(datetime.timezone.utc)
- datetime.datetime(2023, 4, 10, tzinfo=datetime.timezone.utc)
).total_seconds()
offset_sec = int(secs_from_2023_4_10_to_now)
logger.info(f"迁移旧平台数据offset_sec: {offset_sec} 秒。")
stats = db_helper_v3.get_base_stats(offset_sec=offset_sec)
logger.info(f"迁移 {len(stats.platform)} 条旧的平台数据到新的表中...")
platform_stats_v3 = stats.platform
if not platform_stats_v3:
logger.info("没有找到旧平台数据,跳过迁移。")
return
first_time_stamp = platform_stats_v3[0].timestamp
end_time_stamp = platform_stats_v3[-1].timestamp
start_time = first_time_stamp - (first_time_stamp % 3600) # 向下取整到小时
end_time = end_time_stamp + (3600 - (end_time_stamp % 3600)) # 向上取整到小时
idx = 0
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
total_buckets = (end_time - start_time) // 3600
for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)):
if bucket_idx % 500 == 0:
progress = int((bucket_idx + 1) / total_buckets * 100)
logger.info(f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})")
cnt = 0
while (
idx < len(platform_stats_v3)
and platform_stats_v3[idx].timestamp < bucket_end
):
cnt += platform_stats_v3[idx].count
idx += 1
if cnt == 0:
continue
platform_id = get_platform_id(
platform_id_map, platform_stats_v3[idx].name
)
platform_type = get_platform_type(
platform_id_map, platform_stats_v3[idx].name
)
try:
await dbsession.execute(
text("""
INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
VALUES (:timestamp, :platform_id, :platform_type, :count)
ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET
count = platform_stats.count + EXCLUDED.count
"""),
{
"timestamp": datetime.datetime.fromtimestamp(
bucket_end, tz=datetime.timezone.utc
),
"platform_id": platform_id,
"platform_type": platform_type,
"count": cnt,
},
)
except Exception:
logger.error(
f"迁移平台统计数据失败: {platform_id}, {platform_type}, 时间戳: {bucket_end}",
exc_info=True,
)
logger.info(f"成功迁移 {len(platform_stats_v3)} 条旧的平台数据到新表。")
async def migration_webchat_data(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
"""迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中"""
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
)
conversations, total_cnt = db_helper_v3.get_all_conversations(
page=1, page_size=10000000
)
logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...")
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
for idx, conversation in enumerate(conversations):
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
progress = int((idx + 1) / total_cnt * 100)
if progress % 10 == 0:
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
try:
conv = db_helper_v3.get_conversation_by_user_id(
user_id=conversation.get("user_id", "unknown"),
cid=conversation.get("cid", "unknown"),
)
if not conv:
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" in conv.user_id:
continue
platform_id = "webchat"
history = json.loads(conv.history) if conv.history else []
for msg in history:
type_ = msg.get("type") # user type, "bot" or "user"
new_history = PlatformMessageHistory(
platform_id=platform_id,
user_id=conv.cid, # we use conv.cid as user_id for webchat
content=msg,
sender_id=type_,
sender_name=type_,
)
dbsession.add(new_history)
except Exception:
logger.error(
f"迁移旧 WebChat 会话 {conversation.get('cid', 'unknown')} 失败",
exc_info=True,
)
logger.info(f"成功迁移 {total_cnt} 条旧的 WebChat 会话数据到新表。")
async def migration_persona_data(
db_helper: BaseDatabase, astrbot_config: AstrBotConfig
):
"""
迁移 Persona 数据到新的表中。
旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
"""
v3_persona_config: list[dict] = astrbot_config.get("persona", [])
total_personas = len(v3_persona_config)
logger.info(f"迁移 {total_personas} 个 Persona 配置到新表中...")
for idx, persona in enumerate(v3_persona_config):
if total_personas > 0 and (idx + 1) % max(1, total_personas // 10) == 0:
progress = int((idx + 1) / total_personas * 100)
if progress % 10 == 0:
logger.info(f"进度: {progress}% ({idx + 1}/{total_personas})")
try:
begin_dialogs = persona.get("begin_dialogs", [])
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
mood_prompt = ""
user_turn = True
for mood_dialog in mood_imitation_dialogs:
if user_turn:
mood_prompt += f"A: {mood_dialog}\n"
else:
mood_prompt += f"B: {mood_dialog}\n"
user_turn = not user_turn
system_prompt = persona.get("prompt", "")
if mood_prompt:
system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
persona_new = await db_helper.insert_persona(
persona_id=persona["name"],
system_prompt=system_prompt,
begin_dialogs=begin_dialogs,
)
logger.info(
f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
async def migration_preferences(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
# 1. global scope migration
keys = [
"inactivated_llm_tools",
"inactivated_plugins",
"curr_provider",
"curr_provider_tts",
"curr_provider_stt",
"alter_cmd",
]
for key in keys:
value = sp_v3.get(key)
if value is not None:
await sp.put_async("global", "global", key, value)
logger.info(f"迁移全局偏好设置 {key} 成功,值: {value}")
# 2. umo scope migration
session_conversation = sp_v3.get("session_conversation", default={})
for umo, conversation_id in session_conversation.items():
if not umo or not conversation_id:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
await sp.put_async("umo", str(session), "sel_conv_id", conversation_id)
logger.info(f"迁移会话 {umo} 的对话数据到新表成功,平台 ID: {platform_id}")
except Exception as e:
logger.error(f"迁移会话 {umo} 的对话数据失败: {e}", exc_info=True)
session_service_config = sp_v3.get("session_service_config", default={})
for umo, config in session_service_config.items():
if not umo or not config:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
await sp.put_async("umo", str(session), "session_service_config", config)
logger.info(f"迁移会话 {umo} 的服务配置到新表成功,平台 ID: {platform_id}")
except Exception as e:
logger.error(f"迁移会话 {umo} 的服务配置失败: {e}", exc_info=True)
session_variables = sp_v3.get("session_variables", default={})
for umo, variables in session_variables.items():
if not umo or not variables:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
await sp.put_async("umo", str(session), "session_variables", variables)
except Exception as e:
logger.error(f"迁移会话 {umo} 的变量失败: {e}", exc_info=True)
session_provider_perf = sp_v3.get("session_provider_perf", default={})
for umo, perf in session_provider_perf.items():
if not umo or not perf:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
for provider_type, provider_id in perf.items():
await sp.put_async(
"umo", str(session), f"provider_perf_{provider_type}", provider_id
)
logger.info(
f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}"
)
except Exception as e:
logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True)

View File

@@ -1,44 +0,0 @@
from astrbot.api import logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.umop_config_router import UmopConfigRouter
async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
abconf_data = acm.abconf_data
if not isinstance(abconf_data, dict):
# should be unreachable
logger.warning(
f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}"
)
return
# 如果任何一项带有 umop则说明需要迁移
need_migration = False
for conf_id, conf_info in abconf_data.items():
if isinstance(conf_info, dict) and "umop" in conf_info:
need_migration = True
break
if not need_migration:
return
logger.info("Starting migration from version 4.5 to 4.6")
# extract umo->conf_id mapping
umo_to_conf_id = {}
for conf_id, conf_info in abconf_data.items():
if isinstance(conf_info, dict) and "umop" in conf_info:
umop_ls = conf_info.pop("umop")
if not isinstance(umop_ls, list):
continue
for umo in umop_ls:
if isinstance(umo, str) and umo not in umo_to_conf_id:
umo_to_conf_id[umo] = conf_id
# update the abconf data
await sp.global_put("abconf_mapping", abconf_data)
# update the umop config router
await ucr.update_routing_data(umo_to_conf_id)
logger.info("Migration from version 45 to 46 completed successfully")

View File

@@ -1,47 +0,0 @@
import json
import os
from typing import TypeVar
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
_VT = TypeVar("_VT")
class SharedPreferences:
def __init__(self, path=None):
if path is None:
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
self.path = path
self._data = self._load_preferences()
def _load_preferences(self):
if os.path.exists(self.path):
try:
with open(self.path, "r") as f:
return json.load(f)
except json.JSONDecodeError:
os.remove(self.path)
return {}
def _save_preferences(self):
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4, ensure_ascii=False)
f.flush()
def get(self, key, default: _VT = None) -> _VT:
return self._data.get(key, default)
def put(self, key, value):
self._data[key] = value
self._save_preferences()
def remove(self, key):
if key in self._data:
del self._data[key]
self._save_preferences()
def clear(self):
self._data.clear()
self._save_preferences()
sp = SharedPreferences()

View File

@@ -1,494 +0,0 @@
import sqlite3
import time
from astrbot.core.db.po import Platform, Stats
from typing import Tuple, List, Dict, Any
from dataclasses import dataclass
@dataclass
class Conversation:
"""LLM 对话存储
对于网页聊天history 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
"""
user_id: str
cid: str
history: str = ""
"""字符串格式的列表。"""
created_at: int = 0
updated_at: int = 0
title: str = ""
persona_id: str = ""
INIT_SQL = """
CREATE TABLE IF NOT EXISTS platform(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS plugin(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS command(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm_history(
provider_type VARCHAR(32),
session_id VARCHAR(32),
content TEXT
);
-- ATRI
CREATE TABLE IF NOT EXISTS atri_vision(
id TEXT,
url_or_path TEXT,
caption TEXT,
is_meme BOOLEAN,
keywords TEXT,
platform_name VARCHAR(32),
session_id VARCHAR(32),
sender_nickname VARCHAR(32),
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT, -- 会话 id
cid TEXT, -- 对话 id
history TEXT,
created_at INTEGER,
updated_at INTEGER,
title TEXT,
persona_id TEXT
);
PRAGMA encoding = 'UTF-8';
"""
class SQLiteDatabase:
def __init__(self, db_path: str) -> None:
super().__init__()
self.db_path = db_path
sql = INIT_SQL
# 初始化数据库
self.conn = self._get_conn(self.db_path)
c = self.conn.cursor()
c.executescript(sql)
self.conn.commit()
# 检查 webchat_conversation 的 title 字段是否存在
c.execute(
"""
PRAGMA table_info(webchat_conversation)
"""
)
res = c.fetchall()
has_title = False
has_persona_id = False
for row in res:
if row[1] == "title":
has_title = True
if row[1] == "persona_id":
has_persona_id = True
if not has_title:
c.execute(
"""
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
"""
)
self.conn.commit()
if not has_persona_id:
c.execute(
"""
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
"""
)
self.conn.commit()
c.close()
def _get_conn(self, db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
conn.text_factory = str
return conn
def _exec_sql(self, sql: str, params: Tuple = None):
conn = self.conn
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
conn = self._get_conn(self.db_path)
c = conn.cursor()
if params:
c.execute(sql, params)
c.close()
else:
c.execute(sql)
c.close()
conn.commit()
def insert_platform_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
"""
INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
""",
(k, v, int(time.time())),
)
def insert_llm_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
"""
INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
""",
(k, v, int(time.time())),
)
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取 offset_sec 秒前到现在的基础统计数据"""
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT * FROM platform
"""
+ where_clause
)
platform = []
for row in c.fetchall():
platform.append(Platform(*row))
c.close()
return Stats(platform=platform)
def get_total_message_count(self) -> int:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT SUM(count) FROM platform
"""
)
res = c.fetchone()
c.close()
return res[0]
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取 offset_sec 秒前到现在的基础统计数据(合并)"""
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT name, SUM(count), timestamp FROM platform
"""
+ where_clause
+ " GROUP BY name"
)
platform = []
for row in c.fetchall():
platform.append(Platform(*row))
c.close()
return Stats(platform, [], [])
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
""",
(user_id, cid),
)
res = c.fetchone()
c.close()
if not res:
return
return Conversation(*res)
def new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
self._exec_sql(
"""
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
""",
(user_id, cid, history, updated_at, created_at),
)
def get_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
""",
(user_id,),
)
res = c.fetchall()
c.close()
conversations = []
for row in res:
cid = row[0]
created_at = row[1]
updated_at = row[2]
title = row[3]
persona_id = row[4]
conversations.append(
Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
)
return conversations
def update_conversation(self, user_id: str, cid: str, history: str):
"""更新对话,并且同时更新时间"""
updated_at = int(time.time())
self._exec_sql(
"""
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
""",
(history, updated_at, user_id, cid),
)
def update_conversation_title(self, user_id: str, cid: str, title: str):
self._exec_sql(
"""
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
""",
(title, user_id, cid),
)
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
self._exec_sql(
"""
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
""",
(persona_id, user_id, cid),
)
def delete_conversation(self, user_id: str, cid: str):
self._exec_sql(
"""
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
""",
(user_id, cid),
)
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页,按更新时间降序排序"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 获取总记录数
c.execute("""
SELECT COUNT(*) FROM webchat_conversation
""")
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 获取分页数据,按更新时间降序排序
c.execute(
"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
""",
(page_size, offset),
)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型且至少有8个字符否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0确保即使出错也有有效的返回值
return [], 0
finally:
c.close()
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 构建查询条件
where_clauses = []
params = []
# 平台筛选
if platforms and len(platforms) > 0:
platform_conditions = []
for platform in platforms:
platform_conditions.append("user_id LIKE ?")
params.append(f"{platform}:%")
if platform_conditions:
where_clauses.append(f"({' OR '.join(platform_conditions)})")
# 消息类型筛选
if message_types and len(message_types) > 0:
message_type_conditions = []
for msg_type in message_types:
message_type_conditions.append("user_id LIKE ?")
params.append(f"%:{msg_type}:%")
if message_type_conditions:
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
# 搜索关键词
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
where_clauses.append(
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
)
search_param = f"%{search_query}%"
params.extend([search_param, search_param, search_param, search_param])
# 排除特定用户ID
if exclude_ids and len(exclude_ids) > 0:
for exclude_id in exclude_ids:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_id}%")
# 排除特定平台
if exclude_platforms and len(exclude_platforms) > 0:
for exclude_platform in exclude_platforms:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_platform}:%")
# 构建完整的 WHERE 子句
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
# 构建计数查询
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
# 获取总记录数
c.execute(count_sql, params)
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 构建分页数据查询
data_sql = f"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
{where_sql}
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
"""
query_params = params + [page_size, offset]
# 获取分页数据
c.execute(data_sql, query_params)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型,否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0确保即使出错也有有效的返回值
return [], 0
finally:
c.close()

View File

@@ -1,237 +1,7 @@
import uuid
"""指标数据"""
from datetime import datetime, timezone
from dataclasses import dataclass, field
from sqlmodel import (
SQLModel,
Text,
JSON,
UniqueConstraint,
Field,
)
from typing import Optional, TypedDict
class PlatformStat(SQLModel, table=True):
"""This class represents the statistics of bot usage across different platforms.
Note: In astrbot v4, we moved `platform` table to here.
"""
__tablename__ = "platform_stats"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
timestamp: datetime = Field(nullable=False)
platform_id: str = Field(nullable=False)
platform_type: str = Field(nullable=False) # such as "aiocqhttp", "slack", etc.
count: int = Field(default=0, nullable=False)
__table_args__ = (
UniqueConstraint(
"timestamp",
"platform_id",
"platform_type",
name="uix_platform_stats",
),
)
class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations"
inner_conversation_id: int = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}
)
conversation_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False)
content: Optional[list] = Field(default=None, sa_type=JSON)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
title: Optional[str] = Field(default=None, max_length=255)
persona_id: Optional[str] = Field(default=None)
__table_args__ = (
UniqueConstraint(
"conversation_id",
name="uix_conversation_id",
),
)
class Persona(SQLModel, table=True):
"""Persona is a set of instructions for LLMs to follow.
It can be used to customize the behavior of LLMs.
"""
__tablename__ = "personas"
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
persona_id: str = Field(max_length=255, nullable=False)
system_prompt: str = Field(sa_type=Text, nullable=False)
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
"""a list of strings, each representing a dialog to start with"""
tools: Optional[list] = Field(default=None, sa_type=JSON)
"""None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"persona_id",
name="uix_persona_id",
),
)
class Preference(SQLModel, table=True):
"""This class represents preferences for bots."""
__tablename__ = "preferences"
id: int | None = Field(
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
)
scope: str = Field(nullable=False)
"""Scope of the preference, such as 'global', 'umo', 'plugin'."""
scope_id: str = Field(nullable=False)
"""ID of the scope, such as 'global', 'umo', 'plugin_name'."""
key: str = Field(nullable=False)
value: dict = Field(sa_type=JSON, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"scope",
"scope_id",
"key",
name="uix_preference_scope_scope_id_key",
),
)
class PlatformMessageHistory(SQLModel, table=True):
"""This class represents the message history for a specific platform.
It is used to store messages that are not LLM-generated, such as user messages
or platform-specific messages.
"""
__tablename__ = "platform_message_history"
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False) # An id of group, user in platform
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
sender_name: Optional[str] = Field(
default=None
) # Name of the sender in the platform
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
class Attachment(SQLModel, table=True):
"""This class represents attachments for messages in AstrBot.
Attachments can be images, files, or other media types.
"""
__tablename__ = "attachments"
inner_attachment_id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
attachment_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
)
path: str = Field(nullable=False) # Path to the file on disk
type: str = Field(nullable=False) # Type of the file (e.g., 'image', 'file')
mime_type: str = Field(nullable=False) # MIME type of the file
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"attachment_id",
name="uix_attachment_id",
),
)
@dataclass
class Conversation:
"""LLM 对话类
对于 WebChathistory 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
在 v4.0.0 版本及之后WebChat 的历史记录被迁移至 `PlatformMessageHistory` 表中,
"""
platform_id: str
user_id: str
cid: str
"""对话 ID, 是 uuid 格式的字符串"""
history: str = ""
"""字符串格式的对话列表。"""
title: str | None = ""
persona_id: str | None = ""
created_at: int = 0
updated_at: int = 0
class Personality(TypedDict):
"""LLM 人格类。
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
"""
prompt: str = ""
name: str = ""
begin_dialogs: list[str] = []
mood_imitation_dialogs: list[str] = []
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
tools: list[str] | None = None
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
# cache
_begin_dialogs_processed: list[dict] = []
_mood_imitation_dialogs_processed: str = ""
# ====
# Deprecated, and will be removed in future versions.
# ====
from typing import List
@dataclass
@@ -243,6 +13,77 @@ class Platform:
timestamp: int
@dataclass
class Provider:
"""供应商使用统计数据"""
name: str
count: int
timestamp: int
@dataclass
class Plugin:
"""插件使用统计数据"""
name: str
count: int
timestamp: int
@dataclass
class Command:
"""命令使用统计数据"""
name: str
count: int
timestamp: int
@dataclass
class Stats:
platform: list[Platform] = field(default_factory=list)
platform: List[Platform] = field(default_factory=list)
command: List[Command] = field(default_factory=list)
llm: List[Provider] = field(default_factory=list)
@dataclass
class LLMHistory:
"""LLM 聊天时持久化的信息"""
provider_type: str
session_id: str
content: str
@dataclass
class ATRIVision:
"""Deprecated"""
id: str
url_or_path: str
caption: str
is_meme: bool
keywords: List[str]
platform_name: str
session_id: str
sender_nickname: str
timestamp: int = -1
@dataclass
class Conversation:
"""LLM 对话存储
对于网页聊天history 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
"""
user_id: str
cid: str
history: str = ""
"""字符串格式的列表。"""
created_at: int = 0
updated_at: int = 0
title: str = ""
persona_id: str = ""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,50 @@
CREATE TABLE IF NOT EXISTS platform(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS plugin(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS command(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm_history(
provider_type VARCHAR(32),
session_id VARCHAR(32),
content TEXT
);
-- ATRI
CREATE TABLE IF NOT EXISTS atri_vision(
id TEXT,
url_or_path TEXT,
caption TEXT,
is_meme BOOLEAN,
keywords TEXT,
platform_name VARCHAR(32),
session_id VARCHAR(32),
sender_nickname VARCHAR(32),
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT, -- 会话 id
cid TEXT, -- 对话 id
history TEXT,
created_at INTEGER,
updated_at INTEGER,
title TEXT,
persona_id TEXT
);
PRAGMA encoding = 'UTF-8';

View File

@@ -16,42 +16,14 @@ class BaseVecDB:
pass
@abc.abstractmethod
async def insert(
self, content: str, metadata: dict | None = None, id: str | None = None
) -> int:
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
...
@abc.abstractmethod
async def insert_batch(
self,
contents: list[str],
metadatas: list[dict] | None = None,
ids: list[str] | None = None,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
) -> int:
"""
批量插入文本和其对应向量,自动生成 ID 并保持一致性。
Args:
progress_callback: 进度回调函数,接收参数 (current, total)
"""
...
@abc.abstractmethod
async def retrieve(
self,
query: str,
top_k: int = 5,
fetch_k: int = 20,
rerank: bool = False,
metadata_filters: dict | None = None,
) -> list[Result]:
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
"""
搜索最相似的文档。
Args:
@@ -72,6 +44,3 @@ class BaseVecDB:
bool: 删除是否成功
"""
...
@abc.abstractmethod
async def close(self): ...

View File

@@ -1,224 +1,59 @@
import aiosqlite
import os
import json
from datetime import datetime
from contextlib import asynccontextmanager
from sqlalchemy import Text, Column
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import Field, SQLModel, select, col, func, text, MetaData
from astrbot.core import logger
class BaseDocModel(SQLModel, table=False):
metadata = MetaData()
class Document(BaseDocModel, table=True):
"""SQLModel for documents table."""
__tablename__ = "documents" # type: ignore
id: int | None = Field(
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
)
doc_id: str = Field(nullable=False)
text: str = Field(nullable=False)
metadata_: str | None = Field(default=None, sa_column=Column("metadata", Text))
created_at: datetime | None = Field(default=None)
updated_at: datetime | None = Field(default=None)
class DocumentStorage:
def __init__(self, db_path: str):
self.db_path = db_path
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
self.engine: AsyncEngine | None = None
self.async_session_maker: sessionmaker | None = None
self.connection = None
self.sqlite_init_path = os.path.join(
os.path.dirname(__file__), "sqlite_init.sql"
)
async def initialize(self):
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
await self.connect()
async with self.engine.begin() as conn: # type: ignore
# Create tables using SQLModel
await conn.run_sync(BaseDocModel.metadata.create_all)
try:
await conn.execute(
text(
"ALTER TABLE documents ADD COLUMN kb_doc_id TEXT "
"GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED"
)
)
await conn.execute(
text(
"ALTER TABLE documents ADD COLUMN user_id TEXT "
"GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED"
)
)
# Create indexes
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)"
)
)
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)"
)
)
except BaseException:
pass
await conn.commit()
if not os.path.exists(self.db_path):
await self.connect()
async with self.connection.cursor() as cursor:
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
sql_script = f.read()
await cursor.executescript(sql_script)
await self.connection.commit()
else:
await self.connect()
async def connect(self):
"""Connect to the SQLite database."""
if self.engine is None:
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
future=True,
)
self.async_session_maker = sessionmaker(
self.engine, # type: ignore
class_=AsyncSession,
expire_on_commit=False,
) # type: ignore
self.connection = await aiosqlite.connect(self.db_path)
@asynccontextmanager
async def get_session(self):
"""Context manager for database sessions."""
async with self.async_session_maker() as session: # type: ignore
yield session
async def get_documents(
self,
metadata_filters: dict,
ids: list | None = None,
offset: int | None = 0,
limit: int | None = 100,
) -> list[dict]:
async def get_documents(self, metadata_filters: dict, ids: list = None):
"""Retrieve documents by metadata filters and ids.
Args:
metadata_filters (dict): The metadata filters to apply.
ids (list | None): Optional list of document IDs to filter.
offset (int | None): Offset for pagination.
limit (int | None): Limit for pagination.
Returns:
list: The list of documents that match the filters.
list: The list of document IDs(primary key, not doc_id) that match the filters.
"""
if self.engine is None:
logger.warning(
"Database connection is not initialized, returning empty result"
)
return []
# metadata filter -> SQL WHERE clause
where_clauses = []
values = []
for key, val in metadata_filters.items():
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
values.append(val)
if ids is not None and len(ids) > 0:
ids = [str(i) for i in ids if i != -1]
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
values.extend(ids)
where_sql = " AND ".join(where_clauses) or "1=1"
async with self.get_session() as session:
query = select(Document)
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
).params(**{f"filter_{key}": val})
if ids is not None and len(ids) > 0:
valid_ids = [int(i) for i in ids if i != -1]
if valid_ids:
query = query.where(col(Document.id).in_(valid_ids))
if limit is not None:
query = query.limit(limit)
if offset is not None:
query = query.offset(offset)
result = await session.execute(query)
documents = result.scalars().all()
return [self._document_to_dict(doc) for doc in documents]
async def insert_document(self, doc_id: str, text: str, metadata: dict) -> int:
"""Insert a single document and return its integer ID.
Args:
doc_id (str): The document ID (UUID string).
text (str): The document text.
metadata (dict): The document metadata.
Returns:
int: The integer ID of the inserted document.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
async with session.begin():
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
session.add(document)
await session.flush() # Flush to get the ID
return document.id # type: ignore
async def insert_documents_batch(
self, doc_ids: list[str], texts: list[str], metadatas: list[dict]
) -> list[int]:
"""Batch insert documents and return their integer IDs.
Args:
doc_ids (list[str]): List of document IDs (UUID strings).
texts (list[str]): List of document texts.
metadatas (list[dict]): List of document metadata.
Returns:
list[int]: List of integer IDs of the inserted documents.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
async with session.begin():
import json
documents = []
for doc_id, text, metadata in zip(doc_ids, texts, metadatas):
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
documents.append(document)
session.add(document)
await session.flush() # Flush to get all IDs
return [doc.id for doc in documents] # type: ignore
async def delete_document_by_doc_id(self, doc_id: str):
"""Delete a document by its doc_id.
Args:
doc_id (str): The doc_id of the document to delete.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
async with session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
await session.delete(document)
result = []
async with self.connection.cursor() as cursor:
sql = "SELECT * FROM documents WHERE " + where_sql
await cursor.execute(sql, values)
for row in await cursor.fetchall():
result.append(await self.tuple_to_dict(row))
return result
async def get_document_by_doc_id(self, doc_id: str):
"""Retrieve a document by its doc_id.
@@ -227,91 +62,28 @@ class DocumentStorage:
doc_id (str): The doc_id of the document to retrieve.
Returns:
dict: The document data or None if not found.
dict: The document data.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
return self._document_to_dict(document)
return None
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
row = await cursor.fetchone()
if row:
return await self.tuple_to_dict(row)
else:
return None
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
"""Update a document by its doc_id.
"""Retrieve a document by its doc_id.
Args:
doc_id (str): The doc_id.
new_text (str): The new text to update the document with.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
async with session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
document.text = new_text
document.updated_at = datetime.now()
session.add(document)
async def delete_documents(self, metadata_filters: dict):
"""Delete documents by their metadata filters.
Args:
metadata_filters (dict): The metadata filters to apply.
"""
if self.engine is None:
logger.warning(
"Database connection is not initialized, skipping delete operation"
async with self.connection.cursor() as cursor:
await cursor.execute(
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
)
return
async with self.get_session() as session:
async with session.begin():
query = select(Document)
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
).params(**{f"filter_{key}": val})
result = await session.execute(query)
documents = result.scalars().all()
for doc in documents:
await session.delete(doc)
async def count_documents(self, metadata_filters: dict | None = None) -> int:
"""Count documents in the database.
Args:
metadata_filters (dict | None): Metadata filters to apply.
Returns:
int: The count of documents.
"""
if self.engine is None:
logger.warning("Database connection is not initialized, returning 0")
return 0
async with self.get_session() as session:
query = select(func.count(col(Document.id)))
if metadata_filters:
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
).params(**{f"filter_{key}": val})
result = await session.execute(query)
count = result.scalar_one_or_none()
return count if count is not None else 0
await self.connection.commit()
async def get_user_ids(self) -> list[str]:
"""Retrieve all user IDs from the documents table.
@@ -319,38 +91,11 @@ class DocumentStorage:
Returns:
list: A list of user IDs.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
query = text(
"SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL"
)
result = await session.execute(query)
rows = result.fetchall()
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT DISTINCT user_id FROM documents")
rows = await cursor.fetchall()
return [row[0] for row in rows]
def _document_to_dict(self, document: Document) -> dict:
"""Convert a Document model to a dictionary.
Args:
document (Document): The document to convert.
Returns:
dict: The converted dictionary.
"""
return {
"id": document.id,
"doc_id": document.doc_id,
"text": document.text,
"metadata": document.metadata_,
"created_at": document.created_at.isoformat()
if isinstance(document.created_at, datetime)
else document.created_at,
"updated_at": document.updated_at.isoformat()
if isinstance(document.updated_at, datetime)
else document.updated_at,
}
async def tuple_to_dict(self, row):
"""Convert a tuple to a dictionary.
@@ -359,8 +104,6 @@ class DocumentStorage:
Returns:
dict: The converted dictionary.
Note: This method is kept for backward compatibility but is no longer used internally.
"""
return {
"id": row[0],
@@ -373,7 +116,6 @@ class DocumentStorage:
async def close(self):
"""Close the connection to the SQLite database."""
if self.engine:
await self.engine.dispose()
self.engine = None
self.async_session_maker = None
if self.connection:
await self.connection.close()
self.connection = None

View File

@@ -9,7 +9,7 @@ import numpy as np
class EmbeddingStorage:
def __init__(self, dimension: int, path: str | None = None):
def __init__(self, dimension: int, path: str = None):
self.dimension = dimension
self.path = path
self.index = None
@@ -18,6 +18,7 @@ class EmbeddingStorage:
else:
base_index = faiss.IndexFlatL2(dimension)
self.index = faiss.IndexIDMap(base_index)
self.storage = {}
async def insert(self, vector: np.ndarray, id: int):
"""插入向量
@@ -28,29 +29,12 @@ class EmbeddingStorage:
Raises:
ValueError: 如果向量的维度与存储的维度不匹配
"""
assert self.index is not None, "FAISS index is not initialized."
if vector.shape[0] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
)
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
await self.save_index()
async def insert_batch(self, vectors: np.ndarray, ids: list[int]):
"""批量插入向量
Args:
vectors (np.ndarray): 要插入的向量数组
ids (list[int]): 向量的ID列表
Raises:
ValueError: 如果向量的维度与存储的维度不匹配
"""
assert self.index is not None, "FAISS index is not initialized."
if vectors.shape[1] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}"
)
self.index.add_with_ids(vectors, np.array(ids))
self.storage[id] = vector
await self.save_index()
async def search(self, vector: np.ndarray, k: int) -> tuple:
@@ -62,22 +46,10 @@ class EmbeddingStorage:
Returns:
tuple: (距离, 索引)
"""
assert self.index is not None, "FAISS index is not initialized."
faiss.normalize_L2(vector)
distances, indices = self.index.search(vector, k)
return distances, indices
async def delete(self, ids: list[int]):
"""删除向量
Args:
ids (list[int]): 要删除的向量ID列表
"""
assert self.index is not None, "FAISS index is not initialized."
id_array = np.array(ids, dtype=np.int64)
self.index.remove_ids(id_array)
await self.save_index()
async def save_index(self):
"""保存索引

View File

@@ -1,12 +1,10 @@
import uuid
import time
import json
import numpy as np
from .document_storage import DocumentStorage
from .embedding_storage import EmbeddingStorage
from ..base import Result, BaseVecDB
from astrbot.core.provider.provider import EmbeddingProvider
from astrbot.core.provider.provider import RerankProvider
from astrbot import logger
class FaissVecDB(BaseVecDB):
@@ -19,7 +17,6 @@ class FaissVecDB(BaseVecDB):
doc_store_path: str,
index_store_path: str,
embedding_provider: EmbeddingProvider,
rerank_provider: RerankProvider | None = None,
):
self.doc_store_path = doc_store_path
self.index_store_path = index_store_path
@@ -29,14 +26,11 @@ class FaissVecDB(BaseVecDB):
embedding_provider.get_dim(), index_store_path
)
self.embedding_provider = embedding_provider
self.rerank_provider = rerank_provider
async def initialize(self):
await self.document_storage.initialize()
async def insert(
self, content: str, metadata: dict | None = None, id: str | None = None
) -> int:
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
@@ -45,64 +39,21 @@ class FaissVecDB(BaseVecDB):
vector = await self.embedding_provider.get_embedding(content)
vector = np.array(vector, dtype=np.float32)
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute(
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
(str_id, content, json.dumps(metadata)),
)
await self.document_storage.connection.commit()
result = await self.document_storage.get_document_by_doc_id(str_id)
int_id = result["id"]
# 使用 DocumentStorage 的方法插入文档
int_id = await self.document_storage.insert_document(str_id, content, metadata)
# 插入向量到 FAISS
await self.embedding_storage.insert(vector, int_id)
return int_id
async def insert_batch(
self,
contents: list[str],
metadatas: list[dict] | None = None,
ids: list[str] | None = None,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
) -> list[int]:
"""
批量插入文本和其对应向量,自动生成 ID 并保持一致性。
Args:
progress_callback: 进度回调函数,接收参数 (current, total)
"""
metadatas = metadatas or [{} for _ in contents]
ids = ids or [str(uuid.uuid4()) for _ in contents]
start = time.time()
logger.debug(f"Generating embeddings for {len(contents)} contents...")
vectors = await self.embedding_provider.get_embeddings_batch(
contents,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
)
end = time.time()
logger.debug(
f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds."
)
# 使用 DocumentStorage 的批量插入方法
int_ids = await self.document_storage.insert_documents_batch(
ids, contents, metadatas
)
# 批量插入向量到 FAISS
vectors_array = np.array(vectors).astype("float32")
await self.embedding_storage.insert_batch(vectors_array, int_ids)
return int_ids
# 插入向量到 FAISS
await self.embedding_storage.insert(vector, int_id)
return int_id
async def retrieve(
self,
query: str,
k: int = 5,
fetch_k: int = 20,
rerank: bool = False,
metadata_filters: dict | None = None,
self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None
) -> list[Result]:
"""
搜索最相似的文档。
@@ -111,7 +62,6 @@ class FaissVecDB(BaseVecDB):
query (str): 查询文本
k (int): 返回的最相似文档的数量
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
rerank (bool): 是否使用重排序。这需要在实例化时提供 rerank_provider, 如果未提供并且 rerank 为 True, 不会抛出异常。
metadata_filters (dict): 元数据过滤器
Returns:
@@ -122,6 +72,7 @@ class FaissVecDB(BaseVecDB):
vector=np.array([embedding]).astype("float32"),
k=fetch_k if metadata_filters else k,
)
# TODO: rerank
if len(indices[0]) == 0 or indices[0][0] == -1:
return []
# normalize scores
@@ -132,7 +83,7 @@ class FaissVecDB(BaseVecDB):
)
if not fetched_docs:
return []
result_docs: list[Result] = []
result_docs = []
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
for i, indice_idx in enumerate(indices[0]):
@@ -142,58 +93,25 @@ class FaissVecDB(BaseVecDB):
fetch_doc = fetched_docs[pos]
score = scores[0][i]
result_docs.append(Result(similarity=float(score), data=fetch_doc))
return result_docs[:k]
top_k_results = result_docs[:k]
if rerank and self.rerank_provider:
documents = [doc.data["text"] for doc in top_k_results]
reranked_results = await self.rerank_provider.rerank(query, documents)
reranked_results = sorted(
reranked_results, key=lambda x: x.relevance_score, reverse=True
)
top_k_results = [
top_k_results[reranked_result.index]
for reranked_result in reranked_results
]
return top_k_results
async def delete(self, doc_id: str):
async def delete(self, doc_id: int):
"""
删除一条文档chunk
删除一条文档
"""
# 获得对应的 int id
result = await self.document_storage.get_document_by_doc_id(doc_id)
int_id = result["id"] if result else None
if int_id is None:
return
# 使用 DocumentStorage 的删除方法
await self.document_storage.delete_document_by_doc_id(doc_id)
await self.embedding_storage.delete([int_id])
await self.document_storage.connection.execute(
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
)
await self.document_storage.connection.commit()
async def close(self):
await self.document_storage.close()
async def count_documents(self, metadata_filter: dict | None = None) -> int:
async def count_documents(self) -> int:
"""
计算文档数量
Args:
metadata_filter (dict | None): 元数据过滤器
"""
count = await self.document_storage.count_documents(
metadata_filters=metadata_filter or {}
)
return count
async def delete_documents(self, metadata_filters: dict):
"""
根据元数据过滤器删除文档
"""
docs = await self.document_storage.get_documents(
metadata_filters=metadata_filters, offset=None, limit=None
)
doc_ids: list[int] = [doc["id"] for doc in docs]
await self.embedding_storage.delete(doc_ids)
await self.document_storage.delete_documents(metadata_filters=metadata_filters)
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute("SELECT COUNT(*) FROM documents")
count = await cursor.fetchone()
return count[0] if count else 0

View File

@@ -16,32 +16,30 @@ from asyncio import Queue
from astrbot.core.pipeline.scheduler import PipelineScheduler
from astrbot.core import logger
from .platform import AstrMessageEvent
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
class EventBus:
"""用于处理事件的分发和处理"""
"""事件总线: 用于处理事件的分发和处理
def __init__(
self,
event_queue: Queue,
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
astrbot_config_mgr: AstrBotConfigManager = None,
):
维护一个异步队列, 来接受各种消息事件
"""
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
self.event_queue = event_queue # 事件队列
# abconf uuid -> scheduler
self.pipeline_scheduler_mapping = pipeline_scheduler_mapping
self.astrbot_config_mgr = astrbot_config_mgr
self.pipeline_scheduler = pipeline_scheduler # 管道调度器
async def dispatch(self):
"""无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑"""
while True:
event: AstrMessageEvent = await self.event_queue.get()
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
self._print_event(event, conf_info["name"])
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
asyncio.create_task(scheduler.execute(event))
event: AstrMessageEvent = (
await self.event_queue.get()
) # 从事件队列中获取新的事件
self._print_event(event) # 打印日志
asyncio.create_task(
self.pipeline_scheduler.execute(event)
) # 创建新的异步任务来执行管道调度器的处理逻辑
def _print_event(self, event: AstrMessageEvent, conf_name: str):
def _print_event(self, event: AstrMessageEvent):
"""用于记录事件信息
Args:
@@ -50,10 +48,10 @@ class EventBus:
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
if event.get_sender_name():
logger.info(
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
)
# 没有发送者名称: [平台名] 发送者ID: 消息概要
else:
logger.info(
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}"
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
)

View File

@@ -23,12 +23,7 @@ class FileTokenService:
for token in expired_tokens:
self.staged_files.pop(token, None)
async def check_token_expired(self, file_token: str) -> bool:
async with self.lock:
await self._cleanup_expired_tokens()
return file_token not in self.staged_files
async def register_file(self, file_path: str, timeout: float | None = None) -> str:
async def register_file(self, file_path: str, timeout: float = None) -> str:
"""向令牌服务注册一个文件。
Args:

View File

@@ -22,7 +22,6 @@ class InitialLoader:
self.db = db
self.logger = logger
self.log_broker = log_broker
self.webui_dir: str | None = None
async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
@@ -36,18 +35,13 @@ class InitialLoader:
core_task = core_lifecycle.start()
webui_dir = self.webui_dir
self.dashboard_server = AstrBotDashboard(
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event, webui_dir
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
)
task = asyncio.gather(
core_task, self.dashboard_server.run()
) # 启动核心任务和仪表板服务器
coro = self.dashboard_server.run()
if coro:
# 启动核心任务和仪表板服务器
task = asyncio.gather(core_task, coro)
else:
task = core_task
try:
await task # 整个AstrBot在这里运行
except asyncio.CancelledError:

View File

@@ -1,11 +0,0 @@
"""
文档分块模块
"""
from .base import BaseChunker
from .fixed_size import FixedSizeChunker
__all__ = [
"BaseChunker",
"FixedSizeChunker",
]

View File

@@ -1,24 +0,0 @@
"""文档分块器基类
定义了文档分块处理的抽象接口。
"""
from abc import ABC, abstractmethod
class BaseChunker(ABC):
"""分块器基类
所有分块器都应该继承此类并实现 chunk 方法。
"""
@abstractmethod
async def chunk(self, text: str, **kwargs) -> list[str]:
"""将文本分块
Args:
text: 输入文本
Returns:
list[str]: 分块后的文本列表
"""

View File

@@ -1,57 +0,0 @@
"""固定大小分块器
按照固定的字符数将文本分块,支持重叠区域。
"""
from .base import BaseChunker
class FixedSizeChunker(BaseChunker):
"""固定大小分块器
按照固定的字符数分块,并支持块之间的重叠。
"""
def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
"""初始化分块器
Args:
chunk_size: 块的大小(字符数)
chunk_overlap: 块之间的重叠字符数
"""
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
async def chunk(self, text: str, **kwargs) -> list[str]:
"""固定大小分块
Args:
text: 输入文本
chunk_size: 每个文本块的最大大小
chunk_overlap: 每个文本块之间的重叠部分大小
Returns:
list[str]: 分块后的文本列表
"""
chunk_size = kwargs.get("chunk_size", self.chunk_size)
chunk_overlap = kwargs.get("chunk_overlap", self.chunk_overlap)
chunks = []
start = 0
text_len = len(text)
while start < text_len:
end = start + chunk_size
chunk = text[start:end]
if chunk:
chunks.append(chunk)
# 移动窗口,保留重叠部分
start = end - chunk_overlap
# 防止无限循环: 如果重叠过大,直接移到end
if start >= end or chunk_overlap >= chunk_size:
start = end
return chunks

View File

@@ -1,155 +0,0 @@
from collections.abc import Callable
from .base import BaseChunker
class RecursiveCharacterChunker(BaseChunker):
def __init__(
self,
chunk_size: int = 500,
chunk_overlap: int = 100,
length_function: Callable[[str], int] = len,
is_separator_regex: bool = False,
separators: list[str] | None = None,
):
"""
初始化递归字符文本分割器
Args:
chunk_size: 每个文本块的最大大小
chunk_overlap: 每个文本块之间的重叠部分大小
length_function: 计算文本长度的函数
is_separator_regex: 分隔符是否为正则表达式
separators: 用于分割文本的分隔符列表,按优先级排序
"""
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.length_function = length_function
self.is_separator_regex = is_separator_regex
# 默认分隔符列表,按优先级从高到低
self.separators = separators or [
"\n\n", # 段落
"\n", # 换行
"", # 中文句子
"", # 中文逗号
". ", # 句子
", ", # 逗号分隔
" ", # 单词
"", # 字符
]
async def chunk(self, text: str, **kwargs) -> list[str]:
"""
递归地将文本分割成块
Args:
text: 要分割的文本
chunk_size: 每个文本块的最大大小
chunk_overlap: 每个文本块之间的重叠部分大小
Returns:
分割后的文本块列表
"""
if not text:
return []
overlap = kwargs.get("chunk_overlap", self.chunk_overlap)
chunk_size = kwargs.get("chunk_size", self.chunk_size)
text_length = self.length_function(text)
if text_length <= chunk_size:
return [text]
for separator in self.separators:
if separator == "":
return self._split_by_character(text, chunk_size, overlap)
if separator in text:
splits = text.split(separator)
# 重新添加分隔符(除了最后一个片段)
splits = [s + separator for s in splits[:-1]] + [splits[-1]]
splits = [s for s in splits if s]
if len(splits) == 1:
continue
# 递归合并分割后的文本块
final_chunks = []
current_chunk = []
current_chunk_length = 0
for split in splits:
split_length = self.length_function(split)
# 如果单个分割部分已经超过了chunk_size需要递归分割
if split_length > chunk_size:
# 先处理当前积累的块
if current_chunk:
combined_text = "".join(current_chunk)
final_chunks.extend(
await self.chunk(
combined_text,
chunk_size=chunk_size,
chunk_overlap=overlap,
)
)
current_chunk = []
current_chunk_length = 0
# 递归分割过大的部分
final_chunks.extend(
await self.chunk(
split, chunk_size=chunk_size, chunk_overlap=overlap
)
)
# 如果添加这部分会使当前块超过chunk_size
elif current_chunk_length + split_length > chunk_size:
# 合并当前块并添加到结果中
combined_text = "".join(current_chunk)
final_chunks.append(combined_text)
# 处理重叠部分
overlap_start = max(0, len(combined_text) - overlap)
if overlap_start > 0:
overlap_text = combined_text[overlap_start:]
current_chunk = [overlap_text, split]
current_chunk_length = (
self.length_function(overlap_text) + split_length
)
else:
current_chunk = [split]
current_chunk_length = split_length
else:
# 添加到当前块
current_chunk.append(split)
current_chunk_length += split_length
# 处理剩余的块
if current_chunk:
final_chunks.append("".join(current_chunk))
return final_chunks
return [text]
def _split_by_character(
self, text: str, chunk_size: int | None = None, overlap: int | None = None
) -> list[str]:
"""
按字符级别分割文本
Args:
text: 要分割的文本
Returns:
分割后的文本块列表
"""
chunk_size = chunk_size or self.chunk_size
overlap = overlap or self.chunk_overlap
result = []
for i in range(0, len(text), chunk_size - overlap):
end = min(i + chunk_size, len(text))
result.append(text[i:end])
if end == len(text):
break
return result

View File

@@ -1,299 +0,0 @@
from contextlib import asynccontextmanager
from pathlib import Path
from sqlmodel import col, desc
from sqlalchemy import text, func, select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from astrbot.core import logger
from astrbot.core.knowledge_base.models import (
BaseKBModel,
KBDocument,
KBMedia,
KnowledgeBase,
)
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
class KBSQLiteDatabase:
def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
"""初始化知识库数据库
Args:
db_path: 数据库文件路径, 默认为 data/knowledge_base/kb.db
"""
self.db_path = db_path
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
self.inited = False
# 确保目录存在
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
# 创建异步引擎
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
pool_pre_ping=True,
pool_recycle=3600,
)
# 创建会话工厂
self.async_session = async_sessionmaker(
self.engine,
class_=AsyncSession,
expire_on_commit=False,
)
@asynccontextmanager
async def get_db(self):
"""获取数据库会话
用法:
async with kb_db.get_db() as session:
# 执行数据库操作
result = await session.execute(stmt)
"""
async with self.async_session() as session:
yield session
async def initialize(self) -> None:
"""初始化数据库,创建表并配置 SQLite 参数"""
async with self.engine.begin() as conn:
# 创建所有知识库相关表
await conn.run_sync(BaseKBModel.metadata.create_all)
# 配置 SQLite 性能优化参数
await conn.execute(text("PRAGMA journal_mode=WAL"))
await conn.execute(text("PRAGMA synchronous=NORMAL"))
await conn.execute(text("PRAGMA cache_size=20000"))
await conn.execute(text("PRAGMA temp_store=MEMORY"))
await conn.execute(text("PRAGMA mmap_size=134217728"))
await conn.execute(text("PRAGMA optimize"))
await conn.commit()
self.inited = True
async def migrate_to_v1(self) -> None:
"""执行知识库数据库 v1 迁移
创建所有必要的索引以优化查询性能
"""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
# 创建知识库表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_kb_id "
"ON knowledge_bases(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_name "
"ON knowledge_bases(kb_name)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_created_at "
"ON knowledge_bases(created_at)"
)
)
# 创建文档表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_doc_id "
"ON kb_documents(doc_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_kb_id "
"ON kb_documents(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_name "
"ON kb_documents(doc_name)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_type "
"ON kb_documents(file_type)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_created_at "
"ON kb_documents(created_at)"
)
)
# 创建多媒体表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_media_id "
"ON kb_media(media_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_doc_id "
"ON kb_media(doc_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_type "
"ON kb_media(media_type)"
)
)
await session.commit()
async def close(self) -> None:
"""关闭数据库连接"""
await self.engine.dispose()
logger.info(f"知识库数据库已关闭: {self.db_path}")
async def get_kb_by_id(self, kb_id: str) -> KnowledgeBase | None:
"""根据 ID 获取知识库"""
async with self.get_db() as session:
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_kb_by_name(self, kb_name: str) -> KnowledgeBase | None:
"""根据名称获取知识库"""
async with self.get_db() as session:
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_name) == kb_name)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
"""列出所有知识库"""
async with self.get_db() as session:
stmt = (
select(KnowledgeBase)
.offset(offset)
.limit(limit)
.order_by(desc(KnowledgeBase.created_at))
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def count_kbs(self) -> int:
"""统计知识库数量"""
async with self.get_db() as session:
stmt = select(func.count(col(KnowledgeBase.id)))
result = await session.execute(stmt)
return result.scalar() or 0
# ===== 文档查询 =====
async def get_document_by_id(self, doc_id: str) -> KBDocument | None:
"""根据 ID 获取文档"""
async with self.get_db() as session:
stmt = select(KBDocument).where(col(KBDocument.doc_id) == doc_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_documents_by_kb(
self, kb_id: str, offset: int = 0, limit: int = 100
) -> list[KBDocument]:
"""列出知识库的所有文档"""
async with self.get_db() as session:
stmt = (
select(KBDocument)
.where(col(KBDocument.kb_id) == kb_id)
.offset(offset)
.limit(limit)
.order_by(desc(KBDocument.created_at))
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def count_documents_by_kb(self, kb_id: str) -> int:
"""统计知识库的文档数量"""
async with self.get_db() as session:
stmt = select(func.count(col(KBDocument.id))).where(
col(KBDocument.kb_id) == kb_id
)
result = await session.execute(stmt)
return result.scalar() or 0
async def get_document_with_metadata(self, doc_id: str) -> dict | None:
async with self.get_db() as session:
stmt = (
select(KBDocument, KnowledgeBase)
.join(KnowledgeBase, col(KBDocument.kb_id) == col(KnowledgeBase.kb_id))
.where(col(KBDocument.doc_id) == doc_id)
)
result = await session.execute(stmt)
row = result.first()
if not row:
return None
return {
"document": row[0],
"knowledge_base": row[1],
}
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB):
"""删除单个文档及其相关数据"""
# 在知识库表中删除
async with self.get_db() as session:
async with session.begin():
# 删除文档记录
delete_stmt = delete(KBDocument).where(col(KBDocument.doc_id) == doc_id)
await session.execute(delete_stmt)
await session.commit()
# 在 vec db 中删除相关向量
await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id})
# ===== 多媒体查询 =====
async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
"""列出文档的所有多媒体资源"""
async with self.get_db() as session:
stmt = select(KBMedia).where(col(KBMedia.doc_id) == doc_id)
result = await session.execute(stmt)
return list(result.scalars().all())
async def get_media_by_id(self, media_id: str) -> KBMedia | None:
"""根据 ID 获取多媒体资源"""
async with self.get_db() as session:
stmt = select(KBMedia).where(col(KBMedia.media_id) == media_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def update_kb_stats(self, kb_id: str, vec_db: FaissVecDB) -> None:
"""更新知识库统计信息"""
chunk_cnt = await vec_db.count_documents()
async with self.get_db() as session:
async with session.begin():
update_stmt = (
update(KnowledgeBase)
.where(col(KnowledgeBase.kb_id) == kb_id)
.values(
doc_count=select(func.count(col(KBDocument.id)))
.where(col(KBDocument.kb_id) == kb_id)
.scalar_subquery(),
chunk_count=chunk_cnt,
)
)
await session.execute(update_stmt)
await session.commit()

View File

@@ -1,348 +0,0 @@
import uuid
import aiofiles
import json
from pathlib import Path
from .models import KnowledgeBase, KBDocument, KBMedia
from .kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from astrbot.core.provider.manager import ProviderManager
from .parsers.util import select_parser
from .chunking.base import BaseChunker
from astrbot.core import logger
class KBHelper:
vec_db: BaseVecDB
kb: KnowledgeBase
def __init__(
self,
kb_db: KBSQLiteDatabase,
kb: KnowledgeBase,
provider_manager: ProviderManager,
kb_root_dir: str,
chunker: BaseChunker,
):
self.kb_db = kb_db
self.kb = kb
self.prov_mgr = provider_manager
self.kb_root_dir = kb_root_dir
self.chunker = chunker
self.kb_dir = Path(self.kb_root_dir) / self.kb.kb_id
self.kb_medias_dir = Path(self.kb_dir) / "medias" / self.kb.kb_id
self.kb_files_dir = Path(self.kb_dir) / "files" / self.kb.kb_id
self.kb_medias_dir.mkdir(parents=True, exist_ok=True)
self.kb_files_dir.mkdir(parents=True, exist_ok=True)
async def initialize(self):
await self._ensure_vec_db()
async def get_ep(self) -> EmbeddingProvider:
if not self.kb.embedding_provider_id:
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id(
self.kb.embedding_provider_id
) # type: ignore
if not ep:
raise ValueError(
f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider"
)
return ep
async def get_rp(self) -> RerankProvider | None:
if not self.kb.rerank_provider_id:
return None
rp: RerankProvider = await self.prov_mgr.get_provider_by_id(
self.kb.rerank_provider_id
) # type: ignore
if not rp:
raise ValueError(
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider"
)
return rp
async def _ensure_vec_db(self) -> FaissVecDB:
if not self.kb.embedding_provider_id:
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
ep = await self.get_ep()
rp = await self.get_rp()
vec_db = FaissVecDB(
doc_store_path=str(self.kb_dir / "doc.db"),
index_store_path=str(self.kb_dir / "index.faiss"),
embedding_provider=ep,
rerank_provider=rp,
)
await vec_db.initialize()
self.vec_db = vec_db
return vec_db
async def delete_vec_db(self):
"""删除知识库的向量数据库和所有相关文件"""
import shutil
await self.terminate()
if self.kb_dir.exists():
shutil.rmtree(self.kb_dir)
async def terminate(self):
if self.vec_db:
await self.vec_db.close()
async def upload_document(
self,
file_name: str,
file_content: bytes,
file_type: str,
chunk_size: int = 512,
chunk_overlap: int = 50,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
) -> KBDocument:
"""上传并处理文档(带原子性保证和失败清理)
流程:
1. 保存原始文件
2. 解析文档内容
3. 提取多媒体资源
4. 分块处理
5. 生成向量并存储
6. 保存元数据(事务)
7. 更新统计
Args:
progress_callback: 进度回调函数,接收参数 (stage, current, total)
- stage: 当前阶段 ('parsing', 'chunking', 'embedding')
- current: 当前进度
- total: 总数
"""
await self._ensure_vec_db()
doc_id = str(uuid.uuid4())
media_paths: list[Path] = []
# file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
# async with aiofiles.open(file_path, "wb") as f:
# await f.write(file_content)
try:
# 阶段1: 解析文档
if progress_callback:
await progress_callback("parsing", 0, 100)
parser = await select_parser(f".{file_type}")
parse_result = await parser.parse(file_content, file_name)
text_content = parse_result.text
media_items = parse_result.media
if progress_callback:
await progress_callback("parsing", 100, 100)
# 保存媒体文件
saved_media = []
for media_item in media_items:
media = await self._save_media(
doc_id=doc_id,
media_type=media_item.media_type,
file_name=media_item.file_name,
content=media_item.content,
mime_type=media_item.mime_type,
)
saved_media.append(media)
media_paths.append(Path(media.file_path))
# 阶段2: 分块
if progress_callback:
await progress_callback("chunking", 0, 100)
chunks_text = await self.chunker.chunk(
text_content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
contents = []
metadatas = []
for idx, chunk_text in enumerate(chunks_text):
contents.append(chunk_text)
metadatas.append(
{
"kb_id": self.kb.kb_id,
"kb_doc_id": doc_id,
"chunk_index": idx,
}
)
if progress_callback:
await progress_callback("chunking", 100, 100)
# 阶段3: 生成向量(带进度回调)
async def embedding_progress_callback(current, total):
if progress_callback:
await progress_callback("embedding", current, total)
await self.vec_db.insert_batch(
contents=contents,
metadatas=metadatas,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=embedding_progress_callback,
)
# 保存文档的元数据
doc = KBDocument(
doc_id=doc_id,
kb_id=self.kb.kb_id,
doc_name=file_name,
file_type=file_type,
file_size=len(file_content),
# file_path=str(file_path),
file_path="",
chunk_count=len(chunks_text),
media_count=0,
)
async with self.kb_db.get_db() as session:
async with session.begin():
session.add(doc)
for media in saved_media:
session.add(media)
await session.commit()
await session.refresh(doc)
vec_db: FaissVecDB = self.vec_db # type: ignore
await self.kb_db.update_kb_stats(kb_id=self.kb.kb_id, vec_db=vec_db)
await self.refresh_kb()
await self.refresh_document(doc_id)
return doc
except Exception as e:
logger.error(f"上传文档失败: {e}")
# if file_path.exists():
# file_path.unlink()
for media_path in media_paths:
try:
if media_path.exists():
media_path.unlink()
except Exception as me:
logger.warning(f"清理多媒体文件失败 {media_path}: {me}")
raise e
async def list_documents(
self, offset: int = 0, limit: int = 100
) -> list[KBDocument]:
"""列出知识库的所有文档"""
docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit)
return docs
async def get_document(self, doc_id: str) -> KBDocument | None:
"""获取单个文档"""
doc = await self.kb_db.get_document_by_id(doc_id)
return doc
async def delete_document(self, doc_id: str):
"""删除单个文档及其相关数据"""
await self.kb_db.delete_document_by_id(
doc_id=doc_id,
vec_db=self.vec_db, # type: ignore
)
await self.kb_db.update_kb_stats(
kb_id=self.kb.kb_id,
vec_db=self.vec_db, # type: ignore
)
await self.refresh_kb()
async def delete_chunk(self, chunk_id: str, doc_id: str):
"""删除单个文本块及其相关数据"""
vec_db: FaissVecDB = self.vec_db # type: ignore
await vec_db.delete(chunk_id)
await self.kb_db.update_kb_stats(
kb_id=self.kb.kb_id,
vec_db=self.vec_db, # type: ignore
)
await self.refresh_kb()
await self.refresh_document(doc_id)
async def refresh_kb(self):
if self.kb:
kb = await self.kb_db.get_kb_by_id(self.kb.kb_id)
if kb:
self.kb = kb
async def refresh_document(self, doc_id: str) -> None:
"""更新文档的元数据"""
doc = await self.get_document(doc_id)
if not doc:
raise ValueError(f"无法找到 ID 为 {doc_id} 的文档")
chunk_count = await self.get_chunk_count_by_doc_id(doc_id)
doc.chunk_count = chunk_count
async with self.kb_db.get_db() as session:
async with session.begin():
session.add(doc)
await session.commit()
await session.refresh(doc)
async def get_chunks_by_doc_id(
self, doc_id: str, offset: int = 0, limit: int = 100
) -> list[dict]:
"""获取文档的所有块及其元数据"""
vec_db: FaissVecDB = self.vec_db # type: ignore
chunks = await vec_db.document_storage.get_documents(
metadata_filters={"kb_doc_id": doc_id}, offset=offset, limit=limit
)
result = []
for chunk in chunks:
chunk_md = json.loads(chunk["metadata"])
result.append(
{
"chunk_id": chunk["doc_id"],
"doc_id": chunk_md["kb_doc_id"],
"kb_id": chunk_md["kb_id"],
"chunk_index": chunk_md["chunk_index"],
"content": chunk["text"],
"char_count": len(chunk["text"]),
}
)
return result
async def get_chunk_count_by_doc_id(self, doc_id: str) -> int:
"""获取文档的块数量"""
vec_db: FaissVecDB = self.vec_db # type: ignore
count = await vec_db.count_documents(metadata_filter={"kb_doc_id": doc_id})
return count
async def _save_media(
self,
doc_id: str,
media_type: str,
file_name: str,
content: bytes,
mime_type: str,
) -> KBMedia:
"""保存多媒体资源"""
media_id = str(uuid.uuid4())
ext = Path(file_name).suffix
# 保存文件
file_path = self.kb_medias_dir / doc_id / f"{media_id}{ext}"
file_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(file_path, "wb") as f:
await f.write(content)
media = KBMedia(
media_id=media_id,
doc_id=doc_id,
kb_id=self.kb.kb_id,
media_type=media_type,
file_name=file_name,
file_path=str(file_path),
file_size=len(content),
mime_type=mime_type,
)
return media

View File

@@ -1,287 +0,0 @@
import traceback
from pathlib import Path
from astrbot.core import logger
from astrbot.core.provider.manager import ProviderManager
from .retrieval.manager import RetrievalManager, RetrievalResult
from .retrieval.sparse_retriever import SparseRetriever
from .retrieval.rank_fusion import RankFusion
from .kb_db_sqlite import KBSQLiteDatabase
# from .chunking.fixed_size import FixedSizeChunker
from .chunking.recursive import RecursiveCharacterChunker
from .kb_helper import KBHelper
from .models import KnowledgeBase
FILES_PATH = "data/knowledge_base"
DB_PATH = Path(FILES_PATH) / "kb.db"
"""Knowledge Base storage root directory"""
CHUNKER = RecursiveCharacterChunker()
class KnowledgeBaseManager:
kb_db: KBSQLiteDatabase
retrieval_manager: RetrievalManager
def __init__(
self,
provider_manager: ProviderManager,
):
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
self.provider_manager = provider_manager
self._session_deleted_callback_registered = False
self.kb_insts: dict[str, KBHelper] = {}
async def initialize(self):
"""初始化知识库模块"""
try:
logger.info("正在初始化知识库模块...")
# 初始化数据库
await self._init_kb_database()
# 初始化检索管理器
sparse_retriever = SparseRetriever(self.kb_db)
rank_fusion = RankFusion(self.kb_db)
self.retrieval_manager = RetrievalManager(
sparse_retriever=sparse_retriever,
rank_fusion=rank_fusion,
kb_db=self.kb_db,
)
await self.load_kbs()
except ImportError as e:
logger.error(f"知识库模块导入失败: {e}")
logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
except Exception as e:
logger.error(f"知识库模块初始化失败: {e}")
logger.error(traceback.format_exc())
async def _init_kb_database(self):
self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix())
await self.kb_db.initialize()
await self.kb_db.migrate_to_v1()
logger.info(f"KnowledgeBase database initialized: {DB_PATH}")
async def load_kbs(self):
"""加载所有知识库实例"""
kb_records = await self.kb_db.list_kbs()
for record in kb_records:
kb_helper = KBHelper(
kb_db=self.kb_db,
kb=record,
provider_manager=self.provider_manager,
kb_root_dir=FILES_PATH,
chunker=CHUNKER,
)
await kb_helper.initialize()
self.kb_insts[record.kb_id] = kb_helper
async def create_kb(
self,
kb_name: str,
description: str | None = None,
emoji: str | None = None,
embedding_provider_id: str | None = None,
rerank_provider_id: str | None = None,
chunk_size: int | None = None,
chunk_overlap: int | None = None,
top_k_dense: int | None = None,
top_k_sparse: int | None = None,
top_m_final: int | None = None,
) -> KBHelper:
"""创建新的知识库实例"""
kb = KnowledgeBase(
kb_name=kb_name,
description=description,
emoji=emoji or "📚",
embedding_provider_id=embedding_provider_id,
rerank_provider_id=rerank_provider_id,
chunk_size=chunk_size if chunk_size is not None else 512,
chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
top_k_dense=top_k_dense if top_k_dense is not None else 50,
top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
top_m_final=top_m_final if top_m_final is not None else 5,
)
async with self.kb_db.get_db() as session:
session.add(kb)
await session.commit()
await session.refresh(kb)
kb_helper = KBHelper(
kb_db=self.kb_db,
kb=kb,
provider_manager=self.provider_manager,
kb_root_dir=FILES_PATH,
chunker=CHUNKER,
)
await kb_helper.initialize()
self.kb_insts[kb.kb_id] = kb_helper
return kb_helper
async def get_kb(self, kb_id: str) -> KBHelper | None:
"""获取知识库实例"""
if kb_id in self.kb_insts:
return self.kb_insts[kb_id]
async def get_kb_by_name(self, kb_name: str) -> KBHelper | None:
"""通过名称获取知识库实例"""
for kb_helper in self.kb_insts.values():
if kb_helper.kb.kb_name == kb_name:
return kb_helper
return None
async def delete_kb(self, kb_id: str) -> bool:
"""删除知识库实例"""
kb_helper = await self.get_kb(kb_id)
if not kb_helper:
return False
await kb_helper.delete_vec_db()
async with self.kb_db.get_db() as session:
await session.delete(kb_helper.kb)
await session.commit()
self.kb_insts.pop(kb_id, None)
return True
async def list_kbs(self) -> list[KnowledgeBase]:
"""列出所有知识库实例"""
kbs = [kb_helper.kb for kb_helper in self.kb_insts.values()]
return kbs
async def update_kb(
self,
kb_id: str,
kb_name: str,
description: str | None = None,
emoji: str | None = None,
embedding_provider_id: str | None = None,
rerank_provider_id: str | None = None,
chunk_size: int | None = None,
chunk_overlap: int | None = None,
top_k_dense: int | None = None,
top_k_sparse: int | None = None,
top_m_final: int | None = None,
) -> KBHelper | None:
"""更新知识库实例"""
kb_helper = await self.get_kb(kb_id)
if not kb_helper:
return None
kb = kb_helper.kb
if kb_name is not None:
kb.kb_name = kb_name
if description is not None:
kb.description = description
if emoji is not None:
kb.emoji = emoji
if embedding_provider_id is not None:
kb.embedding_provider_id = embedding_provider_id
kb.rerank_provider_id = rerank_provider_id # 允许设置为 None
if chunk_size is not None:
kb.chunk_size = chunk_size
if chunk_overlap is not None:
kb.chunk_overlap = chunk_overlap
if top_k_dense is not None:
kb.top_k_dense = top_k_dense
if top_k_sparse is not None:
kb.top_k_sparse = top_k_sparse
if top_m_final is not None:
kb.top_m_final = top_m_final
async with self.kb_db.get_db() as session:
session.add(kb)
await session.commit()
await session.refresh(kb)
return kb_helper
async def retrieve(
self,
query: str,
kb_names: list[str],
top_k_fusion: int = 20,
top_m_final: int = 5,
) -> dict | None:
"""从指定知识库中检索相关内容"""
kb_ids = []
kb_id_helper_map = {}
for kb_name in kb_names:
if kb_helper := await self.get_kb_by_name(kb_name):
kb_ids.append(kb_helper.kb.kb_id)
kb_id_helper_map[kb_helper.kb.kb_id] = kb_helper
if not kb_ids:
return {}
results = await self.retrieval_manager.retrieve(
query=query,
kb_ids=kb_ids,
kb_id_helper_map=kb_id_helper_map,
top_k_fusion=top_k_fusion,
top_m_final=top_m_final,
)
if not results:
return None
context_text = self._format_context(results)
results_dict = [
{
"chunk_id": r.chunk_id,
"doc_id": r.doc_id,
"kb_id": r.kb_id,
"kb_name": r.kb_name,
"doc_name": r.doc_name,
"chunk_index": r.metadata.get("chunk_index", 0),
"content": r.content,
"score": r.score,
"char_count": r.metadata.get("char_count", 0),
}
for r in results
]
return {
"context_text": context_text,
"results": results_dict,
}
def _format_context(self, results: list[RetrievalResult]) -> str:
"""格式化知识上下文
Args:
results: 检索结果列表
Returns:
str: 格式化的上下文文本
"""
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
for i, result in enumerate(results, 1):
lines.append(f"【知识 {i}")
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
lines.append(f"内容: {result.content}")
lines.append(f"相关度: {result.score:.2f}")
lines.append("")
return "\n".join(lines)
async def terminate(self):
"""终止所有知识库实例,关闭数据库连接"""
for kb_id, kb_helper in self.kb_insts.items():
try:
await kb_helper.terminate()
except Exception as e:
logger.error(f"关闭知识库 {kb_id} 失败: {e}")
self.kb_insts.clear()
# 关闭元数据数据库
if hasattr(self, "kb_db") and self.kb_db:
try:
await self.kb_db.close()
except Exception as e:
logger.error(f"关闭知识库元数据数据库失败: {e}")

View File

@@ -1,114 +0,0 @@
import uuid
from datetime import datetime, timezone
from sqlmodel import Field, SQLModel, Text, UniqueConstraint, MetaData
class BaseKBModel(SQLModel, table=False):
metadata = MetaData()
class KnowledgeBase(BaseKBModel, table=True):
"""知识库表
存储知识库的基本信息和统计数据。
"""
__tablename__ = "knowledge_bases" # type: ignore
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
kb_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
index=True,
)
kb_name: str = Field(max_length=100, nullable=False)
description: str | None = Field(default=None, sa_type=Text)
emoji: str | None = Field(default="📚", max_length=10)
embedding_provider_id: str | None = Field(default=None, max_length=100)
rerank_provider_id: str | None = Field(default=None, max_length=100)
# 分块配置参数
chunk_size: int | None = Field(default=512, nullable=True)
chunk_overlap: int | None = Field(default=50, nullable=True)
# 检索配置参数
top_k_dense: int | None = Field(default=50, nullable=True)
top_k_sparse: int | None = Field(default=50, nullable=True)
top_m_final: int | None = Field(default=5, nullable=True)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
doc_count: int = Field(default=0, nullable=False)
chunk_count: int = Field(default=0, nullable=False)
__table_args__ = (
UniqueConstraint(
"kb_name",
name="uix_kb_name",
),
)
class KBDocument(BaseKBModel, table=True):
"""文档表
存储上传到知识库的文档元数据。
"""
__tablename__ = "kb_documents" # type: ignore
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
doc_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
index=True,
)
kb_id: str = Field(max_length=36, nullable=False, index=True)
doc_name: str = Field(max_length=255, nullable=False)
file_type: str = Field(max_length=20, nullable=False)
file_size: int = Field(nullable=False)
file_path: str = Field(max_length=512, nullable=False)
chunk_count: int = Field(default=0, nullable=False)
media_count: int = Field(default=0, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
class KBMedia(BaseKBModel, table=True):
"""多媒体资源表
存储从文档中提取的图片、视频等多媒体资源。
"""
__tablename__ = "kb_media" # type: ignore
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
media_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
index=True,
)
doc_id: str = Field(max_length=36, nullable=False, index=True)
kb_id: str = Field(max_length=36, nullable=False, index=True)
media_type: str = Field(max_length=20, nullable=False)
file_name: str = Field(max_length=255, nullable=False)
file_path: str = Field(max_length=512, nullable=False)
file_size: int = Field(nullable=False)
mime_type: str = Field(max_length=100, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

View File

@@ -1,15 +0,0 @@
"""
文档解析器模块
"""
from .base import BaseParser, MediaItem, ParseResult
from .text_parser import TextParser
from .pdf_parser import PDFParser
__all__ = [
"BaseParser",
"MediaItem",
"ParseResult",
"TextParser",
"PDFParser",
]

View File

@@ -1,50 +0,0 @@
"""文档解析器基类和数据结构
定义了文档解析器的抽象接口和相关数据类。
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass
class MediaItem:
"""多媒体项
表示从文档中提取的多媒体资源。
"""
media_type: str # image, video
file_name: str
content: bytes
mime_type: str
@dataclass
class ParseResult:
"""解析结果
包含解析后的文本内容和提取的多媒体资源。
"""
text: str
media: list[MediaItem]
class BaseParser(ABC):
"""文档解析器基类
所有文档解析器都应该继承此类并实现 parse 方法。
"""
@abstractmethod
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
"""解析文档
Args:
file_content: 文件内容
file_name: 文件名
Returns:
ParseResult: 解析结果
"""

View File

@@ -1,25 +0,0 @@
import io
import os
from astrbot.core.knowledge_base.parsers.base import (
BaseParser,
ParseResult,
)
from markitdown_no_magika import MarkItDown, StreamInfo
class MarkitdownParser(BaseParser):
"""解析 docx, xls, xlsx 格式"""
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
md = MarkItDown(enable_plugins=False)
bio = io.BytesIO(file_content)
stream_info = StreamInfo(
extension=os.path.splitext(file_name)[1].lower(),
filename=file_name,
)
result = md.convert(bio, stream_info=stream_info)
return ParseResult(
text=result.markdown,
media=[],
)

View File

@@ -1,100 +0,0 @@
"""PDF 文件解析器
支持解析 PDF 文件中的文本和图片资源。
"""
import io
from pypdf import PdfReader
from astrbot.core.knowledge_base.parsers.base import (
BaseParser,
MediaItem,
ParseResult,
)
class PDFParser(BaseParser):
"""PDF 文档解析器
提取 PDF 中的文本内容和嵌入的图片资源。
"""
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
"""解析 PDF 文件
Args:
file_content: 文件内容
file_name: 文件名
Returns:
ParseResult: 包含文本和图片的解析结果
"""
pdf_file = io.BytesIO(file_content)
reader = PdfReader(pdf_file)
text_parts = []
media_items = []
# 提取文本
for page in reader.pages:
text = page.extract_text()
if text:
text_parts.append(text)
# 提取图片
image_counter = 0
for page_num, page in enumerate(reader.pages):
try:
# 安全检查 Resources
if "/Resources" not in page:
continue
resources = page["/Resources"]
if not resources or "/XObject" not in resources: # type: ignore
continue
xobjects = resources["/XObject"].get_object() # type: ignore
if not xobjects:
continue
for obj_name in xobjects:
try:
obj = xobjects[obj_name]
if obj.get("/Subtype") != "/Image":
continue
# 提取图片数据
image_data = obj.get_data()
# 确定格式
filter_type = obj.get("/Filter", "")
if filter_type == "/DCTDecode":
ext = "jpg"
mime_type = "image/jpeg"
elif filter_type == "/FlateDecode":
ext = "png"
mime_type = "image/png"
else:
ext = "png"
mime_type = "image/png"
image_counter += 1
media_items.append(
MediaItem(
media_type="image",
file_name=f"page_{page_num}_img_{image_counter}.{ext}",
content=image_data,
mime_type=mime_type,
)
)
except Exception:
# 单个图片提取失败不影响整体
continue
except Exception:
# 页面处理失败不影响其他页面
continue
full_text = "\n\n".join(text_parts)
return ParseResult(text=full_text, media=media_items)

View File

@@ -1,41 +0,0 @@
"""文本文件解析器
支持解析 TXT 和 Markdown 文件。
"""
from astrbot.core.knowledge_base.parsers.base import BaseParser, ParseResult
class TextParser(BaseParser):
"""TXT/MD 文本解析器
支持多种字符编码的自动检测。
"""
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
"""解析文本文件
尝试使用多种编码解析文件内容。
Args:
file_content: 文件内容
file_name: 文件名
Returns:
ParseResult: 解析结果,不包含多媒体资源
Raises:
ValueError: 如果无法解码文件
"""
# 尝试多种编码
for encoding in ["utf-8", "gbk", "gb2312", "gb18030"]:
try:
text = file_content.decode(encoding)
break
except UnicodeDecodeError:
continue
else:
raise ValueError(f"无法解码文件: {file_name}")
# 文本文件无多媒体资源
return ParseResult(text=text, media=[])

View File

@@ -1,13 +0,0 @@
from .base import BaseParser
async def select_parser(ext: str) -> BaseParser:
if ext in {".md", ".txt", ".markdown", ".xlsx", ".docx", ".xls"}:
from .markitdown_parser import MarkitdownParser
return MarkitdownParser()
elif ext == ".pdf":
from .pdf_parser import PDFParser
return PDFParser()
raise ValueError(f"暂时不支持的文件格式: {ext}")

View File

@@ -1,16 +0,0 @@
"""
检索模块
"""
from .manager import RetrievalManager, RetrievalResult
from .sparse_retriever import SparseRetriever, SparseResult
from .rank_fusion import RankFusion, FusedResult
__all__ = [
"RetrievalManager",
"RetrievalResult",
"SparseRetriever",
"SparseResult",
"RankFusion",
"FusedResult",
]

View File

@@ -1,767 +0,0 @@
———
》),
)÷(1-
”,
)、
:
&
*
一一
~~~~
.
.一
./
--
=″
[⑤]]
[①D]
ng昉
//
[②e]
[②g]
}
,也
[①⑥]
[②B]
[①a]
[④a]
[①③]
[③h]
③]
[②b]
×××
[①⑧]
[⑤b]
[②c]
[④b]
[②③]
[③a]
[④c]
[①⑤]
[①⑦]
[①g]
∈[
[①⑨]
[①④]
[①c]
[②f]
[②⑧]
[②①]
[①C]
[③c]
[③g]
[②⑤]
[②②]
一.
[①h]
.数
[①B]
数/
[①i]
[③e]
[①①]
[④d]
[④e]
[③b]
[⑤a]
[①A]
[②⑧]
[②⑦]
[①d]
[②j]
://
′∈
[②④
[⑤e]
...
...................
…………………………………………………③
[③F]
[①o]
]∧′=[
∪φ∈
②c
[③①]
[①E]
Ψ
.日
[②d]
[②
[②⑦]
[②②]
[③e]
[①i]
[①B]
[①h]
[①d]
[①g]
[①②]
[②a]
[⑩]
[①e]
[②h]
[②⑥]
[③d]
[②⑩]
元/吨
[②⑩]
[①]
::
[②]
[③]
[④]
[⑤]
[⑥]
[⑦]
[⑧]
[⑨]
……
——
?
,
'
?
·
———
──
?
<
>
[
]
(
)
-
+
×
/
В
"
;
#
@
γ
μ
φ
φ.
×
Δ
sub
exp
sup
sub
Lex
+ξ
-β
<±
<Δ
<λ
<φ
=
=☆
>λ
_
~±
[⑤f]
[⑤d]
[②i]
[②G]
[①f]
......
[③⑩]
第二
一番
一直
一个
一些
许多
有的是
也就是说
末##末
哎呀
哎哟
俺们
按照
吧哒
罢了
本着
比方
比如
鄙人
彼此
别的
别说
并且
不比
不成
不单
不但
不独
不管
不光
不过
不仅
不拘
不论
不怕
不然
不如
不特
不惟
不问
不只
朝着
趁着
除此之外
除非
除了
此间
此外
从而
但是
当着
的话
等等
叮咚
对于
多少
而况
而且
而是
而外
而言
而已
尔后
反过来
反过来说
反之
非但
非徒
否则
嘎登
各个
各位
各种
各自
根据
故此
固然
关于
果然
果真
哈哈
何处
何况
何时
哼唷
呼哧
还是
还有
换句话说
换言之
或是
或者
极了
及其
及至
即便
即或
即令
即若
即使
几时
既然
既是
继而
加之
假如
假若
假使
鉴于
较之
接着
结果
紧接着
进而
尽管
经过
就是
就是说
具体地说
具体说来
开始
开外
可见
可是
可以
况且
来着
例如
连同
两者
另外
另一方面
慢说
漫说
每当
莫若
某个
某些
哪边
哪儿
哪个
哪里
哪年
哪怕
哪天
哪些
哪样
那边
那儿
那个
那会儿
那里
那么
那么些
那么样
那时
那些
那样
乃至
你们
宁可
宁肯
宁愿
啪达
旁人
凭借
其次
其二
其他
其它
其一
其余
其中
起见
起见
岂但
恰恰相反
前后
前者
然而
然后
然则
人家
任何
任凭
如此
如果
如何
如其
如若
如上所述
若非
若是
上下
尚且
设若
设使
甚而
甚么
甚至
省得
时候
什么
什么样
使得
是的
首先
谁知
顺着
似的
虽然
虽说
虽则
随着
所以
他们
他人
它们
她们
倘或
倘然
倘若
倘使
通过
同时
万一
为何
为了
为什么
为着
嗡嗡
我们
呜呼
乌乎
无论
无宁
毋宁
相对而言
向着
沿
沿着
要不
要不然
要不是
要么
要是
也罢
也好
一般
一旦
一方面
一来
一切
一样
一则
依照
以便
以及
以免
以至
以至于
以致
抑或
因此
因而
因为
由此可见
由于
有的
有关
有些
于是
于是乎
与此同时
与否
与其
越是
云云
再说
再者
在下
咱们
怎么
怎么办
怎么样
怎样
照着
这边
这儿
这个
这会儿
这就是说
这里
这么
这么点儿
这么些
这么样
这时
这些
这样
正如
之类
之所以
之一
只是
只限
只要
只有
至于
诸位
着呢
自从
自个儿
自各儿
自己
自家
自身
综上所述
总的来看
总的来说
总的说来
总而言之
总之
纵令
纵然
纵使
遵照
作为
喔唷

View File

@@ -1,273 +0,0 @@
"""检索管理器
协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
"""
import time
from dataclasses import dataclass
from typing import List
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
from astrbot.core.provider.provider import RerankProvider
from astrbot.core.db.vec_db.base import Result
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from ..kb_helper import KBHelper
from astrbot import logger
@dataclass
class RetrievalResult:
"""检索结果"""
chunk_id: str
doc_id: str
doc_name: str
kb_id: str
kb_name: str
content: str
score: float
metadata: dict
class RetrievalManager:
"""检索管理器
职责:
- 协调稠密检索、稀疏检索和 Rerank
- 结果融合和排序
"""
def __init__(
self,
sparse_retriever: SparseRetriever,
rank_fusion: RankFusion,
kb_db: KBSQLiteDatabase,
):
"""初始化检索管理器
Args:
vec_db_factory: 向量数据库工厂
sparse_retriever: 稀疏检索器
rank_fusion: 结果融合器
kb_db: 知识库数据库实例
"""
self.sparse_retriever = sparse_retriever
self.rank_fusion = rank_fusion
self.kb_db = kb_db
async def retrieve(
self,
query: str,
kb_ids: List[str],
kb_id_helper_map: dict[str, KBHelper],
top_k_fusion: int = 20,
top_m_final: int = 5,
) -> List[RetrievalResult]:
"""混合检索
流程:
1. 稠密检索 (向量相似度)
2. 稀疏检索 (BM25)
3. 结果融合 (RRF)
4. Rerank 重排序
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
top_m_final: 最终返回数量
enable_rerank: 是否启用 Rerank
Returns:
List[RetrievalResult]: 检索结果列表
"""
if not kb_ids:
return []
kb_options: dict = {}
new_kb_ids = []
for kb_id in kb_ids:
kb_helper = kb_id_helper_map.get(kb_id)
if kb_helper:
kb = kb_helper.kb
kb_options[kb_id] = {
"top_k_dense": kb.top_k_dense or 50,
"top_k_sparse": kb.top_k_sparse or 50,
"top_m_final": kb.top_m_final or 5,
"vec_db": kb_helper.vec_db,
"rerank_provider_id": kb.rerank_provider_id,
}
new_kb_ids.append(kb_id)
else:
logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索")
kb_ids = new_kb_ids
# 1. 稠密检索
time_start = time.time()
dense_results = await self._dense_retrieve(
query=query,
kb_ids=kb_ids,
kb_options=kb_options,
)
time_end = time.time()
logger.debug(
f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results."
)
# 2. 稀疏检索
time_start = time.time()
sparse_results = await self.sparse_retriever.retrieve(
query=query,
kb_ids=kb_ids,
kb_options=kb_options,
)
time_end = time.time()
logger.debug(
f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results."
)
# 3. 结果融合
time_start = time.time()
fused_results = await self.rank_fusion.fuse(
dense_results=dense_results,
sparse_results=sparse_results,
top_k=top_k_fusion,
)
time_end = time.time()
logger.debug(
f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results."
)
# 4. 转换为 RetrievalResult (获取元数据)
retrieval_results = []
for fr in fused_results:
metadata_dict = await self.kb_db.get_document_with_metadata(fr.doc_id)
if metadata_dict:
retrieval_results.append(
RetrievalResult(
chunk_id=fr.chunk_id,
doc_id=fr.doc_id,
doc_name=metadata_dict["document"].doc_name,
kb_id=fr.kb_id,
kb_name=metadata_dict["knowledge_base"].kb_name,
content=fr.content,
score=fr.score,
metadata={
"chunk_index": fr.chunk_index,
"char_count": len(fr.content),
},
)
)
# 5. Rerank
first_rerank = None
for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
if (
vec_db
and vec_db.rerank_provider
and rerank_pi
and rerank_pi == vec_db.rerank_provider.meta().id
):
first_rerank = vec_db.rerank_provider
break
if first_rerank and retrieval_results:
retrieval_results = await self._rerank(
query=query,
results=retrieval_results,
top_k=top_m_final,
rerank_provider=first_rerank,
)
return retrieval_results[:top_m_final]
async def _dense_retrieve(
self,
query: str,
kb_ids: List[str],
kb_options: dict,
):
"""稠密检索 (向量相似度)
为每个知识库使用独立的向量数据库进行检索,然后合并结果。
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
top_k: 返回结果数量
Returns:
List[Result]: 检索结果列表
"""
all_results: list[Result] = []
for kb_id in kb_ids:
if kb_id not in kb_options:
continue
try:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
dense_k = int(kb_options[kb_id]["top_k_dense"])
vec_results = await vec_db.retrieve(
query=query,
k=dense_k,
fetch_k=dense_k * 2,
rerank=False, # 稠密检索阶段不进行 rerank
metadata_filters={"kb_id": kb_id},
)
all_results.extend(vec_results)
except Exception as e:
from astrbot.core import logger
logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
continue
# 按相似度排序并返回 top_k
all_results.sort(key=lambda x: x.similarity, reverse=True)
# return all_results[: len(all_results) // len(kb_ids)]
return all_results
async def _rerank(
self,
query: str,
results: List[RetrievalResult],
top_k: int,
rerank_provider: RerankProvider,
) -> List[RetrievalResult]:
"""Rerank 重排序
Args:
query: 查询文本
results: 检索结果列表
top_k: 返回结果数量
Returns:
List[RetrievalResult]: 重排序后的结果列表
"""
if not results:
return []
# 准备文档列表
docs = [r.content for r in results]
# 调用 Rerank Provider
rerank_results = await rerank_provider.rerank(
query=query,
documents=docs,
)
# 更新分数并重新排序
reranked_list = []
for rerank_result in rerank_results:
idx = rerank_result.index
if idx < len(results):
result = results[idx]
result.score = rerank_result.relevance_score
reranked_list.append(result)
reranked_list.sort(key=lambda x: x.score, reverse=True)
return reranked_list[:top_k]

View File

@@ -1,138 +0,0 @@
"""检索结果融合器
使用 Reciprocal Rank Fusion (RRF) 算法融合稠密检索和稀疏检索的结果
"""
import json
from dataclasses import dataclass
from astrbot.core.db.vec_db.base import Result
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult
@dataclass
class FusedResult:
"""融合后的检索结果"""
chunk_id: str
chunk_index: int
doc_id: str
kb_id: str
content: str
score: float
class RankFusion:
"""检索结果融合器
职责:
- 融合稠密检索和稀疏检索的结果
- 使用 Reciprocal Rank Fusion (RRF) 算法
"""
def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60):
"""初始化结果融合器
Args:
kb_db: 知识库数据库实例
k: RRF 参数,用于平滑排名
"""
self.kb_db = kb_db
self.k = k
async def fuse(
self,
dense_results: list[Result],
sparse_results: list[SparseResult],
top_k: int = 20,
) -> list[FusedResult]:
"""融合稠密和稀疏检索结果
RRF 公式:
score(doc) = sum(1 / (k + rank_i))
Args:
dense_results: 稠密检索结果
sparse_results: 稀疏检索结果
top_k: 返回结果数量
Returns:
List[FusedResult]: 融合后的结果列表
"""
# 1. 构建排名映射
dense_ranks = {
r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results)
} # 这里的 doc_id 实际上是 chunk_id
sparse_ranks = {r.chunk_id: (idx + 1) for idx, r in enumerate(sparse_results)}
# 2. 收集所有唯一的 ID
# 需要统一为 chunk_id
all_chunk_ids = set()
vec_doc_id_to_dense: dict[str, Result] = {} # vec_doc_id -> Result
chunk_id_to_sparse: dict[str, SparseResult] = {} # chunk_id -> SparseResult
# 处理稀疏检索结果
for r in sparse_results:
all_chunk_ids.add(r.chunk_id)
chunk_id_to_sparse[r.chunk_id] = r
# 处理稠密检索结果 (需要转换 vec_doc_id 到 chunk_id)
for r in dense_results:
vec_doc_id = r.data["doc_id"]
all_chunk_ids.add(vec_doc_id)
vec_doc_id_to_dense[vec_doc_id] = r
# 3. 计算 RRF 分数
rrf_scores: dict[str, float] = {}
for identifier in all_chunk_ids:
score = 0.0
# 来自稠密检索的贡献
if identifier in dense_ranks:
score += 1.0 / (self.k + dense_ranks[identifier])
# 来自稀疏检索的贡献
if identifier in sparse_ranks:
score += 1.0 / (self.k + sparse_ranks[identifier])
rrf_scores[identifier] = score
# 4. 排序
sorted_ids = sorted(
rrf_scores.keys(), key=lambda cid: rrf_scores[cid], reverse=True
)[:top_k]
# 5. 构建融合结果
fused_results = []
for identifier in sorted_ids:
# 优先从稀疏检索获取完整信息
if identifier in chunk_id_to_sparse:
sr = chunk_id_to_sparse[identifier]
fused_results.append(
FusedResult(
chunk_id=sr.chunk_id,
chunk_index=sr.chunk_index,
doc_id=sr.doc_id,
kb_id=sr.kb_id,
content=sr.content,
score=rrf_scores[identifier],
)
)
elif identifier in vec_doc_id_to_dense:
# 从向量检索获取信息,需要从数据库获取块的详细信息
vec_result = vec_doc_id_to_dense[identifier]
chunk_md = json.loads(vec_result.data["metadata"])
fused_results.append(
FusedResult(
chunk_id=identifier,
chunk_index=chunk_md["chunk_index"],
doc_id=chunk_md["kb_doc_id"],
kb_id=chunk_md["kb_id"],
content=vec_result.data["text"],
score=rrf_scores[identifier],
)
)
return fused_results

View File

@@ -1,130 +0,0 @@
"""稀疏检索器
使用 BM25 算法进行基于关键词的文档检索
"""
import jieba
import os
import json
from dataclasses import dataclass
from rank_bm25 import BM25Okapi
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
@dataclass
class SparseResult:
"""稀疏检索结果"""
chunk_index: int
chunk_id: str
doc_id: str
kb_id: str
content: str
score: float
class SparseRetriever:
"""BM25 稀疏检索器
职责:
- 基于关键词的文档检索
- 使用 BM25 算法计算相关度
"""
def __init__(self, kb_db: KBSQLiteDatabase):
"""初始化稀疏检索器
Args:
kb_db: 知识库数据库实例
"""
self.kb_db = kb_db
self._index_cache = {} # 缓存 BM25 索引
with open(
os.path.join(os.path.dirname(__file__), "hit_stopwords.txt"),
encoding="utf-8",
) as f:
self.hit_stopwords = {
word.strip() for word in set(f.read().splitlines()) if word.strip()
}
async def retrieve(
self,
query: str,
kb_ids: list[str],
kb_options: dict,
) -> list[SparseResult]:
"""执行稀疏检索
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
kb_options: 每个知识库的检索选项
Returns:
List[SparseResult]: 检索结果列表
"""
# 1. 获取所有相关块
top_k_sparse = 0
chunks = []
for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options.get(kb_id, {}).get("vec_db")
if not vec_db:
continue
result = await vec_db.document_storage.get_documents(
metadata_filters={}, limit=None, offset=None
)
chunk_mds = [json.loads(doc["metadata"]) for doc in result]
result = [
{
"chunk_id": doc["doc_id"],
"chunk_index": chunk_md["chunk_index"],
"doc_id": chunk_md["kb_doc_id"],
"kb_id": kb_id,
"text": doc["text"],
}
for doc, chunk_md in zip(result, chunk_mds)
]
chunks.extend(result)
top_k_sparse += kb_options.get(kb_id, {}).get("top_k_sparse", 50)
if not chunks:
return []
# 2. 准备文档和索引
corpus = [chunk["text"] for chunk in chunks]
tokenized_corpus = [list(jieba.cut(doc)) for doc in corpus]
tokenized_corpus = [
[word for word in doc if word not in self.hit_stopwords]
for doc in tokenized_corpus
]
# 3. 构建 BM25 索引
bm25 = BM25Okapi(tokenized_corpus)
# 4. 执行检索
tokenized_query = list(jieba.cut(query))
tokenized_query = [
word for word in tokenized_query if word not in self.hit_stopwords
]
scores = bm25.get_scores(tokenized_query)
# 5. 排序并返回 Top-K
results = []
for idx, score in enumerate(scores):
chunk = chunks[idx]
results.append(
SparseResult(
chunk_id=chunk["chunk_id"],
chunk_index=chunk["chunk_index"],
doc_id=chunk["doc_id"],
kb_id=chunk["kb_id"],
content=chunk["text"],
score=float(score),
)
)
results.sort(key=lambda x: x.score, reverse=True)
# return results[: len(results) // len(kb_ids)]
return results[:top_k_sparse]

View File

@@ -37,7 +37,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
class ComponentType(str, Enum):
class ComponentType(Enum):
Plain = "Plain" # 纯文本消息
Face = "Face" # QQ表情
Record = "Record" # 语音
@@ -108,7 +108,7 @@ class BaseMessageComponent(BaseModel):
class Plain(BaseMessageComponent):
type = ComponentType.Plain
type: ComponentType = "Plain"
text: str
convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息
@@ -130,7 +130,7 @@ class Plain(BaseMessageComponent):
class Face(BaseMessageComponent):
type = ComponentType.Face
type: ComponentType = "Face"
id: int
def __init__(self, **_):
@@ -138,7 +138,7 @@ class Face(BaseMessageComponent):
class Record(BaseMessageComponent):
type = ComponentType.Record
type: ComponentType = "Record"
file: T.Optional[str] = ""
magic: T.Optional[bool] = False
url: T.Optional[str] = ""
@@ -165,24 +165,19 @@ class Record(BaseMessageComponent):
return Record(file=url, **_)
raise Exception("not a valid url")
@staticmethod
def fromBase64(bs64_data: str, **_):
return Record(file=f"base64://{bs64_data}", **_)
async def convert_to_file_path(self) -> str:
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
Returns:
str: 语音的本地路径,以绝对路径表示。
"""
if not self.file:
raise Exception(f"not a valid file: {self.file}")
if self.file.startswith("file:///"):
return self.file[8:]
elif self.file.startswith("http"):
if self.file and self.file.startswith("file:///"):
file_path = self.file[8:]
return file_path
elif self.file and self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
return os.path.abspath(file_path)
elif self.file.startswith("base64://"):
elif self.file and self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -191,7 +186,8 @@ class Record(BaseMessageComponent):
f.write(image_bytes)
return os.path.abspath(file_path)
elif os.path.exists(self.file):
return os.path.abspath(self.file)
file_path = self.file
return os.path.abspath(file_path)
else:
raise Exception(f"not a valid file: {self.file}")
@@ -202,14 +198,12 @@ class Record(BaseMessageComponent):
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
if not self.file:
raise Exception(f"not a valid file: {self.file}")
if self.file.startswith("file:///"):
if self.file and self.file.startswith("file:///"):
bs64_data = file_to_base64(self.file[8:])
elif self.file.startswith("http"):
elif self.file and self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
bs64_data = file_to_base64(file_path)
elif self.file.startswith("base64://"):
elif self.file and self.file.startswith("base64://"):
bs64_data = self.file
elif os.path.exists(self.file):
bs64_data = file_to_base64(self.file)
@@ -243,7 +237,7 @@ class Record(BaseMessageComponent):
class Video(BaseMessageComponent):
type = ComponentType.Video
type: ComponentType = "Video"
file: str
cover: T.Optional[str] = ""
c: T.Optional[int] = 2
@@ -329,7 +323,7 @@ class Video(BaseMessageComponent):
class At(BaseMessageComponent):
type = ComponentType.At
type: ComponentType = "At"
qq: T.Union[int, str] # 此处str为all时代表所有人
name: T.Optional[str] = ""
@@ -351,28 +345,28 @@ class AtAll(At):
class RPS(BaseMessageComponent): # TODO
type = ComponentType.RPS
type: ComponentType = "RPS"
def __init__(self, **_):
super().__init__(**_)
class Dice(BaseMessageComponent): # TODO
type = ComponentType.Dice
type: ComponentType = "Dice"
def __init__(self, **_):
super().__init__(**_)
class Shake(BaseMessageComponent): # TODO
type = ComponentType.Shake
type: ComponentType = "Shake"
def __init__(self, **_):
super().__init__(**_)
class Anonymous(BaseMessageComponent): # TODO
type = ComponentType.Anonymous
type: ComponentType = "Anonymous"
ignore: T.Optional[bool] = False
def __init__(self, **_):
@@ -380,7 +374,7 @@ class Anonymous(BaseMessageComponent): # TODO
class Share(BaseMessageComponent):
type = ComponentType.Share
type: ComponentType = "Share"
url: str
title: str
content: T.Optional[str] = ""
@@ -391,7 +385,7 @@ class Share(BaseMessageComponent):
class Contact(BaseMessageComponent): # TODO
type = ComponentType.Contact
type: ComponentType = "Contact"
_type: str # type 字段冲突
id: T.Optional[int] = 0
@@ -400,7 +394,7 @@ class Contact(BaseMessageComponent): # TODO
class Location(BaseMessageComponent): # TODO
type = ComponentType.Location
type: ComponentType = "Location"
lat: float
lon: float
title: T.Optional[str] = ""
@@ -411,7 +405,7 @@ class Location(BaseMessageComponent): # TODO
class Music(BaseMessageComponent):
type = ComponentType.Music
type: ComponentType = "Music"
_type: str
id: T.Optional[int] = 0
url: T.Optional[str] = ""
@@ -428,7 +422,7 @@ class Music(BaseMessageComponent):
class Image(BaseMessageComponent):
type = ComponentType.Image
type: ComponentType = "Image"
file: T.Optional[str] = ""
_type: T.Optional[str] = ""
subType: T.Optional[int] = 0
@@ -471,15 +465,14 @@ class Image(BaseMessageComponent):
Returns:
str: 图片的本地路径,以绝对路径表示。
"""
url = self.url or self.file
if not url:
raise ValueError("No valid file or URL provided")
if url.startswith("file:///"):
return url[8:]
elif url.startswith("http"):
url = self.url if self.url else self.file
if url and url.startswith("file:///"):
image_file_path = url[8:]
return image_file_path
elif url and url.startswith("http"):
image_file_path = await download_image_by_url(url)
return os.path.abspath(image_file_path)
elif url.startswith("base64://"):
elif url and url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -488,7 +481,8 @@ class Image(BaseMessageComponent):
f.write(image_bytes)
return os.path.abspath(image_file_path)
elif os.path.exists(url):
return os.path.abspath(url)
image_file_path = url
return os.path.abspath(image_file_path)
else:
raise Exception(f"not a valid file: {url}")
@@ -499,15 +493,13 @@ class Image(BaseMessageComponent):
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
url = self.url or self.file
if not url:
raise ValueError("No valid file or URL provided")
if url.startswith("file:///"):
url = self.url if self.url else self.file
if url and url.startswith("file:///"):
bs64_data = file_to_base64(url[8:])
elif url.startswith("http"):
elif url and url.startswith("http"):
image_file_path = await download_image_by_url(url)
bs64_data = file_to_base64(image_file_path)
elif url.startswith("base64://"):
elif url and url.startswith("base64://"):
bs64_data = url
elif os.path.exists(url):
bs64_data = file_to_base64(url)
@@ -541,7 +533,7 @@ class Image(BaseMessageComponent):
class Reply(BaseMessageComponent):
type = ComponentType.Reply
type: ComponentType = "Reply"
id: T.Union[str, int]
"""所引用的消息 ID"""
chain: T.Optional[T.List["BaseMessageComponent"]] = []
@@ -567,7 +559,7 @@ class Reply(BaseMessageComponent):
class RedBag(BaseMessageComponent):
type = ComponentType.RedBag
type: ComponentType = "RedBag"
title: str
def __init__(self, **_):
@@ -575,7 +567,7 @@ class RedBag(BaseMessageComponent):
class Poke(BaseMessageComponent):
type: str = ComponentType.Poke
type: str = ""
id: T.Optional[int] = 0
qq: T.Optional[int] = 0
@@ -585,7 +577,7 @@ class Poke(BaseMessageComponent):
class Forward(BaseMessageComponent):
type = ComponentType.Forward
type: ComponentType = "Forward"
id: str
def __init__(self, **_):
@@ -595,7 +587,7 @@ class Forward(BaseMessageComponent):
class Node(BaseMessageComponent):
"""群合并转发消息"""
type = ComponentType.Node
type: ComponentType = "Node"
id: T.Optional[int] = 0 # 忽略
name: T.Optional[str] = "" # qq昵称
uin: T.Optional[str] = "0" # qq号
@@ -647,7 +639,7 @@ class Node(BaseMessageComponent):
class Nodes(BaseMessageComponent):
type = ComponentType.Nodes
type: ComponentType = "Nodes"
nodes: T.List[Node]
def __init__(self, nodes: T.List[Node], **_):
@@ -673,7 +665,7 @@ class Nodes(BaseMessageComponent):
class Xml(BaseMessageComponent):
type = ComponentType.Xml
type: ComponentType = "Xml"
data: str
resid: T.Optional[int] = 0
@@ -682,7 +674,7 @@ class Xml(BaseMessageComponent):
class Json(BaseMessageComponent):
type = ComponentType.Json
type: ComponentType = "Json"
data: T.Union[str, dict]
resid: T.Optional[int] = 0
@@ -693,7 +685,7 @@ class Json(BaseMessageComponent):
class CardImage(BaseMessageComponent):
type = ComponentType.CardImage
type: ComponentType = "CardImage"
file: str
cache: T.Optional[bool] = True
minwidth: T.Optional[int] = 400
@@ -712,7 +704,7 @@ class CardImage(BaseMessageComponent):
class TTS(BaseMessageComponent):
type = ComponentType.TTS
type: ComponentType = "TTS"
text: str
def __init__(self, **_):
@@ -720,7 +712,7 @@ class TTS(BaseMessageComponent):
class Unknown(BaseMessageComponent):
type = ComponentType.Unknown
type: ComponentType = "Unknown"
text: str
def toString(self):
@@ -732,7 +724,7 @@ class File(BaseMessageComponent):
文件消息段
"""
type = ComponentType.File
type: ComponentType = "File"
name: T.Optional[str] = "" # 名字
file_: T.Optional[str] = "" # 本地路径
url: T.Optional[str] = "" # url
@@ -862,7 +854,7 @@ class File(BaseMessageComponent):
class WechatEmoji(BaseMessageComponent):
type = ComponentType.WechatEmoji
type: ComponentType = "WechatEmoji"
md5: T.Optional[str] = ""
md5_len: T.Optional[int] = 0
cdnurl: T.Optional[str] = ""

View File

@@ -1,183 +0,0 @@
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Persona, Personality
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.platform.message_session import MessageSession
from astrbot import logger
DEFAULT_PERSONALITY = Personality(
prompt="You are a helpful and friendly assistant.",
name="default",
begin_dialogs=[],
mood_imitation_dialogs=[],
tools=None,
_begin_dialogs_processed=[],
_mood_imitation_dialogs_processed="",
)
class PersonaManager:
def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager):
self.db = db_helper
self.acm = acm
default_ps = acm.default_conf.get("provider_settings", {})
self.default_persona: str = default_ps.get("default_personality", "default")
self.personas: list[Persona] = []
self.selected_default_persona: Persona | None = None
self.personas_v3: list[Personality] = []
self.selected_default_persona_v3: Personality | None = None
self.persona_v3_config: list[dict] = []
async def initialize(self):
self.personas = await self.get_all_personas()
self.get_v3_persona_data()
logger.info(f"已加载 {len(self.personas)} 个人格。")
async def get_persona(self, persona_id: str):
"""获取指定 persona 的信息"""
persona = await self.db.get_persona_by_id(persona_id)
if not persona:
raise ValueError(f"Persona with ID {persona_id} does not exist.")
return persona
async def get_default_persona_v3(
self, umo: str | MessageSession | None = None
) -> Personality:
"""获取默认 persona"""
cfg = self.acm.get_conf(umo)
default_persona_id = cfg.get("provider_settings", {}).get(
"default_personality", "default"
)
if not default_persona_id or default_persona_id == "default":
return DEFAULT_PERSONALITY
try:
return next(p for p in self.personas_v3 if p["name"] == default_persona_id)
except Exception:
return DEFAULT_PERSONALITY
async def delete_persona(self, persona_id: str):
"""删除指定 persona"""
if not await self.db.get_persona_by_id(persona_id):
raise ValueError(f"Persona with ID {persona_id} does not exist.")
await self.db.delete_persona(persona_id)
self.personas = [p for p in self.personas if p.persona_id != persona_id]
self.get_v3_persona_data()
async def update_persona(
self,
persona_id: str,
system_prompt: str = None,
begin_dialogs: list[str] = None,
tools: list[str] = None,
):
"""更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
existing_persona = await self.db.get_persona_by_id(persona_id)
if not existing_persona:
raise ValueError(f"Persona with ID {persona_id} does not exist.")
persona = await self.db.update_persona(
persona_id, system_prompt, begin_dialogs, tools=tools
)
if persona:
for i, p in enumerate(self.personas):
if p.persona_id == persona_id:
self.personas[i] = persona
break
self.get_v3_persona_data()
return persona
async def get_all_personas(self) -> list[Persona]:
"""获取所有 personas"""
return await self.db.get_personas()
async def create_persona(
self,
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] = None,
tools: list[str] = None,
) -> Persona:
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
if await self.db.get_persona_by_id(persona_id):
raise ValueError(f"Persona with ID {persona_id} already exists.")
new_persona = await self.db.insert_persona(
persona_id, system_prompt, begin_dialogs, tools=tools
)
self.personas.append(new_persona)
self.get_v3_persona_data()
return new_persona
def get_v3_persona_data(
self,
) -> tuple[list[dict], list[Personality], Personality]:
"""获取 AstrBot <4.0.0 版本的 persona 数据。
Returns:
- list[dict]: 包含 persona 配置的字典列表。
- list[Personality]: 包含 Personality 对象的列表。
- Personality: 默认选择的 Personality 对象。
"""
v3_persona_config = [
{
"prompt": persona.system_prompt,
"name": persona.persona_id,
"begin_dialogs": persona.begin_dialogs or [],
"mood_imitation_dialogs": [], # deprecated
"tools": persona.tools,
}
for persona in self.personas
]
personas_v3: list[Personality] = []
selected_default_persona: Personality | None = None
for persona_cfg in v3_persona_config:
begin_dialogs = persona_cfg.get("begin_dialogs", [])
bd_processed = []
if begin_dialogs:
if len(begin_dialogs) % 2 != 0:
logger.error(
f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。"
)
begin_dialogs = []
user_turn = True
for dialog in begin_dialogs:
bd_processed.append(
{
"role": "user" if user_turn else "assistant",
"content": dialog,
"_no_save": None, # 不持久化到 db
}
)
user_turn = not user_turn
try:
persona = Personality(
**persona_cfg,
_begin_dialogs_processed=bd_processed,
_mood_imitation_dialogs_processed="", # deprecated
)
if persona["name"] == self.default_persona:
selected_default_persona = persona
personas_v3.append(persona)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
if not selected_default_persona and len(personas_v3) > 0:
# 默认选择第一个
selected_default_persona = personas_v3[0]
if not selected_default_persona:
selected_default_persona = DEFAULT_PERSONALITY
personas_v3.append(selected_default_persona)
self.personas_v3 = personas_v3
self.selected_default_persona_v3 = selected_default_persona
self.persona_v3_config = v3_persona_config
self.selected_default_persona = Persona(
persona_id=selected_default_persona["name"],
system_prompt=selected_default_persona["prompt"],
begin_dialogs=selected_default_persona["begin_dialogs"],
tools=selected_default_persona["tools"] or None,
)
return v3_persona_config, personas_v3, selected_default_persona

View File

@@ -4,6 +4,7 @@ from astrbot.core.message.message_event_result import (
)
from .content_safety_check.stage import ContentSafetyCheckStage
from .platform_compatibility.stage import PlatformCompatibilityStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .rate_limit_check.stage import RateLimitStage
@@ -20,6 +21,7 @@ STAGES_ORDER = [
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果比如添加回复前缀、t2i、转换为语音 等
@@ -32,6 +34,7 @@ __all__ = [
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PlatformCompatibilityStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",

View File

@@ -19,7 +19,7 @@ class ContentSafetyCheckStage(Stage):
self.strategy_selector = StrategySelector(config)
async def process(
self, event: AstrMessageEvent, check_text: str | None = None
self, event: AstrMessageEvent, check_text: str = None
) -> Union[None, AsyncGenerator[None, None]]:
"""检查内容安全"""
text = check_text if check_text else event.get_message_str()

View File

@@ -13,7 +13,7 @@ class BaiduAipStrategy(ContentSafetyStrategy):
self.secret_key = sk
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
def check(self, content: str) -> tuple[bool, str]:
def check(self, content: str):
res = self.client.textCensorUserDefined(content)
if "conclusionType" not in res:
return False, ""

View File

@@ -16,7 +16,7 @@ class KeywordsStrategy(ContentSafetyStrategy):
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
# )
def check(self, content: str) -> tuple[bool, str]:
def check(self, content: str) -> bool:
for keyword in self.keywords:
if re.search(keyword, content):
return False, "内容安全检查不通过,匹配到敏感词。"

View File

@@ -1,7 +1,14 @@
import inspect
import traceback
import typing as T
from dataclasses import dataclass
from astrbot.core.config import AstrBotConfig
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star import PluginManager
from .context_utils import call_handler, call_event_hook
from astrbot.api import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
@dataclass
@@ -10,6 +17,97 @@ class PipelineContext:
astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: PluginManager # 插件管理器对象
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook
async def call_event_hook(
self,
event: AstrMessageEvent,
hook_type: EventType,
*args,
) -> bool:
"""调用事件钩子函数
Returns:
bool: 如果事件被终止,返回 True
"""
platform_id = event.get_platform_id()
handlers = star_handlers_registry.get_handlers_by_event_type(
hook_type, platform_id=platform_id
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, *args)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return event.is_stopped()
async def call_handler(
self,
event: AstrMessageEvent,
handler: T.Awaitable,
*args,
**kwargs,
) -> T.AsyncGenerator[None, None]:
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
2. 协程: 执行一次并处理返回值
Args:
ctx (PipelineContext): 消息管道上下文对象
event (AstrMessageEvent): 事件对象
handler (Awaitable): 事件处理函数
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError as _:
# 向下兼容
trace_ = traceback.format_exc()
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret

View File

@@ -1,102 +0,0 @@
import inspect
import traceback
import typing as T
from astrbot import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
async def call_handler(
event: AstrMessageEvent,
handler: T.Callable[..., T.Awaitable[T.Any]],
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
2. 协程: 执行一次并处理返回值
Args:
event (AstrMessageEvent): 事件对象
handler (Awaitable): 事件处理函数
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
if not ready_to_call:
return
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret
async def call_event_hook(
event: AstrMessageEvent,
hook_type: EventType,
*args,
**kwargs,
) -> bool:
"""调用事件钩子函数
Returns:
bool: 如果事件被终止,返回 True
#"""
handlers = star_handlers_registry.get_handlers_by_event_type(
hook_type, plugins_name=event.plugins_name
)
for handler in handlers:
try:
logger.debug(
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, *args, **kwargs)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return True
return event.is_stopped()

View File

@@ -0,0 +1,56 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core import logger
@register_stage
class PlatformCompatibilityStage(Stage):
"""检查所有处理器的平台兼容性。
这个阶段会检查所有处理器是否在当前平台启用如果未启用则设置platform_compatible属性为False。
"""
async def initialize(self, ctx: PipelineContext) -> None:
"""初始化平台兼容性检查阶段
Args:
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
"""
self.ctx = ctx
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
# 获取当前平台ID
platform_id = event.get_platform_id()
# 获取已激活的处理器
activated_handlers = event.get_extra("activated_handlers")
if activated_handlers is None:
activated_handlers = []
# 标记不兼容的处理器
for handler in activated_handlers:
if not isinstance(handler, StarHandlerMetadata):
continue
# 检查处理器是否在当前平台启用
enabled = handler.is_enabled_for_platform(platform_id)
if not enabled:
if handler.handler_module_path in star_map:
plugin_name = star_map[handler.handler_module_path].name
logger.debug(
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
)
# 设置处理器为平台不兼容状态
# TODO: 更好的标记方式
handler.platform_compatible = False
else:
# 确保处理器为平台兼容状态
handler.platform_compatible = True
# 更新已激活的处理器列表
event.set_extra("activated_handlers", activated_handlers)

View File

@@ -1,6 +1,5 @@
import traceback
import asyncio
import random
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
@@ -23,26 +22,6 @@ class PreProcessStage(Stage):
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""在处理事件之前的预处理"""
# 平台特异配置platform_specific.<platform>.pre_ack_emoji
supported = {"telegram", "lark"}
platform = event.get_platform_name()
cfg = (
self.config.get("platform_specific", {})
.get(platform, {})
.get("pre_ack_emoji", {})
) or {}
emojis = cfg.get("emojis") or []
if (
cfg.get("enable", False)
and platform in supported
and emojis
and event.is_at_or_wake_command
):
try:
await event.react(random.choice(emojis))
except Exception as e:
logger.warning(f"{platform} 预回应表情发送失败: {e}")
# 路径映射
if mappings := self.platform_settings.get("path_mapping", []):
# 支持 RecordImage 消息段的路径映射。
@@ -67,9 +46,6 @@ class PreProcessStage(Stage):
ctx = self.plugin_manager.context
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
if not stt_provider:
logger.warning(
f"会话 {event.unified_msg_origin} 未配置语音转文本模型。"
)
return
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):

View File

@@ -1,33 +1,33 @@
import abc
import typing as T
from enum import Enum, auto
from ..run_context import ContextWrapper, TContext
from ..response import AgentResponse
from ..hooks import BaseAgentRunHooks
from ..tool_executor import BaseFunctionToolExecutor
from astrbot.core.provider import Provider
from dataclasses import dataclass
from astrbot.core.provider.entities import LLMResponse
from ....message.message_event_result import MessageChain
from enum import Enum, auto
class AgentState(Enum):
"""Defines the state of the agent."""
"""Agent 状态枚举"""
IDLE = auto() # Initial state
RUNNING = auto() # Currently processing
DONE = auto() # Completed
ERROR = auto() # Error state
IDLE = auto() # 初始状态
RUNNING = auto() # 运行中
DONE = auto() # 完成
ERROR = auto() # 错误状态
class BaseAgentRunner(T.Generic[TContext]):
class AgentResponseData(T.TypedDict):
chain: MessageChain
@dataclass
class AgentResponse:
type: str
data: AgentResponseData
class BaseAgentRunner:
@abc.abstractmethod
async def reset(
self,
provider: Provider,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
**kwargs: T.Any,
) -> None:
async def reset(self) -> None:
"""
Reset the agent to its initial state.
This method should be called before starting a new run.

View File

@@ -1,12 +1,10 @@
import sys
import traceback
import typing as T
from .base import BaseAgentRunner, AgentResponse, AgentState
from ..hooks import BaseAgentRunHooks
from ..tool_executor import BaseFunctionToolExecutor
from ..run_context import ContextWrapper, TContext
from ..response import AgentResponseData
from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState
from ...context import PipelineContext
from astrbot.core.provider.provider import Provider
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import (
MessageChain,
)
@@ -23,8 +21,8 @@ from mcp.types import (
EmbeddedResource,
TextResourceContents,
BlobResourceContents,
CallToolResult,
)
from astrbot.core.star.star_handler import EventType
from astrbot import logger
if sys.version_info >= (3, 12):
@@ -33,25 +31,28 @@ else:
from typing_extensions import override
class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
@override
async def reset(
self,
provider: Provider,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
**kwargs: T.Any,
# TODO:
# 1. 处理平台不兼容的处理器
class ToolLoopAgent(BaseAgentRunner):
def __init__(
self, provider: Provider, event: AstrMessageEvent, pipeline_ctx: PipelineContext
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.provider = provider
self.req = None
self.event = event
self.pipeline_ctx = pipeline_ctx
self._state = AgentState.IDLE
self.final_llm_resp = None
self.streaming = False
@override
async def reset(self, req: ProviderRequest, streaming: bool) -> None:
self.req = req
self.streaming = streaming
self.final_llm_resp = None
self._state = AgentState.IDLE
self.tool_executor = tool_executor
self.agent_hooks = agent_hooks
self.run_context = run_context
def _transition_state(self, new_state: AgentState) -> None:
"""转换 Agent 状态"""
@@ -77,12 +78,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
if self._state == AgentState.IDLE:
try:
await self.agent_hooks.on_agent_begin(self.run_context)
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# 开始处理,转换到运行状态
self._transition_state(AgentState.RUNNING)
llm_resp_result = None
@@ -129,10 +124,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE)
try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
# 执行事件钩子
if await self.pipeline_ctx.call_event_hook(
self.event, EventType.OnLLMResponseEvent, llm_resp
):
return
# 返回 LLM 结果
if llm_resp.result_chain:
@@ -196,76 +193,50 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
if not req.func_tool:
return
func_tool = req.func_tool.get_func(func_tool_name)
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
if not func_tool:
logger.warning(f"未找到指定的工具: {func_tool_name},将跳过。")
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: 未找到工具 {func_tool_name}",
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
)
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
res = await client.session.call_tool(func_tool.name, func_tool_args)
if not res:
continue
if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
)
)
)
continue
valid_params = {} # 参数过滤:只传递函数实际需要的参数
# 获取实际的 handler 函数
if func_tool.handler:
logger.debug(
f"工具 {func_tool_name} 期望的参数: {func_tool.parameters}"
)
if func_tool.parameters and func_tool.parameters.get("properties"):
expected_params = set(func_tool.parameters["properties"].keys())
valid_params = {
k: v
for k, v in func_tool_args.items()
if k in expected_params
}
# 记录被忽略的参数
ignored_params = set(func_tool_args.keys()) - set(
valid_params.keys()
)
if ignored_params:
logger.warning(
f"工具 {func_tool_name} 忽略非期望参数: {ignored_params}"
yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
else:
# 如果没有 handler如 MCP 工具),使用所有参数
valid_params = func_tool_args
logger.warning(f"工具 {func_tool_name} 没有 handler使用所有参数")
try:
await self.agent_hooks.on_tool_start(
self.run_context, func_tool, valid_params
)
except Exception as e:
logger.error(f"Error in on_tool_start hook: {e}", exc_info=True)
executor = self.tool_executor.execute(
tool=func_tool,
run_context=self.run_context,
**valid_params, # 只传递有效的参数
)
_final_resp: CallToolResult | None = None
async for resp in executor: # type: ignore
if isinstance(resp, CallToolResult):
res = resp
_final_resp = resp
if isinstance(res.content[0], TextContent):
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
content=resource.text,
)
)
yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent):
yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
@@ -276,65 +247,43 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resource.text,
)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回的数据类型不受支持",
)
yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain(
type="tool_direct_result"
).base64_image(resource.blob)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回的数据类型不受支持",
)
)
yield MessageChain().message("返回的数据类型不受支持。")
elif resp is None:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
self._transition_state(AgentState.DONE)
if res := self.run_context.event.get_result():
if res.chain:
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
else:
# 不应该出现其他类型
logger.warning(
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool, func_tool_args, _final_resp
)
yield MessageChain().message("返回的数据类型不受支持。")
else:
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
# 尝试调用工具函数
wrapper = self.pipeline_ctx.call_handler(
self.event, func_tool.handler, **func_tool_args
)
except Exception as e:
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
async for resp in wrapper:
if resp is not None:
# Tool 返回结果
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resp,
)
)
yield MessageChain().message(resp)
else:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
self._transition_state(AgentState.DONE)
if res := self.event.get_result():
if res.chain:
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
self.run_context.event.clear_result()
self.event.clear_result()
except Exception as e:
logger.warning(traceback.format_exc())
tool_call_result_blocks.append(

View File

@@ -6,9 +6,7 @@ import asyncio
import copy
import json
import traceback
from datetime import timedelta
from collections.abc import AsyncGenerator
from astrbot.core.conversation_mgr import Conversation
from typing import AsyncGenerator, Union
from astrbot.core import logger
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
@@ -22,292 +20,12 @@ from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolSet, FunctionTool
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from ...context import PipelineContext, call_event_hook, call_handler
from ...context import PipelineContext
from ..agent_runner.tool_loop_agent import ToolLoopAgent
from ..stage import Stage
from ..utils import inject_kb_context
from astrbot.core.provider.register import llm_tools
from astrbot.core.star.star_handler import star_map
from astrbot.core.astr_agent_context import AstrAgentContext
try:
import mcp
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
AgentContextWrapper = ContextWrapper[AstrAgentContext]
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
async def execute(cls, tool, run_context, **tool_args):
"""执行函数调用。
Args:
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
**kwargs: 函数调用的参数。
Returns:
AsyncGenerator[None | mcp.types.CallToolResult, None]
"""
if isinstance(tool, HandoffTool):
async for r in cls._execute_handoff(tool, run_context, **tool_args):
yield r
return
if tool.origin == "local":
async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
elif tool.origin == "mcp":
async for r in cls._execute_mcp(tool, run_context, **tool_args):
yield r
return
raise Exception(f"Unknown function origin: {tool.origin}")
@classmethod
async def _execute_handoff(
cls,
tool: HandoffTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
input_ = tool_args.get("input", "agent")
agent_runner = AgentRunner()
# make toolset for the agent
tools = tool.agent.tools
if tools:
toolset = ToolSet()
for t in tools:
if isinstance(t, str):
_t = llm_tools.get_func(t)
if _t:
toolset.add_tool(_t)
elif isinstance(t, FunctionTool):
toolset.add_tool(t)
else:
toolset = None
request = ProviderRequest(
prompt=input_,
system_prompt=tool.description or "",
image_urls=[], # 暂时不传递原始 agent 的上下文
contexts=[], # 暂时不传递原始 agent 的上下文
func_tool=toolset,
)
astr_agent_ctx = AstrAgentContext(
provider=run_context.context.provider,
first_provider_request=run_context.context.first_provider_request,
curr_provider_request=request,
streaming=run_context.context.streaming,
)
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
await run_context.event.send(
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name)
)
await agent_runner.reset(
provider=run_context.context.provider,
request=request,
run_context=AgentContextWrapper(
context=astr_agent_ctx, event=run_context.event
),
tool_executor=FunctionToolExecutor(),
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
streaming=run_context.context.streaming,
)
async for _ in run_agent(agent_runner, 15, True):
pass
if agent_runner.done():
llm_response = agent_runner.get_final_llm_resp()
if not llm_response:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
logger.debug(
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
)
result = (
f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n"
"Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)."
)
text_content = mcp.types.TextContent(
type="text",
text=result,
)
yield mcp.types.CallToolResult(content=[text_content])
else:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
@classmethod
async def _execute_local(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
if not run_context.event:
raise ValueError("Event must be provided for local function tools.")
# 检查 tool 下有没有 run 方法
if not tool.handler and not hasattr(tool, "run"):
raise ValueError("Tool must have a valid handler or 'run' method.")
awaitable = tool.handler or getattr(tool, "run")
wrapper = call_handler(
event=run_context.event,
handler=awaitable,
**tool_args,
)
# async for resp in wrapper:
while True:
try:
resp = await asyncio.wait_for(
anext(wrapper),
timeout=run_context.context.tool_call_timeout,
)
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
yield None
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds."
)
except StopAsyncIteration:
break
@classmethod
async def _execute_mcp(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
if not tool.mcp_client:
raise ValueError("MCP client is not available for MCP function tools.")
session = tool.mcp_client.session
if not session:
raise ValueError("MCP session is not available for MCP function tools.")
res = await session.call_tool(
name=tool.name,
arguments=tool_args,
read_timeout_seconds=timedelta(
seconds=run_context.context.tool_call_timeout
),
)
if not res:
return
yield res
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
run_context.event, EventType.OnLLMResponseEvent, llm_response
)
MAIN_AGENT_HOOKS = MainAgentHooks()
async def run_agent(
agent_runner: AgentRunner, max_step: int = 30, show_tool_use: bool = True
) -> AsyncGenerator[MessageChain, None]:
step_idx = 0
astr_event = agent_runner.run_context.event
while step_idx < max_step:
step_idx += 1
try:
async for resp in agent_runner.step():
if astr_event.is_stopped():
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
resp.data["chain"].type = "tool_call_result"
await astr_event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if agent_runner.streaming:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if show_tool_use or astr_event.get_platform_name() == "webchat":
resp.data["chain"].type = "tool_call"
await astr_event.send(resp.data["chain"])
continue
if not agent_runner.streaming:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
astr_event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
)
)
yield
astr_event.clear_result()
else:
if resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
if agent_runner.done():
break
except Exception as e:
logger.error(traceback.format_exc())
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
if agent_runner.streaming:
yield MessageChain().message(err_msg)
else:
astr_event.set_result(MessageEventResult().message(err_msg))
return
class LLMRequestSubStage(Stage):
@@ -323,10 +41,7 @@ class LLMRequestSubStage(Stage):
self.max_context_length - 1,
)
self.streaming_response: bool = settings["streaming_response"]
self.max_step: int = settings.get("max_agent_step", 30)
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.max_step: int = settings.get("max_agent_step", 10)
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
for bwp in self.bot_wake_prefixs:
@@ -338,7 +53,7 @@ class LLMRequestSubStage(Stage):
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent):
def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
"""选择使用的 LLM 提供商"""
sel_provider = event.get_extra("selected_provider")
_ctx = self.ctx.plugin_manager.context
@@ -350,25 +65,9 @@ class LLMRequestSubStage(Stage):
return _ctx.get_using_provider(umo=event.unified_msg_origin)
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
umo = event.unified_msg_origin
conv_mgr = self.conv_manager
# 获取对话上下文
cid = await conv_mgr.get_curr_conversation_id(umo)
if not cid:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
raise RuntimeError("无法创建新的对话。")
return conversation
async def process(
self, event: AstrMessageEvent, _nested: bool = False
) -> None | AsyncGenerator[None, None]:
) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
@@ -383,9 +82,6 @@ class LLMRequestSubStage(Stage):
provider = self._select_provider(event)
if provider is None:
return
if not isinstance(provider, Provider):
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
@@ -404,14 +100,30 @@ class LLMRequestSubStage(Stage):
if not event.message_str.startswith(self.provider_wake_prefix):
return
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
conversation = await self._get_session_conv(event)
# 获取对话上下文
conversation_id = await self.conv_manager.get_curr_conversation_id(
event.unified_msg_origin
)
if not conversation_id:
conversation_id = await self.conv_manager.new_conversation(
event.unified_msg_origin
)
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, conversation_id
)
if not conversation:
conversation_id = await self.conv_manager.new_conversation(
event.unified_msg_origin
)
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, conversation_id
)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
@@ -420,16 +132,8 @@ class LLMRequestSubStage(Stage):
if not req.prompt and not req.image_urls:
return
# 应用知识库
try:
await inject_kb_context(
umo=event.unified_msg_origin, p_ctx=self.ctx, req=req
)
except Exception as e:
logger.error(f"调用知识库时遇到问题: {e}")
# 执行请求 LLM 前事件钩子。
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
if await self.ctx.call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
if isinstance(req.contexts, str):
@@ -463,79 +167,98 @@ class LLMRequestSubStage(Stage):
# fix messages
req.contexts = self.fix_messages(req.contexts)
# check provider modalities
# 如果提供商不支持图像/工具使用,但请求中包含图像/工具列表,则清空。图片转述等的检测和调用发生在这之前,因此这里可以这样处理。
if req.image_urls:
provider_cfg = provider.provider_config.get("modalities", ["image"])
if "image" not in provider_cfg:
logger.debug(f"用户设置提供商 {provider} 不支持图像,清空图像列表。")
req.image_urls = []
if req.func_tool:
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
if "tool_use" not in provider_cfg:
logger.debug(
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。"
)
req.func_tool = None
# 插件可用性设置
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
mp = tool.handler_module_path
if not mp:
continue
plugin = star_map.get(mp)
if not plugin:
continue
if plugin.name in event.plugins_name or plugin.reserved:
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set
# 备份 req.contexts
backup_contexts = copy.deepcopy(req.contexts)
# run agent
agent_runner = AgentRunner()
# Call Agent
tool_loop_agent = ToolLoopAgent(
provider=provider,
event=event,
pipeline_ctx=self.ctx,
)
logger.debug(
f"handle provider[id: {provider.provider_config['id']}] request: {req}"
)
astr_agent_ctx = AstrAgentContext(
provider=provider,
first_provider_request=req,
curr_provider_request=req,
streaming=self.streaming_response,
tool_call_timeout=self.tool_call_timeout,
)
await agent_runner.reset(
provider=provider,
request=req,
run_context=AgentContextWrapper(context=astr_agent_ctx, event=event),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=self.streaming_response,
)
await tool_loop_agent.reset(req=req, streaming=self.streaming_response)
async def requesting():
step_idx = 0
while step_idx < self.max_step:
step_idx += 1
try:
async for resp in tool_loop_agent.step():
if event.is_stopped():
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
resp.data["chain"].type = "tool_call_result"
await event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if self.streaming_response:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if (
self.show_tool_use
or event.get_platform_name() == "webchat"
):
resp.data["chain"].type = "tool_call"
await event.send(resp.data["chain"])
continue
if not self.streaming_response:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
)
)
yield
event.clear_result()
else:
if resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
if tool_loop_agent.done():
break
except Exception as e:
logger.error(traceback.format_exc())
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
)
)
return
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=provider.get_model(),
provider_type=provider.meta().type,
)
)
if self.streaming_response:
# 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_agent(agent_runner, self.max_step, self.show_tool_use)
)
.set_async_stream(requesting())
)
yield
if agent_runner.done():
if final_llm_resp := agent_runner.get_final_llm_resp():
if tool_loop_agent.done():
if final_llm_resp := tool_loop_agent.get_final_llm_resp():
if final_llm_resp.completion_text:
chain = (
MessageChain().message(final_llm_resp.completion_text).chain
)
elif final_llm_resp.result_chain:
chain = final_llm_resp.result_chain.chain
else:
chain = MessageChain().chain
chain = final_llm_resp.result_chain.chain
event.set_result(
MessageEventResult(
chain=chain,
@@ -543,32 +266,19 @@ class LLMRequestSubStage(Stage):
)
)
else:
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
async for _ in requesting():
yield
# 恢复备份的 contexts
req.contexts = backup_contexts
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
# 异步处理 WebChat 特殊情况
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
)
)
await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp())
async def _handle_webchat(
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
if not req.conversation:
return
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, req.conversation.cid
)
@@ -577,23 +287,7 @@ class LLMRequestSubStage(Stage):
latest_pair = messages[-2:]
if not latest_pair:
return
content = latest_pair[0].get("content", "")
if isinstance(content, list):
# 多模态
text_parts = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif item.get("type") == "image":
text_parts.append("[图片]")
elif isinstance(item, str):
text_parts.append(item)
cleaned_text = "User: " + " ".join(text_parts).strip()
elif isinstance(content, str):
cleaned_text = "User: " + content.strip()
else:
return
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
llm_resp = await prov.text_chat(
system_prompt="You are expert in summarizing user's query.",
@@ -613,10 +307,19 @@ class LLMRequestSubStage(Stage):
if not title or "<None>" in title:
return
await self.conv_manager.update_conversation_title(
unified_msg_origin=event.unified_msg_origin,
title=title,
conversation_id=req.conversation.cid,
event.unified_msg_origin, title=title
)
# 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题
# webchat adapter 中session_id 的格式是 f"webchat!{username}!{cid}"
# TODO: 优化 WebChat 适配器的对话管理
if event.session_id:
username, cid = event.session_id.split("!")[1:3]
db_helper = self.ctx.plugin_manager.context._db
db_helper.update_conversation_title(
user_id=username,
cid=cid,
title=title,
)
async def _save_to_history(
self,
@@ -632,10 +335,6 @@ class LLMRequestSubStage(Stage):
):
return
if not llm_response.completion_text and not req.tool_calls_result:
logger.debug("LLM 响应为空,不保存记录。")
return
# 历史上下文
messages = copy.deepcopy(req.contexts)
# 这一轮对话请求的用户输入

View File

@@ -2,7 +2,7 @@
本地 Agent 模式的 AstrBot 插件调用 Stage
"""
from ...context import PipelineContext, call_handler
from ...context import PipelineContext
from ..stage import Stage
from typing import Dict, Any, List, AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -33,16 +33,24 @@ class StarRequestSubStage(Stage):
handlers_parsed_params = {}
for handler in activated_handlers:
params = handlers_parsed_params.get(handler.handler_full_name, {})
md = star_map.get(handler.handler_module_path)
if not md:
logger.warning(
f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
# 检查处理器是否在当前平台兼容
if (
hasattr(handler, "platform_compatible")
and handler.platform_compatible is False
):
logger.debug(
f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行"
)
continue
logger.debug(f"plugin -> {md.name} - {handler.handler_name}")
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
wrapper = call_handler(event, handler.handler, **params)
if handler.handler_module_path not in star_map:
continue
logger.debug(
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
)
wrapper = self.ctx.call_handler(event, handler.handler, **params)
async for ret in wrapper:
yield ret
event.clear_result() # 清除上一个 handler 的结果
@@ -51,7 +59,7 @@ class StarRequestSubStage(Stage):
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
if event.is_at_or_wake_command:
ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()

View File

@@ -1,80 +0,0 @@
from ..context import PipelineContext
from astrbot.core.provider.entities import ProviderRequest
from astrbot.api import logger, sp
async def inject_kb_context(
umo: str,
p_ctx: PipelineContext,
req: ProviderRequest,
) -> None:
"""inject knowledge base context into the provider request
Args:
umo: Unique message object (session ID)
p_ctx: Pipeline context
req: Provider request
"""
kb_mgr = p_ctx.plugin_manager.context.kb_manager
# 1. 优先读取会话级配置
session_config = await sp.session_get(umo, "kb_config", default={})
if session_config and "kb_ids" in session_config:
# 会话级配置
kb_ids = session_config.get("kb_ids", [])
# 如果配置为空列表,明确表示不使用知识库
if not kb_ids:
logger.info(f"[知识库] 会话 {umo} 已被配置为不使用知识库")
return
top_k = session_config.get("top_k", 5)
# 将 kb_ids 转换为 kb_names
kb_names = []
invalid_kb_ids = []
for kb_id in kb_ids:
kb_helper = await kb_mgr.get_kb(kb_id)
if kb_helper:
kb_names.append(kb_helper.kb.kb_name)
else:
logger.warning(f"[知识库] 知识库不存在或未加载: {kb_id}")
invalid_kb_ids.append(kb_id)
if invalid_kb_ids:
logger.warning(
f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}"
)
if not kb_names:
return
logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}")
else:
kb_names = p_ctx.astrbot_config.get("kb_names", [])
top_k = p_ctx.astrbot_config.get("kb_final_top_k", 5)
logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}")
top_k_fusion = p_ctx.astrbot_config.get("kb_fusion_top_k", 20)
if not kb_names:
return
logger.debug(f"[知识库] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}")
kb_context = await kb_mgr.retrieve(
query=req.prompt,
kb_names=kb_names,
top_k_fusion=top_k_fusion,
top_m_final=top_k,
)
if not kb_context:
return
formatted = kb_context.get("context_text", "")
if formatted:
results = kb_context.get("results", [])
logger.debug(f"[知识库] 为会话 {umo} 注入了 {len(results)} 条相关知识块")
req.system_prompt = f"{formatted}\n\n{req.system_prompt or ''}"

View File

@@ -1,15 +1,17 @@
import random
import asyncio
import math
import traceback
import astrbot.core.message.components as Comp
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext, call_event_hook
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core import logger
from astrbot.core.message.components import BaseMessageComponent, ComponentType
from astrbot.core.star.star_handler import EventType
from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.utils.path_util import path_Mapping
from astrbot.core.utils.session_lock import session_lock_manager
@@ -112,43 +114,6 @@ class RespondStage(Stage):
# 如果所有组件都为空
return True
def is_seg_reply_required(self, event: AstrMessageEvent) -> bool:
"""检查是否需要分段回复"""
if not self.enable_seg:
return False
if self.only_llm_result and not event.get_result().is_llm_result():
return False
if event.get_platform_name() in [
"qq_official",
"weixin_official_account",
"dingtalk",
]:
return False
return True
def _extract_comp(
self,
raw_chain: list[BaseMessageComponent],
extract_types: set[ComponentType],
modify_raw_chain: bool = True,
):
extracted = []
if modify_raw_chain:
remaining = []
for comp in raw_chain:
if comp.type in extract_types:
extracted.append(comp)
else:
remaining.append(comp)
raw_chain[:] = remaining
else:
extracted = [comp for comp in raw_chain if comp.type in extract_types]
return extracted
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
@@ -158,19 +123,12 @@ class RespondStage(Stage):
if result.result_content_type == ResultContentType.STREAMING_FINISH:
return
logger.info(
f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
)
if result.result_content_type == ResultContentType.STREAMING_RESULT:
if result.async_stream is None:
logger.warning("async_stream 为空,跳过发送。")
return
# 流式结果直接交付平台适配器处理
use_fallback = self.config.get("provider_settings", {}).get(
"streaming_segmented", False
)
logger.info(f"应用流式输出({event.get_platform_id()})")
logger.info(f"应用流式输出({event.get_platform_name()})")
await event.send_streaming(result.async_stream, use_fallback)
return
elif len(result.chain) > 0:
@@ -186,85 +144,93 @@ class RespondStage(Stage):
try:
if await self._is_empty_message_chain(result.chain):
logger.info("消息为空,跳过发送阶段")
event.clear_result()
event.stop_event()
return
except Exception as e:
logger.warning(f"空内容检查异常: {e}")
# 将 Plain 为空的消息段移除
result.chain = [
comp
for comp in result.chain
if not (
isinstance(comp, Comp.Plain)
and (not comp.text or not comp.text.strip())
)
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
non_record_comps = [
c for c in result.chain if not isinstance(c, Comp.Record)
]
# 发送消息链
# Record 需要强制单独发送
need_separately = {ComponentType.Record}
if self.is_seg_reply_required(event):
header_comps = self._extract_comp(
result.chain,
{ComponentType.Reply, ComponentType.At},
modify_raw_chain=True,
if (
self.enable_seg
and (
(self.only_llm_result and result.is_llm_result())
or not self.only_llm_result
)
if not result.chain or len(result.chain) == 0:
# may fix #2670
logger.warning(
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}"
)
return
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
and event.get_platform_name()
not in ["qq_official", "weixin_official_account", "dingtalk"]
):
decorated_comps = []
if self.reply_with_mention:
for comp in result.chain:
if isinstance(comp, Comp.At):
decorated_comps.append(comp)
result.chain.remove(comp)
break
if self.reply_with_quote:
for comp in result.chain:
if isinstance(comp, Comp.Reply):
decorated_comps.append(comp)
result.chain.remove(comp)
break
# leverage lock to guarentee the order of message sending among different events
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
for rcomp in record_comps:
i = await self._calc_comp_interval(rcomp)
await asyncio.sleep(i)
try:
await event.send(MessageChain([rcomp]))
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
# 分段回复
for comp in non_record_comps:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
try:
if comp.type in need_separately:
await event.send(MessageChain([comp]))
else:
await event.send(MessageChain([*header_comps, comp]))
header_comps.clear()
await event.send(MessageChain([*decorated_comps, comp]))
decorated_comps = [] # 清空已发送的装饰组件
except Exception as e:
logger.error(
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
exc_info=True,
)
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
else:
if all(
comp.type in {ComponentType.Reply, ComponentType.At}
for comp in result.chain
):
# may fix #2670
logger.warning(
f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}"
)
return
sep_comps = self._extract_comp(
result.chain,
need_separately,
modify_raw_chain=True,
)
for comp in sep_comps:
chain = MessageChain([comp])
for rcomp in record_comps:
try:
await event.send(chain)
await event.send(MessageChain([rcomp]))
except Exception as e:
logger.error(
f"发送消息链失败: chain = {chain}, error = {e}",
exc_info=True,
)
chain = MessageChain(result.chain)
if result.chain and len(result.chain) > 0:
try:
await event.send(chain)
except Exception as e:
logger.error(
f"发送消息链失败: chain = {chain}, error = {e}",
exc_info=True,
)
logger.error(f"发送消息失败: {e} chain: {result.chain}")
if await call_event_hook(event, EventType.OnAfterMessageSentEvent):
return
try:
await event.send(MessageChain(non_record_comps))
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"发送消息失败: {e} chain: {result.chain}")
logger.info(
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
)
for handler in handlers:
try:
logger.debug(
f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
event.clear_result()

View File

@@ -36,7 +36,6 @@ class ResultDecorateStage(Stage):
self.t2i_word_threshold = 150
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
self.t2i_use_network = self.t2i_strategy == "remote"
self.t2i_active_template = ctx.astrbot_config["t2i_active_template"]
self.forward_threshold = ctx.astrbot_config["platform_settings"][
"forward_threshold"
@@ -65,10 +64,9 @@ class ResultDecorateStage(Stage):
]
self.content_safe_check_stage = None
if self.content_safe_check_reply:
for stage_cls in registered_stages:
if stage_cls.__name__ == "ContentSafetyCheckStage":
self.content_safe_check_stage = stage_cls()
await self.content_safe_check_stage.initialize(ctx)
for stage in registered_stages:
if stage.__class__.__name__ == "ContentSafetyCheckStage":
self.content_safe_check_stage = stage
async def process(
self, event: AstrMessageEvent
@@ -100,7 +98,7 @@ class ResultDecorateStage(Stage):
# 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnDecoratingResultEvent, plugins_name=event.plugins_name
EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id()
)
for handler in handlers:
try:
@@ -183,60 +181,56 @@ class ResultDecorateStage(Stage):
if (
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
and result.is_llm_result()
and tts_provider
and SessionServiceManager.should_process_tts_request(event)
):
if not tts_provider:
logger.warning(
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。"
)
else:
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info(f"TTS 请求: {comp.text}")
audio_path = await tts_provider.get_audio(comp.text)
logger.info(f"TTS 结果: {audio_path}")
if not audio_path:
logger.error(
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
)
new_chain.append(comp)
continue
use_file_service = self.ctx.astrbot_config[
"provider_tts_settings"
]["use_file_service"]
callback_api_base = self.ctx.astrbot_config[
"callback_api_base"
]
dual_output = self.ctx.astrbot_config[
"provider_tts_settings"
]["dual_output"]
url = None
if use_file_service and callback_api_base:
token = await file_token_service.register_file(
audio_path
)
url = f"{callback_api_base}/api/file/{token}"
logger.debug(f"已注册:{url}")
new_chain.append(
Record(
file=url or audio_path,
url=url or audio_path,
)
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info(f"TTS 请求: {comp.text}")
audio_path = await tts_provider.get_audio(comp.text)
logger.info(f"TTS 结果: {audio_path}")
if not audio_path:
logger.error(
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
)
if dual_output:
new_chain.append(comp)
except Exception:
logger.error(traceback.format_exc())
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
else:
continue
use_file_service = self.ctx.astrbot_config[
"provider_tts_settings"
]["use_file_service"]
callback_api_base = self.ctx.astrbot_config[
"callback_api_base"
]
dual_output = self.ctx.astrbot_config[
"provider_tts_settings"
]["dual_output"]
url = None
if use_file_service and callback_api_base:
token = await file_token_service.register_file(
audio_path
)
url = f"{callback_api_base}/api/file/{token}"
logger.debug(f"已注册:{url}")
new_chain.append(
Record(
file=url or audio_path,
url=url or audio_path,
)
)
if dual_output:
new_chain.append(comp)
except Exception:
logger.error(traceback.format_exc())
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
result.chain = new_chain
else:
new_chain.append(comp)
result.chain = new_chain
# 文本转图片
elif (
@@ -252,10 +246,7 @@ class ResultDecorateStage(Stage):
render_start = time.time()
try:
url = await html_renderer.render_t2i(
plain_str,
return_url=True,
use_network=self.t2i_use_network,
template_name=self.t2i_active_template,
plain_str, return_url=True, use_network=self.t2i_use_network
)
except BaseException:
logger.error("文本转图片失败,使用文本发送。")
@@ -279,6 +270,7 @@ class ResultDecorateStage(Stage):
result.chain = [Image.fromFileSystem(url)]
# 触发转发消息
has_forwarded = False
if event.get_platform_name() == "aiocqhttp":
word_cnt = 0
for comp in result.chain:
@@ -289,9 +281,9 @@ class ResultDecorateStage(Stage):
uin=event.get_self_id(), name="AstrBot", content=[*result.chain]
)
result.chain = [node]
has_forwarded = True
has_plain = any(isinstance(item, Plain) for item in result.chain)
if has_plain:
if not has_forwarded:
# at 回复
if (
self.reply_with_mention

View File

@@ -11,17 +11,16 @@ class PipelineScheduler:
def __init__(self, context: PipelineContext):
registered_stages.sort(
key=lambda x: STAGES_ORDER.index(x.__name__)
key=lambda x: STAGES_ORDER.index(x.__class__.__name__)
) # 按照顺序排序
self.ctx = context # 上下文对象
self.stages = [] # 存储阶段实例
async def initialize(self):
"""初始化管道调度器时, 初始化所有阶段"""
for stage_cls in registered_stages:
stage_instance = stage_cls() # 创建实例
await stage_instance.initialize(self.ctx)
self.stages.append(stage_instance)
for stage in registered_stages:
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
await stage.initialize(self.ctx)
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
"""依次执行各个阶段
@@ -30,9 +29,9 @@ class PipelineScheduler:
event (AstrMessageEvent): 事件对象
from_stage (int): 从第几个阶段开始执行, 默认从0开始
"""
for i in range(from_stage, len(self.stages)):
stage = self.stages[i] # 获取当前要执行的阶段
# logger.debug(f"执行阶段 {stage.__class__.__name__}")
for i in range(from_stage, len(registered_stages)):
stage = registered_stages[i] # 获取当前要执行的阶段
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
coroutine = stage.process(
event
) # 调用阶段的process方法, 返回协程或者异步生成器
@@ -74,7 +73,7 @@ class PipelineScheduler:
await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
if event.get_platform_name() == "webchat":
await event.send(None)
logger.debug("pipeline 执行完毕。")

View File

@@ -11,8 +11,7 @@ class SessionStatusCheckStage(Stage):
"""检查会话是否整体启用"""
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.conv_mgr = ctx.plugin_manager.context.conversation_manager
pass
async def process(
self, event: AstrMessageEvent
@@ -20,14 +19,4 @@ class SessionStatusCheckStage(Stage):
# 检查会话是否整体启用
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
# workaround for #2309
conv_id = await self.conv_mgr.get_curr_conversation_id(
event.unified_msg_origin
)
if not conv_id:
await self.conv_mgr.new_conversation(
event.unified_msg_origin, platform_id=event.get_platform_id()
)
event.stop_event()

View File

@@ -1,15 +1,15 @@
from __future__ import annotations
import abc
from typing import List, AsyncGenerator, Union, Type
from typing import List, AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext
registered_stages: List[Type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
def register_stage(cls):
"""一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类"""
registered_stages.append(cls)
registered_stages.append(cls())
return cls

View File

@@ -5,7 +5,6 @@ from astrbot.core.message.components import At, AtAll, Reply
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
@@ -113,17 +112,8 @@ class WakingCheckStage(Stage):
activated_handlers = []
handlers_parsed_params = {} # 注册了指令的 handler
# 将 plugins_name 设置到 event 中
enabled_plugins_name = self.ctx.astrbot_config.get("plugin_set", ["*"])
if enabled_plugins_name == ["*"]:
# 如果是 *,则表示所有插件都启用
event.plugins_name = None
else:
event.plugins_name = enabled_plugins_name
logger.debug(f"enabled_plugins_name: {enabled_plugins_name}")
for handler in star_handlers_registry.get_handlers_by_event_type(
EventType.AdapterMessageEvent, plugins_name=event.plugins_name
EventType.AdapterMessageEvent
):
# filter 需满足 AND 逻辑关系
passed = True
@@ -171,15 +161,11 @@ class WakingCheckStage(Stage):
is_wake = True
event.is_wake = True
is_group_cmd_handler = any(
isinstance(f, CommandGroupFilter) for f in handler.event_filters
)
if not is_group_cmd_handler:
activated_handlers.append(handler)
if "parsed_params" in event.get_extra(default={}):
handlers_parsed_params[handler.handler_full_name] = (
event.get_extra("parsed_params")
)
activated_handlers.append(handler)
if "parsed_params" in event.get_extra():
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
"parsed_params"
)
event._extras.pop("parsed_params", None)

View File

@@ -3,10 +3,9 @@ import asyncio
import re
import hashlib
import uuid
from dataclasses import dataclass
from typing import List, Union, Optional, AsyncGenerator
from typing import List, Union, Optional, AsyncGenerator, Any
from astrbot import logger
from astrbot.core.db.po import Conversation
from astrbot.core.message.components import (
Plain,
@@ -24,7 +23,21 @@ from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.utils.metrics import Metric
from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
from .message_session import MessageSession, MessageSesion # noqa
@dataclass
class MessageSesion:
platform_name: str
message_type: MessageType
session_id: str
def __str__(self):
return f"{self.platform_name}:{self.message_type.value}:{self.session_id}"
@staticmethod
def from_str(session_str: str):
platform_name, message_type, session_id = session_str.split(":")
return MessageSesion(platform_name, MessageType(message_type), session_id)
class AstrMessageEvent(abc.ABC):
@@ -49,15 +62,15 @@ class AstrMessageEvent(abc.ABC):
"""是否唤醒(是否通过 WakingStage)"""
self.is_at_or_wake_command = False
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
self._extras: dict[str, Any] = {}
self._extras = {}
self.session = MessageSesion(
platform_name=platform_meta.id,
platform_name=platform_meta.name,
message_type=message_obj.type,
session_id=session_id,
)
self.unified_msg_origin = str(self.session)
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
self._result: MessageEventResult | None = None
self._result: MessageEventResult = None
"""消息事件的结果"""
self._has_send_oper = False
@@ -65,23 +78,13 @@ class AstrMessageEvent(abc.ABC):
self.call_llm = False
"""是否在此消息事件中禁止默认的 LLM 请求"""
self.plugins_name: list[str] | None = None
"""该事件启用的插件名称列表。None 表示所有插件都启用。空列表表示没有启用任何插件。"""
# back_compability
self.platform = platform_meta
def get_platform_name(self):
"""获取这个事件所属的平台的类型(如 aiocqhttp, slack, discord 等)。
NOTE: 用户可能会同时运行多个相同类型的平台适配器。"""
return self.platform_meta.name
def get_platform_id(self):
"""获取这个事件所属的平台的 ID。
NOTE: 用户可能会同时运行多个相同类型的平台适配器,但能确定的是 ID 是唯一的。
"""
return self.platform_meta.id
def get_message_str(self) -> str:
@@ -90,10 +93,8 @@ class AstrMessageEvent(abc.ABC):
"""
return self.message_str
def _outline_chain(self, chain: Optional[List[BaseMessageComponent]]) -> str:
def _outline_chain(self, chain: List[BaseMessageComponent]) -> str:
outline = ""
if not chain:
return outline
for i in chain:
if isinstance(i, Plain):
outline += i.text
@@ -175,19 +176,18 @@ class AstrMessageEvent(abc.ABC):
"""
self._extras[key] = value
def get_extra(self, key: str | None = None, default=None) -> Any:
def get_extra(self, key=None):
"""
获取额外的信息。
"""
if key is None:
return self._extras
return self._extras.get(key, default)
return self._extras.get(key, None)
def clear_extra(self):
"""
清除额外的信息。
"""
logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}")
self._extras.clear()
def is_private_chat(self) -> bool:
@@ -263,9 +263,6 @@ class AstrMessageEvent(abc.ABC):
"""
if isinstance(result, str):
result = MessageEventResult().message(result)
# 兼容外部插件或调用方传入的 chain=None 的情况,确保为可迭代列表
if isinstance(result, MessageEventResult) and result.chain is None:
result.chain = []
self._result = result
def stop_event(self):
@@ -417,16 +414,6 @@ class AstrMessageEvent(abc.ABC):
)
self._has_send_oper = True
async def react(self, emoji: str):
"""
对消息添加表情回应。
默认实现为发送一条包含该表情的消息。
注意:此实现并不一定符合所有平台的原生“表情回应”行为。
如需支持平台原生的消息反应功能,请在对应平台的子类中重写本方法。
"""
await self.send(MessageChain([Plain(emoji)]))
async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]:
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息返回当前群聊的数据。

View File

@@ -55,7 +55,7 @@ class AstrBotMessage:
self_id: str # 机器人的识别id
session_id: str # 会话id。取决于 unique_session 的设置。
message_id: str # 消息id
group: Group # 群组
group_id: str = "" # 群组id如果为私聊则为空
sender: MessageMember # 发送者
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串
@@ -64,28 +64,6 @@ class AstrBotMessage:
def __init__(self) -> None:
self.timestamp = int(time.time())
self.group = None
def __str__(self) -> str:
return str(self.__dict__)
@property
def group_id(self) -> str:
"""
向后兼容的 group_id 属性
群组id如果为私聊则为空
"""
if self.group:
return self.group.group_id
return ""
@group_id.setter
def group_id(self, value: str):
"""设置 group_id"""
if value:
if self.group:
self.group.group_id = value
else:
self.group = Group(group_id=value)
else:
self.group = None

View File

@@ -6,7 +6,6 @@ from typing import List
from asyncio import Queue
from .register import platform_cls_map
from astrbot.core import logger
from astrbot.core.star.star_handler import star_handlers_registry, star_map, EventType
from .sources.webchat.webchat_adapter import WebChatAdapter
@@ -19,9 +18,6 @@ class PlatformManager:
self.platforms_config = config["platform"]
self.settings = config["platform_settings"]
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
这个配置中的 unique_session 需要特殊处理,
约定整个项目中对 unique_session 的引用都从 default 的配置中获取"""
self.event_queue = event_queue
async def initialize(self):
@@ -67,43 +63,25 @@ class PlatformManager:
WeChatPadProAdapter, # noqa: F401
)
case "lark":
from .sources.lark.lark_adapter import (
LarkPlatformAdapter, # noqa: F401
)
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
case "dingtalk":
from .sources.dingtalk.dingtalk_adapter import (
DingtalkPlatformAdapter, # noqa: F401
)
case "telegram":
from .sources.telegram.tg_adapter import (
TelegramPlatformAdapter, # noqa: F401
)
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
case "wecom":
from .sources.wecom.wecom_adapter import (
WecomPlatformAdapter, # noqa: F401
)
case "wecom_ai_bot":
from .sources.wecom_ai_bot.wecomai_adapter import (
WecomAIBotAdapter, # noqa: F401
)
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
case "weixin_official_account":
from .sources.weixin_official_account.weixin_offacc_adapter import (
WeixinOfficialAccountPlatformAdapter, # noqa: F401
WeixinOfficialAccountPlatformAdapter, # noqa
)
case "discord":
from .sources.discord.discord_platform_adapter import (
DiscordPlatformAdapter, # noqa: F401
)
case "misskey":
from .sources.misskey.misskey_adapter import (
MisskeyPlatformAdapter, # noqa: F401
)
case "slack":
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
case "satori":
from .sources.satori.satori_adapter import (
SatoriPlatformAdapter, # noqa: F401
)
except (ImportError, ModuleNotFoundError) as e:
logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
@@ -132,17 +110,6 @@ class PlatformManager:
)
)
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnPlatformLoadedEvent
)
for handler in handlers:
try:
logger.info(
f"hook(on_platform_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler()
except Exception:
logger.error(traceback.format_exc())
async def _task_wrapper(self, task: asyncio.Task):
try:

View File

@@ -1,28 +0,0 @@
from astrbot.core.platform.message_type import MessageType
from dataclasses import dataclass
@dataclass
class MessageSession:
"""描述一条消息在 AstrBot 中对应的会话的唯一标识。
如果您需要实例化 MessageSession请不要给 platform_id 赋值(或者同时给 platform_name 和 platform_id 赋值相同值)。它会在 __post_init__ 中自动设置为 platform_name 的值。"""
platform_name: str
"""平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。"""
message_type: MessageType
session_id: str
platform_id: str = None
def __str__(self):
return f"{self.platform_id}:{self.message_type.value}:{self.session_id}"
def __post_init__(self):
self.platform_id = self.platform_name
@staticmethod
def from_str(session_str: str):
platform_id, message_type, session_id = session_str.split(":")
return MessageSession(platform_id, MessageType(message_type), session_id)
MessageSesion = MessageSession # back compatibility

View File

@@ -5,7 +5,7 @@ from asyncio import Queue
from .platform_metadata import PlatformMetadata
from .astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain
from .message_session import MessageSesion
from .astr_message_event import MessageSesion
from astrbot.core.utils.metrics import Metric

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass
@dataclass
class PlatformMetadata:
name: str
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
"""平台的名称"""
description: str
"""平台的描述"""
id: str = None
@@ -14,5 +14,3 @@ class PlatformMetadata:
"""平台的默认配置模板"""
adapter_display_name: str = None
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
logo_path: str = None
"""平台适配器的 logo 文件路径(相对于插件目录)"""

View File

@@ -13,12 +13,10 @@ def register_platform_adapter(
desc: str,
default_config_tmpl: dict = None,
adapter_display_name: str = None,
logo_path: str = None,
):
"""用于注册平台适配器的带参装饰器。
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
logo_path 指定了平台适配器的 logo 文件路径,是相对于插件目录的路径。
"""
def decorator(cls):
@@ -41,7 +39,6 @@ def register_platform_adapter(
description=desc,
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name,
logo_path=logo_path,
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls

View File

@@ -67,19 +67,12 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
session_id: str,
messages: list[dict],
):
# session_id 必须是纯数字字符串
session_id = int(session_id) if session_id.isdigit() else None
if is_group and isinstance(session_id, int):
await bot.send_group_msg(group_id=session_id, message=messages)
elif not is_group and isinstance(session_id, int):
await bot.send_private_msg(user_id=session_id, message=messages)
elif isinstance(event, Event): # 最后兜底
if event:
await bot.send(event=event, message=messages)
elif is_group:
await bot.send_group_msg(group_id=session_id, message=messages)
else:
raise ValueError(
f"无法发送消息:缺少有效的数字 session_id({session_id}) 或 event({event})"
)
await bot.send_private_msg(user_id=session_id, message=messages)
@classmethod
async def send_message(
@@ -90,15 +83,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
is_group: bool = False,
session_id: str = None,
):
"""发送消息至 QQ 协议端aiocqhttp
Args:
bot (CQHttp): aiocqhttp 机器人实例
message_chain (MessageChain): 要发送的消息链
event (Event | None, optional): aiocqhttp 事件对象.
is_group (bool, optional): 是否为群消息.
session_id (str | None, optional): 会话 ID群号或 QQ 号
"""
"""发送消息"""
# 转发消息、文件消息不能和普通消息混在一起发送
send_one_by_one = any(
@@ -137,15 +122,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
async def send(self, message: MessageChain):
"""发送消息"""
event = getattr(self.message_obj, "raw_message", None)
is_group = bool(self.get_group_id())
session_id = self.get_group_id() if is_group else self.get_sender_id()
event = self.message_obj.raw_message
assert isinstance(event, Event), "Event must be an instance of aiocqhttp.Event"
is_group = False
if self.get_group_id():
is_group = True
session_id = self.get_group_id()
else:
session_id = self.get_sender_id()
await self.send_message(
bot=self.bot,
message_chain=message,
event=event, # 不强制要求一定是 Event
event=event,
is_group=is_group,
session_id=session_id,
)

Some files were not shown because too many files have changed in this diff Show More