Compare commits: v4.5.1...feat/agent

105 commits
| SHA1 |
| --- |
| 944b81ce09 |
| 4e2154feb7 |
| 604958898c |
| a093f5ad0a |
| a7e9a7f30c |
| 5d1e9de096 |
| 89da4eb747 |
| 8899a1dee1 |
| 384a687ec3 |
| 70cfdd2f8b |
| bdbd2f009a |
| 164e0d26e0 |
| cb087b5ff9 |
| 1d3928d145 |
| 6dc3d161e7 |
| e9805ba205 |
| d5280dcd88 |
| 67a9663eff |
| 77dd89b8eb |
| 8e511bf14b |
| 164a4226ea |
| 6d6fefc435 |
| aa59532287 |
| 8488c9aeab |
| 676f9fd4ff |
| 1935ce4700 |
| e760956353 |
| be3e5f3f8b |
| cdf617feac |
| afb56cf707 |
| cd2556ab94 |
| cf4a5d9ea4 |
| 0747099cac |
| 323ec29b02 |
| ae81d70685 |
| 270c89c12f |
| c7a58252fe |
| 47ad8c86e5 |
| 937e879e5e |
| 1ecf26eead |
| adbb84530a |
| 6cf169f4f2 |
| 5ab9ea12c0 |
| fd9cb703db |
| 388c1ab16d |
| f867c2a271 |
| 605bb2cb90 |
| 5ea15dde5a |
| 3ca545c4c7 |
| e200835074 |
| 3a90348353 |
| 5a11d8f0ee |
| 824af5eeea |
| 08ec787491 |
| b062e83d54 |
| 17422ba9c3 |
| 6849af2bad |
| 09c3da64f9 |
| 2c8470e8ac |
| c4ea3db73d |
| 89e79863f6 |
| d19945009f |
| c77256ee0e |
| 7d823af627 |
| 3957861878 |
| 6ac43c600e |
| 27af9ebb6b |
| b360c8446e |
| 6d00717655 |
| bb5f06498e |
| aca5743ab6 |
| 6903032f7e |
| 1ce0ff87bd |
| e39d6bae0b |
| 8028e9e9a6 |
| 817f20ea01 |
| ad5579a2f4 |
| 81a689a79b |
| 1893dd8336 |
| 021ca8175b |
| 39d6207fe1 |
| 23ce687229 |
| 3715312fd2 |
| 8196922cac |
| 8089ad91da |
| 2930cc3fd8 |
| 0e841a8b25 |
| 67fa1611cc |
| 91136bb9f7 |
| 7c050d1adc |
| a0690a6afc |
| c51609b261 |
| 72148f66eb |
| a04993a2bb |
| 74f845b06d |
| 50144ddcae |
| 94bf3b8195 |
| e190bbeeed |
| 92abc43c9d |
| c8e34ff26f |
| 630df3e76e |
| bdbf382201 |
| 00eefc82db |
| dc97080837 |
| 0b7fc29ac4 |
```
@@ -1,9 +1,9 @@
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# github acions
# github actions
.git
.github/
.*ignore
.git/
# User-specific stuff
.idea/
# Byte-compiled / optimized / DLL files
@@ -15,10 +15,10 @@ env/
venv*/
ENV/
.conda/
README*.md
dashboard/
data/
changelogs/
tests/
.ruff_cache/
.astrbot
.astrbot
astrbot.lock
```
**.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml** (vendored, 2 changes)
```
@@ -16,7 +16,7 @@ body:

请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。

不熟悉 JSON ?可以从 [此处](https://plugins.astrbot.app/submit) 生成 JSON ,生成后记得复制粘贴过来.
不熟悉 JSON ?可以从 [此站](https://plugins.astrbot.app) 右下角提交。

- type: textarea
id: plugin-info
```
**.github/ISSUE_TEMPLATE/bug-report.yml** (vendored, 44 changes)
@@ -1,46 +1,44 @@
|
||||
name: '🐛 报告 Bug'
|
||||
name: '🐛 Report Bug / 报告 Bug'
|
||||
title: '[Bug]'
|
||||
description: 提交报告帮助我们改进。
|
||||
description: Submit bug report to help us improve. / 提交报告帮助我们改进。
|
||||
labels: [ 'bug' ]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
|
||||
Thank you for taking the time to report this issue! Please describe your problem accurately. If possible, please provide a reproducible snippet (this will help resolve the issue more quickly). Please note that issues that are not detailed or have no logs will be closed immediately. Thank you for your understanding. / 感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 发生了什么
|
||||
description: 描述你遇到的异常
|
||||
label: What happened / 发生了什么
|
||||
description: Description
|
||||
placeholder: >
|
||||
一个清晰且具体的描述这个异常是什么。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
|
||||
Please provide a clear and specific description of what this exception is. Please note that issues that are not detailed or have no logs will be closed immediately. Thank you for your understanding. / 一个清晰且具体的描述这个异常是什么。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 如何复现?
|
||||
label: Reproduce / 如何复现?
|
||||
description: >
|
||||
复现该问题的步骤
|
||||
The steps to reproduce the issue. / 复现该问题的步骤
|
||||
placeholder: >
|
||||
如: 1. 打开 '...'
|
||||
Example: 1. Open '...'
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: AstrBot 版本、部署方式(如 Windows Docker Desktop 部署)、使用的提供商、使用的消息平台适配器
|
||||
description: >
|
||||
请提供您的 AstrBot 版本和部署方式。
|
||||
label: AstrBot version, deployment method (e.g., Windows Docker Desktop deployment), provider used, and messaging platform used. / AstrBot 版本、部署方式(如 Windows Docker Desktop 部署)、使用的提供商、使用的消息平台适配器
|
||||
placeholder: >
|
||||
如: 3.1.8 Docker, 3.1.7 Windows启动器
|
||||
Example: 4.5.7 Docker, 3.1.7 Windows Launcher
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 操作系统
|
||||
label: OS
|
||||
description: |
|
||||
你在哪个操作系统上遇到了这个问题?
|
||||
On which operating system did you encounter this problem? / 你在哪个操作系统上遇到了这个问题?
|
||||
multiple: false
|
||||
options:
|
||||
- 'Windows'
|
||||
@@ -53,30 +51,30 @@ body:
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 报错日志
|
||||
label: Logs / 报错日志
|
||||
description: >
|
||||
如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
|
||||
Please provide complete Debug-level logs, such as error logs and screenshots. Don't worry if they're long! Please note that issues with insufficient details or no logs will be closed immediately. Thank you for your understanding. / 如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
|
||||
placeholder: >
|
||||
请提供完整的报错日志或截图。
|
||||
Please provide a complete error log or screenshot. / 请提供完整的报错日志或截图。
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: 你愿意提交 PR 吗?
|
||||
label: Are you willing to submit a PR? / 你愿意提交 PR 吗?
|
||||
description: >
|
||||
这不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
|
||||
This is not required, but we would be happy to provide guidance during the contribution process, especially if you already have a good understanding of how to implement the fix. / 这不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
|
||||
options:
|
||||
- label: 是的,我愿意提交 PR!
|
||||
- label: Yes!
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Code of Conduct
|
||||
options:
|
||||
- label: >
|
||||
我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||
I have read and agree to abide by the project's [Code of Conduct](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||
required: true
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: "感谢您填写我们的表单!"
|
||||
value: "Thank you for filling out our form! / 感谢您填写我们的表单!"
|
||||
|
||||
**.github/PULL_REQUEST_TEMPLATE.md** (vendored, 31 changes)
@@ -1,44 +1,25 @@
|
||||
<!-- 如果有的话,请指定此 PR 旨在解决的 ISSUE 编号。 -->
|
||||
<!-- If applicable, please specify the ISSUE number this PR aims to resolve. -->
|
||||
|
||||
fixes #XYZ
|
||||
|
||||
---
|
||||
|
||||
### Motivation / 动机
|
||||
|
||||
<!--请描述此项更改的动机:它解决了什么问题?(例如:修复了 XX 错误,添加了 YY 功能)-->
|
||||
<!--Please describe the motivation for this change: What problem does it solve? (e.g., Fixes XX bug, adds YY feature)-->
|
||||
<!--Please describe the motivation for this change: What problem does it solve? (e.g., Fixes XX issue, adds YY feature)-->
|
||||
<!--请描述此项更改的动机:它解决了什么问题?(例如:修复了 XX issue,添加了 YY 功能)-->
|
||||
|
||||
### Modifications / 改动点
|
||||
|
||||
<!--请总结你的改动:哪些核心文件被修改了?实现了什么功能?-->
|
||||
<!--Please summarize your changes: What core files were modified? What functionality was implemented?-->
|
||||
|
||||
### Verification Steps / 验证步骤
|
||||
|
||||
<!--请为审查者 (Reviewer) 提供清晰、可复现的验证步骤(例如:1. 导航到... 2. 点击...)。-->
|
||||
<!--Please provide clear and reproducible verification steps for the Reviewer (e.g., 1. Navigate to... 2. Click...).-->
|
||||
- [x] This is NOT a breaking change. / 这不是一个破坏性变更。
|
||||
<!-- If your changes is a breaking change, please uncheck the checkbox above -->
|
||||
|
||||
### Screenshots or Test Results / 运行截图或测试结果
|
||||
|
||||
<!--请粘贴截图、GIF 或测试日志,作为执行“验证步骤”的证据,证明此改动有效。-->
|
||||
<!--Please paste screenshots, GIFs, or test logs here as evidence of executing the "Verification Steps" to prove this change is effective.-->
|
||||
|
||||
### Compatibility & Breaking Changes / 兼容性与破坏性变更
|
||||
|
||||
<!--请说明此变更的兼容性:哪些是破坏性变更?哪些地方做了向后兼容处理?是否提供了数据迁移方法?-->
|
||||
<!--Please explain the compatibility of this change: What are the breaking changes? What backward-compatible measures were taken? Are data migration paths provided?-->
|
||||
|
||||
- [ ] 这是一个破坏性变更 (Breaking Change)。/ This is a breaking change.
|
||||
- [ ] 这不是一个破坏性变更。/ This is NOT a breaking change.
|
||||
<!--请粘贴截图、GIF 或测试日志,作为执行“验证步骤”的证据,证明此改动有效。-->
|
||||
|
||||
---
|
||||
|
||||
### Checklist / 检查清单
|
||||
|
||||
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
|
||||
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
|
||||
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
|
||||
|
||||
- [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
|
||||
- [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
|
||||
|
||||
**.github/workflows/docker-image.yml** (vendored, 143 changes)
```
@@ -3,18 +3,125 @@ name: Docker Image CI/CD
on:
push:
tags:
- 'v*'
- "v*"
schedule:
# Run at 00:00 UTC every day
- cron: "0 0 * * *"
workflow_dispatch:

jobs:
publish-docker:
build-nightly-image:
if: github.event_name == 'schedule'
runs-on: ubuntu-latest
env:
DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
GHCR_OWNER: soulter
HAS_GHCR_TOKEN: ${{ secrets.GHCR_GITHUB_TOKEN != '' }}

steps:
- name: Pull The Codes
- name: Checkout
uses: actions/checkout@v5
with:
fetch-depth: 0 # Must be 0 so we can fetch tags
fetch-depth: 1
fetch-tag: true

- name: Check for new commits today
if: github.event_name == 'schedule'
id: check-commits
run: |
# Get commits from the last 24 hours
commits=$(git log --since="24 hours ago" --oneline)
if [ -z "$commits" ]; then
echo "No commits in the last 24 hours, skipping build"
echo "has_commits=false" >> $GITHUB_OUTPUT
else
echo "Found commits in the last 24 hours:"
echo "$commits"
echo "has_commits=true" >> $GITHUB_OUTPUT
fi

- name: Exit if no commits
if: github.event_name == 'schedule' && steps.check-commits.outputs.has_commits == 'false'
run: exit 0
```
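One caveat about the "Exit if no commits" step above: `run: exit 0` only ends that step successfully; it does not stop the job, so the remaining steps still execute. The usual way to actually short-circuit is to gate later steps on the recorded output. A minimal sketch, reusing the `check-commits` step id and `has_commits` output from the workflow above (the gated step body here is illustrative):

```yaml
# Sketch: a later step is skipped when no commits were found.
# "check-commits" and "has_commits" come from the workflow above;
# the run command is a placeholder.
- name: Build Dashboard
  if: steps.check-commits.outputs.has_commits == 'true'
  run: echo "building dashboard..."
```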
```
- name: Build Dashboard
run: |
cd dashboard
npm install
npm run build
mkdir -p dist/assets
echo $(git rev-parse HEAD) > dist/assets/version
cd ..
mkdir -p data
cp -r dashboard/dist data/

- name: Determine test image tags
id: test-meta
run: |
short_sha=$(echo "${GITHUB_SHA}" | cut -c1-12)
build_date=$(date +%Y%m%d)
echo "short_sha=$short_sha" >> $GITHUB_OUTPUT
echo "build_date=$build_date" >> $GITHUB_OUTPUT

- name: Set QEMU
uses: docker/setup-qemu-action@v3

- name: Set Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Log in to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }}

- name: Login to GitHub Container Registry
if: env.HAS_GHCR_TOKEN == 'true'
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ env.GHCR_OWNER }}
password: ${{ secrets.GHCR_GITHUB_TOKEN }}

- name: Build nightly image tags list
id: test-tags
run: |
TAGS="${{ env.DOCKER_HUB_USERNAME }}/astrbot:nightly-latest
${{ env.DOCKER_HUB_USERNAME }}/astrbot:nightly-${{ steps.test-meta.outputs.build_date }}-${{ steps.test-meta.outputs.short_sha }}"
if [ "${{ env.HAS_GHCR_TOKEN }}" = "true" ]; then
TAGS="$TAGS
ghcr.io/${{ env.GHCR_OWNER }}/astrbot:nightly-latest
ghcr.io/${{ env.GHCR_OWNER }}/astrbot:nightly-${{ steps.test-meta.outputs.build_date }}-${{ steps.test-meta.outputs.short_sha }}"
fi
echo "tags<<EOF" >> $GITHUB_OUTPUT
echo "$TAGS" >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
```
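The tag list is handed to the next step as a multi-line step output, which requires the heredoc form of `GITHUB_OUTPUT` (`name<<DELIMITER ... DELIMITER`); a plain `echo "tags=$TAGS"` would break on the newlines. A minimal standalone sketch of the same pattern (tag values are illustrative):

```bash
# Sketch: writing a multi-line value to GITHUB_OUTPUT, as the
# "Build nightly image tags list" step above does.
TAGS="soulter/astrbot:nightly-latest
soulter/astrbot:nightly-20240101-abcdef123456"
{
  echo "tags<<EOF"
  echo "$TAGS"
  echo "EOF"
} >> "$GITHUB_OUTPUT"
```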
```
- name: Build and Push Nightly Image
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.test-tags.outputs.tags }}

- name: Post build notifications
run: echo "Test Docker image has been built and pushed successfully"

build-release-image:
if: github.event_name == 'workflow_dispatch' || (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v'))
runs-on: ubuntu-latest
env:
DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
GHCR_OWNER: soulter
HAS_GHCR_TOKEN: ${{ secrets.GHCR_GITHUB_TOKEN != '' }}

steps:
- name: Checkout
uses: actions/checkout@v5
with:
fetch-depth: 1
fetch-tag: true

- name: Get latest tag (only on manual trigger)
id: get-latest-tag
@@ -27,21 +134,22 @@ jobs:
if: github.event_name == 'workflow_dispatch'
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}

- name: Check if version is pre-release
id: check-prerelease
- name: Compute release metadata
id: release-meta
run: |
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
version="${{ steps.get-latest-tag.outputs.latest_tag }}"
else
version="${{ github.ref_name }}"
version="${GITHUB_REF#refs/tags/}"
fi
if [[ "$version" == *"beta"* ]] || [[ "$version" == *"alpha"* ]]; then
echo "is_prerelease=true" >> $GITHUB_OUTPUT
echo "Version $version is a pre-release, will not push latest tag"
echo "Version $version marked as pre-release"
else
echo "is_prerelease=false" >> $GITHUB_OUTPUT
echo "Version $version is a stable release, will push latest tag"
echo "Version $version marked as stable"
fi
echo "version=$version" >> $GITHUB_OUTPUT
```
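The pre-release test above is a plain substring match on the tag name, so tags like `v4.5.1-beta.2` or `v5.0.0-alpha` are treated as pre-releases and do not get the `latest` tag. The same check, isolated (the version value is illustrative):

```bash
# Sketch of the pre-release check used in the release-meta step above.
version="v4.5.1-beta.2"   # illustrative tag
if [[ "$version" == *beta* || "$version" == *alpha* ]]; then
  echo "is_prerelease=true"    # latest tag will not be pushed
else
  echo "is_prerelease=false"   # latest tag will be pushed
fi
```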
```
- name: Build Dashboard
run: |
@@ -67,23 +175,24 @@ jobs:
password: ${{ secrets.DOCKER_HUB_PASSWORD }}

- name: Login to GitHub Container Registry
if: env.HAS_GHCR_TOKEN == 'true'
uses: docker/login-action@v3
with:
registry: ghcr.io
username: Soulter
username: ${{ env.GHCR_OWNER }}
password: ${{ secrets.GHCR_GITHUB_TOKEN }}

- name: Build and Push Docker to DockerHub and Github GHCR
- name: Build and Push Release Image
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && format('{0}/astrbot:latest', secrets.DOCKER_HUB_USERNAME) || '' }}
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && 'ghcr.io/soulter/astrbot:latest' || '' }}
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
${{ steps.release-meta.outputs.is_prerelease == 'false' && format('{0}/astrbot:latest', env.DOCKER_HUB_USERNAME) || '' }}
${{ steps.release-meta.outputs.is_prerelease == 'false' && env.HAS_GHCR_TOKEN == 'true' && format('ghcr.io/{0}/astrbot:latest', env.GHCR_OWNER) || '' }}
${{ format('{0}/astrbot:{1}', env.DOCKER_HUB_USERNAME, steps.release-meta.outputs.version) }}
${{ env.HAS_GHCR_TOKEN == 'true' && format('ghcr.io/{0}/astrbot:{1}', env.GHCR_OWNER, steps.release-meta.outputs.version) || '' }}

- name: Post build notifications
run: echo "Docker image has been built and pushed successfully"
run: echo "Release Docker image has been built and pushed successfully"
```
**.gitignore** (vendored, 62 changes)
```
@@ -1,35 +1,49 @@
# Python related
__pycache__
botpy.log
.vscode
.mypy_cache
.venv*
.idea
data_v2.db
data_v3.db
configs/session
configs/config.yaml
**/.DS_Store
temp
cmd_config.json
data
cookies.json
logs/
addons/plugins
.conda/
uv.lock
.coverage

# IDE and editors
.vscode
.idea

# Logs and temporary files
botpy.log
logs/
temp
cookies.json

# Data files
data_v2.db
data_v3.db
data
configs/session
configs/config.yaml
cmd_config.json

# Plugins and packages
addons/plugins
packages/python_interpreter/workplace
tests/astrbot_plugin_openai
chroma

# Dashboard
dashboard/node_modules/
dashboard/dist/
.DS_Store
package-lock.json
package.json
venv/*
packages/python_interpreter/workplace
.venv/*
.conda/
.idea
pytest.ini
.astrbot

uv.lock
# Operating System
**/.DS_Store
.DS_Store

# AstrBot specific
.astrbot
astrbot.lock

# Other
chroma
venv/*
pytest.ini
```
**Dockerfile** (18 changes)
```
@@ -12,19 +12,21 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
bash \
ffmpeg \
curl \
gnupg \
git \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*

RUN apt-get update && apt-get install -y curl gnupg && \
curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
apt-get install -y nodejs && \
rm -rf /var/lib/apt/lists/*
RUN apt-get update && apt-get install -y curl gnupg \
&& curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
&& apt-get install -y nodejs

RUN python -m pip install uv
RUN python -m pip install uv \
&& echo "3.11" > .python-version
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pilk --no-cache-dir --system

EXPOSE 6185
EXPOSE 6186

CMD [ "python", "main.py" ]
CMD ["python", "main.py"]
```
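The image installs its dependencies through uv directly into the system interpreter (`--system`) rather than into a virtual environment. The same install steps can be reproduced outside Docker, assuming a `requirements.txt` in the working directory; this sketch omits the redundant re-install of `uv` itself that appears in the original `RUN` line:

```bash
# Sketch: the image's dependency install, runnable locally.
python -m pip install uv
uv pip install -r requirements.txt --no-cache-dir --system
uv pip install socksio pilk --no-cache-dir --system
```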
```
@@ -1,35 +0,0 @@
FROM python:3.10-slim

WORKDIR /AstrBot

COPY . /AstrBot/

RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
build-essential \
python3-dev \
libffi-dev \
libssl-dev \
curl \
unzip \
ca-certificates \
bash \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

# Installation of Node.js
ENV NVM_DIR="/root/.nvm"
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
. "$NVM_DIR/nvm.sh" && \
nvm install 22 && \
nvm use 22
RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"

RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system

EXPOSE 6185
EXPOSE 6186

CMD ["python", "main.py"]
```
**README.md** (118 changes)
```
@@ -8,7 +8,7 @@

<div>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=1" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</div>

<br>
@@ -42,7 +42,7 @@ AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
5. **WebUI**。可视化配置和管理机器人,功能齐全。

## 部署方式
## 部署方式

#### Docker 部署(推荐 🥳)

@@ -119,83 +119,73 @@ uv run main.py

<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>

## ⚡ 消息平台支持情况
## 支持的消息平台

**官方维护**

| 平台 | 支持性 |
| -------- | ------- |
| QQ(官方平台) | ✔ |
| QQ(OneBot) | ✔ |
| Telegram | ✔ |
| 企微应用 | ✔ |
| 企微智能机器人 | ✔ |
| 微信客服 | ✔ |
| 微信公众号 | ✔ |
| 飞书 | ✔ |
| 钉钉 | ✔ |
| Slack | ✔ |
| Discord | ✔ |
| Satori | ✔ |
| Misskey | ✔ |
| Whatsapp | 将支持 |
| LINE | 将支持 |
- QQ (官方平台 & OneBot)
- Telegram
- 企微应用 & 企微智能机器人
- 微信客服 & 微信公众号
- 飞书
- 钉钉
- Slack
- Discord
- Satori
- Misskey
- Whatsapp (将支持)
- LINE (将支持)

**社区维护**

| 平台 | 支持性 |
| -------- | ------- |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
| [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) | ✔ |
| [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11) | ✔ |
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)

## ⚡ 提供商支持情况
## 支持的模型服务

**大模型服务**

| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI | ✔ | 支持任何兼容 OpenAI API 的服务 |
| Anthropic | ✔ | |
| Google Gemini | ✔ | |
| Moonshot AI | ✔ | |
| 智谱 AI | ✔ | |
| DeepSeek | ✔ | |
| Ollama | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| LM Studio | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | |
| [小马算力](https://www.tokenpony.cn/3YPyf) | ✔ | |
| 硅基流动 | ✔ | |
| PPIO 派欧云 | ✔ | |
| ModelScope | ✔ | |
| OneAPI | ✔ | |
| Dify | ✔ | |
| 阿里云百炼应用 | ✔ | |
| Coze | ✔ | |
- OpenAI 及兼容服务
- Anthropic
- Google Gemini
- Moonshot AI
- 智谱 AI
- DeepSeek
- Ollama (本地部署)
- LM Studio (本地部署)
- [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
- [302.AI](https://share.302.ai/rr1M3l)
- [小马算力](https://www.tokenpony.cn/3YPyf)
- [硅基流动](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
- [PPIO 派欧云](https://ppio.com/user/register?invited_by=AIOONE)
- ModelScope
- OneAPI

**LLMOps 平台**

- Dify
- 阿里云百炼应用
- Coze

**语音转文本服务**

| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| Whisper | ✔ | 支持 API、本地部署 |
| SenseVoice | ✔ | 本地部署 |
- OpenAI Whisper
- SenseVoice

**文本转语音服务**

| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI TTS | ✔ | |
| Gemini TTS | ✔ | |
| GSVI | ✔ | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | GPT-Sovits |
| FishAudio | ✔ | |
| Edge TTS | ✔ | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | |
| Azure TTS | ✔ | |
| Minimax TTS | ✔ | |
| 火山引擎 TTS | ✔ | |
- OpenAI TTS
- Gemini TTS
- GPT-Sovits-Inference
- GPT-Sovits
- FishAudio
- Edge TTS
- 阿里云百炼 TTS
- Azure TTS
- Minimax TTS
- 火山引擎 TTS

## ❤️ 贡献

@@ -229,7 +219,7 @@ pre-commit install

## ⭐ Star History

> [!TIP]
> [!TIP]
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我们维护这个开源项目的动力 <3

<div align="center">
```
**README_en.md** (287 changes)
````
@@ -1,182 +1,233 @@
<p align="center">




</p>

<div align="center">

_✨ Easy-to-use Multi-platform LLM Chatbot & Development Framework ✨_
<br>

<div>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>

[](https://github.com/AstrBotDevs/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a>
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)

[](https://codecov.io/gh/AstrBotDevs/AstrBot)

<a href="https://astrbot.app/">Documentation</a> |
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracking</a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</div>

AstrBot is a loosely coupled, asynchronous chatbot and development framework that supports multi-platform deployment, featuring an easy-to-use plugin system and comprehensive Large Language Model (LLM) integration capabilities.
<br>

## ✨ Key Features
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
</div>

1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper).
2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation.
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://dify.ai/) for easy access to Dify assistants/knowledge bases/workflows.
4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins.
5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction.
6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling.
<br>

> [!TIP]
> Dashboard Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
> Username: `astrbot`, Password: `astrbot` (LLM not configured for chat page)
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
<a href="https://astrbot.app/">Documentation</a> |
<a href="https://blog.astrbot.app/">Blog</a> |
<a href="https://astrbot.featurebase.app/roadmap">Roadmap</a> |
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracker</a>
</div>

## ✨ Deployment
AstrBot is an open-source all-in-one Agent chatbot platform and development framework.

#### Docker Deployment
## Key Features

See docs: [Deploy with Docker](https://astrbot.app/deploy/astrbot/docker.html#docker-deployment)
1. **LLM Conversations**. Supports integration with various large language model services. Features include multimodal capabilities, tool calling, MCP, native knowledge base, character personas, and more.
2. **Multi-Platform Support**. Integrates with QQ, WeChat Work, WeChat Official Accounts, Feishu, Telegram, DingTalk, Discord, KOOK, and other platforms. Supports rate limiting, whitelisting, and Baidu content moderation.
3. **Agent Capabilities**. Fully optimized agentic features including multi-turn tool calling, built-in sandboxed code executor, web search, and more.
4. **Plugin Extensions**. Deeply optimized plugin mechanism supporting [plugin development](https://astrbot.app/dev/plugin.html) to extend functionality, with a rich community plugin ecosystem.
5. **Web UI**. Visual configuration and management of your bot with comprehensive features.

#### Windows Installer
## Deployment Methods

Requires Python (>3.10). See docs: [Windows Installer Guide](https://astrbot.app/deploy/astrbot/windows.html)
#### Docker Deployment (Recommended 🥳)

#### Replit Deployment
We recommend deploying AstrBot using Docker or Docker Compose.

Please refer to the official documentation: [Deploy AstrBot with Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).

#### BT-Panel Deployment

AstrBot has partnered with BT-Panel and is now available in their marketplace.

Please refer to the official documentation: [BT-Panel Deployment](https://astrbot.app/deploy/astrbot/btpanel.html).

#### 1Panel Deployment

AstrBot has been officially listed on the 1Panel marketplace.

Please refer to the official documentation: [1Panel Deployment](https://astrbot.app/deploy/astrbot/1panel.html).

#### Deploy on RainYun

AstrBot has been officially listed on RainYun's cloud application platform with one-click deployment.

[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)

#### Deploy on Replit

Community-contributed deployment method.

[](https://repl.it/github/AstrBotDevs/AstrBot)

#### Windows One-Click Installer

Please refer to the official documentation: [Deploy AstrBot with Windows One-Click Installer](https://astrbot.app/deploy/astrbot/windows.html).

#### CasaOS Deployment

Community-contributed method.
See docs: [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html)
Community-contributed deployment method.

Please refer to the official documentation: [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html).

#### Manual Deployment

See docs: [Source Code Deployment](https://astrbot.app/deploy/astrbot/cli.html)
First, install uv:

## ⚡ Platform Support
```bash
pip install uv
```

| Platform | Status | Details | Message Types |
| -------------------------------------------------------------- | ------ | ------------------- | ------------------- |
| QQ (Official Bot) | ✔ | Private/Group chats | Text, Images |
| QQ (OneBot) | ✔ | Private/Group chats | Text, Images, Voice |
| WeChat (Personal) | ✔ | Private/Group chats | Text, Images, Voice |
| [Telegram](https://github.com/AstrBotDevs/AstrBot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
| [WeChat Work](https://github.com/AstrBotDevs/AstrBot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
| Feishu | ✔ | Group chats | Text, Images |
| WeChat Open Platform | 🚧 | Planned | - |
| Discord | 🚧 | Planned | - |
| WhatsApp | 🚧 | Planned | - |
| Xiaomi Speakers | 🚧 | Planned | - |
Install AstrBot via Git Clone:

## Provider Support Status
```bash
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
uv run main.py
```

| Name | Support | Type | Notes |
|---------------------------|---------|------------------------|-----------------------------------------------------------------------|
| OpenAI API | ✔ | Text Generation | Supports all OpenAI API-compatible services including DeepSeek, Google Gemini, GLM, Moonshot, Alibaba Cloud Bailian, Silicon Flow, xAI, etc. |
| Claude API | ✔ | Text Generation | |
| Google Gemini API | ✔ | Text Generation | |
| Dify | ✔ | LLMOps | |
| DashScope (Alibaba Cloud) | ✔ | LLMOps | |
| Ollama | ✔ | Model Loader | Local deployment for open-source LLMs (DeepSeek, Llama, etc.) |
| LM Studio | ✔ | Model Loader | Local deployment for open-source LLMs (DeepSeek, Llama, etc.) |
| LLMTuner | ✔ | Model Loader | Local loading of fine-tuned models (e.g. LoRA) |
| OneAPI | ✔ | LLM Distribution | |
| Whisper | ✔ | Speech-to-Text | Supports API and local deployment |
| SenseVoice | ✔ | Speech-to-Text | Local deployment |
| OpenAI TTS API | ✔ | Text-to-Speech | |
| Fishaudio | ✔ | Text-to-Speech | Project involving GPT-Sovits author |
Or refer to the official documentation: [Deploy AstrBot from Source](https://astrbot.app/deploy/astrbot/cli.html).

# 🦌 Roadmap
## 🌍 Community

> [!TIP]
> Suggestions welcome via Issues <3
### QQ Groups

- [ ] Ensure feature parity across all platform adapters
- [ ] Optimize plugin APIs
- [ ] Add default TTS services (e.g., GPT-Sovits)
- [ ] Enhance chat features with persistent memory
- [ ] i18n Planning
- Group 1: 322154837
- Group 3: 630166526
- Group 5: 822130018
- Group 6: 753075035
- Developer Group: 975206796

## ❤️ Contributions
### Telegram Group

All Issues/PRs welcome! Simply submit your changes to this project :)
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>

For major features, please discuss via Issues first.
### Discord Server

## 🌟 Support
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>

- Star this project!
- Support via [Afdian](https://afdian.com/a/soulter)
- WeChat support: [QR Code](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)
## Supported Messaging Platforms

## ✨ Demos
**Officially Maintained**

> [!NOTE]
> Code executor file I/O currently tested with Napcat(QQ)/Lagrange(QQ)
- QQ (Official Platform & OneBot)
- Telegram
- WeChat Work Application & WeChat Work Intelligent Bot
- WeChat Customer Service & WeChat Official Accounts
- Feishu (Lark)
- DingTalk
- Slack
- Discord
- Satori
- Misskey
- WhatsApp (Coming Soon)
- LINE (Coming Soon)

<div align='center'>
**Community Maintained**

<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili Direct Messages](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)

_✨ Docker-based Sandboxed Code Executor (Beta) ✨_
## Supported Model Services

<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
**LLM Services**

_✨ Multimodal Input, Web Search, Text-to-Image ✨_
- OpenAI and Compatible Services
- Anthropic
- Google Gemini
- Moonshot AI
- Zhipu AI
- DeepSeek
- Ollama (Self-hosted)
- LM Studio (Self-hosted)
- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
- [302.AI](https://share.302.ai/rr1M3l)
- [TokenPony](https://www.tokenpony.cn/3YPyf)
- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
- ModelScope
- OneAPI

<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
**LLMOps Platforms**

_✨ Natural Language TODO Lists ✨_
- Dify
- Alibaba Cloud Bailian Applications
- Coze

<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
**Speech-to-Text Services**

_✨ Plugin System Showcase ✨_
- OpenAI Whisper
- SenseVoice

<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
**Text-to-Speech Services**

_✨ Web Dashboard ✨_
- OpenAI TTS
- Gemini TTS
- GPT-Sovits-Inference
- GPT-Sovits
- FishAudio
- Edge TTS
- Alibaba Cloud Bailian TTS
- Azure TTS
- Minimax TTS
- Volcano Engine TTS


## ❤️ Contributing

_✨ Built-in Web Chat Interface ✨_
Issues and Pull Requests are always welcome! Feel free to submit your changes to this project :)

</div>
### How to Contribute

You can contribute by reviewing issues or helping with pull request reviews. Any issues or PRs are welcome to encourage community participation. Of course, these are just suggestions—you can contribute in any way you like. For adding new features, please discuss through an Issue first.

### Development Environment

AstrBot uses `ruff` for code formatting and linting.

```bash
git clone https://github.com/AstrBotDevs/AstrBot
pip install pre-commit
pre-commit install
```

## ❤️ Special Thanks

Special thanks to all Contributors and plugin developers for their contributions to AstrBot ❤️

<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>

Additionally, the birth of this project would not have been possible without the help of the following open-source projects:

- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - The amazing cat framework

## ⭐ Star History

> [!TIP]
> If this project helps you, please give it a star <3
> [!TIP]
> If this project has helped you in your life or work, or if you're interested in its future development, please give the project a Star. It's the driving force behind maintaining this open-source project <3

<div align="center">

[](https://star-history.com/#AstrBotDevs/AstrBot&Date)

[](https://star-history.com/#astrbotdevs/astrbot&Date)

</div>

## Disclaimer

1. Licensed under `AGPL-v3`.
2. WeChat integration uses [Gewechat](https://github.com/Devo919/Gewechat). Use at your own risk with non-critical accounts.
3. Users must comply with local laws and regulations.

<!-- ## ✨ ATRI [Beta]

Available as plugin: [astrbot_plugin_atri](https://github.com/AstrBotDevs/AstrBot_plugin_atri)

1. Qwen1.5-7B-Chat Lora model fine-tuned with ATRI character data
2. Long-term memory
3. Meme understanding & responses
4. TTS integration
-->

</details>

_私は、高性能ですから!_
````
**README_ja.md** (270 changes)
````
@@ -1,167 +1,233 @@
<p align="center">




</p>

<div align="center">

_✨ 簡単に使えるマルチプラットフォーム LLM チャットボットおよび開発フレームワーク ✨_
<br>

<div>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>

[](https://github.com/AstrBotDevs/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple">
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)

[](https://codecov.io/gh/AstrBotDevs/AstrBot)

<a href="https://astrbot.app/">ドキュメントを見る</a> |
<a href="https://github.com/AstrBotDevs/AstrBot/issues">問題を報告する</a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</div>

AstrBot は、疎結合、非同期、複数のメッセージプラットフォームに対応したデプロイ、使いやすいプラグインシステム、および包括的な大規模言語モデル(LLM)接続機能を備えたチャットボットおよび開発フレームワークです。
<br>

## ✨ 主な機能
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
</div>

1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
<br>

> [!TIP]
> 管理パネルのオンラインデモを体験する: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> ユーザー名: `astrbot`, パスワード: `astrbot`。LLM が設定されていないため、チャットページで大規模モデルを使用することはできません。(デモのログインパスワードを変更しないでください 😭)
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
<a href="https://astrbot.app/">ドキュメント</a> |
<a href="https://blog.astrbot.app/">Blog</a> |
<a href="https://astrbot.featurebase.app/roadmap">ロードマップ</a> |
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue</a>
</div>

## ✨ 使用方法
AstrBot は、オープンソースのオールインワン Agent チャットボットプラットフォーム及び開発フレームワークです。

#### Docker デプロイ
## 主な機能

公式ドキュメント [Docker を使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) を参照してください。
1. **大規模言語モデル対話**。多様な大規模言語モデルサービスとの統合をサポート。マルチモーダル、ツール呼び出し、MCP、ネイティブナレッジベース、キャラクター設定などの機能を搭載。
2. **マルチメッセージプラットフォームサポート**。QQ、WeChat Work、WeChat公式アカウント、Feishu、Telegram、DingTalk、Discord、KOOK などのプラットフォームと統合可能。レート制限、ホワイトリスト、Baidu コンテンツ審査をサポート。
3. **Agent**。完全に最適化された Agentic 機能。マルチターンツール呼び出し、内蔵サンドボックスコード実行環境、Web 検索などの機能をサポート。
4. **プラグイン拡張**。深く最適化されたプラグインメカニズムで、[プラグイン開発](https://astrbot.app/dev/plugin.html)による機能拡張をサポート。豊富なコミュニティプラグインエコシステム。
5. **WebUI**。ビジュアル設定とボット管理、充実した機能。

#### Windows ワンクリックインストーラーのデプロイ
## デプロイ方法

コンピュータに Python(>3.10)がインストールされている必要があります。公式ドキュメント [Windows ワンクリックインストーラーを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/windows.html) を参照してください。
#### Docker デプロイ(推奨 🥳)

#### Replit デプロイ
Docker / Docker Compose を使用した AstrBot のデプロイを推奨します。

公式ドキュメント [Docker を使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) をご参照ください。

#### 宝塔パネルデプロイ

AstrBot は宝塔パネルと提携し、宝塔パネルに公開されています。

公式ドキュメント [宝塔パネルデプロイ](https://astrbot.app/deploy/astrbot/btpanel.html) をご参照ください。

#### 1Panel デプロイ

AstrBot は 1Panel 公式により 1Panel パネルに公開されています。

公式ドキュメント [1Panel デプロイ](https://astrbot.app/deploy/astrbot/1panel.html) をご参照ください。

#### 雨云でのデプロイ

AstrBot は雨云公式によりクラウドアプリケーションプラットフォームに公開され、ワンクリックでデプロイ可能です。

[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)

#### Replit でのデプロイ

コミュニティ貢献によるデプロイ方法。

[](https://repl.it/github/AstrBotDevs/AstrBot)

#### Windows ワンクリックインストーラーデプロイ

公式ドキュメント [Windows ワンクリックインストーラーを使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/windows.html) をご参照ください。

#### CasaOS デプロイ

コミュニティが提供するデプロイ方法です。
コミュニティ貢献によるデプロイ方法。

公式ドキュメント [ソースコードを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/casaos.html) を参照してください。
公式ドキュメント [CasaOS デプロイ](https://astrbot.app/deploy/astrbot/casaos.html) をご参照ください。

#### 手動デプロイ

公式ドキュメント [ソースコードを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/cli.html) を参照してください。
まず uv をインストールします:

## ⚡ メッセージプラットフォームのサポート状況
```bash
pip install uv
```

| プラットフォーム | サポート状況 | 詳細 | メッセージタイプ |
| -------- | ------- | ------- | ------ |
| QQ(公式ロボットインターフェース) | ✔ | プライベートチャット、グループチャット、QQ チャンネルプライベートチャット、グループチャット | テキスト、画像 |
| QQ(OneBot) | ✔ | プライベートチャット、グループチャット | テキスト、画像、音声 |
| WeChat(個人アカウント) | ✔ | WeChat 個人アカウントのプライベートチャット、グループチャット | テキスト、画像、音声 |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | プライベートチャット、グループチャット | テキスト、画像 |
| [WeChat(企業 WeChat)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | プライベートチャット | テキスト、画像、音声 |
| Feishu | ✔ | グループチャット | テキスト、画像 |
| WeChat 対話オープンプラットフォーム | 🚧 | 計画中 | - |
| Discord | 🚧 | 計画中 | - |
| WhatsApp | 🚧 | 計画中 | - |
| Xiaoai 音響 | 🚧 | 計画中 | - |
Git Clone で AstrBot をインストール:

# 🦌 今後のロードマップ
```bash
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
uv run main.py
```

> [!TIP]
> Issue でさらに多くの提案を歓迎します <3
または、公式ドキュメント [ソースコードから AstrBot をデプロイ](https://astrbot.app/deploy/astrbot/cli.html) をご参照ください。

- [ ] 現在のすべてのプラットフォームアダプターの機能の一貫性を確保し、改善する
- [ ] プラグインインターフェースの最適化
- [ ] GPT-Sovits などの TTS サービスをデフォルトでサポート
- [ ] "チャット強化" 部分を完成させ、永続的な記憶をサポート
- [ ] i18n の計画
## 🌍 コミュニティ

## ❤️ 貢献
### QQ グループ

Issue や Pull Request を歓迎します!このプロジェクトに変更を加えるだけです :)
- 1群:322154837
- 3群:630166526
- 5群:822130018
- 6群:753075035
- 開発者群:975206796

新機能の追加については、まず Issue で議論してください。
### Telegram グループ

## 🌟 サポート
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>

- このプロジェクトに Star を付けてください!
- [愛発電](https://afdian.com/a/soulter)で私をサポートしてください!
- [WeChat](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)で私をサポートしてください~
### Discord サーバー

## ✨ デモ
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>

> [!NOTE]
> コードエグゼキューターのファイル入力/出力は現在 Napcat(QQ)、Lagrange(QQ) でのみテストされています
## サポートされているメッセージプラットフォーム

<div align='center'>
**公式メンテナンス**

<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
- QQ (公式プラットフォーム & OneBot)
- Telegram
- WeChat Work アプリケーション & WeChat Work インテリジェントボット
- WeChat カスタマーサービス & WeChat 公式アカウント
- Feishu (Lark)
- DingTalk
- Slack
- Discord
- Satori
- Misskey
- WhatsApp (近日対応予定)
- LINE (近日対応予定)

_✨ Docker ベースのサンドボックス化されたコードエグゼキューター(ベータテスト中)✨_
**コミュニティメンテナンス**

<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili ダイレクトメッセージ](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)

_✨ 多モーダル、ウェブ検索、長文の画像変換(設定可能)✨_
## サポートされているモデルサービス

<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
**大規模言語モデルサービス**

_✨ 自然言語タスク ✨_
- OpenAI および互換サービス
- Anthropic
- Google Gemini
- Moonshot AI
- 智谱 AI
- DeepSeek
- Ollama (セルフホスト)
- LM Studio (セルフホスト)
- [優云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
- [302.AI](https://share.302.ai/rr1M3l)
- [小馬算力](https://www.tokenpony.cn/3YPyf)
- [硅基流動](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
- [PPIO 派欧云](https://ppio.com/user/register?invited_by=AIOONE)
- ModelScope
- OneAPI

<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
**LLMOps プラットフォーム**

_✨ プラグインシステム - 一部のプラグインの展示 ✨_
- Dify
- Alibaba Cloud 百炼アプリケーション
- Coze

<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width="600">
**音声認識サービス**

_✨ 管理パネル ✨_
- OpenAI Whisper
- SenseVoice


**音声合成サービス**

_✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
- OpenAI TTS
- Gemini TTS
- GPT-Sovits-Inference
- GPT-Sovits
- FishAudio
- Edge TTS
- Alibaba Cloud 百炼 TTS
- Azure TTS
- Minimax TTS
- Volcano Engine TTS

</div>
## ❤️ コントリビューション

Issue や Pull Request は大歓迎です!このプロジェクトに変更を送信してください :)

### コントリビュート方法

Issue を確認したり、PR(プルリクエスト)のレビューを手伝うことで貢献できます。どんな Issue や PR への参加も歓迎され、コミュニティ貢献を促進します。もちろん、これらは提案に過ぎず、どんな方法でも貢献できます。新機能の追加については、まず Issue で議論してください。

### 開発環境

AstrBot はコードのフォーマットとチェックに `ruff` を使用しています。

```bash
git clone https://github.com/AstrBotDevs/AstrBot
pip install pre-commit
pre-commit install
```

## ❤️ Special Thanks

AstrBot への貢献をしていただいたすべてのコントリビューターとプラグイン開発者に特別な感謝を ❤️

<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>

また、このプロジェクトの誕生は以下のオープンソースプロジェクトの助けなしには実現できませんでした:

- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 素晴らしい猫猫フレームワーク

## ⭐ Star History

> [!TIP]
> このプロジェクトがあなたの生活や仕事に役立った場合、またはこのプロジェクトの将来の発展に関心がある場合は、プロジェクトに Star を付けてください。これはこのオープンソースプロジェクトを維持するためのモチベーションです <3
> このプロジェクトがあなたの生活や仕事に役立ったり、このプロジェクトの今後の発展に関心がある場合は、プロジェクトに Star をください。これがこのオープンソースプロジェクトを維持する原動力です <3

<div align="center">

[](https://star-history.com/#soulter/astrbot&Date)
[](https://star-history.com/#astrbotdevs/astrbot&Date)

</div>

## スポンサー

[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)

## 免責事項

1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
2. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。

<!-- ## ✨ ATRI [ベータテスト]

この機能はプラグインとしてロードされます。プラグインリポジトリのアドレス:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)

1. 《ATRI ~ My Dear Moments》の主人公 ATRI のキャラクターセリフを微調整データセットとして使用した `Qwen1.5-7B-Chat Lora` 微調整モデル。
2. 長期記憶
3. ミームの理解と返信
4. TTS
-->
</details>

_私は、高性能ですから!_
````
@@ -1,20 +1,19 @@
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core import html_renderer
from astrbot.core import sp
from astrbot.core.star.register import register_llm_tool as llm_tool
from astrbot.core.star.register import register_agent as agent
from astrbot.core.agent.tool import ToolSet, FunctionTool
from astrbot.core import html_renderer, sp
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.register import register_agent as agent
from astrbot.core.star.register import register_llm_tool as llm_tool

__all__ = [
    "AstrBotConfig",
    "logger",
    "BaseFunctionToolExecutor",
    "FunctionTool",
    "ToolSet",
    "agent",
    "html_renderer",
    "llm_tool",
    "agent",
    "logger",
    "sp",
    "ToolSet",
    "FunctionTool",
    "BaseFunctionToolExecutor",
]

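The reordering in this hunk (and in the similar `__all__` hunks below) matches plain ASCII sorting, which is presumably enforced by the repo's `ruff` pre-commit hook (import sorting plus the RUF022 unsorted-`__all__` fix; the exact tool invocation is an assumption, since it is not shown in the diff). A quick way to check the new order:

```python
# ASCII sort: uppercase names come before lowercase ones, matching the
# new __all__ ordering above. (Attribution to ruff's RUF022 autofix is an
# assumption; the sort order itself is just sorted().)
names = [
    "AstrBotConfig", "logger", "html_renderer", "llm_tool", "agent",
    "sp", "ToolSet", "FunctionTool", "BaseFunctionToolExecutor",
]
print(sorted(names))
# ['AstrBotConfig', 'BaseFunctionToolExecutor', 'FunctionTool', 'ToolSet',
#  'agent', 'html_renderer', 'llm_tool', 'logger', 'sp']
```
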
@@ -36,7 +36,8 @@ from astrbot.core.star.config import *

# provider
from astrbot.core.provider import Provider, Personality, ProviderMetaData
from astrbot.core.provider import Provider, ProviderMetaData
from astrbot.core.db.po import Personality

# platform
from astrbot.core.platform import (

@@ -1,18 +1,17 @@
from astrbot.core.message.message_event_result import (
    MessageEventResult,
    MessageChain,
    CommandResult,
    EventResultType,
    MessageChain,
    MessageEventResult,
    ResultContentType,
)

from astrbot.core.platform import AstrMessageEvent

__all__ = [
    "MessageEventResult",
    "MessageChain",
    "AstrMessageEvent",
    "CommandResult",
    "EventResultType",
    "AstrMessageEvent",
    "MessageChain",
    "MessageEventResult",
    "ResultContentType",
]

@@ -1,51 +1,52 @@
from astrbot.core.star.register import (
    register_command as command,
    register_command_group as command_group,
    register_event_message_type as event_message_type,
    register_regex as regex,
    register_platform_adapter_type as platform_adapter_type,
    register_permission_type as permission_type,
    register_custom_filter as custom_filter,
    register_on_astrbot_loaded as on_astrbot_loaded,
    register_on_platform_loaded as on_platform_loaded,
    register_on_llm_request as on_llm_request,
    register_on_llm_response as on_llm_response,
    register_llm_tool as llm_tool,
    register_on_decorating_result as on_decorating_result,
    register_after_message_sent as after_message_sent,
)

from astrbot.core.star.filter.event_message_type import (
    EventMessageTypeFilter,
    EventMessageType,
)
from astrbot.core.star.filter.platform_adapter_type import (
    PlatformAdapterTypeFilter,
    PlatformAdapterType,
)
from astrbot.core.star.filter.permission import PermissionTypeFilter, PermissionType
from astrbot.core.star.filter.custom_filter import CustomFilter
from astrbot.core.star.filter.event_message_type import (
    EventMessageType,
    EventMessageTypeFilter,
)
from astrbot.core.star.filter.permission import PermissionType, PermissionTypeFilter
from astrbot.core.star.filter.platform_adapter_type import (
    PlatformAdapterType,
    PlatformAdapterTypeFilter,
)
from astrbot.core.star.register import register_after_message_sent as after_message_sent
from astrbot.core.star.register import register_command as command
from astrbot.core.star.register import register_command_group as command_group
from astrbot.core.star.register import register_custom_filter as custom_filter
from astrbot.core.star.register import register_event_message_type as event_message_type
from astrbot.core.star.register import register_llm_tool as llm_tool
from astrbot.core.star.register import register_on_astrbot_loaded as on_astrbot_loaded
from astrbot.core.star.register import (
    register_on_decorating_result as on_decorating_result,
)
from astrbot.core.star.register import register_on_llm_request as on_llm_request
from astrbot.core.star.register import register_on_llm_response as on_llm_response
from astrbot.core.star.register import register_on_platform_loaded as on_platform_loaded
from astrbot.core.star.register import register_permission_type as permission_type
from astrbot.core.star.register import (
    register_platform_adapter_type as platform_adapter_type,
)
from astrbot.core.star.register import register_regex as regex

__all__ = [
    "CustomFilter",
    "EventMessageType",
    "EventMessageTypeFilter",
    "PermissionType",
    "PermissionTypeFilter",
    "PlatformAdapterType",
    "PlatformAdapterTypeFilter",
    "after_message_sent",
    "command",
    "command_group",
    "event_message_type",
    "regex",
    "platform_adapter_type",
    "permission_type",
    "EventMessageTypeFilter",
    "EventMessageType",
    "PlatformAdapterTypeFilter",
    "PlatformAdapterType",
    "PermissionTypeFilter",
    "CustomFilter",
    "custom_filter",
    "PermissionType",
    "on_astrbot_loaded",
    "on_platform_loaded",
    "on_llm_request",
    "event_message_type",
    "llm_tool",
    "on_astrbot_loaded",
    "on_decorating_result",
    "after_message_sent",
    "on_llm_request",
    "on_llm_response",
    "on_platform_loaded",
    "permission_type",
    "platform_adapter_type",
    "regex",
]

@@ -1,23 +1,22 @@
from astrbot.core.message.components import *
from astrbot.core.platform import (
    AstrMessageEvent,
    Platform,
    AstrBotMessage,
    AstrMessageEvent,
    Group,
    MessageMember,
    MessageType,
    Platform,
    PlatformMetadata,
    Group,
)

from astrbot.core.platform.register import register_platform_adapter
from astrbot.core.message.components import *

__all__ = [
    "AstrMessageEvent",
    "Platform",
    "AstrBotMessage",
    "AstrMessageEvent",
    "Group",
    "MessageMember",
    "MessageType",
    "Platform",
    "PlatformMetadata",
    "register_platform_adapter",
    "Group",
]

@@ -1,17 +1,18 @@
from astrbot.core.provider import Provider, STTProvider, Personality
from astrbot.core.db.po import Personality
from astrbot.core.provider import Provider, STTProvider
from astrbot.core.provider.entities import (
    LLMResponse,
    ProviderMetaData,
    ProviderRequest,
    ProviderType,
    ProviderMetaData,
    LLMResponse,
)

__all__ = [
    "Provider",
    "STTProvider",
    "LLMResponse",
    "Personality",
    "Provider",
    "ProviderMetaData",
    "ProviderRequest",
    "ProviderType",
    "ProviderMetaData",
    "LLMResponse",
    "STTProvider",
]

@@ -1,8 +1,7 @@
from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star.config import *
from astrbot.core.star.register import (
    register_star as register,  # 注册插件(Star)
)

from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star.config import *

__all__ = ["register", "Context", "Star", "StarTools"]
__all__ = ["Context", "Star", "StarTools", "register"]

@@ -1,7 +1,7 @@
from astrbot.core.utils.session_waiter import (
    SessionWaiter,
    SessionController,
    SessionWaiter,
    session_waiter,
)

__all__ = ["SessionWaiter", "SessionController", "session_waiter"]
__all__ = ["SessionController", "SessionWaiter", "session_waiter"]

@@ -1,11 +1,11 @@
"""
AstrBot CLI入口
"""
"""AstrBot CLI入口"""

import sys

import click
import sys

from . import __version__
from .commands import init, run, plug, conf
from .commands import conf, init, plug, run

logo_tmpl = r"""
 ___   _______.___________..______   .______    ______   .___________.

@@ -1,6 +1,6 @@
from .cmd_init import init
from .cmd_run import run
from .cmd_plug import plug
from .cmd_conf import conf
from .cmd_init import init
from .cmd_plug import plug
from .cmd_run import run

__all__ = ["init", "run", "plug", "conf"]
__all__ = ["conf", "init", "plug", "run"]

@@ -1,9 +1,12 @@
import json
import click
import hashlib
import json
import zoneinfo
from typing import Any, Callable
from ..utils import get_astrbot_root, check_astrbot_root
from collections.abc import Callable
from typing import Any

import click

from ..utils import check_astrbot_root, get_astrbot_root


def _validate_log_level(value: str) -> str:

@@ -11,7 +14,7 @@ def _validate_log_level(value: str) -> str:
    value = value.upper()
    if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
        raise click.ClickException(
            "日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一"
            "日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一",
        )
    return value

@@ -73,7 +76,7 @@ def _load_config() -> dict[str, Any]:
    root = get_astrbot_root()
    if not check_astrbot_root(root):
        raise click.ClickException(
            f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
            f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
        )

    config_path = root / "data" / "cmd_config.json"

@@ -88,7 +91,7 @@ def _load_config() -> dict[str, Any]:
    try:
        return json.loads(config_path.read_text(encoding="utf-8-sig"))
    except json.JSONDecodeError as e:
        raise click.ClickException(f"配置文件解析失败: {str(e)}")
        raise click.ClickException(f"配置文件解析失败: {e!s}")


def _save_config(config: dict[str, Any]) -> None:

@@ -96,7 +99,8 @@ def _save_config(config: dict[str, Any]) -> None:
    config_path = get_astrbot_root() / "data" / "cmd_config.json"

    config_path.write_text(
        json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8-sig"
        json.dumps(config, ensure_ascii=False, indent=2),
        encoding="utf-8-sig",
    )


@@ -108,7 +112,7 @@ def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
            obj[part] = {}
        elif not isinstance(obj[part], dict):
            raise click.ClickException(
                f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典"
                f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典",
            )
        obj = obj[part]
    obj[parts[-1]] = value

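For context, here is a self-contained sketch of what `_set_nested_item` does, based on the fragments visible in this hunk: it walks a dot-separated key path, creating intermediate dicts as needed, and assigns the value at the leaf. The `parts` splitting step is an assumption, since it falls outside the hunk.

```python
def set_nested_item(obj: dict, path: str, value) -> None:
    # Walk a dot-separated path such as "dashboard.port", creating
    # intermediate dicts, then assign the value at the final key.
    parts = path.split(".")  # assumed: the split happens before the loop shown above
    for part in parts[:-1]:
        if part not in obj:
            obj[part] = {}
        elif not isinstance(obj[part], dict):
            raise ValueError(f"path conflict at {part!r}: not a dict")
        obj = obj[part]
    obj[parts[-1]] = value

cfg: dict = {}
set_nested_item(cfg, "dashboard.port", 6185)
assert cfg == {"dashboard": {"port": 6185}}
```
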
@@ -140,7 +144,6 @@ def conf():

    - callback_api_base: 回调接口基址
    """
    pass


@conf.command(name="set")

@@ -148,7 +151,7 @@ def conf():
@click.argument("value")
def set_config(key: str, value: str):
    """设置配置项的值"""
    if key not in CONFIG_VALIDATORS.keys():
    if key not in CONFIG_VALIDATORS:
        raise click.ClickException(f"不支持的配置项: {key}")

    config = _load_config()

@@ -170,17 +173,17 @@ def set_config(key: str, value: str):
    except KeyError:
        raise click.ClickException(f"未知的配置项: {key}")
    except Exception as e:
        raise click.UsageError(f"设置配置失败: {str(e)}")
        raise click.UsageError(f"设置配置失败: {e!s}")


@conf.command(name="get")
@click.argument("key", required=False)
def get_config(key: str = None):
def get_config(key: str | None = None):
    """获取配置项的值,不提供key则显示所有可配置项"""
    config = _load_config()

    if key:
        if key not in CONFIG_VALIDATORS.keys():
        if key not in CONFIG_VALIDATORS:
            raise click.ClickException(f"不支持的配置项: {key}")

        try:

@@ -191,10 +194,10 @@ def get_config(key: str = None):
        except KeyError:
            raise click.ClickException(f"未知的配置项: {key}")
        except Exception as e:
            raise click.UsageError(f"获取配置失败: {str(e)}")
            raise click.UsageError(f"获取配置失败: {e!s}")
    else:
        click.echo("当前配置:")
        for key in CONFIG_VALIDATORS.keys():
        for key in CONFIG_VALIDATORS:
            try:
                value = (
                    "********"

@@ -1,4 +1,5 @@
import asyncio
from pathlib import Path

import click
from filelock import FileLock, Timeout

@@ -6,14 +7,14 @@ from filelock import FileLock, Timeout
from ..utils import check_dashboard, get_astrbot_root


async def initialize_astrbot(astrbot_root) -> None:
async def initialize_astrbot(astrbot_root: Path) -> None:
    """执行 AstrBot 初始化逻辑"""
    dot_astrbot = astrbot_root / ".astrbot"

    if not dot_astrbot.exists():
        click.echo(f"Current Directory: {astrbot_root}")
        click.echo(
            "如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。"
            "如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。",
        )
        if click.confirm(
            f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",

@@ -1,31 +1,29 @@
import re
import shutil
from pathlib import Path

import click
import shutil


from ..utils import (
    get_git_repo,
    build_plug_list,
    manage_plugin,
    PluginStatus,
    build_plug_list,
    check_astrbot_root,
    get_astrbot_root,
    get_git_repo,
    manage_plugin,
)


@click.group()
def plug():
    """插件管理"""
    pass


def _get_data_path() -> Path:
    base = get_astrbot_root()
    if not check_astrbot_root(base):
        raise click.ClickException(
            f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
            f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
        )
    return (base / "data").resolve()

@@ -41,7 +39,7 @@ def display_plugins(plugins, title=None, color=None):
        desc = p["desc"][:30] + ("..." if len(p["desc"]) > 30 else "")
        click.echo(
            f"{p['name']:<20} {p['version']:<10} {p['status']:<10} "
            f"{p['author']:<15} {desc:<30}"
            f"{p['author']:<15} {desc:<30}",
        )


@@ -78,7 +76,7 @@ def new(name: str):
        f"desc: {desc}\n"
        f"version: {version}\n"
        f"author: {author}\n"
        f"repo: {repo}\n"
        f"repo: {repo}\n",
    )

    # 重写 README.md

@@ -86,7 +84,7 @@ def new(name: str):
        f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n")

    # 重写 main.py
    with open(plug_path / "main.py", "r", encoding="utf-8") as f:
    with open(plug_path / "main.py", encoding="utf-8") as f:
        content = f.read()

    new_content = content.replace(

@@ -1,19 +1,18 @@
import asyncio
import os
import sys
import traceback
from pathlib import Path

import click
import asyncio
import traceback

from filelock import FileLock, Timeout

from ..utils import check_dashboard, check_astrbot_root, get_astrbot_root
from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root


async def run_astrbot(astrbot_root: Path):
    """运行 AstrBot"""
    from astrbot.core import logger, LogManager, LogBroker, db_helper
    from astrbot.core import LogBroker, LogManager, db_helper, logger
    from astrbot.core.initial_loader import InitialLoader

    await check_dashboard(astrbot_root / "data")

@@ -38,7 +37,7 @@ def run(reload: bool, port: str) -> None:

    if not check_astrbot_root(astrbot_root):
        raise click.ClickException(
            f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
            f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
        )

    os.environ["ASTRBOT_ROOT"] = str(astrbot_root)

@@ -1,18 +1,18 @@
from .basic import (
    get_astrbot_root,
    check_astrbot_root,
    check_dashboard,
    get_astrbot_root,
)
from .plugin import get_git_repo, manage_plugin, build_plug_list, PluginStatus
from .plugin import PluginStatus, build_plug_list, get_git_repo, manage_plugin
from .version_comparator import VersionComparator

__all__ = [
    "get_astrbot_root",
    "PluginStatus",
    "VersionComparator",
    "build_plug_list",
    "check_astrbot_root",
    "check_dashboard",
    "get_astrbot_root",
    "get_git_repo",
    "manage_plugin",
    "build_plug_list",
    "VersionComparator",
    "PluginStatus",
]

@@ -21,8 +21,9 @@ def get_astrbot_root() -> Path:

async def check_dashboard(astrbot_root: Path) -> None:
    """检查是否安装了dashboard"""
    from astrbot.core.utils.io import get_dashboard_version, download_dashboard
    from astrbot.core.config.default import VERSION
    from astrbot.core.utils.io import download_dashboard, get_dashboard_version

    from .version_comparator import VersionComparator

    try:

@@ -48,19 +49,18 @@ async def check_dashboard(astrbot_root: Path) -> None:
        if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
            click.echo("管理面板已是最新版本")
            return
        else:
            try:
                version = dashboard_version.split("v")[1]
                click.echo(f"管理面板版本: {version}")
                await download_dashboard(
                    path="data/dashboard.zip",
                    extract_path=str(astrbot_root),
                    version=f"v{VERSION}",
                    latest=False,
                )
            except Exception as e:
                click.echo(f"下载管理面板失败: {e}")
                return
        try:
            version = dashboard_version.split("v")[1]
            click.echo(f"管理面板版本: {version}")
            await download_dashboard(
                path="data/dashboard.zip",
                extract_path=str(astrbot_root),
                version=f"v{VERSION}",
                latest=False,
            )
        except Exception as e:
            click.echo(f"下载管理面板失败: {e}")
            return
    except FileNotFoundError:
        click.echo("初始化管理面板目录...")
        try:

@@ -1,14 +1,14 @@
import shutil
import tempfile

import httpx
import yaml
from enum import Enum
from io import BytesIO
from pathlib import Path
from zipfile import ZipFile

import click
import httpx
import yaml

from .version_comparator import VersionComparator

@@ -32,7 +32,8 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
    release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
    try:
        with httpx.Client(
            proxy=proxy if proxy else None, follow_redirects=True
            proxy=proxy if proxy else None,
            follow_redirects=True,
        ) as client:
            resp = client.get(release_url)
            resp.raise_for_status()

@@ -55,7 +56,8 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None):

    # 下载并解压
    with httpx.Client(
        proxy=proxy if proxy else None, follow_redirects=True
        proxy=proxy if proxy else None,
        follow_redirects=True,
    ) as client:
        resp = client.get(download_url)
        if (

@@ -89,6 +91,7 @@ def load_yaml_metadata(plugin_dir: Path) -> dict:

    Returns:
        dict: 包含元数据的字典,如果读取失败则返回空字典

    """
    yaml_path = plugin_dir / "metadata.yaml"
    if yaml_path.exists():

@@ -107,6 +110,7 @@ def build_plug_list(plugins_dir: Path) -> list:

    Returns:
        list: 包含插件信息的字典列表

    """
    # 获取本地插件信息
    result = []

@@ -133,7 +137,7 @@ def build_plug_list(plugins_dir: Path) -> list:
                "repo": str(metadata.get("repo", "")),
                "status": PluginStatus.INSTALLED,
                "local_path": str(plugin_dir),
            }
            },
        )

    # 获取在线插件列表

@@ -153,7 +157,7 @@ def build_plug_list(plugins_dir: Path) -> list:
                    "repo": str(plugin_info.get("repo", "")),
                    "status": PluginStatus.NOT_INSTALLED,
                    "local_path": None,
                }
                },
            )
    except Exception as e:
        click.echo(f"获取在线插件列表失败: {e}", err=True)

@@ -168,7 +172,8 @@ def build_plug_list(plugins_dir: Path) -> list:
        )
        if (
            VersionComparator.compare_version(
                local_plugin["version"], online_plugin["version"]
                local_plugin["version"],
                online_plugin["version"],
            )
            < 0
        ):

@@ -186,7 +191,10 @@ def build_plug_list(plugins_dir: Path) -> list:


def manage_plugin(
    plugin: dict, plugins_dir: Path, is_update: bool = False, proxy: str | None = None
    plugin: dict,
    plugins_dir: Path,
    is_update: bool = False,
    proxy: str | None = None,
) -> None:
    """安装或更新插件

@@ -195,6 +203,7 @@ def manage_plugin(
        plugins_dir (Path): 插件目录
        is_update (bool, optional): 是否为更新操作. 默认为 False
        proxy (str, optional): 代理服务器地址

    """
    plugin_name = plugin["name"]
    repo_url = plugin["repo"]

@@ -212,26 +221,26 @@ def manage_plugin(
            raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")

    # 备份现有插件
    if is_update and backup_path.exists():
    if is_update and backup_path is not None and backup_path.exists():
        shutil.rmtree(backup_path)
    if is_update:
    if is_update and backup_path is not None:
        shutil.copytree(target_path, backup_path)

    try:
        click.echo(
            f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}..."
            f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}...",
        )
        get_git_repo(repo_url, target_path, proxy)

        # 更新成功,删除备份
        if is_update and backup_path.exists():
        if is_update and backup_path is not None and backup_path.exists():
            shutil.rmtree(backup_path)
        click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
    except Exception as e:
        if target_path.exists():
            shutil.rmtree(target_path, ignore_errors=True)
        if is_update and backup_path.exists():
        if is_update and backup_path is not None and backup_path.exists():
            shutil.move(backup_path, target_path)
        raise click.ClickException(
            f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}"
            f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}",
        )

@@ -1,6 +1,4 @@
"""
拷贝自 astrbot.core.utils.version_comparator
"""
"""拷贝自 astrbot.core.utils.version_comparator"""

import re

@@ -42,15 +40,15 @@ class VersionComparator:
        for i in range(length):
            if v1_parts[i] > v2_parts[i]:
                return 1
            elif v1_parts[i] < v2_parts[i]:
            if v1_parts[i] < v2_parts[i]:
                return -1

        # 比较预发布标签
        if v1_prerelease is None and v2_prerelease is not None:
            return 1  # 没有预发布标签的版本高于有预发布标签的版本
        elif v1_prerelease is not None and v2_prerelease is None:
        if v1_prerelease is not None and v2_prerelease is None:
            return -1  # 有预发布标签的版本低于没有预发布标签的版本
        elif v1_prerelease is not None and v2_prerelease is not None:
        if v1_prerelease is not None and v2_prerelease is not None:
            len_pre = max(len(v1_prerelease), len(v2_prerelease))
            for i in range(len_pre):
                p1 = v1_prerelease[i] if i < len(v1_prerelease) else None

@@ -58,21 +56,21 @@ class VersionComparator:

                if p1 is None and p2 is not None:
                    return -1
                elif p1 is not None and p2 is None:
                if p1 is not None and p2 is None:
                    return 1
                elif isinstance(p1, int) and isinstance(p2, str):
                if isinstance(p1, int) and isinstance(p2, str):
                    return -1
                elif isinstance(p1, str) and isinstance(p2, int):
                if isinstance(p1, str) and isinstance(p2, int):
                    return 1
                elif isinstance(p1, int) and isinstance(p2, int):
                if isinstance(p1, int) and isinstance(p2, int):
                    if p1 > p2:
                        return 1
                    elif p1 < p2:
                    if p1 < p2:
                        return -1
                elif isinstance(p1, str) and isinstance(p2, str):
                    if p1 > p2:
                        return 1
                    elif p1 < p2:
                    if p1 < p2:
                        return -1
        return 0  # 预发布标签完全相同

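The pattern in these two hunks is the same throughout: an `elif` that follows a branch ending in `return` is rewritten as a plain `if`. A minimal sketch of why that is behavior-preserving (the attribution to a ruff RET505-style autofix is an assumption):

```python
def sign_old(a: int, b: int) -> int:
    if a > b:
        return 1
    elif a < b:
        return -1
    return 0

def sign_new(a: int, b: int) -> int:
    if a > b:
        return 1
    if a < b:  # safe: the branch above always returns, so there is no fall-through
        return -1
    return 0

assert all(sign_old(x, y) == sign_new(x, y) for x in range(3) for y in range(3))
```
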
@@ -1,12 +1,14 @@
import os
from .log import LogManager, LogBroker  # noqa
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config.default import DB_PATH

from astrbot.core.config import AstrBotConfig
from astrbot.core.config.default import DB_PATH
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.file_token_service import FileTokenService
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.t2i.renderer import HtmlRenderer

from .log import LogBroker, LogManager  # noqa
from .utils.astrbot_path import get_astrbot_data_path

# 初始化数据存储文件夹

@@ -1,8 +1,9 @@
from dataclasses import dataclass
from .tool import FunctionTool
from typing import Generic
from .run_context import TContext

from .hooks import BaseAgentRunHooks
from .run_context import TContext
from .tool import FunctionTool


@dataclass

@@ -1,14 +1,18 @@
from typing import Generic
from .tool import FunctionTool

from .agent import Agent
from .run_context import TContext
from .tool import FunctionTool


class HandoffTool(FunctionTool, Generic[TContext]):
    """Handoff tool for delegating tasks to another agent."""

    def __init__(
        self, agent: Agent[TContext], parameters: dict | None = None, **kwargs
        self,
        agent: Agent[TContext],
        parameters: dict | None = None,
        **kwargs,
    ):
        self.agent = agent
        super().__init__(

@@ -1,12 +1,13 @@
import mcp
from dataclasses import dataclass
from .run_context import ContextWrapper, TContext
from typing import Generic
from astrbot.core.provider.entities import LLMResponse

import mcp

from astrbot.core.agent.tool import FunctionTool
from astrbot.core.provider.entities import LLMResponse

from .run_context import ContextWrapper, TContext


@dataclass
class BaseAgentRunHooks(Generic[TContext]):
    async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
    async def on_tool_start(

@@ -23,5 +24,7 @@ class BaseAgentRunHooks(Generic[TContext]):
        tool_result: mcp.types.CallToolResult | None,
    ): ...
    async def on_agent_done(
        self, run_context: ContextWrapper[TContext], llm_response: LLMResponse
        self,
        run_context: ContextWrapper[TContext],
        llm_response: LLMResponse,
    ): ...

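Since `BaseAgentRunHooks` is a set of optional async callbacks, a subclass only needs to override the hooks it cares about. A minimal sketch based on the signatures in this hunk (the print bodies are placeholders):

```python
class LoggingHooks(BaseAgentRunHooks[None]):
    async def on_agent_begin(self, run_context):
        # called once before the run starts
        print("agent run started")

    async def on_agent_done(self, run_context, llm_response):
        # llm_response is the final LLMResponse of the run
        print("agent finished")
```
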
@@ -1,28 +1,44 @@
import asyncio
import logging
from datetime import timedelta
from typing import Optional
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Generic

from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from astrbot import logger
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.utils.log_pipe import LogPipe

from .run_context import TContext
from .tool import FunctionTool

try:
    import anyio
    import mcp
    from mcp.client.sse import sse_client
except (ModuleNotFoundError, ImportError):
    logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
    logger.warning(
        "Warning: Missing 'mcp' dependency, MCP services will be unavailable."
    )

try:
    from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
    logger.warning(
        "警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
        "Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
    )


def _prepare_config(config: dict) -> dict:
    """准备配置,处理嵌套格式"""
    if "mcpServers" in config and config["mcpServers"]:
    """Prepare configuration, handle nested format"""
    if config.get("mcpServers"):
        first_key = next(iter(config["mcpServers"]))
        config = config["mcpServers"][first_key]
        config.pop("active", None)

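To make the `_prepare_config` change concrete, here is a hedged sketch of the nested format it unwraps: the standard `mcpServers` wrapper (documented at https://modelcontextprotocol.io/quickstart/server) is flattened to its first server entry and the AstrBot-specific `active` flag is dropped. The server name and command below are illustrative only:

```python
nested = {
    "mcpServers": {
        "everything": {  # hypothetical server entry
            "command": "npx",
            "args": ["-y", "@modelcontextprotocol/server-everything"],
            "active": True,
        },
    },
}

# Equivalent to the body of _prepare_config above:
first_key = next(iter(nested["mcpServers"]))
cfg = dict(nested["mcpServers"][first_key])
cfg.pop("active", None)
assert cfg == {"command": "npx", "args": ["-y", "@modelcontextprotocol/server-everything"]}
```
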
@@ -30,7 +46,7 @@ def _prepare_config(config: dict) -> dict:


async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
    """快速测试 MCP 服务器可达性"""
    """Quick test MCP server connectivity"""
    import aiohttp

    cfg = _prepare_config(config.copy())

@@ -45,7 +61,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
    elif "type" in cfg:
        transport_type = cfg["type"]
    else:
        raise Exception("MCP 连接配置缺少 transport 或 type 字段")
        raise Exception("MCP connection config missing transport or type field")

    async with aiohttp.ClientSession() as session:
        if transport_type == "streamable_http":

@@ -71,8 +87,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
            ) as response:
                if response.status == 200:
                    return True, ""
                else:
                    return False, f"HTTP {response.status}: {response.reason}"
                return False, f"HTTP {response.status}: {response.reason}"
        else:
            async with session.get(
                url,

@@ -84,11 +99,10 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
            ) as response:
                if response.status == 200:
                    return True, ""
                else:
                    return False, f"HTTP {response.status}: {response.reason}"
                return False, f"HTTP {response.status}: {response.reason}"

    except asyncio.TimeoutError:
        return False, f"连接超时: {timeout}秒"
        return False, f"Connection timeout: {timeout} seconds"
    except Exception as e:
        return False, f"{e!s}"

@@ -96,8 +110,9 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
class MCPClient:
    def __init__(self):
        # Initialize session and client objects
        self.session: Optional[mcp.ClientSession] = None
        self.session: mcp.ClientSession | None = None
        self.exit_stack = AsyncExitStack()
        self._old_exit_stacks: list[AsyncExitStack] = []  # Track old stacks for cleanup

        self.name: str | None = None
        self.active: bool = True

@@ -105,21 +120,32 @@ class MCPClient:
        self.server_errlogs: list[str] = []
        self.running_event = asyncio.Event()

    async def connect_to_server(self, mcp_server_config: dict, name: str):
        """连接到 MCP 服务器
        # Store connection config for reconnection
        self._mcp_server_config: dict | None = None
        self._server_name: str | None = None
        self._reconnect_lock = asyncio.Lock()  # Lock for thread-safe reconnection
        self._reconnecting: bool = False  # For logging and debugging

        如果 `url` 参数存在:
        1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
        1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
        2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
    async def connect_to_server(self, mcp_server_config: dict, name: str):
        """Connect to MCP server

        If `url` parameter exists:
        1. When transport is specified as `streamable_http`, use Streamable HTTP connection.
        2. When transport is specified as `sse`, use SSE connection.
        3. If not specified, default to SSE connection to MCP service.

        Args:
            mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server

        """
        # Store config for reconnection
        self._mcp_server_config = mcp_server_config
        self._server_name = name

        cfg = _prepare_config(mcp_server_config.copy())

        def logging_callback(msg: str):
            # 处理 MCP 服务的错误日志
            # Handle MCP service error logs
            print(f"MCP Server {name} Error: {msg}")
            self.server_errlogs.append(msg)

@@ -133,7 +159,7 @@ class MCPClient:
        elif "type" in cfg:
            transport_type = cfg["type"]
        else:
            raise Exception("MCP 连接配置缺少 transport 或 type 字段")
            raise Exception("MCP connection config missing transport or type field")

        if transport_type != "streamable_http":
            # SSE transport method

@@ -144,7 +170,7 @@ class MCPClient:
                sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
            )
            streams = await self.exit_stack.enter_async_context(
                self._streams_context
                self._streams_context,
            )

            # Create a new client session

@@ -154,12 +180,12 @@ class MCPClient:
                    *streams,
                    read_timeout_seconds=read_timeout,
                    logging_callback=logging_callback,  # type: ignore
                )
                ),
            )
        else:
            timeout = timedelta(seconds=cfg.get("timeout", 30))
            sse_read_timeout = timedelta(
                seconds=cfg.get("sse_read_timeout", 60 * 5)
                seconds=cfg.get("sse_read_timeout", 60 * 5),
            )
            self._streams_context = streamablehttp_client(
                url=cfg["url"],

@@ -169,7 +195,7 @@ class MCPClient:
                terminate_on_close=cfg.get("terminate_on_close", True),
            )
            read_s, write_s, _ = await self.exit_stack.enter_async_context(
                self._streams_context
                self._streams_context,
            )

            # Create a new client session

@@ -180,7 +206,7 @@ class MCPClient:
                    write_stream=write_s,
                    read_timeout_seconds=read_timeout,
                    logging_callback=logging_callback,  # type: ignore
                )
                ),
            )

        else:

@@ -189,7 +215,7 @@ class MCPClient:
            )

            def callback(msg: str):
                # 处理 MCP 服务的错误日志
                # Handle MCP service error logs
                self.server_errlogs.append(msg)

            stdio_transport = await self.exit_stack.enter_async_context(

@@ -206,7 +232,7 @@ class MCPClient:

            # Create a new client session
            self.session = await self.exit_stack.enter_async_context(
                mcp.ClientSession(*stdio_transport)
                mcp.ClientSession(*stdio_transport),
            )
        await self.session.initialize()

@@ -218,7 +244,142 @@ class MCPClient:
        self.tools = response.tools
        return response

    async def _reconnect(self) -> None:
        """Reconnect to the MCP server using the stored configuration.

        Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments.

        Raises:
            Exception: raised when reconnection fails
        """
        async with self._reconnect_lock:
            # Check if already reconnecting (useful for logging)
            if self._reconnecting:
                logger.debug(
                    f"MCP Client {self._server_name} is already reconnecting, skipping"
                )
                return

            if not self._mcp_server_config or not self._server_name:
                raise Exception("Cannot reconnect: missing connection configuration")

            self._reconnecting = True
            try:
                logger.info(
                    f"Attempting to reconnect to MCP server {self._server_name}..."
                )

                # Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues)
                if self.exit_stack:
                    self._old_exit_stacks.append(self.exit_stack)

                # Mark old session as invalid
                self.session = None

                # Create new exit stack for new connection
                self.exit_stack = AsyncExitStack()

                # Reconnect using stored config
                await self.connect_to_server(self._mcp_server_config, self._server_name)
                await self.list_tools_and_save()

                logger.info(
                    f"Successfully reconnected to MCP server {self._server_name}"
                )
            except Exception as e:
                logger.error(
                    f"Failed to reconnect to MCP server {self._server_name}: {e}"
                )
                raise
            finally:
                self._reconnecting = False

    async def call_tool_with_reconnect(
        self,
        tool_name: str,
        arguments: dict,
        read_timeout_seconds: timedelta,
    ) -> mcp.types.CallToolResult:
        """Call MCP tool with automatic reconnection on failure, max 2 retries.

        Args:
            tool_name: tool name
            arguments: tool arguments
            read_timeout_seconds: read timeout

        Returns:
            MCP tool call result

        Raises:
            ValueError: MCP session is not available
            anyio.ClosedResourceError: raised after reconnection failure
        """

        @retry(
            retry=retry_if_exception_type(anyio.ClosedResourceError),
            stop=stop_after_attempt(2),
            wait=wait_exponential(multiplier=1, min=1, max=3),
            before_sleep=before_sleep_log(logger, logging.WARNING),
            reraise=True,
        )
        async def _call_with_retry():
            if not self.session:
                raise ValueError("MCP session is not available for MCP function tools.")

            try:
                return await self.session.call_tool(
                    name=tool_name,
                    arguments=arguments,
                    read_timeout_seconds=read_timeout_seconds,
                )
            except anyio.ClosedResourceError:
                logger.warning(
                    f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
                )
                # Attempt to reconnect
                await self._reconnect()
                # Reraise the exception to trigger tenacity retry
                raise

        return await _call_with_retry()

    async def cleanup(self):
        """Clean up resources"""
        await self.exit_stack.aclose()
        self.running_event.set()  # Set the running event to indicate cleanup is done
        """Clean up resources including old exit stacks from reconnections"""
        # Set running_event first to unblock any waiting tasks
        self.running_event.set()

        # Close current exit stack
        try:
            await self.exit_stack.aclose()
        except Exception as e:
            logger.debug(f"Error closing current exit stack: {e}")

        # Don't close old exit stacks as they may be in different task contexts
        # They will be garbage collected naturally
        # Just clear the list to release references
        self._old_exit_stacks.clear()


class MCPTool(FunctionTool, Generic[TContext]):
    """A function tool that calls an MCP service."""

    def __init__(
        self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs
    ):
        super().__init__(
            name=mcp_tool.name,
            description=mcp_tool.description or "",
            parameters=mcp_tool.inputSchema,
        )
        self.mcp_tool = mcp_tool
        self.mcp_client = mcp_client
        self.mcp_server_name = mcp_server_name

    async def call(
        self, context: ContextWrapper[TContext], **kwargs
    ) -> mcp.types.CallToolResult:
        return await self.mcp_client.call_tool_with_reconnect(
            tool_name=self.mcp_tool.name,
            arguments=kwargs,
            read_timeout_seconds=timedelta(seconds=context.tool_call_timeout),
        )

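A hedged usage sketch of the reconnect-aware call path defined above (`connect_to_server`, `call_tool_with_reconnect`, and `cleanup` are taken verbatim from this hunk; the server config and tool name are illustrative):

```python
import asyncio
from datetime import timedelta

async def main() -> None:
    client = MCPClient()
    # Hypothetical stdio server config; any valid MCP server config works here.
    await client.connect_to_server(
        {"command": "npx", "args": ["-y", "@modelcontextprotocol/server-everything"]},
        name="everything",
    )
    result = await client.call_tool_with_reconnect(
        tool_name="echo",  # hypothetical tool name
        arguments={"message": "hi"},
        read_timeout_seconds=timedelta(seconds=60),
    )
    print(result.content)
    await client.cleanup()

asyncio.run(main())
```

If the underlying stream dies mid-call, the `anyio.ClosedResourceError` triggers one `_reconnect()` plus a tenacity retry (at most two attempts) before the error is re-raised to the caller.
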
astrbot/core/agent/message.py (new file, 175 lines)

@@ -0,0 +1,175 @@
# Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation.
# License: Apache License 2.0

from typing import Any, ClassVar, Literal, cast

from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema


class ContentPart(BaseModel):
    """A part of the content in a message."""

    __content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}

    type: str

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)

        invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`"

        type_value = getattr(cls, "type", None)
        if type_value is None or not isinstance(type_value, str):
            raise ValueError(invalid_subclass_error_msg)

        cls.__content_part_registry[type_value] = cls

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # If we're dealing with the base ContentPart class, use custom validation
        if cls.__name__ == "ContentPart":

            def validate_content_part(value: Any) -> Any:
                # if it's already an instance of a ContentPart subclass, return it
                if hasattr(value, "__class__") and issubclass(value.__class__, cls):
                    return value

                # if it's a dict with a type field, dispatch to the appropriate subclass
                if isinstance(value, dict) and "type" in value:
                    type_value: Any | None = cast(dict[str, Any], value).get("type")
                    if not isinstance(type_value, str):
                        raise ValueError(f"Cannot validate {value} as ContentPart")
                    target_class = cls.__content_part_registry[type_value]
                    return target_class.model_validate(value)

                raise ValueError(f"Cannot validate {value} as ContentPart")

            return core_schema.no_info_plain_validator_function(validate_content_part)

        # for subclasses, use the default schema
        return handler(source_type)


class TextPart(ContentPart):
    """
    >>> TextPart(text="Hello, world!").model_dump()
    {'type': 'text', 'text': 'Hello, world!'}
    """

    type: str = "text"
    text: str


class ImageURLPart(ContentPart):
    """
    >>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump()
    {'type': 'image_url', 'image_url': 'http://example.com/image.jpg'}
    """

    class ImageURL(BaseModel):
        url: str
        """The URL of the image, can be data URI scheme like `data:image/png;base64,...`."""
        id: str | None = None
        """The ID of the image, to allow LLMs to distinguish different images."""

    type: str = "image_url"
    image_url: ImageURL


class AudioURLPart(ContentPart):
    """
    >>> AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump()
    {'type': 'audio_url', 'audio_url': {'url': 'https://example.com/audio.mp3', 'id': None}}
    """

    class AudioURL(BaseModel):
        url: str
        """The URL of the audio, can be data URI scheme like `data:audio/aac;base64,...`."""
        id: str | None = None
        """The ID of the audio, to allow LLMs to distinguish different audios."""

    type: str = "audio_url"
    audio_url: AudioURL


class ToolCall(BaseModel):
    """
    A tool call requested by the assistant.

    >>> ToolCall(
    ...     id="123",
    ...     function=ToolCall.FunctionBody(
    ...         name="function",
    ...         arguments="{}"
    ...     ),
    ... ).model_dump()
    {'type': 'function', 'id': '123', 'function': {'name': 'function', 'arguments': '{}'}}
    """

    class FunctionBody(BaseModel):
        name: str
        arguments: str | None

    type: Literal["function"] = "function"

    id: str
    """The ID of the tool call."""
    function: FunctionBody
    """The function body of the tool call."""
    extra_content: dict[str, Any] | None = None
    """Extra metadata for the tool call."""

    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        if self.extra_content is None:
            kwargs.setdefault("exclude", set()).add("extra_content")
        return super().model_dump(**kwargs)


class ToolCallPart(BaseModel):
    """A part of the tool call."""

    arguments_part: str | None = None
    """A part of the arguments of the tool call."""


class Message(BaseModel):
    """A message in a conversation."""

    role: Literal[
        "system",
        "user",
        "assistant",
        "tool",
    ]

    content: str | list[ContentPart]
    """The content of the message."""


class AssistantMessageSegment(Message):
    """A message segment from the assistant."""

    role: Literal["assistant"] = "assistant"
    tool_calls: list[ToolCall] | list[dict] | None = None


class ToolCallMessageSegment(Message):
    """A message segment representing a tool call."""

    role: Literal["tool"] = "tool"
    tool_call_id: str


class UserMessageSegment(Message):
    """A message segment from the user."""

    role: Literal["user"] = "user"


class SystemMessageSegment(Message):
    """A message segment from the system."""

    role: Literal["system"] = "system"

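The interesting part of this new module is the base-class `__get_pydantic_core_schema__`: validating a plain dict against `ContentPart` dispatches to the subclass registered for its `type` value via `__init_subclass__`. A small sketch of that behavior, derived from the code above rather than from separate documentation:

```python
msg = Message.model_validate(
    {"role": "user", "content": [{"type": "text", "text": "Hello, world!"}]}
)
part = msg.content[0]
assert isinstance(part, TextPart)  # the dict was dispatched via the registry
assert part.text == "Hello, world!"
```
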
@@ -1,5 +1,6 @@
from dataclasses import dataclass
import typing as T
from dataclasses import dataclass

from astrbot.core.message.message_event_result import MessageChain

@@ -1,18 +1,22 @@
from dataclasses import dataclass
from typing import Any, Generic

from pydantic import Field
from pydantic.dataclasses import dataclass
from typing_extensions import TypeVar

from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .message import Message

TContext = TypeVar("TContext", default=Any)


@dataclass
@dataclass(config={"arbitrary_types_allowed": True})
class ContextWrapper(Generic[TContext]):
    """A context for running an agent, which can be used to pass additional data or state."""

    context: TContext
    event: AstrMessageEvent
    messages: list[Message] = Field(default_factory=list)
    """This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
    tool_call_timeout: int = 60  # Default tool call timeout in seconds


NoContext = ContextWrapper[None]

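The switch from a stdlib dataclass to a pydantic dataclass with `arbitrary_types_allowed` matters because `AstrMessageEvent` is not a pydantic model; without that config, pydantic would reject it as a field type. A minimal standalone sketch, with a stand-in class instead of the real event:

```python
from pydantic import Field
from pydantic.dataclasses import dataclass

class PlainEvent:
    """Stand-in for AstrMessageEvent, which is not a pydantic model."""

@dataclass(config={"arbitrary_types_allowed": True})
class Wrapper:
    event: PlainEvent
    messages: list[str] = Field(default_factory=list)

w = Wrapper(event=PlainEvent())
assert w.messages == []
```
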
@@ -1,13 +1,15 @@
import abc
import typing as T
from enum import Enum, auto
from ..run_context import ContextWrapper, TContext
from ..response import AgentResponse
from ..hooks import BaseAgentRunHooks
from ..tool_executor import BaseFunctionToolExecutor

from astrbot.core.provider import Provider
from astrbot.core.provider.entities import LLMResponse

from ..hooks import BaseAgentRunHooks
from ..response import AgentResponse
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor


class AgentState(Enum):
    """Defines the state of the agent."""

@@ -28,31 +30,33 @@ class BaseAgentRunner(T.Generic[TContext]):
        agent_hooks: BaseAgentRunHooks[TContext],
        **kwargs: T.Any,
    ) -> None:
        """
        Reset the agent to its initial state.
        """Reset the agent to its initial state.
        This method should be called before starting a new run.
        """
        ...

    @abc.abstractmethod
    async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
        """
        Process a single step of the agent.
        """
        """Process a single step of the agent."""
        ...

    @abc.abstractmethod
    async def step_until_done(
        self, max_step: int
    ) -> T.AsyncGenerator[AgentResponse, None]:
        """Process steps until the agent is done."""
        ...

    @abc.abstractmethod
    def done(self) -> bool:
        """
        Check if the agent has completed its task.
        """Check if the agent has completed its task.
        Returns True if the agent is done, False otherwise.
        """
        ...

    @abc.abstractmethod
    def get_final_llm_resp(self) -> LLMResponse | None:
        """
        Get the final observation from the agent.
        """Get the final observation from the agent.
        This method should be called after the agent is done.
        """
        ...

@@ -1,31 +1,33 @@
import sys
import traceback
import typing as T
from .base import BaseAgentRunner, AgentResponse, AgentState
from ..hooks import BaseAgentRunHooks
from ..tool_executor import BaseFunctionToolExecutor
from ..run_context import ContextWrapper, TContext
from ..response import AgentResponseData
from astrbot.core.provider.provider import Provider

from mcp.types import (
    BlobResourceContents,
    CallToolResult,
    EmbeddedResource,
    ImageContent,
    TextContent,
    TextResourceContents,
)

from astrbot import logger
from astrbot.core.message.message_event_result import (
    MessageChain,
)
from astrbot.core.provider.entities import (
    ProviderRequest,
    LLMResponse,
    ToolCallMessageSegment,
    AssistantMessageSegment,
    ProviderRequest,
    ToolCallsResult,
)
from mcp.types import (
    TextContent,
    ImageContent,
    EmbeddedResource,
    TextResourceContents,
    BlobResourceContents,
    CallToolResult,
)
from astrbot import logger
from astrbot.core.provider.provider import Provider

from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..response import AgentResponseData
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
from .base import AgentResponse, AgentState, BaseAgentRunner

if sys.version_info >= (3, 12):
    from typing import override

@@ -53,6 +55,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
        self.agent_hooks = agent_hooks
        self.run_context = run_context

        messages = []
        # append existing messages in the run context
        for msg in request.contexts:
            messages.append(Message.model_validate(msg))
        if request.prompt is not None:
            m = await request.assemble_context()
            messages.append(Message.model_validate(m))
        if request.system_prompt:
            messages.insert(
                0,
                Message(role="system", content=request.system_prompt),
            )
        self.run_context.messages = messages

    def _transition_state(self, new_state: AgentState) -> None:
        """转换 Agent 状态"""
        if self._state != new_state:

@@ -70,8 +86,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):

    @override
    async def step(self):
        """
        Process a single step of the agent.
        """Process a single step of the agent.
        This method should return the result of the step.
        """
        if not self.req:

@@ -95,11 +110,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
                        type="streaming_delta",
                        data=AgentResponseData(chain=llm_response.result_chain),
                    )
                else:
                elif llm_response.completion_text:
                    yield AgentResponse(
                        type="streaming_delta",
                        data=AgentResponseData(
                            chain=MessageChain().message(llm_response.completion_text)
                            chain=MessageChain().message(llm_response.completion_text),
                        ),
                    )
                elif llm_response.reasoning_content:
                    yield AgentResponse(
                        type="streaming_delta",
                        data=AgentResponseData(
                            chain=MessageChain(type="reasoning").message(
                                llm_response.reasoning_content,
                            ),
                        ),
                    )
                continue

@@ -120,8 +144,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
                type="err",
                data=AgentResponseData(
                    chain=MessageChain().message(
                        f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
                    )
                        f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}",
                    ),
                ),
            )

@@ -129,6 +153,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
            # 如果没有工具调用,转换到完成状态
            self.final_llm_resp = llm_resp
            self._transition_state(AgentState.DONE)
            # record the final assistant message
            self.run_context.messages.append(
                Message(
                    role="assistant",
                    content=llm_resp.completion_text or "",
                ),
            )
            try:
                await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
            except Exception as e:

@@ -144,7 +175,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
            yield AgentResponse(
                type="llm_result",
                data=AgentResponseData(
                    chain=MessageChain().message(llm_resp.completion_text)
                    chain=MessageChain().message(llm_resp.completion_text),
                ),
            )

@@ -155,13 +186,16 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
            yield AgentResponse(
                type="tool_call",
                data=AgentResponseData(
                    chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}")
                    chain=MessageChain(type="tool_call").message(
                        f"🔨 调用工具: {tool_call_name}"
                    ),
                ),
            )
        async for result in self._handle_function_tools(self.req, llm_resp):
            if isinstance(result, list):
                tool_call_result_blocks = result
            elif isinstance(result, MessageChain):
                result.type = "tool_call_result"
                yield AgentResponse(
                    type="tool_call_result",
                    data=AgentResponseData(chain=result),

@@ -169,14 +203,28 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
        # 将结果添加到上下文中
        tool_calls_result = ToolCallsResult(
            tool_calls_info=AssistantMessageSegment(
                role="assistant",
                tool_calls=llm_resp.to_openai_tool_calls(),
                tool_calls=llm_resp.to_openai_to_calls_model(),
                content=llm_resp.completion_text,
            ),
            tool_calls_result=tool_call_result_blocks,
        )
        # record the assistant message with tool calls
        self.run_context.messages.extend(
            tool_calls_result.to_openai_messages_model()
        )

        self.req.append_tool_calls_result(tool_calls_result)

    async def step_until_done(
        self, max_step: int
    ) -> T.AsyncGenerator[AgentResponse, None]:
        """Process steps until the agent is done."""
        step_count = 0
        while not self.done() and step_count < max_step:
            step_count += 1
            async for resp in self.step():
                yield resp

    async def _handle_function_tools(
        self,
        req: ProviderRequest,

@@ -205,7 +253,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=f"error: 未找到工具 {func_tool_name}",
|
||||
)
|
||||
),
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -214,7 +262,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# 获取实际的 handler 函数
|
||||
if func_tool.handler:
|
||||
logger.debug(
|
||||
f"工具 {func_tool_name} 期望的参数: {func_tool.parameters}"
|
||||
f"工具 {func_tool_name} 期望的参数: {func_tool.parameters}",
|
||||
)
|
||||
if func_tool.parameters and func_tool.parameters.get("properties"):
|
||||
expected_params = set(func_tool.parameters["properties"].keys())
|
||||
@@ -227,20 +275,21 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
|
||||
# 记录被忽略的参数
|
||||
ignored_params = set(func_tool_args.keys()) - set(
|
||||
valid_params.keys()
|
||||
valid_params.keys(),
|
||||
)
|
||||
if ignored_params:
|
||||
logger.warning(
|
||||
f"工具 {func_tool_name} 忽略非期望参数: {ignored_params}"
|
||||
f"工具 {func_tool_name} 忽略非期望参数: {ignored_params}",
|
||||
)
|
||||
else:
|
||||
# 如果没有 handler(如 MCP 工具),使用所有参数
|
||||
valid_params = func_tool_args
|
||||
logger.warning(f"工具 {func_tool_name} 没有 handler,使用所有参数")
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_start(
|
||||
self.run_context, func_tool, valid_params
|
||||
self.run_context,
|
||||
func_tool,
|
||||
valid_params,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_tool_start hook: {e}", exc_info=True)
|
||||
@@ -262,7 +311,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=res.content[0].text,
|
||||
)
|
||||
),
|
||||
)
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
@@ -271,10 +320,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
),
|
||||
)
|
||||
yield MessageChain(type="tool_direct_result").base64_image(
|
||||
res.content[0].data
|
||||
res.content[0].data,
|
||||
)
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
@@ -284,7 +333,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resource.text,
|
||||
)
|
||||
),
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
@@ -297,10 +346,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
),
|
||||
)
|
||||
yield MessageChain(
|
||||
type="tool_direct_result"
|
||||
type="tool_direct_result",
|
||||
).base64_image(resource.blob)
|
||||
else:
|
||||
tool_call_result_blocks.append(
|
||||
@@ -308,41 +357,41 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
),
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
|
||||
elif resp is None:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
# 发送消息逻辑在 ToolExecutor 中处理了。
|
||||
logger.warning(
|
||||
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
|
||||
)
|
||||
self._transition_state(AgentState.DONE)
|
||||
if res := self.run_context.event.get_result():
|
||||
if res.chain:
|
||||
yield MessageChain(
|
||||
chain=res.chain, type="tool_direct_result"
|
||||
)
|
||||
else:
|
||||
# 不应该出现其他类型
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。",
|
||||
)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool, func_tool_args, _final_resp
|
||||
self.run_context,
|
||||
func_tool,
|
||||
func_tool_args,
|
||||
_final_resp,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
|
||||
|
||||
self.run_context.event.clear_result()
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=f"error: {str(e)}",
|
||||
)
|
||||
content=f"error: {e!s}",
|
||||
),
|
||||
)
|
||||
|
||||
# 处理函数调用响应
|
||||
|
||||
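The whole runner contract visible in this diff boils down to: call `step()` until `done()`, consuming `AgentResponse` objects of type `streaming_delta`, `tool_call`, `tool_call_result`, `llm_result`, or `err`; the new `step_until_done()` wraps exactly that loop. A minimal consumer sketch (illustrative only; it assumes a runner that has already been initialized with a request):

```python
async def drain_agent(runner, max_step: int = 30) -> None:
    # Iterate the tool loop to completion, inspecting each response type.
    async for resp in runner.step_until_done(max_step=max_step):
        print(resp.type, resp.data)
```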
astrbot/core/agent/tool.py

@@ -1,58 +1,77 @@
-from dataclasses import dataclass
+from collections.abc import Awaitable, Callable
+from typing import Any, Generic

+import jsonschema
import mcp
-from deprecated import deprecated
-from typing import Awaitable, Callable, Literal, Any, Optional
-from .mcp_client import MCPClient
+from pydantic import Field, model_validator
+from pydantic.dataclasses import dataclass

+from .run_context import ContextWrapper, TContext

ParametersType = dict[str, Any]
ToolExecResult = str | mcp.types.CallToolResult


@dataclass
-class FunctionTool:
-    """A class representing a function tool that can be used in function calling."""
+class ToolSchema:
+    """A class representing the schema of a tool for function calling."""

    name: str
-    parameters: dict | None = None
-    description: str | None = None
-    handler: Callable[..., Awaitable[Any]] | None = None
-    """处理函数, 当 origin 为 mcp 时,这个为空"""
-    handler_module_path: str | None = None
-    """处理函数的模块路径,当 origin 为 mcp 时,这个为空
-
-    必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
-    """
+    """The name of the tool."""
+
+    description: str
+    """The description of the tool."""
+
+    parameters: ParametersType
+    """The parameters of the tool, in JSON Schema format."""
+
+    @model_validator(mode="after")
+    def validate_parameters(self) -> "ToolSchema":
+        jsonschema.validate(
+            self.parameters, jsonschema.Draft202012Validator.META_SCHEMA
+        )
+        return self
+
+
+@dataclass
+class FunctionTool(ToolSchema, Generic[TContext]):
+    """A callable tool, for function calling."""
+
+    handler: Callable[..., Awaitable[Any]] | None = None
+    """a callable that implements the tool's functionality. It should be an async function."""
+
+    handler_module_path: str | None = None
+    """
+    The module path of the handler function. This is empty when the origin is mcp.
+    This field must be retained, as the handler will be wrapped in functools.partial during initialization,
+    causing the handler's __module__ to be functools
+    """
    active: bool = True
-    """是否激活"""
-
-    origin: Literal["local", "mcp"] = "local"
-    """函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
-
-    # MCP 相关字段
-    mcp_server_name: str | None = None
-    """MCP 服务名称,当 origin 为 mcp 时有效"""
-    mcp_client: MCPClient | None = None
-    """MCP 客户端,当 origin 为 mcp 时有效"""
+    """
+    Whether the tool is active. This field is a special field for AstrBot.
+    You can ignore it when integrating with other frameworks.
+    """

    def __repr__(self):
-        return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
+        return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"

-    def __dict__(self) -> dict[str, Any]:
-        """将 FunctionTool 转换为字典格式"""
-        return {
-            "name": self.name,
-            "parameters": self.parameters,
-            "description": self.description,
-            "active": self.active,
-            "origin": self.origin,
-            "mcp_server_name": self.mcp_server_name,
-        }
+    async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult:
+        """Run the tool with the given arguments. The handler field has priority."""
+        raise NotImplementedError(
+            "FunctionTool.call() must be implemented by subclasses or set a handler."
+        )


@dataclass
class ToolSet:
    """A set of function tools that can be used in function calling.

    This class provides methods to add, remove, and retrieve tools, as well as
-    convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
+    convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).
+    """

-    def __init__(self, tools: list[FunctionTool] | None = None):
-        self.tools: list[FunctionTool] = tools or []
+    tools: list[FunctionTool] = Field(default_factory=list)

    def empty(self) -> bool:
        """Check if the tool set is empty."""

@@ -71,7 +90,7 @@ class ToolSet:
        """Remove a tool by its name."""
        self.tools = [tool for tool in self.tools if tool.name != name]

-    def get_tool(self, name: str) -> Optional[FunctionTool]:
+    def get_tool(self, name: str) -> FunctionTool | None:
        """Get a tool by its name."""
        for tool in self.tools:
            if tool.name == name:

@@ -132,10 +151,8 @@ class ToolSet:
        }

        if (
-            tool.parameters
-            and tool.parameters.get("properties")
-            or not omit_empty_parameter_field
-        ):
+            tool.parameters and tool.parameters.get("properties")
+        ) or not omit_empty_parameter_field:
            func_def["function"]["parameters"] = tool.parameters

        result.append(func_def)

@@ -185,7 +202,8 @@ class ToolSet:
        if "type" in schema and schema["type"] in supported_types:
            result["type"] = schema["type"]
            if "format" in schema and schema["format"] in supported_formats.get(
-                result["type"], set()
+                result["type"],
+                set(),
            ):
                result["format"] = schema["format"]
        else:

@@ -222,7 +240,7 @@ class ToolSet:

        tools = []
        for tool in self.tools:
-            d = {
+            d: dict[str, Any] = {
                "name": tool.name,
                "description": tool.description,
            }
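Because `FunctionTool` is now a pydantic dataclass whose `parameters` field is checked against the JSON Schema 2020-12 meta-schema by the `model_validator`, a malformed schema fails at construction time rather than at call time. A short sketch of the intended usage (the `web_search` tool here is a made-up example, not part of the diff):

```python
from astrbot.core.agent.tool import FunctionTool, ToolSet

search_tool = FunctionTool(
    name="web_search",
    description="Search the web for a query.",
    parameters={
        "type": "object",
        "properties": {"query": {"type": "string"}},
        "required": ["query"],
    },
)

toolset = ToolSet(tools=[search_tool])
assert toolset.get_tool("web_search") is search_tool
```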
astrbot/core/agent/tool_executor.py

@@ -1,11 +1,17 @@
+from collections.abc import AsyncGenerator
+from typing import Any, Generic
+
import mcp
-from typing import Any, Generic, AsyncGenerator
-from .run_context import TContext, ContextWrapper

+from .run_context import ContextWrapper, TContext
from .tool import FunctionTool


class BaseFunctionToolExecutor(Generic[TContext]):
    @classmethod
    async def execute(
-        cls, tool: FunctionTool, run_context: ContextWrapper[TContext], **tool_args
+        cls,
+        tool: FunctionTool,
+        run_context: ContextWrapper[TContext],
+        **tool_args,
    ) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ...
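`execute` is an async-generator classmethod, so an executor can stream intermediate results back to the runner. A toy conforming subclass, purely illustrative:

```python
import mcp

from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor


class EchoToolExecutor(BaseFunctionToolExecutor):
    @classmethod
    async def execute(cls, tool, run_context, **tool_args):
        # Echo the call arguments back as a single MCP-style text result.
        text = mcp.types.TextContent(type="text", text=f"{tool.name}: {tool_args}")
        yield mcp.types.CallToolResult(content=[text])
```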
astrbot/core/astr_agent_context.py

@@ -1,12 +1,19 @@
-from dataclasses import dataclass
-from astrbot.core.provider import Provider
-from astrbot.core.provider.entities import ProviderRequest
+from pydantic import Field
+from pydantic.dataclasses import dataclass

+from astrbot.core.agent.run_context import ContextWrapper
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
+from astrbot.core.star.context import Context


-@dataclass
+@dataclass(config={"arbitrary_types_allowed": True})
class AstrAgentContext:
-    provider: Provider
-    first_provider_request: ProviderRequest
-    curr_provider_request: ProviderRequest
-    streaming: bool
-    tool_call_timeout: int = 60  # Default tool call timeout in seconds
+    context: Context
+    """The star context instance"""
+    event: AstrMessageEvent
+    """The message event associated with the agent context."""
+    extra: dict[str, str] = Field(default_factory=dict)
+    """Customized extra data."""


+AgentContextWrapper = ContextWrapper[AstrAgentContext]
astrbot/core/astr_agent_hooks.py (new file, 36 lines)

@@ -0,0 +1,36 @@
from typing import Any

from mcp.types import CallToolResult

from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.pipeline.context_utils import call_event_hook
from astrbot.core.star.star_handler import EventType


class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
    async def on_agent_done(self, run_context, llm_response):
        # 执行事件钩子
        await call_event_hook(
            run_context.context.event,
            EventType.OnLLMResponseEvent,
            llm_response,
        )

    async def on_tool_end(
        self,
        run_context: ContextWrapper[AstrAgentContext],
        tool: FunctionTool[Any],
        tool_args: dict | None,
        tool_result: CallToolResult | None,
    ):
        run_context.context.event.clear_result()


class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
    pass


MAIN_AGENT_HOOKS = MainAgentHooks()
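The hooks are plain overridable coroutines, so plugging in custom behaviour is just subclassing. A hypothetical logging variant (names and print statements are illustrative, not part of the diff):

```python
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.astr_agent_context import AstrAgentContext


class LoggingAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
    async def on_tool_start(self, run_context, tool, tool_args):
        # mirrors the (run_context, func_tool, valid_params) call site above
        print(f"tool start: {tool.name} args={tool_args}")

    async def on_agent_done(self, run_context, llm_response):
        print("agent done:", llm_response.completion_text)
```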
astrbot/core/astr_agent_run_util.py (new file, 80 lines)

@@ -0,0 +1,80 @@
import traceback
from collections.abc import AsyncGenerator

from astrbot.core import logger
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import (
    MessageChain,
    MessageEventResult,
    ResultContentType,
)

AgentRunner = ToolLoopAgentRunner[AstrAgentContext]


async def run_agent(
    agent_runner: AgentRunner,
    max_step: int = 30,
    show_tool_use: bool = True,
    stream_to_general: bool = False,
    show_reasoning: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
    step_idx = 0
    astr_event = agent_runner.run_context.context.event
    while step_idx < max_step:
        step_idx += 1
        try:
            async for resp in agent_runner.step():
                if astr_event.is_stopped():
                    return
                if resp.type == "tool_call_result":
                    msg_chain = resp.data["chain"]
                    if msg_chain.type == "tool_direct_result":
                        # tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
                        await astr_event.send(resp.data["chain"])
                        continue
                    # 对于其他情况,暂时先不处理
                    continue
                elif resp.type == "tool_call":
                    if agent_runner.streaming:
                        # 用来标记流式响应需要分节
                        yield MessageChain(chain=[], type="break")
                    if show_tool_use:
                        await astr_event.send(resp.data["chain"])
                    continue

                if stream_to_general and resp.type == "streaming_delta":
                    continue

                if stream_to_general or not agent_runner.streaming:
                    content_typ = (
                        ResultContentType.LLM_RESULT
                        if resp.type == "llm_result"
                        else ResultContentType.GENERAL_RESULT
                    )
                    astr_event.set_result(
                        MessageEventResult(
                            chain=resp.data["chain"].chain,
                            result_content_type=content_typ,
                        ),
                    )
                    yield
                    astr_event.clear_result()
                elif resp.type == "streaming_delta":
                    chain = resp.data["chain"]
                    if chain.type == "reasoning" and not show_reasoning:
                        # display the reasoning content only when configured
                        continue
                    yield resp.data["chain"]  # MessageChain
            if agent_runner.done():
                break

        except Exception as e:
            logger.error(traceback.format_exc())
            err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
            if agent_runner.streaming:
                yield MessageChain().message(err_msg)
            else:
                astr_event.set_result(MessageEventResult().message(err_msg))
            return
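Note that `run_agent` has two output modes: when streaming it yields `MessageChain` deltas (plus `type="break"` separators around tool calls), and when not streaming it yields a bare `None` after setting a `MessageEventResult` on the event. A rough consumer, assuming an already-constructed runner:

```python
async def consume(agent_runner) -> None:
    async for chain in run_agent(agent_runner, max_step=30, show_tool_use=True):
        if chain is None:
            continue  # non-streaming: the result was already set on the event
        if chain.type == "break":
            continue  # streaming: start a new output segment
        ...  # forward the streaming delta to the platform
```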
astrbot/core/astr_agent_tool_exec.py (new file, 246 lines)

@@ -0,0 +1,246 @@
import asyncio
import inspect
import traceback
import typing as T

import mcp

from astrbot import logger
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import (
    CommandResult,
    MessageChain,
    MessageEventResult,
)
from astrbot.core.provider.register import llm_tools


class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
    @classmethod
    async def execute(cls, tool, run_context, **tool_args):
        """执行函数调用。

        Args:
            event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
            **kwargs: 函数调用的参数。

        Returns:
            AsyncGenerator[None | mcp.types.CallToolResult, None]

        """
        if isinstance(tool, HandoffTool):
            async for r in cls._execute_handoff(tool, run_context, **tool_args):
                yield r
            return

        elif isinstance(tool, MCPTool):
            async for r in cls._execute_mcp(tool, run_context, **tool_args):
                yield r
            return

        else:
            async for r in cls._execute_local(tool, run_context, **tool_args):
                yield r
            return

    @classmethod
    async def _execute_handoff(
        cls,
        tool: HandoffTool,
        run_context: ContextWrapper[AstrAgentContext],
        **tool_args,
    ):
        input_ = tool_args.get("input")

        # make toolset for the agent
        tools = tool.agent.tools
        if tools:
            toolset = ToolSet()
            for t in tools:
                if isinstance(t, str):
                    _t = llm_tools.get_func(t)
                    if _t:
                        toolset.add_tool(_t)
                elif isinstance(t, FunctionTool):
                    toolset.add_tool(t)
        else:
            toolset = None

        ctx = run_context.context.context
        event = run_context.context.event
        umo = event.unified_msg_origin
        prov_id = await ctx.get_current_chat_provider_id(umo)
        llm_resp = await ctx.tool_loop_agent(
            event=event,
            chat_provider_id=prov_id,
            prompt=input_,
            system_prompt=tool.agent.instructions,
            tools=toolset,
            max_steps=30,
            run_hooks=tool.agent.run_hooks,
        )
        yield mcp.types.CallToolResult(
            content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
        )

    @classmethod
    async def _execute_local(
        cls,
        tool: FunctionTool,
        run_context: ContextWrapper[AstrAgentContext],
        **tool_args,
    ):
        event = run_context.context.event
        if not event:
            raise ValueError("Event must be provided for local function tools.")

        is_override_call = False
        for ty in type(tool).mro():
            if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
                is_override_call = True
                break

        # 检查 tool 下有没有 run 方法
        if not tool.handler and not hasattr(tool, "run") and not is_override_call:
            raise ValueError("Tool must have a valid handler or override 'run' method.")

        awaitable = None
        method_name = ""
        if tool.handler:
            awaitable = tool.handler
            method_name = "decorator_handler"
        elif is_override_call:
            awaitable = tool.call
            method_name = "call"
        elif hasattr(tool, "run"):
            awaitable = getattr(tool, "run")
            method_name = "run"
        if awaitable is None:
            raise ValueError("Tool must have a valid handler or override 'run' method.")

        wrapper = call_local_llm_tool(
            context=run_context,
            handler=awaitable,
            method_name=method_name,
            **tool_args,
        )
        while True:
            try:
                resp = await asyncio.wait_for(
                    anext(wrapper),
                    timeout=run_context.tool_call_timeout,
                )
                if resp is not None:
                    if isinstance(resp, mcp.types.CallToolResult):
                        yield resp
                    else:
                        text_content = mcp.types.TextContent(
                            type="text",
                            text=str(resp),
                        )
                        yield mcp.types.CallToolResult(content=[text_content])
                else:
                    # NOTE: Tool 在这里直接请求发送消息给用户
                    # TODO: 是否需要判断 event.get_result() 是否为空?
                    # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
                    if res := run_context.context.event.get_result():
                        if res.chain:
                            try:
                                await event.send(
                                    MessageChain(
                                        chain=res.chain,
                                        type="tool_direct_result",
                                    )
                                )
                            except Exception as e:
                                logger.error(
                                    f"Tool 直接发送消息失败: {e}",
                                    exc_info=True,
                                )
                    yield None
            except asyncio.TimeoutError:
                raise Exception(
                    f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
                )
            except StopAsyncIteration:
                break

    @classmethod
    async def _execute_mcp(
        cls,
        tool: FunctionTool,
        run_context: ContextWrapper[AstrAgentContext],
        **tool_args,
    ):
        res = await tool.call(run_context, **tool_args)
        if not res:
            return
        yield res


async def call_local_llm_tool(
    context: ContextWrapper[AstrAgentContext],
    handler: T.Callable[..., T.Awaitable[T.Any]],
    method_name: str,
    *args,
    **kwargs,
) -> T.AsyncGenerator[T.Any, None]:
    """执行本地 LLM 工具的处理函数并处理其返回结果"""
    ready_to_call = None  # 一个协程或者异步生成器

    trace_ = None

    event = context.context.event

    try:
        if method_name == "run" or method_name == "decorator_handler":
            ready_to_call = handler(event, *args, **kwargs)
        elif method_name == "call":
            ready_to_call = handler(context, *args, **kwargs)
        else:
            raise ValueError(f"未知的方法名: {method_name}")
    except ValueError as e:
        logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
    except TypeError:
        logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
    except Exception as e:
        trace_ = traceback.format_exc()
        logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")

    if not ready_to_call:
        return

    if inspect.isasyncgen(ready_to_call):
        _has_yielded = False
        try:
            async for ret in ready_to_call:
                # 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
                # 返回值只能是 MessageEventResult 或者 None(无返回值)
                _has_yielded = True
                if isinstance(ret, (MessageEventResult, CommandResult)):
                    # 如果返回值是 MessageEventResult, 设置结果并继续
                    event.set_result(ret)
                    yield
                else:
                    # 如果返回值是 None, 则不设置结果并继续
                    # 继续执行后续阶段
                    yield ret
            if not _has_yielded:
                # 如果这个异步生成器没有执行到 yield 分支
                yield
        except Exception as e:
            logger.error(f"Previous Error: {trace_}")
            raise e
    elif inspect.iscoroutine(ready_to_call):
        # 如果只是一个协程, 直接执行
        ret = await ready_to_call
        if isinstance(ret, (MessageEventResult, CommandResult)):
            event.set_result(ret)
            yield
        else:
            yield ret
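The contract that `call_local_llm_tool` enforces on local handlers: yield a `MessageEventResult`/`CommandResult` to have it sent to the user directly (which ends the agent loop for that turn), or yield any other value to hand it back to the LLM as the tool result. Two hypothetical handlers exercising each path:

```python
import random

from astrbot.core.message.message_event_result import MessageEventResult


async def dice_tool_handler(event, sides: int = 6):
    # Yielding a plain value hands it back to the LLM as the tool result.
    yield f"The dice landed on {random.randint(1, sides)}."


async def notify_tool_handler(event, text: str):
    # Yielding a MessageEventResult sets it on the event; the executor then
    # sends it to the user directly and the agent loop ends for this turn.
    yield MessageEventResult().message(text)
```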
astrbot/core/astrbot_config_mgr.py

@@ -1,13 +1,14 @@
import os
import uuid
+from typing import TypedDict, TypeVar

from astrbot.core import AstrBotConfig, logger
-from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
from astrbot.core.config.default import DEFAULT_CONFIG
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.utils.astrbot_path import get_astrbot_config_path
-from typing import TypeVar, TypedDict
+from astrbot.core.utils.shared_preferences import SharedPreferences

_VT = TypeVar("_VT")

@@ -48,7 +49,10 @@ class AstrBotConfigManager:
        """获取所有的 abconf 数据"""
        if self.abconf_data is None:
            self.abconf_data = self.sp.get(
-                "abconf_mapping", {}, scope="global", scope_id="global"
+                "abconf_mapping",
+                {},
+                scope="global",
+                scope_id="global",
            )
        return self.abconf_data

@@ -64,7 +68,7 @@ class AstrBotConfigManager:
                self.confs[uuid_] = conf
            else:
                logger.warning(
-                    f"Config file {conf_path} for UUID {uuid_} does not exist, skipping."
+                    f"Config file {conf_path} for UUID {uuid_} does not exist, skipping.",
                )
                continue

@@ -73,6 +77,7 @@ class AstrBotConfigManager:

        Returns:
            ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
+
        """
        # uuid -> { "path": str, "name": str }
        abconf_data = self._get_abconf_data()

@@ -103,7 +108,10 @@ class AstrBotConfigManager:
    ) -> None:
        """保存配置文件的映射关系"""
        abconf_data = self.sp.get(
-            "abconf_mapping", {}, scope="global", scope_id="global"
+            "abconf_mapping",
+            {},
+            scope="global",
+            scope_id="global",
        )
        random_word = abconf_name or uuid.uuid4().hex[:8]
        abconf_data[abconf_id] = {

@@ -177,13 +185,17 @@ class AstrBotConfigManager:

        Raises:
            ValueError: 如果试图删除默认配置文件
+
        """
        if conf_id == "default":
            raise ValueError("不能删除默认配置文件")

        # 从映射中移除
        abconf_data = self.sp.get(
-            "abconf_mapping", {}, scope="global", scope_id="global"
+            "abconf_mapping",
+            {},
+            scope="global",
+            scope_id="global",
        )
        if conf_id not in abconf_data:
            logger.warning(f"配置文件 {conf_id} 不存在于映射中")

@@ -191,7 +203,8 @@ class AstrBotConfigManager:

        # 获取配置文件路径
        conf_path = os.path.join(
-            get_astrbot_config_path(), abconf_data[conf_id]["path"]
+            get_astrbot_config_path(),
+            abconf_data[conf_id]["path"],
        )

        # 删除配置文件

@@ -224,12 +237,16 @@ class AstrBotConfigManager:

        Returns:
            bool: 更新是否成功
+
        """
        if conf_id == "default":
            raise ValueError("不能更新默认配置文件的信息")

        abconf_data = self.sp.get(
-            "abconf_mapping", {}, scope="global", scope_id="global"
+            "abconf_mapping",
+            {},
+            scope="global",
+            scope_id="global",
        )
        if conf_id not in abconf_data:
            logger.warning(f"配置文件 {conf_id} 不存在于映射中")

@@ -246,7 +263,10 @@ class AstrBotConfigManager:
        return True

    def g(
-        self, umo: str | None = None, key: str | None = None, default: _VT = None
+        self,
+        umo: str | None = None,
+        key: str | None = None,
+        default: _VT = None,
    ) -> _VT:
        """获取配置项。umo 为 None 时使用默认配置"""
        if umo is None:
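In practice `g()` is the read path most callers want: it resolves the per-session config (or the default one when `umo` is `None`) and returns a single key. A sketch, where `config_mgr` and the `umo` value are made-up placeholders:

```python
# umo identifies a session: "<platform>:<message_type>:<session_id>"
umo = "aiocqhttp:GroupMessage:123456789"
settings = config_mgr.g(umo=umo, key="provider_settings", default={})
max_step = settings.get("max_agent_step", 30)
```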
astrbot/core/config/__init__.py

@@ -1,9 +1,9 @@
-from .default import DEFAULT_CONFIG, VERSION, DB_PATH
from .astrbot_config import *
+from .default import DB_PATH, DEFAULT_CONFIG, VERSION

__all__ = [
+    "DB_PATH",
    "DEFAULT_CONFIG",
    "VERSION",
-    "DB_PATH",
    "AstrBotConfig",
]
astrbot/core/config/astrbot_config.py

@@ -1,11 +1,12 @@
-import os
+import enum
import json
import logging
-import enum
-from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
-from typing import Dict
+import os

from astrbot.core.utils.astrbot_path import get_astrbot_data_path

+from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP

ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
logger = logging.getLogger("astrbot")

@@ -27,7 +28,7 @@ class AstrBotConfig(dict):
        self,
        config_path: str = ASTRBOT_CONFIG_PATH,
        default_config: dict = DEFAULT_CONFIG,
-        schema: dict = None,
+        schema: dict | None = None,
    ):
        super().__init__()

@@ -45,7 +46,7 @@ class AstrBotConfig(dict):
                json.dump(default_config, f, indent=4, ensure_ascii=False)
            object.__setattr__(self, "first_deploy", True)  # 标记第一次部署

-        with open(config_path, "r", encoding="utf-8-sig") as f:
+        with open(config_path, encoding="utf-8-sig") as f:
            conf_str = f.read()
            conf = json.loads(conf_str)

@@ -65,7 +66,7 @@ class AstrBotConfig(dict):
        for k, v in schema.items():
            if v["type"] not in DEFAULT_VALUE_MAP:
                raise TypeError(
-                    f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}"
+                    f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}",
                )
            if "default" in v:
                default = v["default"]

@@ -82,7 +83,7 @@ class AstrBotConfig(dict):

        return conf

-    def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
+    def check_config_integrity(self, refer_conf: dict, conf: dict, path=""):
        """检查配置完整性,如果有新的配置项或顺序不一致则返回 True"""
        has_new = False

@@ -97,27 +98,28 @@ class AstrBotConfig(dict):
                logger.info(f"检查到配置项 {path_} 不存在,已插入默认值 {value}")
                new_conf[key] = value
                has_new = True
-            else:
-                if conf[key] is None:
-                    # 配置项为 None,使用默认值
-                    new_conf[key] = value
-                    has_new = True
-                elif isinstance(value, dict):
-                    # 递归检查子配置项
-                    if not isinstance(conf[key], dict):
-                        # 类型不匹配,使用默认值
-                        new_conf[key] = value
-                        has_new = True
-                    else:
-                        # 递归检查并同步顺序
-                        child_has_new = self.check_config_integrity(
-                            value, conf[key], path + "." + key if path else key
-                        )
-                        new_conf[key] = conf[key]
-                        has_new |= child_has_new
-                else:
-                    # 直接使用现有配置
-                    new_conf[key] = conf[key]
+            elif conf[key] is None:
+                # 配置项为 None,使用默认值
+                new_conf[key] = value
+                has_new = True
+            elif isinstance(value, dict):
+                # 递归检查子配置项
+                if not isinstance(conf[key], dict):
+                    # 类型不匹配,使用默认值
+                    new_conf[key] = value
+                    has_new = True
+                else:
+                    # 递归检查并同步顺序
+                    child_has_new = self.check_config_integrity(
+                        value,
+                        conf[key],
+                        path + "." + key if path else key,
+                    )
+                    new_conf[key] = conf[key]
+                    has_new |= child_has_new
+            else:
+                # 直接使用现有配置
+                new_conf[key] = conf[key]

        # 检查是否存在参考配置中没有的配置项
        for key in list(conf.keys()):

@@ -140,7 +142,7 @@ class AstrBotConfig(dict):

        return has_new

-    def save_config(self, replace_config: Dict = None):
+    def save_config(self, replace_config: dict | None = None):
        """将配置写入文件

        如果传入 replace_config,则将配置替换为 replace_config
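The integrity check walks the reference config recursively and backfills anything missing or `None`, which is what keeps old `cmd_config.json` files loadable across upgrades. A sketch of the repair cycle (assumption: the constructor already performs something like this on load, so calling it by hand is rarely needed):

```python
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.config.default import DEFAULT_CONFIG

conf = AstrBotConfig()  # loads data/cmd_config.json
if conf.check_config_integrity(DEFAULT_CONFIG, conf):
    conf.save_config()  # persist any backfilled defaults
```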
astrbot/core/config/default.py

@@ -1,12 +1,10 @@
-"""
-如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
-"""
+"""如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。"""

import os

from astrbot.core.utils.astrbot_path import get_astrbot_data_path

-VERSION = "4.5.1"
+VERSION = "4.6.0"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")

# 默认配置

@@ -70,7 +68,7 @@ DEFAULT_CONFIG = {
        "dequeue_context_length": 1,
        "streaming_response": False,
        "show_tool_use_status": False,
-        "streaming_segmented": False,
+        "unsupported_streaming_strategy": "realtime_segmenting",
        "max_agent_step": 30,
        "tool_call_timeout": 60,
    },

@@ -139,6 +137,7 @@ DEFAULT_CONFIG = {
        "kb_names": [],  # 默认知识库名称列表
        "kb_fusion_top_k": 20,  # 知识库检索融合阶段返回结果数量
        "kb_final_top_k": 5,  # 知识库检索最终返回结果数量
+        "kb_agentic_mode": False,
    }

@@ -742,6 +741,7 @@ CONFIG_METADATA_2 = {
            "api_base": "https://api.openai.com/v1",
            "timeout": 120,
            "model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
+            "custom_headers": {},
            "custom_extra_body": {},
            "modalities": ["text", "image", "tool_use"],
            "hint": "也兼容所有与 OpenAI API 兼容的服务。",

@@ -757,6 +757,7 @@ CONFIG_METADATA_2 = {
            "api_base": "",
            "timeout": 120,
            "model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
+            "custom_headers": {},
            "custom_extra_body": {},
            "modalities": ["text", "image", "tool_use"],
        },

@@ -770,6 +771,7 @@ CONFIG_METADATA_2 = {
            "api_base": "https://api.x.ai/v1",
            "timeout": 120,
            "model_config": {"model": "grok-2-latest", "temperature": 0.4},
+            "custom_headers": {},
            "custom_extra_body": {},
            "xai_native_search": False,
            "modalities": ["text", "image", "tool_use"],

@@ -801,6 +803,7 @@ CONFIG_METADATA_2 = {
            "key": ["ollama"],  # ollama 的 key 默认是 ollama
            "api_base": "http://localhost:11434/v1",
            "model_config": {"model": "llama3.1-8b", "temperature": 0.4},
+            "custom_headers": {},
            "custom_extra_body": {},
            "modalities": ["text", "image", "tool_use"],
        },

@@ -815,6 +818,7 @@ CONFIG_METADATA_2 = {
            "model_config": {
                "model": "llama-3.1-8b",
            },
+            "custom_headers": {},
            "custom_extra_body": {},
            "modalities": ["text", "image", "tool_use"],
        },

@@ -831,6 +835,7 @@ CONFIG_METADATA_2 = {
                "model": "gemini-1.5-flash",
                "temperature": 0.4,
            },
+            "custom_headers": {},
            "custom_extra_body": {},
            "modalities": ["text", "image", "tool_use"],
        },

@@ -872,6 +877,24 @@ CONFIG_METADATA_2 = {
            "api_base": "https://api.deepseek.com/v1",
            "timeout": 120,
            "model_config": {"model": "deepseek-chat", "temperature": 0.4},
+            "custom_headers": {},
            "custom_extra_body": {},
            "modalities": ["text", "tool_use"],
        },
+        "Groq": {
+            "id": "groq_default",
+            "provider": "groq",
+            "type": "groq_chat_completion",
+            "provider_type": "chat_completion",
+            "enable": True,
+            "key": [],
+            "api_base": "https://api.groq.com/openai/v1",
+            "timeout": 120,
+            "model_config": {
+                "model": "openai/gpt-oss-20b",
+                "temperature": 0.4,
+            },
+            "custom_headers": {},
+            "custom_extra_body": {},
+            "modalities": ["text", "tool_use"],
+        },

@@ -885,6 +908,7 @@ CONFIG_METADATA_2 = {
            "api_base": "https://api.302.ai/v1",
            "timeout": 120,
            "model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
+            "custom_headers": {},
            "custom_extra_body": {},
            "modalities": ["text", "image", "tool_use"],
        },

@@ -901,6 +925,7 @@ CONFIG_METADATA_2 = {
                "model": "deepseek-ai/DeepSeek-V3",
                "temperature": 0.4,
            },
+            "custom_headers": {},
            "custom_extra_body": {},
            "modalities": ["text", "image", "tool_use"],
        },

@@ -917,6 +942,7 @@ CONFIG_METADATA_2 = {
                "model": "deepseek/deepseek-r1",
                "temperature": 0.4,
            },
+            "custom_headers": {},
            "custom_extra_body": {},
        },
        "小马算力": {

@@ -932,6 +958,7 @@ CONFIG_METADATA_2 = {
                "model": "kimi-k2-instruct-0905",
                "temperature": 0.7,
            },
+            "custom_headers": {},
            "custom_extra_body": {},
        },
        "优云智算": {

@@ -946,6 +973,7 @@ CONFIG_METADATA_2 = {
            "model_config": {
                "model": "moonshotai/Kimi-K2-Instruct",
            },
+            "custom_headers": {},
            "custom_extra_body": {},
            "modalities": ["text", "image", "tool_use"],
        },

@@ -959,6 +987,7 @@ CONFIG_METADATA_2 = {
            "timeout": 120,
            "api_base": "https://api.moonshot.cn/v1",
            "model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
+            "custom_headers": {},
            "custom_extra_body": {},
            "modalities": ["text", "image", "tool_use"],
        },

@@ -974,6 +1003,8 @@ CONFIG_METADATA_2 = {
            "model_config": {
                "model": "glm-4-flash",
            },
+            "custom_headers": {},
+            "custom_extra_body": {},
            "modalities": ["text", "image", "tool_use"],
        },
        "Dify": {

@@ -1030,6 +1061,7 @@ CONFIG_METADATA_2 = {
            "timeout": 120,
            "api_base": "https://api-inference.modelscope.cn/v1",
            "model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
+            "custom_headers": {},
            "custom_extra_body": {},
            "modalities": ["text", "image", "tool_use"],
        },

@@ -1042,6 +1074,7 @@ CONFIG_METADATA_2 = {
            "key": [],
            "api_base": "https://api.fastgpt.in/api/v1",
            "timeout": 60,
+            "custom_headers": {},
            "custom_extra_body": {},
        },
        "Whisper(API)": {

@@ -1323,6 +1356,12 @@ CONFIG_METADATA_2 = {
            "render_type": "checkbox",
            "hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
        },
+        "custom_headers": {
+            "description": "自定义添加请求头",
+            "type": "dict",
+            "items": {},
+            "hint": "此处添加的键值对将被合并到 OpenAI SDK 的 default_headers 中,用于自定义 HTTP 请求头。值必须为字符串。",
+        },
        "custom_extra_body": {
            "description": "自定义请求体参数",
            "type": "dict",

@@ -1972,8 +2011,8 @@ CONFIG_METADATA_2 = {
        "show_tool_use_status": {
            "type": "bool",
        },
-        "streaming_segmented": {
-            "type": "bool",
+        "unsupported_streaming_strategy": {
+            "type": "string",
        },
        "max_agent_step": {
            "description": "工具调用轮数上限",

@@ -2108,6 +2147,7 @@ CONFIG_METADATA_2 = {
            "kb_names": {"type": "list", "items": {"type": "string"}},
            "kb_fusion_top_k": {"type": "int", "default": 20},
            "kb_final_top_k": {"type": "int", "default": 5},
+            "kb_agentic_mode": {"type": "bool"},
        },
    },
}

@@ -2203,6 +2243,11 @@ CONFIG_METADATA_3 = {
                "type": "int",
                "hint": "从知识库中检索到的结果数量,越大可能获得越多相关信息,但也可能引入噪音。建议根据实际需求调整",
            },
+            "kb_agentic_mode": {
+                "description": "Agentic 知识库检索",
+                "type": "bool",
+                "hint": "启用后,知识库检索将作为 LLM Tool,由模型自主决定何时调用知识库进行查询。需要模型支持函数调用能力。",
+            },
        },
    },
    "websearch": {

@@ -2278,9 +2323,15 @@ CONFIG_METADATA_3 = {
        "description": "流式回复",
        "type": "bool",
    },
-    "provider_settings.streaming_segmented": {
-        "description": "不支持流式回复的平台采取分段输出",
-        "type": "bool",
+    "provider_settings.unsupported_streaming_strategy": {
+        "description": "不支持流式回复的平台",
+        "type": "string",
+        "options": ["realtime_segmenting", "turn_off"],
+        "hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容",
+        "labels": ["实时分段回复", "关闭流式回复"],
+        "condition": {
+            "provider_settings.streaming_response": True,
+        },
    },
    "provider_settings.max_context_length": {
        "description": "最多携带对话轮数",

@@ -2707,9 +2758,9 @@ CONFIG_METADATA_3_SYSTEM = {
                    "items": {"type": "string"},
                },
            },
-        }
+        },
        },
-    }
+    },
    }
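The old `streaming_segmented` boolean becomes a two-valued strategy string, so code that consumed the flag now branches roughly like this (a sketch, not a quote from the diff):

```python
strategy = config["provider_settings"].get(
    "unsupported_streaming_strategy", "realtime_segmenting"
)
if strategy == "turn_off":
    ...  # fall back to non-streaming replies on platforms that cannot stream
else:
    ...  # "realtime_segmenting": flush accumulated text at punctuation boundaries
```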
astrbot/core/conversation_mgr.py

@@ -1,13 +1,14 @@
-"""
-AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库
+"""AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库.

在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
"""

import json
+from collections.abc import Awaitable, Callable

from astrbot.core import sp
-from typing import Dict, List, Callable, Awaitable
+from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation, ConversationV2

@@ -16,31 +17,34 @@ class ConversationManager:
    """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""

    def __init__(self, db_helper: BaseDatabase):
-        self.session_conversations: Dict[str, str] = {}
+        self.session_conversations: dict[str, str] = {}
        self.db = db_helper
        self.save_interval = 60  # 每 60 秒保存一次

        # 会话删除回调函数列表(用于级联清理,如知识库配置)
-        self._on_session_deleted_callbacks: List[Callable[[str], Awaitable[None]]] = []
+        self._on_session_deleted_callbacks: list[Callable[[str], Awaitable[None]]] = []

    def register_on_session_deleted(
-        self, callback: Callable[[str], Awaitable[None]]
+        self,
+        callback: Callable[[str], Awaitable[None]],
    ) -> None:
-        """注册会话删除回调函数
+        """注册会话删除回调函数.

        其他模块可以注册回调来响应会话删除事件,实现级联清理。
        例如:知识库模块可以注册回调来清理会话的知识库配置。

        Args:
            callback: 回调函数,接收会话ID (unified_msg_origin) 作为参数
+
        """
        self._on_session_deleted_callbacks.append(callback)

    async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:
-        """触发会话删除回调
+        """触发会话删除回调.

        Args:
            unified_msg_origin: 会话ID
+
        """
        for callback in self._on_session_deleted_callbacks:
            try:

@@ -49,7 +53,7 @@ class ConversationManager:
                from astrbot.core import logger

                logger.error(
-                    f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}"
+                    f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}",
                )

    def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:

@@ -75,12 +79,13 @@ class ConversationManager:
        title: str | None = None,
        persona_id: str | None = None,
    ) -> str:
-        """新建对话,并将当前会话的对话转移到新对话
+        """新建对话,并将当前会话的对话转移到新对话.

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
        Returns:
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串
+
        """
        if not platform_id:
            # 如果没有提供 platform_id,则从 unified_msg_origin 中解析

@@ -106,18 +111,22 @@ class ConversationManager:
        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串
+
        """
        self.session_conversations[unified_msg_origin] = conversation_id
        await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)

    async def delete_conversation(
-        self, unified_msg_origin: str, conversation_id: str | None = None
+        self,
+        unified_msg_origin: str,
+        conversation_id: str | None = None,
    ):
        """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串
+
        """
        if not conversation_id:
            conversation_id = self.session_conversations.get(unified_msg_origin)

@@ -133,6 +142,7 @@ class ConversationManager:

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
+
        """
        await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
        self.session_conversations.pop(unified_msg_origin, None)

@@ -148,6 +158,7 @@ class ConversationManager:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
        Returns:
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串
+
        """
        ret = self.session_conversations.get(unified_msg_origin, None)
        if not ret:

@@ -162,13 +173,15 @@ class ConversationManager:
        conversation_id: str,
        create_if_not_exists: bool = False,
    ) -> Conversation | None:
-        """获取会话的对话
+        """获取会话的对话.

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串
            create_if_not_exists (bool): 如果对话不存在,是否创建一个新的对话
        Returns:
            conversation (Conversation): 对话对象
+
        """
        conv = await self.db.get_conversation_by_id(cid=conversation_id)
        if not conv and create_if_not_exists:

@@ -181,18 +194,22 @@ class ConversationManager:
        return conv_res

    async def get_conversations(
-        self, unified_msg_origin: str | None = None, platform_id: str | None = None
-    ) -> List[Conversation]:
-        """获取对话列表
+        self,
+        unified_msg_origin: str | None = None,
+        platform_id: str | None = None,
+    ) -> list[Conversation]:
+        """获取对话列表.

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id,可选
            platform_id (str): 平台 ID, 可选参数, 用于过滤对话
        Returns:
            conversations (List[Conversation]): 对话对象列表
+
        """
        convs = await self.db.get_conversations(
-            user_id=unified_msg_origin, platform_id=platform_id
+            user_id=unified_msg_origin,
+            platform_id=platform_id,
        )
        convs_res = []
        for conv in convs:

@@ -208,7 +225,7 @@ class ConversationManager:
        search_query: str = "",
        **kwargs,
    ) -> tuple[list[Conversation], int]:
-        """获取过滤后的对话列表
+        """获取过滤后的对话列表.

        Args:
            page (int): 页码, 默认为 1

@@ -217,6 +234,7 @@ class ConversationManager:
            search_query (str): 搜索查询字符串, 可选
        Returns:
            conversations (list[Conversation]): 对话对象列表
+
        """
        convs, cnt = await self.db.get_filtered_conversations(
            page=page,

@@ -238,13 +256,14 @@ class ConversationManager:
        history: list[dict] | None = None,
        title: str | None = None,
        persona_id: str | None = None,
-    ):
-        """更新会话的对话
+    ) -> None:
+        """更新会话的对话.

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串
            history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
+
        """
        if not conversation_id:
            # 如果没有提供 conversation_id,则获取当前的

@@ -258,16 +277,20 @@ class ConversationManager:
        )

    async def update_conversation_title(
-        self, unified_msg_origin: str, title: str, conversation_id: str | None = None
-    ):
-        """更新会话的对话标题
+        self,
+        unified_msg_origin: str,
+        title: str,
+        conversation_id: str | None = None,
+    ) -> None:
+        """更新会话的对话标题.

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
            title (str): 对话标题
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串
+        Deprecated:
+            Use `update_conversation` with `title` parameter instead.
+
        """
        await self.update_conversation(
            unified_msg_origin=unified_msg_origin,

@@ -280,15 +303,16 @@ class ConversationManager:
        unified_msg_origin: str,
        persona_id: str,
        conversation_id: str | None = None,
-    ):
-        """更新会话的对话 Persona ID
+    ) -> None:
+        """更新会话的对话 Persona ID.

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
            persona_id (str): 对话 Persona ID
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串
+        Deprecated:
+            Use `update_conversation` with `persona_id` parameter instead.
+
        """
        await self.update_conversation(
            unified_msg_origin=unified_msg_origin,

@@ -296,40 +320,85 @@ class ConversationManager:
            persona_id=persona_id,
        )

+    async def add_message_pair(
+        self,
+        cid: str,
+        user_message: UserMessageSegment | dict,
+        assistant_message: AssistantMessageSegment | dict,
+    ) -> None:
+        """Add a user-assistant message pair to the conversation history.
+
+        Args:
+            cid (str): Conversation ID
+            user_message (UserMessageSegment | dict): OpenAI-format user message object or dict
+            assistant_message (AssistantMessageSegment | dict): OpenAI-format assistant message object or dict
+
+        Raises:
+            Exception: If the conversation with the given ID is not found
+        """
+        conv = await self.db.get_conversation_by_id(cid=cid)
+        if not conv:
+            raise Exception(f"Conversation with id {cid} not found")
+        history = conv.content or []
+        if isinstance(user_message, UserMessageSegment):
+            user_msg_dict = user_message.model_dump()
+        else:
+            user_msg_dict = user_message
+        if isinstance(assistant_message, AssistantMessageSegment):
+            assistant_msg_dict = assistant_message.model_dump()
+        else:
+            assistant_msg_dict = assistant_message
+        history.append(user_msg_dict)
+        history.append(assistant_msg_dict)
+        await self.db.update_conversation(
+            cid=cid,
+            content=history,
+        )
+
    async def get_human_readable_context(
-        self, unified_msg_origin, conversation_id, page=1, page_size=10
-    ):
-        """获取人类可读的上下文
+        self,
+        unified_msg_origin: str,
+        conversation_id: str,
+        page: int = 1,
+        page_size: int = 10,
+    ) -> tuple[list[str], int]:
+        """获取人类可读的上下文.

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串
            page (int): 页码
            page_size (int): 每页大小
+
        """
        conversation = await self.get_conversation(unified_msg_origin, conversation_id)
        if not conversation:
            return [], 0
        history = json.loads(conversation.history)

-        contexts = []
-        temp_contexts = []
+        # contexts_groups 存放按顺序的段落(每个段落是一个 str 列表),
+        # 之后会被展平成一个扁平的 str 列表返回。
+        contexts_groups: list[list[str]] = []
+        temp_contexts: list[str] = []
        for record in history:
            if record["role"] == "user":
                temp_contexts.append(f"User: {record['content']}")
            elif record["role"] == "assistant":
-                if "content" in record and record["content"]:
+                if record.get("content"):
                    temp_contexts.append(f"Assistant: {record['content']}")
                elif "tool_calls" in record:
                    tool_calls_str = json.dumps(
-                        record["tool_calls"], ensure_ascii=False
+                        record["tool_calls"],
+                        ensure_ascii=False,
                    )
                    temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
                else:
                    temp_contexts.append("Assistant: [未知的内容]")
-                contexts.insert(0, temp_contexts)
+                contexts_groups.insert(0, temp_contexts)
                temp_contexts = []

-        # 展平 contexts 列表
-        contexts = [item for sublist in contexts for item in sublist]
+        # 展平分组后的 contexts 列表为单层字符串列表
+        contexts = [item for sublist in contexts_groups for item in sublist]

        # 计算分页
        paged_contexts = contexts[(page - 1) * page_size : page * page_size]
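`add_message_pair` appends one OpenAI-format user/assistant exchange to a stored conversation. A short usage sketch (assumes a `conv_mgr` instance and a valid conversation id; the segment constructor fields are inferred from the `AssistantMessageSegment(role=..., content=...)` usage shown earlier in this diff):

```python
from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment

await conv_mgr.add_message_pair(
    cid=conversation_id,
    user_message=UserMessageSegment(role="user", content="Hello"),
    assistant_message=AssistantMessageSegment(role="assistant", content="Hi!"),
)
```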
astrbot/core/core_lifecycle.py

@@ -1,5 +1,5 @@
-"""
-Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
+"""Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作.

该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
该类还负责加载和执行插件, 以及处理事件总线的分发。

@@ -9,44 +9,46 @@ Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、
3. 执行启动完成事件钩子
"""

-import traceback
import asyncio
-import time
-import threading
import os
-from .event_bus import EventBus
-from . import astrbot_config, html_renderer
+import threading
+import time
+import traceback
from asyncio import Queue
-from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
-from astrbot.core.star import PluginManager
-from astrbot.core.platform.manager import PlatformManager
-from astrbot.core.star.context import Context
-from astrbot.core.persona_mgr import PersonaManager
-from astrbot.core.provider.manager import ProviderManager
-from astrbot.core import LogBroker
-from astrbot.core.db import BaseDatabase
-from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
-from astrbot.core.updator import AstrBotUpdator
-from astrbot.core import logger, sp
-
+from astrbot.core import LogBroker, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.default import VERSION
from astrbot.core.conversation_mgr import ConversationManager
-from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
-from astrbot.core.umop_config_router import UmopConfigRouter
-from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
-from astrbot.core.star.star_handler import star_handlers_registry, EventType
-from astrbot.core.star.star_handler import star_map
+from astrbot.core.db import BaseDatabase
+from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
+from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
+from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
+from astrbot.core.memory.memory_manager import MemoryManager
+from astrbot.core.persona_mgr import PersonaManager
+from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler
+from astrbot.core.platform.manager import PlatformManager
+from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
+from astrbot.core.provider.manager import ProviderManager
+from astrbot.core.star import PluginManager
+from astrbot.core.star.context import Context
+from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
+from astrbot.core.umop_config_router import UmopConfigRouter
+from astrbot.core.updator import AstrBotUpdator
+
+from . import astrbot_config, html_renderer
+from .event_bus import EventBus


class AstrBotCoreLifecycle:
-    """
-    AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
+    """AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作.

-    该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
+    该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、
+    EventBus 等。
    该类还负责加载和执行插件, 以及处理事件总线的分发。
    """

-    def __init__(self, log_broker: LogBroker, db: BaseDatabase):
+    def __init__(self, log_broker: LogBroker, db: BaseDatabase) -> None:
        self.log_broker = log_broker  # 初始化日志代理
        self.astrbot_config = astrbot_config  # 初始化配置
        self.db = db  # 初始化数据库

@@ -70,11 +72,11 @@ class AstrBotCoreLifecycle:
                del os.environ["no_proxy"]
            logger.debug("HTTP proxy cleared")

-    async def initialize(self):
-        """
-        初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
-        """
+    async def initialize(self) -> None:
+        """初始化 AstrBot 核心生命周期管理类.
+
+        负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
+        """
        # 初始化日志代理
        logger.info("AstrBot v" + VERSION)
        if os.environ.get("TESTING", ""):

@@ -91,7 +93,9 @@ class AstrBotCoreLifecycle:

        # 初始化 AstrBot 配置管理器
        self.astrbot_config_mgr = AstrBotConfigManager(
-            default_config=self.astrbot_config, ucr=self.umop_config_router, sp=sp
+            default_config=self.astrbot_config,
+            ucr=self.umop_config_router,
+            sp=sp,
        )

        # 4.5 to 4.6 migration for umop_config_router

@@ -101,6 +105,13 @@ class AstrBotCoreLifecycle:
            logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
            logger.error(traceback.format_exc())

+        # migration for webchat session
+        try:
+            await migrate_webchat_session(self.db)
+        except Exception as e:
+            logger.error(f"Migration for webchat session failed: {e!s}")
+            logger.error(traceback.format_exc())
+
        # 初始化事件队列
        self.event_queue = Queue()

@@ -110,7 +121,9 @@ class AstrBotCoreLifecycle:

        # 初始化供应商管理器
        self.provider_manager = ProviderManager(
-            self.astrbot_config_mgr, self.db, self.persona_mgr
+            self.astrbot_config_mgr,
+            self.db,
+            self.persona_mgr,
        )

        # 初始化平台管理器

@@ -124,6 +137,8 @@ class AstrBotCoreLifecycle:

        # 初始化知识库管理器
        self.kb_manager = KnowledgeBaseManager(self.provider_manager)
+        # 初始化记忆管理器
+        self.memory_manager = MemoryManager()

        # 初始化提供给插件的上下文
        self.star_context = Context(

@@ -137,6 +152,7 @@ class AstrBotCoreLifecycle:
            self.persona_mgr,
            self.astrbot_config_mgr,
            self.kb_manager,
+            self.memory_manager,
        )

        # 初始化插件管理器

@@ -158,7 +174,9 @@ class AstrBotCoreLifecycle:

        # 初始化事件总线
        self.event_bus = EventBus(
-            self.event_queue, self.pipeline_scheduler_mapping, self.astrbot_config_mgr
+            self.event_queue,
+            self.pipeline_scheduler_mapping,
+            self.astrbot_config_mgr,
        )

        # 记录启动时间

@@ -173,13 +191,13 @@ class AstrBotCoreLifecycle:
        # 初始化关闭控制面板的事件
        self.dashboard_shutdown_event = asyncio.Event()

-    def _load(self):
-        """加载事件总线和任务并初始化"""
-
+    def _load(self) -> None:
+        """加载事件总线和任务并初始化."""
        # 创建一个异步任务来执行事件总线的 dispatch() 方法
        # dispatch是一个无限循环的协程, 从事件队列中获取事件并处理
        event_bus_task = asyncio.create_task(
-            self.event_bus.dispatch(), name="event_bus"
+            self.event_bus.dispatch(),
+            name="event_bus",
        )

        # 把插件中注册的所有协程函数注册到事件总线中并执行

@@ -190,16 +208,17 @@ class AstrBotCoreLifecycle:
        tasks_ = [event_bus_task, *extra_tasks]
        for task in tasks_:
            self.curr_tasks.append(
-                asyncio.create_task(self._task_wrapper(task), name=task.get_name())
+                asyncio.create_task(self._task_wrapper(task), name=task.get_name()),
            )

        self.start_time = int(time.time())

-    async def _task_wrapper(self, task: asyncio.Task):
-        """异步任务包装器, 用于处理异步任务执行中出现的各种异常
+    async def _task_wrapper(self, task: asyncio.Task) -> None:
+        """异步任务包装器, 用于处理异步任务执行中出现的各种异常.

        Args:
            task (asyncio.Task): 要执行的异步任务
+
        """
        try:
            await task

@@ -212,19 +231,22 @@ class AstrBotCoreLifecycle:
                logger.error(f"| {line}")
|
||||
logger.error("-------")
|
||||
|
||||
async def start(self):
|
||||
"""启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子"""
|
||||
async def start(self) -> None:
|
||||
"""启动 AstrBot 核心生命周期管理类.
|
||||
|
||||
用load加载事件总线和任务并初始化, 执行启动完成事件钩子
|
||||
"""
|
||||
self._load()
|
||||
logger.info("AstrBot 启动完成。")
|
||||
|
||||
# 执行启动完成事件钩子
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnAstrBotLoadedEvent
|
||||
EventType.OnAstrBotLoadedEvent,
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.info(
|
||||
f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
|
||||
)
|
||||
await handler.handler()
|
||||
except BaseException:
|
||||
@@ -233,8 +255,8 @@ class AstrBotCoreLifecycle:
|
||||
# 同时运行curr_tasks中的所有任务
|
||||
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
|
||||
|
||||
async def stop(self):
|
||||
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器"""
|
||||
async def stop(self) -> None:
|
||||
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器."""
|
||||
# 请求停止所有正在运行的异步任务
|
||||
for task in self.curr_tasks:
|
||||
task.cancel()
|
||||
@@ -245,7 +267,7 @@ class AstrBotCoreLifecycle:
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。"
|
||||
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。",
|
||||
)
|
||||
|
||||
await self.provider_manager.terminate()
|
||||
@@ -262,14 +284,16 @@ class AstrBotCoreLifecycle:
|
||||
except Exception as e:
|
||||
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
|
||||
|
||||
async def restart(self):
|
||||
async def restart(self) -> None:
|
||||
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
|
||||
await self.provider_manager.terminate()
|
||||
await self.platform_manager.terminate()
|
||||
await self.kb_manager.terminate()
|
||||
self.dashboard_shutdown_event.set()
|
||||
threading.Thread(
|
||||
target=self.astrbot_updator._reboot, name="restart", daemon=True
|
||||
target=self.astrbot_updator._reboot,
|
||||
name="restart",
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
def load_platform(self) -> list[asyncio.Task]:
|
||||
@@ -281,36 +305,38 @@ class AstrBotCoreLifecycle:
|
||||
asyncio.create_task(
|
||||
platform_inst.run(),
|
||||
name=f"{platform_inst.meta().id}({platform_inst.meta().name})",
|
||||
)
|
||||
),
|
||||
)
|
||||
return tasks
|
||||
|
||||
async def load_pipeline_scheduler(self) -> dict[str, PipelineScheduler]:
|
||||
"""加载消息事件流水线调度器
|
||||
"""加载消息事件流水线调度器.
|
||||
|
||||
Returns:
|
||||
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
|
||||
|
||||
"""
|
||||
mapping = {}
|
||||
for conf_id, ab_config in self.astrbot_config_mgr.confs.items():
|
||||
scheduler = PipelineScheduler(
|
||||
PipelineContext(ab_config, self.plugin_manager, conf_id)
|
||||
PipelineContext(ab_config, self.plugin_manager, conf_id),
|
||||
)
|
||||
await scheduler.initialize()
|
||||
mapping[conf_id] = scheduler
|
||||
return mapping
|
||||
|
||||
async def reload_pipeline_scheduler(self, conf_id: str):
|
||||
"""重新加载消息事件流水线调度器
|
||||
async def reload_pipeline_scheduler(self, conf_id: str) -> None:
|
||||
"""重新加载消息事件流水线调度器.
|
||||
|
||||
Returns:
|
||||
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
|
||||
|
||||
"""
|
||||
ab_config = self.astrbot_config_mgr.confs.get(conf_id)
|
||||
if not ab_config:
|
||||
raise ValueError(f"配置文件 {conf_id} 不存在")
|
||||
scheduler = PipelineScheduler(
|
||||
PipelineContext(ab_config, self.plugin_manager, conf_id)
|
||||
PipelineContext(ab_config, self.plugin_manager, conf_id),
|
||||
)
|
||||
await scheduler.initialize()
|
||||
self.pipeline_scheduler_mapping[conf_id] = scheduler
|
||||
|
||||
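The hunks above wrap every long-running coroutine in `_task_wrapper` so a crash in one task is logged instead of tearing down the event loop. A minimal standalone sketch of the same supervision pattern (the `worker` task and print-based logging are illustrative, not AstrBot API):

import asyncio
import traceback


async def task_wrapper(task: asyncio.Task) -> None:
    # Await the wrapped task; swallow cancellation, log everything else.
    try:
        await task
    except asyncio.CancelledError:
        pass
    except Exception:
        print(f"task {task.get_name()} crashed:")
        print(traceback.format_exc())


async def main() -> None:
    async def worker() -> None:
        raise RuntimeError("boom")

    inner = asyncio.create_task(worker(), name="worker")
    supervisor = asyncio.create_task(task_wrapper(inner), name="worker-supervisor")
    # gather(return_exceptions=True) keeps one failure from cancelling the rest
    await asyncio.gather(supervisor, return_exceptions=True)


asyncio.run(main())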
@@ -1,27 +1,28 @@
import abc
import datetime
import typing as T
from deprecated import deprecated
from dataclasses import dataclass
from astrbot.core.db.po import (
    Stats,
    PlatformStat,
    ConversationV2,
    PlatformMessageHistory,
    Attachment,
    Persona,
    Preference,
)
from contextlib import asynccontextmanager
from dataclasses import dataclass

from deprecated import deprecated
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

from astrbot.core.db.po import (
    Attachment,
    ConversationV2,
    Persona,
    PlatformMessageHistory,
    PlatformSession,
    PlatformStat,
    Preference,
    Stats,
)


@dataclass
class BaseDatabase(abc.ABC):
    """
    Database base class
    """
    """Database base class"""

    DATABASE_URL = ""

@@ -32,12 +33,13 @@ class BaseDatabase(abc.ABC):
            future=True,
        )
        self.AsyncSessionLocal = sessionmaker(
            self.engine, class_=AsyncSession, expire_on_commit=False
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )

    async def initialize(self):
        """Initialize the database connection"""
        pass

    @asynccontextmanager
    async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
@@ -91,7 +93,9 @@ class BaseDatabase(abc.ABC):

    @abc.abstractmethod
    async def get_conversations(
        self, user_id: str | None = None, platform_id: str | None = None
        self,
        user_id: str | None = None,
        platform_id: str | None = None,
    ) -> list[ConversationV2]:
        """Get all conversations for a specific user and platform_id (optional).

@@ -106,7 +110,9 @@ class BaseDatabase(abc.ABC):

    @abc.abstractmethod
    async def get_all_conversations(
        self, page: int = 1, page_size: int = 20
        self,
        page: int = 1,
        page_size: int = 20,
    ) -> list[ConversationV2]:
        """Get all conversations with pagination."""
        ...
@@ -173,9 +179,12 @@ class BaseDatabase(abc.ABC):

    @abc.abstractmethod
    async def delete_platform_message_offset(
        self, platform_id: str, user_id: str, offset_sec: int = 86400
        self,
        platform_id: str,
        user_id: str,
        offset_sec: int = 86400,
    ) -> None:
        """Delete platform message history records older than the specified offset."""
        """Delete platform message history records newer than the specified offset."""
        ...

    @abc.abstractmethod
@@ -243,7 +252,11 @@ class BaseDatabase(abc.ABC):

    @abc.abstractmethod
    async def insert_preference_or_update(
        self, scope: str, scope_id: str, key: str, value: dict
        self,
        scope: str,
        scope_id: str,
        key: str,
        value: dict,
    ) -> Preference:
        """Insert a new preference record."""
        ...
@@ -255,7 +268,10 @@ class BaseDatabase(abc.ABC):

    @abc.abstractmethod
    async def get_preferences(
        self, scope: str, scope_id: str | None = None, key: str | None = None
        self,
        scope: str,
        scope_id: str | None = None,
        key: str | None = None,
    ) -> list[Preference]:
        """Get all preferences for a specific scope ID or key."""
        ...
@@ -298,3 +314,51 @@ class BaseDatabase(abc.ABC):
    ) -> tuple[list[dict], int]:
        """Get paginated session conversations with joined conversation and persona details, supporting search and platform filters."""
        ...

    # ====
    # Platform Session Management
    # ====

    @abc.abstractmethod
    async def create_platform_session(
        self,
        creator: str,
        platform_id: str = "webchat",
        session_id: str | None = None,
        display_name: str | None = None,
        is_group: int = 0,
    ) -> PlatformSession:
        """Create a new Platform session."""
        ...

    @abc.abstractmethod
    async def get_platform_session_by_id(
        self, session_id: str
    ) -> PlatformSession | None:
        """Get a Platform session by its ID."""
        ...

    @abc.abstractmethod
    async def get_platform_sessions_by_creator(
        self,
        creator: str,
        platform_id: str | None = None,
        page: int = 1,
        page_size: int = 20,
    ) -> list[PlatformSession]:
        """Get all Platform sessions for a specific creator (username) and optionally platform."""
        ...

    @abc.abstractmethod
    async def update_platform_session(
        self,
        session_id: str,
        display_name: str | None = None,
    ) -> None:
        """Update a Platform session's updated_at timestamp and optionally display_name."""
        ...

    @abc.abstractmethod
    async def delete_platform_session(self, session_id: str) -> None:
        """Delete a Platform session by its ID."""
        ...
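For orientation, here is a hedged sketch of how a caller might exercise the new platform-session API; `db` stands for any concrete `BaseDatabase` implementation (for example the SQLite one later in this diff), and the `demo` coroutine and its literal values are illustrative only:

async def demo(db: BaseDatabase) -> None:
    # session_id is generated by the implementation when not supplied
    created = await db.create_platform_session(
        creator="alice",
        platform_id="webchat",
        display_name="My first chat",
    )

    fetched = await db.get_platform_session_by_id(created.session_id)
    assert fetched is not None

    # rename and bump updated_at in one call
    await db.update_platform_session(created.session_id, display_name="Renamed")

    sessions = await db.get_platform_sessions_by_creator("alice", page=1, page_size=20)
    print([s.display_name for s in sessions])

    await db.delete_platform_session(created.session_id)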
@@ -1,20 +1,21 @@
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.db import BaseDatabase
from astrbot.core.config import AstrBotConfig

from astrbot.api import logger, sp
from astrbot.core.config import AstrBotConfig
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import get_astrbot_data_path

from .migra_3_to_4 import (
    migration_conversation_table,
    migration_platform_table,
    migration_webchat_data,
    migration_persona_data,
    migration_platform_table,
    migration_preferences,
    migration_webchat_data,
)


async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
    """
    Check whether a database migration is needed
    """Check whether a database migration is needed
    Migration is needed when data_v3.db exists and preference has no migration_done_v4.
    """
    # only consider migrating when old data (a data_v3.db file) exists under the data directory
@@ -24,7 +25,9 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
    if not os.path.exists(data_v3_db):
        return False
    migration_done = await db_helper.get_preference(
        "global", "global", "migration_done_v4"
        "global",
        "global",
        "migration_done_v4",
    )
    if migration_done:
        return False
@@ -36,8 +39,7 @@ async def do_migration_v4(
    platform_id_map: dict[str, dict[str, str]],
    astrbot_config: AstrBotConfig,
) -> None:
    """
    Run the database migration
    """Run the database migration
    Migrates the old webchat_conversation table into the new conversation table.
    Migrates the old platform table into the new platform_stats table.
    """

@@ -1,15 +1,18 @@
import json
import datetime
from .. import BaseDatabase
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
from .shared_preferences_v3 import sp as sp_v3
from astrbot.core.config.default import DB_PATH
import json

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from astrbot.api import logger, sp
from astrbot.core.config import AstrBotConfig
from astrbot.core.platform.astr_message_event import MessageSesion
from sqlalchemy.ext.asyncio import AsyncSession
from astrbot.core.config.default import DB_PATH
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
from sqlalchemy import text
from astrbot.core.platform.astr_message_event import MessageSesion

from .. import BaseDatabase
from .shared_preferences_v3 import sp as sp_v3
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3

"""
1. Migrate the old webchat_conversation table into the new conversation table.
@@ -18,7 +21,8 @@ from sqlalchemy import text


def get_platform_id(
    platform_id_map: dict[str, dict[str, str]], old_platform_name: str
    platform_id_map: dict[str, dict[str, str]],
    old_platform_name: str,
) -> str:
    return platform_id_map.get(
        old_platform_name,
@@ -27,7 +31,8 @@ def get_platform_id(


def get_platform_type(
    platform_id_map: dict[str, dict[str, str]], old_platform_name: str
    platform_id_map: dict[str, dict[str, str]],
    old_platform_name: str,
) -> str:
    return platform_id_map.get(
        old_platform_name,
@@ -36,13 +41,15 @@ def get_platform_type(


async def migration_conversation_table(
    db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
    db_helper: BaseDatabase,
    platform_id_map: dict[str, dict[str, str]],
):
    db_helper_v3 = SQLiteV3DatabaseV3(
        db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
        db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
    )
    conversations, total_cnt = db_helper_v3.get_all_conversations(
        page=1, page_size=10000000
        page=1,
        page_size=10000000,
    )
    logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...")

@@ -61,13 +68,14 @@ async def migration_conversation_table(
            )
            if not conv:
                logger.info(
                    f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
                    f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
                )
            if ":" not in conv.user_id:
                continue
            session = MessageSesion.from_str(session_str=conv.user_id)
            platform_id = get_platform_id(
                platform_id_map, session.platform_name
                platform_id_map,
                session.platform_name,
            )
            session.platform_id = platform_id  # update the platform name to the new ID
            conv_v2 = ConversationV2(
@@ -90,10 +98,11 @@ async def migration_conversation_table(


async def migration_platform_table(
    db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
    db_helper: BaseDatabase,
    platform_id_map: dict[str, dict[str, str]],
):
    db_helper_v3 = SQLiteV3DatabaseV3(
        db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
        db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
    )
    secs_from_2023_4_10_to_now = (
        datetime.datetime.now(datetime.timezone.utc)
@@ -134,10 +143,12 @@ async def migration_platform_table(
            if cnt == 0:
                continue
            platform_id = get_platform_id(
                platform_id_map, platform_stats_v3[idx].name
                platform_id_map,
                platform_stats_v3[idx].name,
            )
            platform_type = get_platform_type(
                platform_id_map, platform_stats_v3[idx].name
                platform_id_map,
                platform_stats_v3[idx].name,
            )
            try:
                await dbsession.execute(
@@ -149,7 +160,8 @@ async def migration_platform_table(
                    """),
                    {
                        "timestamp": datetime.datetime.fromtimestamp(
                            bucket_end, tz=datetime.timezone.utc
                            bucket_end,
                            tz=datetime.timezone.utc,
                        ),
                        "platform_id": platform_id,
                        "platform_type": platform_type,
@@ -165,14 +177,16 @@ async def migration_platform_table(


async def migration_webchat_data(
    db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
    db_helper: BaseDatabase,
    platform_id_map: dict[str, dict[str, str]],
):
    """Migrate WebChat history records into the new PlatformMessageHistory table"""
    db_helper_v3 = SQLiteV3DatabaseV3(
        db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
        db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
    )
    conversations, total_cnt = db_helper_v3.get_all_conversations(
        page=1, page_size=10000000
        page=1,
        page_size=10000000,
    )
    logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...")

@@ -191,7 +205,7 @@ async def migration_webchat_data(
            )
            if not conv:
                logger.info(
                    f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
                    f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
                )
            if ":" in conv.user_id:
                continue
@@ -218,10 +232,10 @@ async def migration_webchat_data(


async def migration_persona_data(
    db_helper: BaseDatabase, astrbot_config: AstrBotConfig
    db_helper: BaseDatabase,
    astrbot_config: AstrBotConfig,
):
    """
    Migrate Persona data into the new table.
    """Migrate Persona data into the new table.
    Old Persona data lived in preference; new Persona data lives in the persona table.
    """
    v3_persona_config: list[dict] = astrbot_config.get("persona", [])
@@ -236,14 +250,15 @@ async def migration_persona_data(
        try:
            begin_dialogs = persona.get("begin_dialogs", [])
            mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
            mood_prompt = ""
            parts = []
            user_turn = True
            for mood_dialog in mood_imitation_dialogs:
                if user_turn:
                    mood_prompt += f"A: {mood_dialog}\n"
                    parts.append(f"A: {mood_dialog}\n")
                else:
                    mood_prompt += f"B: {mood_dialog}\n"
                    parts.append(f"B: {mood_dialog}\n")
                user_turn = not user_turn
            mood_prompt = "".join(parts)
            system_prompt = persona.get("prompt", "")
            if mood_prompt:
                system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
@@ -253,14 +268,15 @@ async def migration_persona_data(
                begin_dialogs=begin_dialogs,
            )
            logger.info(
                f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
                f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。",
            )
        except Exception as e:
            logger.error(f"解析 Persona 配置失败:{e}")


async def migration_preferences(
    db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
    db_helper: BaseDatabase,
    platform_id_map: dict[str, dict[str, str]],
):
    # 1. global scope migration
    keys = [
@@ -329,10 +345,13 @@ async def migration_preferences(

        for provider_type, provider_id in perf.items():
            await sp.put_async(
                "umo", str(session), f"provider_perf_{provider_type}", provider_id
                "umo",
                str(session),
                f"provider_perf_{provider_type}",
                provider_id,
            )
        logger.info(
            f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}"
            f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}",
        )
    except Exception as e:
        logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True)

@@ -9,7 +9,7 @@ async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
    if not isinstance(abconf_data, dict):
        # should be unreachable
        logger.warning(
            f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}"
            f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}",
        )
        return
astrbot/core/db/migration/migra_webchat_session.py (new file, 131 lines)
@@ -0,0 +1,131 @@
"""Migration script for WebChat sessions.

This migration creates PlatformSession from existing platform_message_history records.

Changes:
- Creates platform_sessions table
- Adds platform_id field (default: 'webchat')
- Adds display_name field
- Session_id format: {platform_id}_{uuid}
"""

from sqlalchemy import func, select
from sqlmodel import col

from astrbot.api import logger, sp
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession


async def migrate_webchat_session(db_helper: BaseDatabase):
    """Create PlatformSession records from platform_message_history.

    This migration extracts all unique user_ids from platform_message_history
    where platform_id='webchat' and creates corresponding PlatformSession records.
    """
    # check whether the migration has already been done
    migration_done = await db_helper.get_preference(
        "global", "global", "migration_done_webchat_session"
    )
    if migration_done:
        return

    logger.info("开始执行数据库迁移(WebChat 会话迁移)...")

    try:
        async with db_helper.get_db() as session:
            # build PlatformSession records from platform_message_history
            query = (
                select(
                    col(PlatformMessageHistory.user_id),
                    col(PlatformMessageHistory.sender_name),
                    func.min(PlatformMessageHistory.created_at).label("earliest"),
                    func.max(PlatformMessageHistory.updated_at).label("latest"),
                )
                .where(col(PlatformMessageHistory.platform_id) == "webchat")
                .where(col(PlatformMessageHistory.sender_id) == "astrbot")
                .group_by(col(PlatformMessageHistory.user_id))
            )

            result = await session.execute(query)
            webchat_users = result.all()

            if not webchat_users:
                logger.info("没有找到需要迁移的 WebChat 数据")
                await sp.put_async(
                    "global", "global", "migration_done_webchat_session", True
                )
                return

            logger.info(f"找到 {len(webchat_users)} 个 WebChat 会话需要迁移")

            # check which sessions already exist
            existing_query = select(col(PlatformSession.session_id))
            existing_result = await session.execute(existing_query)
            existing_session_ids = {row[0] for row in existing_result.fetchall()}

            # look up titles in the Conversations table to use as display_name
            # for each user_id, the matching conversation user_id has the form: webchat:FriendMessage:webchat!astrbot!{user_id}
            user_ids_to_query = [
                f"webchat:FriendMessage:webchat!astrbot!{user_id}"
                for user_id, _, _, _ in webchat_users
            ]
            conv_query = select(
                col(ConversationV2.user_id), col(ConversationV2.title)
            ).where(col(ConversationV2.user_id).in_(user_ids_to_query))
            conv_result = await session.execute(conv_query)
            # build a user_id -> title mapping
            title_map = {
                user_id.replace("webchat:FriendMessage:webchat!astrbot!", ""): title
                for user_id, title in conv_result.fetchall()
            }

            # create PlatformSession records in batch
            sessions_to_add = []
            skipped_count = 0

            for user_id, sender_name, created_at, updated_at in webchat_users:
                # user_id is exactly the webchat_conv_id (session_id)
                session_id = user_id

                # sender_name is usually the username, but may be None
                creator = sender_name if sender_name else "guest"

                # skip the session if it already exists
                if session_id in existing_session_ids:
                    logger.debug(f"会话 {session_id} 已存在,跳过")
                    skipped_count += 1
                    continue

                # fetch display_name from the Conversations table
                display_name = title_map.get(user_id)

                # create the new PlatformSession (keeping the original timestamps)
                new_session = PlatformSession(
                    session_id=session_id,
                    platform_id="webchat",
                    creator=creator,
                    is_group=0,
                    created_at=created_at,
                    updated_at=updated_at,
                    display_name=display_name,
                )
                sessions_to_add.append(new_session)

            # bulk insert
            if sessions_to_add:
                session.add_all(sessions_to_add)
                await session.commit()

                logger.info(
                    f"WebChat 会话迁移完成!成功迁移: {len(sessions_to_add)}, 跳过: {skipped_count}",
                )
            else:
                logger.info("没有新会话需要迁移")

        # mark the migration as done
        await sp.put_async("global", "global", "migration_done_webchat_session", True)

    except Exception as e:
        logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
        raise
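The migration above is idempotent: a completion flag in the preference store guards the whole body, so re-running it on every startup is safe. A minimal sketch of that same guard pattern, assuming only an `sp`-like async key-value store (the `FakeStore` class and `run_once` name are illustrative, not AstrBot API):

class FakeStore:
    def __init__(self) -> None:
        self._data: dict[tuple[str, str, str], object] = {}

    async def get(self, scope: str, scope_id: str, key: str):
        return self._data.get((scope, scope_id, key))

    async def put(self, scope: str, scope_id: str, key: str, value) -> None:
        self._data[(scope, scope_id, key)] = value


async def run_once(store: FakeStore) -> None:
    # re-running becomes a no-op once the completion flag is set
    if await store.get("global", "global", "migration_done_webchat_session"):
        return
    ...  # do the actual migration work here
    await store.put("global", "global", "migration_done_webchat_session", True)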
@@ -1,6 +1,7 @@
import json
import os
from typing import TypeVar

from astrbot.core.utils.astrbot_path import get_astrbot_data_path

_VT = TypeVar("_VT")
@@ -16,7 +17,7 @@ class SharedPreferences:
    def _load_preferences(self):
        if os.path.exists(self.path):
            try:
                with open(self.path, "r") as f:
                with open(self.path) as f:
                    return json.load(f)
            except json.JSONDecodeError:
                os.remove(self.path)

@@ -1,8 +1,9 @@
import sqlite3
import time
from astrbot.core.db.po import Platform, Stats
from typing import Tuple, List, Dict, Any
from dataclasses import dataclass
from typing import Any

from astrbot.core.db.po import Platform, Stats


@dataclass
@@ -94,7 +95,7 @@ class SQLiteDatabase:
        c.execute(
            """
            PRAGMA table_info(webchat_conversation)
            """
            """,
        )
        res = c.fetchall()
        has_title = False
@@ -108,14 +109,14 @@ class SQLiteDatabase:
            c.execute(
                """
                ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
                """
                """,
            )
            self.conn.commit()
        if not has_persona_id:
            c.execute(
                """
                ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
                """
                """,
            )
            self.conn.commit()

@@ -126,7 +127,7 @@ class SQLiteDatabase:
        conn.text_factory = str
        return conn

    def _exec_sql(self, sql: str, params: Tuple = None):
    def _exec_sql(self, sql: str, params: tuple = None):
        conn = self.conn
        try:
            c = self.conn.cursor()
@@ -174,7 +175,7 @@ class SQLiteDatabase:
            """
            SELECT * FROM platform
            """
            + where_clause
            + where_clause,
        )

        platform = []
@@ -194,7 +195,7 @@ class SQLiteDatabase:
        c.execute(
            """
            SELECT SUM(count) FROM platform
            """
            """,
        )
        res = c.fetchone()
        c.close()
@@ -214,7 +215,7 @@ class SQLiteDatabase:
            SELECT name, SUM(count), timestamp FROM platform
            """
            + where_clause
            + " GROUP BY name"
            + " GROUP BY name",
        )

        platform = []
@@ -242,7 +243,7 @@ class SQLiteDatabase:
        c.close()

        if not res:
            return
            return None

        return Conversation(*res)

@@ -257,7 +258,7 @@ class SQLiteDatabase:
            (user_id, cid, history, updated_at, created_at),
        )

    def get_conversations(self, user_id: str) -> Tuple:
    def get_conversations(self, user_id: str) -> tuple:
        try:
            c = self.conn.cursor()
        except sqlite3.ProgrammingError:
@@ -280,7 +281,7 @@ class SQLiteDatabase:
            title = row[3]
            persona_id = row[4]
            conversations.append(
                Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
                Conversation("", cid, "[]", created_at, updated_at, title, persona_id),
            )
        return conversations

@@ -319,8 +320,10 @@ class SQLiteDatabase:
        )

    def get_all_conversations(
        self, page: int = 1, page_size: int = 20
    ) -> Tuple[List[Dict[str, Any]], int]:
        self,
        page: int = 1,
        page_size: int = 20,
    ) -> tuple[list[dict[str, Any]], int]:
        """Fetch all conversations, with pagination, sorted by update time in descending order"""
        try:
            c = self.conn.cursor()
@@ -366,7 +369,7 @@ class SQLiteDatabase:
                    "persona_id": persona_id or "",
                    "created_at": created_at or 0,
                    "updated_at": updated_at or 0,
                }
                },
            )

        return conversations, total_count
@@ -381,12 +384,12 @@ class SQLiteDatabase:
        self,
        page: int = 1,
        page_size: int = 20,
        platforms: List[str] = None,
        message_types: List[str] = None,
        search_query: str = None,
        exclude_ids: List[str] = None,
        exclude_platforms: List[str] = None,
    ) -> Tuple[List[Dict[str, Any]], int]:
        platforms: list[str] | None = None,
        message_types: list[str] | None = None,
        search_query: str | None = None,
        exclude_ids: list[str] | None = None,
        exclude_platforms: list[str] | None = None,
    ) -> tuple[list[dict[str, Any]], int]:
        """Fetch the filtered list of conversations"""
        try:
            c = self.conn.cursor()
@@ -422,7 +425,7 @@ class SQLiteDatabase:
        if search_query:
            search_query = search_query.encode("unicode_escape").decode("utf-8")
            where_clauses.append(
                "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
                "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)",
            )
            search_param = f"%{search_query}%"
            params.extend([search_param, search_param, search_param, search_param])
@@ -482,7 +485,7 @@ class SQLiteDatabase:
                    "persona_id": persona_id or "",
                    "created_at": created_at or 0,
                    "updated_at": updated_at or 0,
                }
                },
            )

        return conversations, total_count
@@ -1,15 +1,9 @@
import uuid

from datetime import datetime, timezone
from dataclasses import dataclass, field
from sqlmodel import (
    SQLModel,
    Text,
    JSON,
    UniqueConstraint,
    Field,
)
from typing import Optional, TypedDict
from datetime import datetime, timezone
from typing import TypedDict

from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint


class PlatformStat(SQLModel, table=True):
@@ -18,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
    Note: In astrbot v4, we moved `platform` table to here.
    """

    __tablename__ = "platform_stats"
    __tablename__ = "platform_stats"  # type: ignore

    id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
    timestamp: datetime = Field(nullable=False)
@@ -37,10 +31,11 @@ class PlatformStat(SQLModel, table=True):


class ConversationV2(SQLModel, table=True):
    __tablename__ = "conversations"
    __tablename__ = "conversations"  # type: ignore

    inner_conversation_id: int = Field(
        primary_key=True, sa_column_kwargs={"autoincrement": True}
        primary_key=True,
        sa_column_kwargs={"autoincrement": True},
    )
    conversation_id: str = Field(
        max_length=36,
@@ -50,14 +45,14 @@ class ConversationV2(SQLModel, table=True):
    )
    platform_id: str = Field(nullable=False)
    user_id: str = Field(nullable=False)
    content: Optional[list] = Field(default=None, sa_type=JSON)
    content: list | None = Field(default=None, sa_type=JSON)
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
    )
    title: Optional[str] = Field(default=None, max_length=255)
    persona_id: Optional[str] = Field(default=None)
    title: str | None = Field(default=None, max_length=255)
    persona_id: str | None = Field(default=None)

    __table_args__ = (
        UniqueConstraint(
@@ -73,16 +68,18 @@ class Persona(SQLModel, table=True):
    It can be used to customize the behavior of LLMs.
    """

    __tablename__ = "personas"
    __tablename__ = "personas"  # type: ignore

    id: int | None = Field(
        primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
        primary_key=True,
        sa_column_kwargs={"autoincrement": True},
        default=None,
    )
    persona_id: str = Field(max_length=255, nullable=False)
    system_prompt: str = Field(sa_type=Text, nullable=False)
    begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
    begin_dialogs: list | None = Field(default=None, sa_type=JSON)
    """a list of strings, each representing a dialog to start with"""
    tools: Optional[list] = Field(default=None, sa_type=JSON)
    tools: list | None = Field(default=None, sa_type=JSON)
    """None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(
@@ -101,10 +98,12 @@ class Persona(SQLModel, table=True):
class Preference(SQLModel, table=True):
    """This class represents preferences for bots."""

    __tablename__ = "preferences"
    __tablename__ = "preferences"  # type: ignore

    id: int | None = Field(
        default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
        default=None,
        primary_key=True,
        sa_column_kwargs={"autoincrement": True},
    )
    scope: str = Field(nullable=False)
    """Scope of the preference, such as 'global', 'umo', 'plugin'."""
@@ -135,16 +134,18 @@ class PlatformMessageHistory(SQLModel, table=True):
    or platform-specific messages.
    """

    __tablename__ = "platform_message_history"
    __tablename__ = "platform_message_history"  # type: ignore

    id: int | None = Field(
        primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
        primary_key=True,
        sa_column_kwargs={"autoincrement": True},
        default=None,
    )
    platform_id: str = Field(nullable=False)
    user_id: str = Field(nullable=False)  # An id of group, user in platform
    sender_id: Optional[str] = Field(default=None)  # ID of the sender in the platform
    sender_name: Optional[str] = Field(
        default=None
    sender_id: str | None = Field(default=None)  # ID of the sender in the platform
    sender_name: str | None = Field(
        default=None,
    )  # Name of the sender in the platform
    content: dict = Field(sa_type=JSON, nullable=False)  # a message chain list
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@@ -154,16 +155,60 @@ class PlatformMessageHistory(SQLModel, table=True):
    )


class PlatformSession(SQLModel, table=True):
    """Platform session table for managing user sessions across different platforms.

    A session represents a chat window for a specific user on a specific platform.
    Each session can have multiple conversations (对话) associated with it.
    """

    __tablename__ = "platform_sessions"  # type: ignore

    inner_id: int | None = Field(
        primary_key=True,
        sa_column_kwargs={"autoincrement": True},
        default=None,
    )
    session_id: str = Field(
        max_length=100,
        nullable=False,
        unique=True,
        default_factory=lambda: f"webchat_{uuid.uuid4()}",
    )
    platform_id: str = Field(default="webchat", nullable=False)
    """Platform identifier (e.g., 'webchat', 'qq', 'discord')"""
    creator: str = Field(nullable=False)
    """Username of the session creator"""
    display_name: str | None = Field(default=None, max_length=255)
    """Display name for the session"""
    is_group: int = Field(default=0, nullable=False)
    """0 for private chat, 1 for group chat (not implemented yet)"""
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
    )

    __table_args__ = (
        UniqueConstraint(
            "session_id",
            name="uix_platform_session_id",
        ),
    )


class Attachment(SQLModel, table=True):
    """This class represents attachments for messages in AstrBot.

    Attachments can be images, files, or other media types.
    """

    __tablename__ = "attachments"
    __tablename__ = "attachments"  # type: ignore

    inner_attachment_id: int | None = Field(
        primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
        primary_key=True,
        sa_column_kwargs={"autoincrement": True},
        default=None,
    )
    attachment_id: str = Field(
        max_length=36,
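The new `PlatformSession` model fills in its identity fields at construction time: `session_id` comes from the `default_factory` unless supplied, and `platform_id`/`is_group` default to a private webchat session. A hedged sketch of what instantiation looks like (the literal values printed are illustrative):

# relies on the PlatformSession model defined above
s = PlatformSession(creator="alice")
print(s.platform_id)   # "webchat" (default)
print(s.session_id)    # e.g. "webchat_6f1c..." from the default factory
print(s.is_group)      # 0 = private chat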
@@ -1,22 +1,28 @@
import asyncio
import typing as T
import threading
from datetime import datetime, timedelta
import typing as T
from datetime import datetime, timedelta, timezone

from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update

from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import (
    ConversationV2,
    PlatformStat,
    PlatformMessageHistory,
    Attachment,
    ConversationV2,
    Persona,
    PlatformMessageHistory,
    PlatformSession,
    PlatformStat,
    Preference,
    Stats as DeprecatedStats,
    Platform as DeprecatedPlatformStat,
    SQLModel,
)

from sqlmodel import select, update, delete, text, func, or_, desc, col
from sqlalchemy.ext.asyncio import AsyncSession
from astrbot.core.db.po import (
    Platform as DeprecatedPlatformStat,
)
from astrbot.core.db.po import (
    Stats as DeprecatedStats,
)


NOT_GIVEN = T.TypeVar("NOT_GIVEN")

@@ -57,7 +63,9 @@ class SQLiteDatabase(BaseDatabase):
            async with session.begin():
                if timestamp is None:
                    timestamp = datetime.now().replace(
                        minute=0, second=0, microsecond=0
                        minute=0,
                        second=0,
                        microsecond=0,
                    )
                current_hour = timestamp
                await session.execute(
@@ -81,13 +89,13 @@ class SQLiteDatabase(BaseDatabase):
            session: AsyncSession
            result = await session.execute(
                select(func.count(col(PlatformStat.platform_id))).select_from(
                    PlatformStat
                )
                    PlatformStat,
                ),
            )
            count = result.scalar_one_or_none()
            return count if count is not None else 0

    async def get_platform_stats(self, offset_sec: int = 86400) -> T.List[PlatformStat]:
    async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
        """Get platform statistics within the specified offset in seconds and group by platform_id."""
        async with self.get_db() as session:
            session: AsyncSession
@@ -138,7 +146,7 @@ class SQLiteDatabase(BaseDatabase):
                select(ConversationV2)
                .order_by(desc(ConversationV2.created_at))
                .offset(offset)
                .limit(page_size)
                .limit(page_size),
            )
            return result.scalars().all()

@@ -157,7 +165,7 @@ class SQLiteDatabase(BaseDatabase):

            if platform_ids:
                base_query = base_query.where(
                    col(ConversationV2.platform_id).in_(platform_ids)
                    col(ConversationV2.platform_id).in_(platform_ids),
                )
            if search_query:
                search_query = search_query.encode("unicode_escape").decode("utf-8")
@@ -167,16 +175,16 @@ class SQLiteDatabase(BaseDatabase):
                        col(ConversationV2.content).ilike(f"%{search_query}%"),
                        col(ConversationV2.user_id).ilike(f"%{search_query}%"),
                        col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
                    )
                    ),
                )
            if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
                for msg_type in kwargs["message_types"]:
                    base_query = base_query.where(
                        col(ConversationV2.user_id).ilike(f"%:{msg_type}:%")
                        col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"),
                    )
            if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
                base_query = base_query.where(
                    col(ConversationV2.platform_id).in_(kwargs["platforms"])
                    col(ConversationV2.platform_id).in_(kwargs["platforms"]),
                )

            # Get total count matching the filters
@@ -233,7 +241,7 @@ class SQLiteDatabase(BaseDatabase):
            session: AsyncSession
            async with session.begin():
                query = update(ConversationV2).where(
                    col(ConversationV2.conversation_id) == cid
                    col(ConversationV2.conversation_id) == cid,
                )
                values = {}
                if title is not None:
@@ -243,7 +251,7 @@ class SQLiteDatabase(BaseDatabase):
                if content is not None:
                    values["content"] = content
                if not values:
                    return
                    return None
                query = query.values(**values)
                await session.execute(query)
        return await self.get_conversation_by_id(cid)
@@ -254,8 +262,8 @@ class SQLiteDatabase(BaseDatabase):
            async with session.begin():
                await session.execute(
                    delete(ConversationV2).where(
                        col(ConversationV2.conversation_id) == cid
                    )
                        col(ConversationV2.conversation_id) == cid,
                    ),
                )

    async def delete_conversations_by_user_id(self, user_id: str) -> None:
@@ -263,7 +271,9 @@ class SQLiteDatabase(BaseDatabase):
            session: AsyncSession
            async with session.begin():
                await session.execute(
                    delete(ConversationV2).where(col(ConversationV2.user_id) == user_id)
                    delete(ConversationV2).where(
                        col(ConversationV2.user_id) == user_id
                    ),
                )

    async def get_session_conversations(
@@ -282,7 +292,7 @@ class SQLiteDatabase(BaseDatabase):
                select(
                    col(Preference.scope_id).label("session_id"),
                    func.json_extract(Preference.value, "$.val").label(
                        "conversation_id"
                        "conversation_id",
                    ),  # type: ignore
                    col(ConversationV2.persona_id).label("persona_id"),
                    col(ConversationV2.title).label("title"),
@@ -295,7 +305,8 @@ class SQLiteDatabase(BaseDatabase):
                    == ConversationV2.conversation_id,
                )
                .outerjoin(
                    Persona, col(ConversationV2.persona_id) == Persona.persona_id
                    Persona,
                    col(ConversationV2.persona_id) == Persona.persona_id,
                )
                .where(Preference.scope == "umo", Preference.key == "sel_conv_id")
            )
@@ -308,14 +319,14 @@ class SQLiteDatabase(BaseDatabase):
                        col(Preference.scope_id).ilike(search_pattern),
                        col(ConversationV2.title).ilike(search_pattern),
                        col(Persona.persona_id).ilike(search_pattern),
                    )
                    ),
                )

            # platform filter
            if platform:
                platform_pattern = f"{platform}:%"
                base_query = base_query.where(
                    col(Preference.scope_id).like(platform_pattern)
                    col(Preference.scope_id).like(platform_pattern),
                )

            # ordering
@@ -336,7 +347,8 @@ class SQLiteDatabase(BaseDatabase):
                    == ConversationV2.conversation_id,
                )
                .outerjoin(
                    Persona, col(ConversationV2.persona_id) == Persona.persona_id
                    Persona,
                    col(ConversationV2.persona_id) == Persona.persona_id,
                )
                .where(Preference.scope == "umo", Preference.key == "sel_conv_id")
            )
@@ -349,13 +361,13 @@ class SQLiteDatabase(BaseDatabase):
                        col(Preference.scope_id).ilike(search_pattern),
                        col(ConversationV2.title).ilike(search_pattern),
                        col(Persona.persona_id).ilike(search_pattern),
                    )
                    ),
                )

            if platform:
                platform_pattern = f"{platform}:%"
                count_base_query = count_base_query.where(
                    col(Preference.scope_id).like(platform_pattern)
                    col(Preference.scope_id).like(platform_pattern),
                )

            total_result = await session.execute(count_base_query)
@@ -396,9 +408,12 @@ class SQLiteDatabase(BaseDatabase):
            return new_history

    async def delete_platform_message_offset(
        self, platform_id, user_id, offset_sec=86400
        self,
        platform_id,
        user_id,
        offset_sec=86400,
    ):
        """Delete platform message history records older than the specified offset."""
        """Delete platform message history records newer than the specified offset."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
@@ -408,12 +423,16 @@ class SQLiteDatabase(BaseDatabase):
                    delete(PlatformMessageHistory).where(
                        col(PlatformMessageHistory.platform_id) == platform_id,
                        col(PlatformMessageHistory.user_id) == user_id,
                        col(PlatformMessageHistory.created_at) < cutoff_time,
                    )
                        col(PlatformMessageHistory.created_at) >= cutoff_time,
                    ),
                )

    async def get_platform_message_history(
        self, platform_id, user_id, page=1, page_size=20
        self,
        platform_id,
        user_id,
        page=1,
        page_size=20,
    ):
        """Get platform message history records."""
        async with self.get_db() as session:
@@ -452,7 +471,11 @@ class SQLiteDatabase(BaseDatabase):
            return result.scalar_one_or_none()

    async def insert_persona(
        self, persona_id, system_prompt, begin_dialogs=None, tools=None
        self,
        persona_id,
        system_prompt,
        begin_dialogs=None,
        tools=None,
    ):
        """Insert a new persona record."""
        async with self.get_db() as session:
@@ -484,7 +507,11 @@ class SQLiteDatabase(BaseDatabase):
            return result.scalars().all()

    async def update_persona(
        self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN
        self,
        persona_id,
        system_prompt=None,
        begin_dialogs=None,
        tools=NOT_GIVEN,
    ):
        """Update a persona's system prompt or begin dialogs."""
        async with self.get_db() as session:
@@ -499,7 +526,7 @@ class SQLiteDatabase(BaseDatabase):
                if tools is not NOT_GIVEN:
                    values["tools"] = tools
                if not values:
                    return
                    return None
                query = query.values(**values)
                await session.execute(query)
        return await self.get_persona_by_id(persona_id)
@@ -510,7 +537,7 @@ class SQLiteDatabase(BaseDatabase):
            session: AsyncSession
            async with session.begin():
                await session.execute(
                    delete(Persona).where(col(Persona.persona_id) == persona_id)
                    delete(Persona).where(col(Persona.persona_id) == persona_id),
                )

    async def insert_preference_or_update(self, scope, scope_id, key, value):
@@ -529,7 +556,10 @@ class SQLiteDatabase(BaseDatabase):
                    existing_preference.value = value
                else:
                    new_preference = Preference(
                        scope=scope, scope_id=scope_id, key=key, value=value
                        scope=scope,
                        scope_id=scope_id,
                        key=key,
                        value=value,
                    )
                    session.add(new_preference)
                return existing_preference or new_preference
@@ -568,7 +598,7 @@ class SQLiteDatabase(BaseDatabase):
                        col(Preference.scope) == scope,
                        col(Preference.scope_id) == scope_id,
                        col(Preference.key) == key,
                    )
                    ),
                )
                await session.commit()

@@ -581,7 +611,7 @@ class SQLiteDatabase(BaseDatabase):
                    delete(Preference).where(
                        col(Preference.scope) == scope,
                        col(Preference.scope_id) == scope_id,
                    )
                    ),
                )
                await session.commit()

@@ -598,7 +628,7 @@ class SQLiteDatabase(BaseDatabase):
            now = datetime.now()
            start_time = now - timedelta(seconds=offset_sec)
            result = await session.execute(
                select(PlatformStat).where(PlatformStat.timestamp >= start_time)
                select(PlatformStat).where(PlatformStat.timestamp >= start_time),
            )
            all_datas = result.scalars().all()
            deprecated_stats = DeprecatedStats()
@@ -608,7 +638,7 @@ class SQLiteDatabase(BaseDatabase):
                        name=data.platform_id,
                        count=data.count,
                        timestamp=int(data.timestamp.timestamp()),
                    )
                    ),
                )
            return deprecated_stats

@@ -630,7 +660,7 @@ class SQLiteDatabase(BaseDatabase):
        async with self.get_db() as session:
            session: AsyncSession
            result = await session.execute(
                select(func.sum(PlatformStat.count)).select_from(PlatformStat)
                select(func.sum(PlatformStat.count)).select_from(PlatformStat),
            )
            total_count = result.scalar_one_or_none()
            return total_count if total_count is not None else 0
@@ -656,7 +686,7 @@ class SQLiteDatabase(BaseDatabase):
            result = await session.execute(
                select(PlatformStat.platform_id, func.sum(PlatformStat.count))
                .where(PlatformStat.timestamp >= start_time)
                .group_by(PlatformStat.platform_id)
                .group_by(PlatformStat.platform_id),
            )
            grouped_stats = result.all()
            deprecated_stats = DeprecatedStats()
@@ -666,7 +696,7 @@ class SQLiteDatabase(BaseDatabase):
                        name=platform_id,
                        count=count,
                        timestamp=int(start_time.timestamp()),
                    )
                    ),
                )
            return deprecated_stats

@@ -680,3 +710,101 @@ class SQLiteDatabase(BaseDatabase):
            t.start()
            t.join()
        return result

    # ====
    # Platform Session Management
    # ====

    async def create_platform_session(
        self,
        creator: str,
        platform_id: str = "webchat",
        session_id: str | None = None,
        display_name: str | None = None,
        is_group: int = 0,
    ) -> PlatformSession:
        """Create a new Platform session."""
        kwargs = {}
        if session_id:
            kwargs["session_id"] = session_id

        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                new_session = PlatformSession(
                    creator=creator,
                    platform_id=platform_id,
                    display_name=display_name,
                    is_group=is_group,
                    **kwargs,
                )
                session.add(new_session)
                await session.flush()
                await session.refresh(new_session)
                return new_session

    async def get_platform_session_by_id(
        self, session_id: str
    ) -> PlatformSession | None:
        """Get a Platform session by its ID."""
        async with self.get_db() as session:
            session: AsyncSession
            query = select(PlatformSession).where(
                PlatformSession.session_id == session_id,
            )
            result = await session.execute(query)
            return result.scalar_one_or_none()

    async def get_platform_sessions_by_creator(
        self,
        creator: str,
        platform_id: str | None = None,
        page: int = 1,
        page_size: int = 20,
    ) -> list[PlatformSession]:
        """Get all Platform sessions for a specific creator (username) and optionally platform."""
        async with self.get_db() as session:
            session: AsyncSession
            offset = (page - 1) * page_size
            query = select(PlatformSession).where(PlatformSession.creator == creator)

            if platform_id:
                query = query.where(PlatformSession.platform_id == platform_id)

            query = (
                query.order_by(desc(PlatformSession.updated_at))
                .offset(offset)
                .limit(page_size)
            )
            result = await session.execute(query)
            return list(result.scalars().all())

    async def update_platform_session(
        self,
        session_id: str,
        display_name: str | None = None,
    ) -> None:
        """Update a Platform session's updated_at timestamp and optionally display_name."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
                if display_name is not None:
                    values["display_name"] = display_name

                await session.execute(
                    update(PlatformSession)
                    .where(col(PlatformSession.session_id) == session_id)
                    .values(**values),
                )

    async def delete_platform_session(self, session_id: str) -> None:
        """Delete a Platform session by its ID."""
        async with self.get_db() as session:
            session: AsyncSession
            async with session.begin():
                await session.execute(
                    delete(PlatformSession).where(
                        col(PlatformSession.session_id) == session_id,
                    ),
                )
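Note that the where-clauses in `update_platform_session` and `delete_platform_session` are written here as `col(PlatformSession.session_id) == session_id`; wrapping the whole comparison in `col(...)` would type-check the boolean expression rather than the column. A hedged end-to-end round-trip of the new CRUD methods (how `SQLiteDatabase` is constructed is assumed and may differ from the real class):

async def session_roundtrip(db: SQLiteDatabase) -> None:
    s = await db.create_platform_session(creator="alice", display_name="demo")
    assert await db.get_platform_session_by_id(s.session_id) is not None

    await db.update_platform_session(s.session_id, display_name="demo 2")
    updated = await db.get_platform_session_by_id(s.session_id)
    assert updated is not None and updated.display_name == "demo 2"

    await db.delete_platform_session(s.session_id)
    assert await db.get_platform_session_by_id(s.session_id) is None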
@@ -1,27 +1,34 @@
import abc
from dataclasses import dataclass
from typing import TypedDict


@dataclass
class Result:
class ResultData(TypedDict):
    id: str
    doc_id: str
    text: str
    metadata: str
    created_at: int
    updated_at: int

    similarity: float
    data: dict
    data: ResultData | dict


class BaseVecDB:
    async def initialize(self):
        """
        Initialize the vector database
        """
        pass
        """Initialize the vector database"""

    @abc.abstractmethod
    async def insert(
        self, content: str, metadata: dict | None = None, id: str | None = None
        self,
        content: str,
        metadata: dict | None = None,
        id: str | None = None,
    ) -> int:
        """
        Insert one text and its corresponding vector; generates an ID automatically and keeps it consistent.
        """
        """Insert one text and its corresponding vector; generates an ID automatically and keeps it consistent."""
        ...

    @abc.abstractmethod
@@ -35,11 +42,11 @@ class BaseVecDB:
        max_retries: int = 3,
        progress_callback=None,
    ) -> int:
        """
        Batch-insert texts and their corresponding vectors; generates IDs automatically and keeps them consistent.
        """Batch-insert texts and their corresponding vectors; generates IDs automatically and keeps them consistent.

        Args:
            progress_callback: progress callback, receives the arguments (current, total)

        """
        ...

@@ -52,8 +59,7 @@ class BaseVecDB:
        rerank: bool = False,
        metadata_filters: dict | None = None,
    ) -> list[Result]:
        """
        Search for the most similar documents.
        """Search for the most similar documents.
        Args:
            query (str): the query text
            top_k (int): the number of most-similar documents to return
@@ -64,8 +70,7 @@ class BaseVecDB:

    @abc.abstractmethod
    async def delete(self, doc_id: str) -> bool:
        """
        Delete the specified document.
        """Delete the specified document.
        Args:
            doc_id (str): the ID of the document to delete
        Returns:
@@ -1,12 +1,13 @@
import os
import json
from datetime import datetime
import os
from contextlib import asynccontextmanager
from datetime import datetime

from sqlalchemy import Text, Column
from sqlalchemy import Column, Text
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import Field, SQLModel, select, col, func, text, MetaData
from sqlmodel import Field, MetaData, SQLModel, col, func, select, text

from astrbot.core import logger


@@ -20,7 +21,9 @@ class Document(BaseDocModel, table=True):
__tablename__ = "documents"  # type: ignore

id: int | None = Field(
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
default=None,
primary_key=True,
sa_column_kwargs={"autoincrement": True},
)
doc_id: str = Field(nullable=False)
text: str = Field(nullable=False)
@@ -36,7 +39,8 @@ class DocumentStorage:
self.engine: AsyncEngine | None = None
self.async_session_maker: sessionmaker | None = None
self.sqlite_init_path = os.path.join(
os.path.dirname(__file__), "sqlite_init.sql"
os.path.dirname(__file__),
"sqlite_init.sql",
)

async def initialize(self):
@@ -50,26 +54,26 @@ class DocumentStorage:
await conn.execute(
text(
"ALTER TABLE documents ADD COLUMN kb_doc_id TEXT "
"GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED"
)
"GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED",
),
)
await conn.execute(
text(
"ALTER TABLE documents ADD COLUMN user_id TEXT "
"GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED"
)
"GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED",
),
)

# Create indexes
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)"
)
"CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)",
),
)
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)"
)
"CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)",
),
)
except BaseException:
pass
@@ -113,10 +117,11 @@ class DocumentStorage:

Returns:
list: The list of documents that match the filters.

"""
if self.engine is None:
logger.warning(
"Database connection is not initialized, returning empty result"
"Database connection is not initialized, returning empty result",
)
return []

@@ -125,7 +130,7 @@ class DocumentStorage:

for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
text(f"json_extract(metadata, '$.{key}') = :filter_{key}"),
).params(**{f"filter_{key}": val})

if ids is not None and len(ids) > 0:
@@ -153,24 +158,27 @@ class DocumentStorage:

Returns:
int: The integer ID of the inserted document.

"""
assert self.engine is not None, "Database connection is not initialized."

async with self.get_session() as session:
async with session.begin():
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
session.add(document)
await session.flush()  # Flush to get the ID
return document.id  # type: ignore
async with self.get_session() as session, session.begin():
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
session.add(document)
await session.flush()  # Flush to get the ID
return document.id  # type: ignore

async def insert_documents_batch(
self, doc_ids: list[str], texts: list[str], metadatas: list[dict]
self,
doc_ids: list[str],
texts: list[str],
metadatas: list[dict],
) -> list[int]:
"""Batch insert documents and return their integer IDs.

@@ -181,44 +189,44 @@ class DocumentStorage:

Returns:
list[int]: List of integer IDs of the inserted documents.

"""
assert self.engine is not None, "Database connection is not initialized."

async with self.get_session() as session:
async with session.begin():
import json
async with self.get_session() as session, session.begin():
import json

documents = []
for doc_id, text, metadata in zip(doc_ids, texts, metadatas):
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
documents.append(document)
session.add(document)
documents = []
for doc_id, text, metadata in zip(doc_ids, texts, metadatas):
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
documents.append(document)
session.add(document)

await session.flush()  # Flush to get all IDs
return [doc.id for doc in documents]  # type: ignore
await session.flush()  # Flush to get all IDs
return [doc.id for doc in documents]  # type: ignore

async def delete_document_by_doc_id(self, doc_id: str):
"""Delete a document by its doc_id.

Args:
doc_id (str): The doc_id of the document to delete.

"""
assert self.engine is not None, "Database connection is not initialized."

async with self.get_session() as session:
async with session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
async with self.get_session() as session, session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()

if document:
await session.delete(document)
if document:
await session.delete(document)

async def get_document_by_doc_id(self, doc_id: str):
"""Retrieve a document by its doc_id.
@@ -228,6 +236,7 @@ class DocumentStorage:

Returns:
dict: The document data or None if not found.

"""
assert self.engine is not None, "Database connection is not initialized."

@@ -246,46 +255,46 @@ class DocumentStorage:
Args:
doc_id (str): The doc_id.
new_text (str): The new text to update the document with.

"""
assert self.engine is not None, "Database connection is not initialized."

async with self.get_session() as session:
async with session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
async with self.get_session() as session, session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()

if document:
document.text = new_text
document.updated_at = datetime.now()
session.add(document)
if document:
document.text = new_text
document.updated_at = datetime.now()
session.add(document)

async def delete_documents(self, metadata_filters: dict):
"""Delete documents by their metadata filters.

Args:
metadata_filters (dict): The metadata filters to apply.

"""
if self.engine is None:
logger.warning(
"Database connection is not initialized, skipping delete operation"
"Database connection is not initialized, skipping delete operation",
)
return

async with self.get_session() as session:
async with session.begin():
query = select(Document)
async with self.get_session() as session, session.begin():
query = select(Document)

for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
).params(**{f"filter_{key}": val})
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}"),
).params(**{f"filter_{key}": val})

result = await session.execute(query)
documents = result.scalars().all()
result = await session.execute(query)
documents = result.scalars().all()

for doc in documents:
await session.delete(doc)
for doc in documents:
await session.delete(doc)

async def count_documents(self, metadata_filters: dict | None = None) -> int:
"""Count documents in the database.
@@ -295,6 +304,7 @@ class DocumentStorage:

Returns:
int: The count of documents.

"""
if self.engine is None:
logger.warning("Database connection is not initialized, returning 0")
@@ -306,7 +316,7 @@ class DocumentStorage:
if metadata_filters:
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
text(f"json_extract(metadata, '$.{key}') = :filter_{key}"),
).params(**{f"filter_{key}": val})

result = await session.execute(query)
@@ -318,12 +328,13 @@ class DocumentStorage:

Returns:
list: A list of user IDs.

"""
assert self.engine is not None, "Database connection is not initialized."

async with self.get_session() as session:
query = text(
"SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL"
"SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL",
)
result = await session.execute(query)
rows = result.fetchall()
@@ -337,6 +348,7 @@ class DocumentStorage:

Returns:
dict: The converted dictionary.

"""
return {
"id": document.id,
@@ -361,6 +373,7 @@ class DocumentStorage:
dict: The converted dictionary.

Note: This method is kept for backward compatibility but is no longer used internally.

"""
return {
"id": row[0],

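The `initialize` hunk above materializes `kb_doc_id` and `user_id` out of the JSON `metadata` blob as STORED generated columns and indexes them, so metadata filters can hit an index instead of evaluating `json_extract` per row. A self-contained sketch of the same idea using plain `sqlite3` with a hypothetical table; note SQLite only allows STORED generated columns at `CREATE TABLE` time (`ALTER TABLE` can only add VIRTUAL ones, which is presumably why the migration above swallows errors):

```python
import json
import sqlite3

con = sqlite3.connect(":memory:")
# Declare the generated column inline: it is computed from the JSON blob
# and persisted, so the index below can serve WHERE kb_doc_id = ? directly.
con.execute(
    "CREATE TABLE documents ("
    " id INTEGER PRIMARY KEY,"
    " metadata TEXT,"
    " kb_doc_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED"
    ")"
)
con.execute("CREATE INDEX idx_documents_kb_doc_id ON documents(kb_doc_id)")
con.execute(
    "INSERT INTO documents (metadata) VALUES (?)",
    (json.dumps({"kb_doc_id": "doc-1"}),),
)
print(con.execute("SELECT id FROM documents WHERE kb_doc_id = 'doc-1'").fetchall())
```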
@@ -2,9 +2,10 @@ try:
import faiss
except ModuleNotFoundError:
raise ImportError(
"faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。"
"faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。",
)
import os

import numpy as np


@@ -27,11 +28,12 @@ class EmbeddingStorage:
id (int): the vector's ID
Raises:
ValueError: if the vector's dimensionality does not match the stored dimensionality

"""
assert self.index is not None, "FAISS index is not initialized."
if vector.shape[0] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}",
)
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
await self.save_index()
@@ -44,11 +46,12 @@ class EmbeddingStorage:
ids (list[int]): list of vector IDs
Raises:
ValueError: if the vectors' dimensionality does not match the stored dimensionality

"""
assert self.index is not None, "FAISS index is not initialized."
if vectors.shape[1] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}"
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}",
)
self.index.add_with_ids(vectors, np.array(ids))
await self.save_index()
@@ -61,6 +64,7 @@ class EmbeddingStorage:
k (int): number of most-similar vectors to return
Returns:
tuple: (distances, indices)

"""
assert self.index is not None, "FAISS index is not initialized."
faiss.normalize_L2(vector)
@@ -72,6 +76,7 @@ class EmbeddingStorage:

Args:
ids (list[int]): list of vector IDs to delete

"""
assert self.index is not None, "FAISS index is not initialized."
id_array = np.array(ids, dtype=np.int64)
@@ -83,5 +88,6 @@ class EmbeddingStorage:

Args:
path (str): path the index is saved to

"""
faiss.write_index(self.index, self.path)

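EmbeddingStorage delegates to a FAISS ID-mapped index: `add_with_ids` ties each embedding to the integer row ID that DocumentStorage assigned, and `normalize_L2` makes L2 search behave like cosine search. A self-contained sketch of that pairing, with made-up dimensions and IDs:

```python
import faiss
import numpy as np

dim = 4                                    # assumed embedding dimensionality
index = faiss.IndexIDMap(faiss.IndexFlatL2(dim))

vectors = np.random.rand(3, dim).astype("float32")
faiss.normalize_L2(vectors)                # normalize rows in place
index.add_with_ids(vectors, np.array([10, 11, 12], dtype=np.int64))

query = vectors[:1].copy()                 # query with a known vector
faiss.normalize_L2(query)
distances, ids = index.search(query, 2)
print(ids)                                 # -> [[10, ...]]: nearest ID first
```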
@@ -1,18 +1,18 @@
import uuid
import time
import uuid

import numpy as np

from astrbot import logger
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider

from ..base import BaseVecDB, Result
from .document_storage import DocumentStorage
from .embedding_storage import EmbeddingStorage
from ..base import Result, BaseVecDB
from astrbot.core.provider.provider import EmbeddingProvider
from astrbot.core.provider.provider import RerankProvider
from astrbot import logger


class FaissVecDB(BaseVecDB):
"""
A class to represent a vector database.
"""
"""A class to represent a vector database."""

def __init__(
self,
@@ -26,7 +26,8 @@ class FaissVecDB(BaseVecDB):
self.embedding_provider = embedding_provider
self.document_storage = DocumentStorage(doc_store_path)
self.embedding_storage = EmbeddingStorage(
embedding_provider.get_dim(), index_store_path
embedding_provider.get_dim(),
index_store_path,
)
self.embedding_provider = embedding_provider
self.rerank_provider = rerank_provider
@@ -35,11 +36,12 @@ class FaissVecDB(BaseVecDB):
await self.document_storage.initialize()

async def insert(
self, content: str, metadata: dict | None = None, id: str | None = None
self,
content: str,
metadata: dict | None = None,
id: str | None = None,
) -> int:
"""
Insert a text and its corresponding vector, auto-generating an ID and keeping the two consistent.
"""
"""Insert a text and its corresponding vector, auto-generating an ID and keeping the two consistent."""
metadata = metadata or {}
str_id = id or str(uuid.uuid4())  # use a UUID as the original ID

@@ -63,11 +65,11 @@ class FaissVecDB(BaseVecDB):
max_retries: int = 3,
progress_callback=None,
) -> list[int]:
"""
Batch-insert texts and their corresponding vectors, auto-generating IDs and keeping them consistent.
"""Batch-insert texts and their corresponding vectors, auto-generating IDs and keeping them consistent.

Args:
progress_callback: progress callback invoked with (current, total)

"""
metadatas = metadatas or [{} for _ in contents]
ids = ids or [str(uuid.uuid4()) for _ in contents]
@@ -83,12 +85,14 @@ class FaissVecDB(BaseVecDB):
)
end = time.time()
logger.debug(
f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds."
f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds.",
)

# use DocumentStorage's batch-insert method
int_ids = await self.document_storage.insert_documents_batch(
ids, contents, metadatas
ids,
contents,
metadatas,
)

# batch-insert the vectors into FAISS
@@ -104,8 +108,7 @@ class FaissVecDB(BaseVecDB):
rerank: bool = False,
metadata_filters: dict | None = None,
) -> list[Result]:
"""
Search for the most similar documents.
"""Search for the most similar documents.

Args:
query (str): the query text
@@ -116,6 +119,7 @@ class FaissVecDB(BaseVecDB):

Returns:
List[Result]: the query results

"""
embedding = await self.embedding_provider.get_embedding(query)
scores, indices = await self.embedding_storage.search(
@@ -128,7 +132,8 @@ class FaissVecDB(BaseVecDB):
scores[0] = 1.0 - (scores[0] / 2.0)
# NOTE: maybe the size is less than k.
fetched_docs = await self.document_storage.get_documents(
metadata_filters=metadata_filters or {}, ids=indices[0]
metadata_filters=metadata_filters or {},
ids=indices[0],
)
if not fetched_docs:
return []
@@ -149,7 +154,9 @@ class FaissVecDB(BaseVecDB):
documents = [doc.data["text"] for doc in top_k_results]
reranked_results = await self.rerank_provider.rerank(query, documents)
reranked_results = sorted(
reranked_results, key=lambda x: x.relevance_score, reverse=True
reranked_results,
key=lambda x: x.relevance_score,
reverse=True,
)
top_k_results = [
top_k_results[reranked_result.index]
@@ -159,9 +166,7 @@ class FaissVecDB(BaseVecDB):
return top_k_results

async def delete(self, doc_id: str):
"""
Delete a single document chunk.
"""
"""Delete a single document chunk."""
# look up the corresponding int id
result = await self.document_storage.get_document_by_doc_id(doc_id)
int_id = result["id"] if result else None
@@ -176,23 +181,23 @@ class FaissVecDB(BaseVecDB):
await self.document_storage.close()

async def count_documents(self, metadata_filter: dict | None = None) -> int:
"""
Count the documents.
"""Count the documents.

Args:
metadata_filter (dict | None): metadata filters

"""
count = await self.document_storage.count_documents(
metadata_filters=metadata_filter or {}
metadata_filters=metadata_filter or {},
)
return count

async def delete_documents(self, metadata_filters: dict):
"""
Delete documents matching the metadata filters.
"""
"""Delete documents matching the metadata filters."""
docs = await self.document_storage.get_documents(
metadata_filters=metadata_filters, offset=None, limit=None
metadata_filters=metadata_filters,
offset=None,
limit=None,
)
doc_ids: list[int] = [doc["id"] for doc in docs]
await self.embedding_storage.delete(doc_ids)

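The `scores[0] = 1.0 - (scores[0] / 2.0)` line in `retrieve` relies on the vectors having been L2-normalized before indexing: for unit vectors the squared Euclidean distance is d² = 2 − 2·cos θ, so cos θ = 1 − d²/2, which maps FAISS's (squared) L2 scores onto cosine similarity in [−1, 1]. A quick numeric check of that identity:

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.normal(size=8), rng.normal(size=8)
a /= np.linalg.norm(a)                        # unit vectors, as after normalize_L2
b /= np.linalg.norm(b)

sq_l2 = np.sum((a - b) ** 2)                  # what an L2 index reports
cosine = float(a @ b)
print(np.isclose(1.0 - sq_l2 / 2.0, cosine))  # -> True
```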
@@ -1,5 +1,4 @@
"""
Event bus, responsible for dispatching and handling events.
"""Event bus, responsible for dispatching and handling events.
The event bus is an async queue that receives message events of all kinds and hands them to the pipeline Scheduler for processing.
It contains an infinitely looping dispatch function that pulls new events off the event queue and spawns a new async task to run the pipeline scheduler's handling logic.

@@ -13,10 +12,12 @@ class:

import asyncio
from asyncio import Queue
from astrbot.core.pipeline.scheduler import PipelineScheduler

from astrbot.core import logger
from .platform import AstrMessageEvent
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.pipeline.scheduler import PipelineScheduler

from .platform import AstrMessageEvent


class EventBus:
@@ -46,14 +47,15 @@ class EventBus:

Args:
event (AstrMessageEvent): the event object

"""
# With a sender name: [platform] sender_name/sender_id: message outline
if event.get_sender_name():
logger.info(
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}",
)
# Without a sender name: [platform] sender_id: message outline
else:
logger.info(
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}"
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}",
)

9
astrbot/core/exceptions.py
Normal file
@@ -0,0 +1,9 @@
from __future__ import annotations


class AstrBotError(Exception):
    """Base exception for all AstrBot errors."""


class ProviderNotFoundError(AstrBotError):
    """Raised when a specified provider is not found."""
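A minimal sketch of how a caller might use this hierarchy; catching `AstrBotError` also catches the subclass. The lookup helper here is hypothetical, not part of the diff:

```python
from astrbot.core.exceptions import AstrBotError, ProviderNotFoundError

def require_provider(providers: dict, provider_id: str):
    # Hypothetical helper: raise the domain-specific error on a miss.
    if provider_id not in providers:
        raise ProviderNotFoundError(f"provider {provider_id!r} not found")
    return providers[provider_id]

try:
    require_provider({}, "openai-default")
except AstrBotError as e:          # the base class catches all AstrBot errors
    print(f"handled: {e}")
```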
@@ -1,9 +1,9 @@
import asyncio
import os
import uuid
import time
from urllib.parse import urlparse, unquote
import platform
import time
import uuid
from urllib.parse import unquote, urlparse


class FileTokenService:
@@ -40,8 +40,8 @@ class FileTokenService:

Raises:
FileNotFoundError: raised when the path does not exist
"""

"""
# handle file:///
try:
parsed_uri = urlparse(file_path)
@@ -61,7 +61,7 @@ class FileTokenService:

if not os.path.exists(local_path):
raise FileNotFoundError(
f"文件不存在: {local_path} (原始输入: {file_path})"
f"文件不存在: {local_path} (原始输入: {file_path})",
)

file_token = str(uuid.uuid4())
@@ -84,6 +84,7 @@ class FileTokenService:
Raises:
KeyError: raised when the token does not exist or has expired
FileNotFoundError: raised when the file itself has been deleted

"""
async with self.lock:
await self._cleanup_expired_tokens()

@@ -1,5 +1,4 @@
"""
AstrBot launcher, responsible for initializing and starting the core components and the dashboard server.
"""AstrBot launcher, responsible for initializing and starting the core components and the dashboard server.

Workflow:
1. Initialize the core lifecycle, passing the database and log broker instances into it.
@@ -8,10 +7,10 @@ AstrBot launcher, responsible for initializing and starting the core components and the dashboard server.

import asyncio
import traceback
from astrbot.core import logger

from astrbot.core import LogBroker, logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core import LogBroker
from astrbot.dashboard.server import AstrBotDashboard


@@ -39,7 +38,10 @@ class InitialLoader:
webui_dir = self.webui_dir

self.dashboard_server = AstrBotDashboard(
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event, webui_dir
core_lifecycle,
self.db,
core_lifecycle.dashboard_shutdown_event,
webui_dir,
)

coro = self.dashboard_server.run()

@@ -1,6 +1,4 @@
"""
Document chunking module.
"""
"""Document chunking module."""

from .base import BaseChunker
from .fixed_size import FixedSizeChunker

@@ -21,4 +21,5 @@ class BaseChunker(ABC):

Returns:
list[str]: the list of chunked text segments

"""

@@ -18,6 +18,7 @@ class FixedSizeChunker(BaseChunker):
Args:
chunk_size: chunk size (in characters)
chunk_overlap: number of overlapping characters between chunks

"""
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
@@ -32,6 +33,7 @@ class FixedSizeChunker(BaseChunker):

Returns:
list[str]: the list of chunked text segments

"""
chunk_size = kwargs.get("chunk_size", self.chunk_size)
chunk_overlap = kwargs.get("chunk_overlap", self.chunk_overlap)

@@ -1,4 +1,5 @@
from collections.abc import Callable

from .base import BaseChunker


@@ -11,8 +12,7 @@ class RecursiveCharacterChunker(BaseChunker):
is_separator_regex: bool = False,
separators: list[str] | None = None,
):
"""
Initialize the recursive character text splitter.
"""Initialize the recursive character text splitter.

Args:
chunk_size: maximum size of each text chunk
@@ -20,6 +20,7 @@ class RecursiveCharacterChunker(BaseChunker):
length_function: function used to measure text length
is_separator_regex: whether the separators are regular expressions
separators: list of separators used to split the text, in priority order

"""
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
@@ -39,8 +40,7 @@ class RecursiveCharacterChunker(BaseChunker):
]

async def chunk(self, text: str, **kwargs) -> list[str]:
"""
Recursively split the text into chunks.
"""Recursively split the text into chunks.

Args:
text: the text to split
@@ -49,6 +49,7 @@ class RecursiveCharacterChunker(BaseChunker):

Returns:
the list of split text chunks

"""
if not text:
return []
@@ -90,7 +91,7 @@ class RecursiveCharacterChunker(BaseChunker):
combined_text,
chunk_size=chunk_size,
chunk_overlap=overlap,
)
),
)
current_chunk = []
current_chunk_length = 0
@@ -98,8 +99,10 @@ class RecursiveCharacterChunker(BaseChunker):
# recursively split oversized parts
final_chunks.extend(
await self.chunk(
split, chunk_size=chunk_size, chunk_overlap=overlap
)
split,
chunk_size=chunk_size,
chunk_overlap=overlap,
),
)
# if adding this part would push the current chunk past chunk_size
elif current_chunk_length + split_length > chunk_size:
@@ -132,16 +135,19 @@ class RecursiveCharacterChunker(BaseChunker):
return [text]

def _split_by_character(
self, text: str, chunk_size: int | None = None, overlap: int | None = None
self,
text: str,
chunk_size: int | None = None,
overlap: int | None = None,
) -> list[str]:
"""
Split the text at the character level.
"""Split the text at the character level.

Args:
text: the text to split

Returns:
the list of split text chunks

"""
chunk_size = chunk_size or self.chunk_size
overlap = overlap or self.chunk_overlap

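The recursive strategy tries the highest-priority separator first (paragraph breaks), falls back to finer ones, and re-splits any piece that is still larger than `chunk_size`, ending with a character-level fallback. A standalone sketch of the same idea, simplified (no overlap handling, separators are dropped):

```python
def recursive_chunk(text: str, chunk_size: int, separators: list[str]) -> list[str]:
    # Split on the coarsest separator first; recurse with finer
    # separators on any piece that is still too large.
    if len(text) <= chunk_size:
        return [text] if text else []
    if not separators:
        # character-level fallback, analogous to _split_by_character
        return [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)]
    head, *rest = separators
    chunks: list[str] = []
    for part in text.split(head):
        chunks.extend(recursive_chunk(part, chunk_size, rest))
    return chunks

print(recursive_chunk("para one\n\npara two is rather longer", 12, ["\n\n", " "]))
```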
@@ -1,18 +1,18 @@
from contextlib import asynccontextmanager
from pathlib import Path

from sqlmodel import col, desc
from sqlalchemy import text, func, select, update, delete
from sqlalchemy import delete, func, select, text, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlmodel import col, desc

from astrbot.core import logger
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.models import (
BaseKBModel,
KBDocument,
KBMedia,
KnowledgeBase,
)
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB


class KBSQLiteDatabase:
@@ -21,6 +21,7 @@ class KBSQLiteDatabase:

Args:
db_path: database file path, defaults to data/knowledge_base/kb.db

"""
self.db_path = db_path
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
@@ -85,77 +86,77 @@ class KBSQLiteDatabase:
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_kb_id "
"ON knowledge_bases(kb_id)"
)
"ON knowledge_bases(kb_id)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_name "
"ON knowledge_bases(kb_name)"
)
"ON knowledge_bases(kb_name)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_created_at "
"ON knowledge_bases(created_at)"
)
"ON knowledge_bases(created_at)",
),
)

# create document-table indexes
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_doc_id "
"ON kb_documents(doc_id)"
)
"ON kb_documents(doc_id)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_kb_id "
"ON kb_documents(kb_id)"
)
"ON kb_documents(kb_id)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_name "
"ON kb_documents(doc_name)"
)
"ON kb_documents(doc_name)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_type "
"ON kb_documents(file_type)"
)
"ON kb_documents(file_type)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_created_at "
"ON kb_documents(created_at)"
)
"ON kb_documents(created_at)",
),
)

# create media-table indexes
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_media_id "
"ON kb_media(media_id)"
)
"ON kb_media(media_id)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_doc_id "
"ON kb_media(doc_id)"
)
"ON kb_media(doc_id)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)"
)
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)",
),
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_type "
"ON kb_media(media_type)"
)
"ON kb_media(media_type)",
),
)

await session.commit()
@@ -208,7 +209,10 @@ class KBSQLiteDatabase:
return result.scalar_one_or_none()

async def list_documents_by_kb(
self, kb_id: str, offset: int = 0, limit: int = 100
self,
kb_id: str,
offset: int = 0,
limit: int = 100,
) -> list[KBDocument]:
"""List all documents in the knowledge base."""
async with self.get_db() as session:
@@ -226,7 +230,7 @@ class KBSQLiteDatabase:
"""Count the documents in the knowledge base."""
async with self.get_db() as session:
stmt = select(func.count(col(KBDocument.id))).where(
col(KBDocument.kb_id) == kb_id
col(KBDocument.kb_id) == kb_id,
)
result = await session.execute(stmt)
return result.scalar() or 0
@@ -252,12 +256,11 @@ class KBSQLiteDatabase:
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB):
"""Delete a single document and its related data."""
# delete from the knowledge-base tables
async with self.get_db() as session:
async with session.begin():
# delete the document record
delete_stmt = delete(KBDocument).where(col(KBDocument.doc_id) == doc_id)
await session.execute(delete_stmt)
await session.commit()
async with self.get_db() as session, session.begin():
# delete the document record
delete_stmt = delete(KBDocument).where(col(KBDocument.doc_id) == doc_id)
await session.execute(delete_stmt)
await session.commit()

# delete the related vectors from the vec db
await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id})
@@ -282,18 +285,17 @@ class KBSQLiteDatabase:
"""Update the knowledge base statistics."""
chunk_cnt = await vec_db.count_documents()

async with self.get_db() as session:
async with session.begin():
update_stmt = (
update(KnowledgeBase)
.where(col(KnowledgeBase.kb_id) == kb_id)
.values(
doc_count=select(func.count(col(KBDocument.id)))
.where(col(KBDocument.kb_id) == kb_id)
.scalar_subquery(),
chunk_count=chunk_cnt,
)
async with self.get_db() as session, session.begin():
update_stmt = (
update(KnowledgeBase)
.where(col(KnowledgeBase.kb_id) == kb_id)
.values(
doc_count=select(func.count(col(KBDocument.id)))
.where(col(KBDocument.kb_id) == kb_id)
.scalar_subquery(),
chunk_count=chunk_cnt,
)
)

await session.execute(update_stmt)
await session.commit()
await session.execute(update_stmt)
await session.commit()

@@ -1,16 +1,108 @@
import uuid
import aiofiles
import asyncio
import json
import re
import time
import uuid
from pathlib import Path
from .models import KnowledgeBase, KBDocument, KBMedia
from .kb_db_sqlite import KBSQLiteDatabase

import aiofiles

from astrbot.core import logger
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from astrbot.core.provider.manager import ProviderManager
from .parsers.util import select_parser
from astrbot.core.provider.provider import (
EmbeddingProvider,
RerankProvider,
)
from astrbot.core.provider.provider import (
Provider as LLMProvider,
)

from .chunking.base import BaseChunker
from astrbot.core import logger
from .chunking.recursive import RecursiveCharacterChunker
from .kb_db_sqlite import KBSQLiteDatabase
from .models import KBDocument, KBMedia, KnowledgeBase
from .parsers.url_parser import extract_text_from_url
from .parsers.util import select_parser
from .prompts import TEXT_REPAIR_SYSTEM_PROMPT

class RateLimiter:
"""A simple rate limiter."""

def __init__(self, max_rpm: int):
self.max_per_minute = max_rpm
self.interval = 60.0 / max_rpm if max_rpm > 0 else 0
self.last_call_time = 0

async def __aenter__(self):
if self.interval == 0:
return

now = time.monotonic()
elapsed = now - self.last_call_time

if elapsed < self.interval:
await asyncio.sleep(self.interval - elapsed)

self.last_call_time = time.monotonic()

async def __aexit__(self, exc_type, exc_val, exc_tb):
pass

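Because RateLimiter implements `__aenter__`/`__aexit__`, each `async with rate_limiter:` block sleeps just long enough to keep calls under `max_rpm`; tasks sharing one instance serialize on that interval. A minimal sketch, assuming the RateLimiter class above is in scope and a made-up 120 rpm budget:

```python
import asyncio
import time

async def main() -> None:
    limiter = RateLimiter(max_rpm=120)     # at most one call per 0.5 s
    start = time.monotonic()
    for i in range(3):
        async with limiter:                # sleeps if the last call was < 0.5 s ago
            print(f"call {i} at t={time.monotonic() - start:.2f}s")

asyncio.run(main())                        # prints t≈0.00, 0.50, 1.00
```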
async def _repair_and_translate_chunk_with_retry(
chunk: str,
repair_llm_service: LLMProvider,
rate_limiter: RateLimiter,
max_retries: int = 2,
) -> list[str]:
"""
Repairs, translates, and optionally re-chunks a single text chunk using the small LLM, with rate limiting.
"""
# To guard against LLM context contamination, the user prompt also carries an explicit instruction
user_prompt = f"""IGNORE ALL PREVIOUS INSTRUCTIONS. Your ONLY task is to process the following text chunk according to the system prompt provided.

Text chunk to process:
---
{chunk}
---
"""
for attempt in range(max_retries + 1):
try:
async with rate_limiter:
response = await repair_llm_service.text_chat(
prompt=user_prompt, system_prompt=TEXT_REPAIR_SYSTEM_PROMPT
)

llm_output = response.completion_text

if "<discard_chunk />" in llm_output:
return []  # Signal to discard this chunk

# More robust regex to handle potential LLM formatting errors (spaces, newlines in tags)
matches = re.findall(
r"<\s*repaired_text\s*>\s*(.*?)\s*<\s*/\s*repaired_text\s*>",
llm_output,
re.DOTALL,
)

if matches:
# Further cleaning to ensure no empty strings are returned
return [m.strip() for m in matches if m.strip()]
else:
# If no valid tags and not explicitly discarded, discard it to be safe.
return []
except Exception as e:
logger.warning(
f" - LLM call failed on attempt {attempt + 1}/{max_retries + 1}. Error: {str(e)}"
)

logger.error(
f" - Failed to process chunk after {max_retries + 1} attempts. Using original text."
)
return [chunk]

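The lenient regex above tolerates stray whitespace inside the tags, and multiple `<repaired_text>` blocks in one reply become multiple chunks. A quick standalone check of that extraction on a fabricated LLM reply:

```python
import re

llm_output = """< repaired_text >
The Llama is a domesticated South American camelid.
</ repaired_text>
<repaired_text>Second salvaged topic.</repaired_text>"""

matches = re.findall(
    r"<\s*repaired_text\s*>\s*(.*?)\s*<\s*/\s*repaired_text\s*>",
    llm_output,
    re.DOTALL,
)
print([m.strip() for m in matches if m.strip()])
# -> ['The Llama is a domesticated South American camelid.', 'Second salvaged topic.']
```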
class KBHelper:
@@ -45,11 +137,11 @@ class KBHelper:
if not self.kb.embedding_provider_id:
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id(
self.kb.embedding_provider_id
self.kb.embedding_provider_id,
)  # type: ignore
if not ep:
raise ValueError(
f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider"
f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider",
)
return ep

@@ -57,11 +149,11 @@ class KBHelper:
if not self.kb.rerank_provider_id:
return None
rp: RerankProvider = await self.prov_mgr.get_provider_by_id(
self.kb.rerank_provider_id
self.kb.rerank_provider_id,
)  # type: ignore
if not rp:
raise ValueError(
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider"
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider",
)
return rp

@@ -97,7 +189,7 @@ class KBHelper:
async def upload_document(
self,
file_name: str,
file_content: bytes,
file_content: bytes | None,
file_type: str,
chunk_size: int = 512,
chunk_overlap: int = 50,
@@ -105,6 +197,7 @@ class KBHelper:
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
pre_chunked_text: list[str] | None = None,
) -> KBDocument:
"""Upload and process a document (with atomicity guarantees and failure cleanup).

@@ -122,48 +215,68 @@ class KBHelper:
- stage: current stage ('parsing', 'chunking', 'embedding')
- current: current progress
- total: total count

"""
await self._ensure_vec_db()
doc_id = str(uuid.uuid4())
media_paths: list[Path] = []
file_size = 0

# file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
# async with aiofiles.open(file_path, "wb") as f:
#     await f.write(file_content)

try:
# Stage 1: parse the document
if progress_callback:
await progress_callback("parsing", 0, 100)

parser = await select_parser(f".{file_type}")
parse_result = await parser.parse(file_content, file_name)
text_content = parse_result.text
media_items = parse_result.media

if progress_callback:
await progress_callback("parsing", 100, 100)

# save media files
chunks_text = []
saved_media = []
for media_item in media_items:
media = await self._save_media(
doc_id=doc_id,
media_type=media_item.media_type,
file_name=media_item.file_name,
content=media_item.content,
mime_type=media_item.mime_type,

if pre_chunked_text is not None:
# if pre-chunked text was provided, use it directly
chunks_text = pre_chunked_text
file_size = sum(len(chunk) for chunk in chunks_text)
logger.info(f"使用预分块文本进行上传,共 {len(chunks_text)} 个块。")
else:
# otherwise, run the standard parse-and-chunk flow
if file_content is None:
raise ValueError(
"当未提供 pre_chunked_text 时,file_content 不能为空。"
)

file_size = len(file_content)

# Stage 1: parse the document
if progress_callback:
await progress_callback("parsing", 0, 100)

parser = await select_parser(f".{file_type}")
parse_result = await parser.parse(file_content, file_name)
text_content = parse_result.text
media_items = parse_result.media

if progress_callback:
await progress_callback("parsing", 100, 100)

# save media files
for media_item in media_items:
media = await self._save_media(
doc_id=doc_id,
media_type=media_item.media_type,
file_name=media_item.file_name,
content=media_item.content,
mime_type=media_item.mime_type,
)
saved_media.append(media)
media_paths.append(Path(media.file_path))

# Stage 2: chunking
if progress_callback:
await progress_callback("chunking", 0, 100)

chunks_text = await self.chunker.chunk(
text_content,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
saved_media.append(media)
media_paths.append(Path(media.file_path))

# Stage 2: chunking
if progress_callback:
await progress_callback("chunking", 0, 100)

chunks_text = await self.chunker.chunk(
text_content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
contents = []
metadatas = []
for idx, chunk_text in enumerate(chunks_text):
@@ -173,7 +286,7 @@ class KBHelper:
"kb_id": self.kb.kb_id,
"kb_doc_id": doc_id,
"chunk_index": idx,
}
},
)

if progress_callback:
@@ -199,7 +312,7 @@ class KBHelper:
kb_id=self.kb.kb_id,
doc_name=file_name,
file_type=file_type,
file_size=len(file_content),
file_size=file_size,
# file_path=str(file_path),
file_path="",
chunk_count=len(chunks_text),
@@ -234,7 +347,9 @@ class KBHelper:
raise e

async def list_documents(
self, offset: int = 0, limit: int = 100
self,
offset: int = 0,
limit: int = 100,
) -> list[KBDocument]:
"""List all documents in the knowledge base."""
docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit)
@@ -288,12 +403,17 @@ class KBHelper:
await session.refresh(doc)

async def get_chunks_by_doc_id(
self, doc_id: str, offset: int = 0, limit: int = 100
self,
doc_id: str,
offset: int = 0,
limit: int = 100,
) -> list[dict]:
"""Get all chunks of a document along with their metadata."""
vec_db: FaissVecDB = self.vec_db  # type: ignore
chunks = await vec_db.document_storage.get_documents(
metadata_filters={"kb_doc_id": doc_id}, offset=offset, limit=limit
metadata_filters={"kb_doc_id": doc_id},
offset=offset,
limit=limit,
)
result = []
for chunk in chunks:
@@ -306,7 +426,7 @@ class KBHelper:
"chunk_index": chunk_md["chunk_index"],
"content": chunk["text"],
"char_count": len(chunk["text"]),
}
},
)
return result

@@ -346,3 +466,177 @@ class KBHelper:
)

return media

async def upload_from_url(
self,
url: str,
chunk_size: int = 512,
chunk_overlap: int = 50,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
enable_cleaning: bool = False,
cleaning_provider_id: str | None = None,
) -> KBDocument:
"""Upload and process a document from a URL (with atomicity guarantees and failure cleanup).
Args:
url: the web page URL to extract content from
chunk_size: text chunk size
chunk_overlap: overlap between text chunks
batch_size: batch size
tasks_limit: concurrency limit
max_retries: maximum number of retries
progress_callback: progress callback invoked with (stage, current, total)
- stage: current stage ('extracting', 'cleaning', 'parsing', 'chunking', 'embedding')
- current: current progress
- total: total count
Returns:
KBDocument: the uploaded document object
Raises:
ValueError: if the URL is empty or no content could be extracted
IOError: if the network request fails
"""
# fetch the Tavily API keys
config = self.prov_mgr.acm.default_conf
tavily_keys = config.get("provider_settings", {}).get(
"websearch_tavily_key", []
)
if not tavily_keys:
raise ValueError(
"Error: Tavily API key is not configured in provider_settings."
)

# Stage 1: extract content from the URL
if progress_callback:
await progress_callback("extracting", 0, 100)

try:
text_content = await extract_text_from_url(url, tavily_keys)
except Exception as e:
logger.error(f"Failed to extract content from URL {url}: {e}")
raise OSError(f"Failed to extract content from URL {url}: {e}") from e

if not text_content:
raise ValueError(f"No content extracted from URL: {url}")

if progress_callback:
await progress_callback("extracting", 100, 100)

# Stage 2: (optionally) clean the content, then chunk it
final_chunks = await self._clean_and_rechunk_content(
content=text_content,
url=url,
progress_callback=progress_callback,
enable_cleaning=enable_cleaning,
cleaning_provider_id=cleaning_provider_id,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)

if enable_cleaning and not final_chunks:
raise ValueError(
"内容清洗后未提取到有效文本。请尝试关闭内容清洗功能,或更换更高性能的LLM模型后重试。"
)

# create a virtual file name
file_name = url.split("/")[-1] or f"document_from_{url}"
if not Path(file_name).suffix:
file_name += ".url"

# reuse the existing upload_document method, passing the pre-chunked text
return await self.upload_document(
file_name=file_name,
file_content=None,
file_type="url",  # use 'url' as a special file type
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
pre_chunked_text=final_chunks,
)

async def _clean_and_rechunk_content(
self,
content: str,
url: str,
progress_callback=None,
enable_cleaning: bool = False,
cleaning_provider_id: str | None = None,
repair_max_rpm: int = 60,
chunk_size: int = 512,
chunk_overlap: int = 50,
) -> list[str]:
"""
Clean, repair, translate, and re-chunk content fetched from a URL.
"""
if not enable_cleaning:
# if cleaning is disabled, chunk with the parameters passed in from the frontend
logger.info(
f"内容清洗未启用,使用指定参数进行分块: chunk_size={chunk_size}, chunk_overlap={chunk_overlap}"
)
return await self.chunker.chunk(
content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
)

if not cleaning_provider_id:
logger.warning(
"启用了内容清洗,但未提供 cleaning_provider_id,跳过清洗并使用默认分块。"
)
return await self.chunker.chunk(content)

if progress_callback:
await progress_callback("cleaning", 0, 100)

try:
# fetch the specified LLM Provider
llm_provider = await self.prov_mgr.get_provider_by_id(cleaning_provider_id)
if not llm_provider or not isinstance(llm_provider, LLMProvider):
raise ValueError(
f"无法找到 ID 为 {cleaning_provider_id} 的 LLM Provider 或类型不正确"
)

# initial chunking
# tuned separators: split on paragraphs first for higher-quality chunks
text_splitter = RecursiveCharacterChunker(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separators=["\n\n", "\n", " "],  # prefer paragraph separators
)
initial_chunks = await text_splitter.chunk(content)
logger.info(f"初步分块完成,生成 {len(initial_chunks)} 个块用于修复。")

# process all chunks concurrently
rate_limiter = RateLimiter(repair_max_rpm)
tasks = [
_repair_and_translate_chunk_with_retry(
chunk, llm_provider, rate_limiter
)
for chunk in initial_chunks
]

repaired_results = await asyncio.gather(*tasks, return_exceptions=True)

final_chunks = []
for i, result in enumerate(repaired_results):
if isinstance(result, Exception):
logger.warning(f"块 {i} 处理异常: {str(result)}. 回退到原始块。")
final_chunks.append(initial_chunks[i])
elif isinstance(result, list):
final_chunks.extend(result)

logger.info(
f"文本修复完成: {len(initial_chunks)} 个原始块 -> {len(final_chunks)} 个最终块。"
)

if progress_callback:
await progress_callback("cleaning", 100, 100)

return final_chunks

except Exception as e:
logger.error(f"使用 Provider '{cleaning_provider_id}' 清洗内容失败: {e}")
# cleaning failed; return the default chunking so the flow is not interrupted
return await self.chunker.chunk(content)

@@ -1,19 +1,17 @@
import traceback
from pathlib import Path

from astrbot.core import logger
from astrbot.core.provider.manager import ProviderManager

from .retrieval.manager import RetrievalManager, RetrievalResult
from .retrieval.sparse_retriever import SparseRetriever
from .retrieval.rank_fusion import RankFusion
from .kb_db_sqlite import KBSQLiteDatabase

# from .chunking.fixed_size import FixedSizeChunker
from .chunking.recursive import RecursiveCharacterChunker
from .kb_db_sqlite import KBSQLiteDatabase
from .kb_helper import KBHelper

from .models import KnowledgeBase

from .models import KBDocument, KnowledgeBase
from .retrieval.manager import RetrievalManager, RetrievalResult
from .retrieval.rank_fusion import RankFusion
from .retrieval.sparse_retriever import SparseRetriever

FILES_PATH = "data/knowledge_base"
DB_PATH = Path(FILES_PATH) / "kb.db"
@@ -257,6 +255,7 @@ class KnowledgeBaseManager:

Returns:
str: the formatted context text

"""
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]

@@ -285,3 +284,47 @@ class KnowledgeBaseManager:
await self.kb_db.close()
except Exception as e:
logger.error(f"关闭知识库元数据数据库失败: {e}")

async def upload_from_url(
self,
kb_id: str,
url: str,
chunk_size: int = 512,
chunk_overlap: int = 50,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
) -> KBDocument:
"""Upload a document from a URL into the specified knowledge base.

Args:
kb_id: knowledge base ID
url: the web page URL to extract content from
chunk_size: text chunk size
chunk_overlap: overlap between text chunks
batch_size: batch size
tasks_limit: concurrency limit
max_retries: maximum number of retries
progress_callback: progress callback

Returns:
KBDocument: the uploaded document object

Raises:
ValueError: if the knowledge base does not exist or the URL is empty
IOError: if the network request fails
"""
kb_helper = await self.get_kb(kb_id)
if not kb_helper:
raise ValueError(f"Knowledge base with id {kb_id} not found.")

return await kb_helper.upload_from_url(
url=url,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
)

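A sketch of driving this entry point, assuming a constructed manager instance and an existing knowledge base ID (both illustrative); the progress printer mirrors the (stage, current, total) contract documented above:

```python
async def report(stage: str, current: int, total: int) -> None:
    print(f"[{stage}] {current}/{total}")

async def demo(manager: "KnowledgeBaseManager") -> None:
    doc = await manager.upload_from_url(
        kb_id="kb-1234",                     # assumed existing knowledge base
        url="https://example.com/article",   # placeholder URL
        chunk_size=512,
        chunk_overlap=50,
        progress_callback=report,
    )
    print(doc.doc_name, doc.chunk_count)
```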
@@ -1,7 +1,7 @@
import uuid
from datetime import datetime, timezone

from sqlmodel import Field, SQLModel, Text, UniqueConstraint, MetaData
from sqlmodel import Field, MetaData, SQLModel, Text, UniqueConstraint


class BaseKBModel(SQLModel, table=False):
@@ -17,7 +17,9 @@ class KnowledgeBase(BaseKBModel, table=True):
__tablename__ = "knowledge_bases"  # type: ignore

id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
)
kb_id: str = Field(
max_length=36,
@@ -63,7 +65,9 @@ class KBDocument(BaseKBModel, table=True):
__tablename__ = "kb_documents"  # type: ignore

id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
)
doc_id: str = Field(
max_length=36,
@@ -95,7 +99,9 @@ class KBMedia(BaseKBModel, table=True):
__tablename__ = "kb_media"  # type: ignore

id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
primary_key=True,
sa_column_kwargs={"autoincrement": True},
default=None,
)
media_id: str = Field(
max_length=36,

@@ -1,15 +1,13 @@
"""
Document parser module.
"""
"""Document parser module."""

from .base import BaseParser, MediaItem, ParseResult
from .text_parser import TextParser
from .pdf_parser import PDFParser
from .text_parser import TextParser

__all__ = [
"BaseParser",
"MediaItem",
"PDFParser",
"ParseResult",
"TextParser",
"PDFParser",
]

@@ -47,4 +47,5 @@ class BaseParser(ABC):

Returns:
ParseResult: the parse result

"""

@@ -1,11 +1,12 @@
import io
import os

from markitdown_no_magika import MarkItDown, StreamInfo

from astrbot.core.knowledge_base.parsers.base import (
BaseParser,
ParseResult,
)
from markitdown_no_magika import MarkItDown, StreamInfo


class MarkitdownParser(BaseParser):

@@ -29,6 +29,7 @@ class PDFParser(BaseParser):

Returns:
ParseResult: parse result containing the text and images

"""
pdf_file = io.BytesIO(file_content)
reader = PdfReader(pdf_file)
@@ -87,7 +88,7 @@ class PDFParser(BaseParser):
file_name=f"page_{page_num}_img_{image_counter}.{ext}",
content=image_data,
mime_type=mime_type,
)
),
)
except Exception:
# a single failed image extraction does not affect the rest

@@ -26,6 +26,7 @@ class TextParser(BaseParser):

Raises:
ValueError: if the file cannot be decoded

"""
# try multiple encodings
for encoding in ["utf-8", "gbk", "gb2312", "gb18030"]:

103
astrbot/core/knowledge_base/parsers/url_parser.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import asyncio
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
class URLExtractor:
|
||||
"""URL 内容提取器,封装了 Tavily API 调用和密钥管理"""
|
||||
|
||||
def __init__(self, tavily_keys: list[str]):
|
||||
"""
|
||||
初始化 URL 提取器
|
||||
|
||||
Args:
|
||||
tavily_keys: Tavily API 密钥列表
|
||||
"""
|
||||
if not tavily_keys:
|
||||
raise ValueError("Error: Tavily API keys are not configured.")
|
||||
|
||||
self.tavily_keys = tavily_keys
|
||||
self.tavily_key_index = 0
|
||||
self.tavily_key_lock = asyncio.Lock()
|
||||
|
||||
async def _get_tavily_key(self) -> str:
|
||||
"""并发安全的从列表中获取并轮换Tavily API密钥。"""
|
||||
async with self.tavily_key_lock:
|
||||
key = self.tavily_keys[self.tavily_key_index]
|
||||
self.tavily_key_index = (self.tavily_key_index + 1) % len(self.tavily_keys)
|
||||
return key
|
||||
|
||||
async def extract_text_from_url(self, url: str) -> str:
|
||||
"""
|
||||
使用 Tavily API 从 URL 提取主要文本内容。
|
||||
这是 web_searcher 插件中 tavily_extract_web_page 方法的简化版本,
|
||||
专门为知识库模块设计,不依赖 AstrMessageEvent。
|
||||
|
||||
Args:
|
||||
url: 要提取内容的网页 URL
|
||||
|
||||
Returns:
|
||||
提取的文本内容
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 URL 为空或 API 密钥未配置
|
||||
IOError: 如果请求失败或返回错误
|
||||
"""
|
||||
if not url:
|
||||
raise ValueError("Error: url must be a non-empty string.")
|
||||
|
||||
tavily_key = await self._get_tavily_key()
|
||||
api_url = "https://api.tavily.com/extract"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {tavily_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"urls": [url],
|
||||
"extract_depth": "basic", # 使用基础提取深度
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.post(
|
||||
api_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=30.0, # 增加超时时间,因为内容提取可能需要更长时间
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
reason = await response.text()
|
||||
raise OSError(
|
||||
f"Tavily web extraction failed: {reason}, status: {response.status}"
|
||||
)
|
||||
|
||||
data = await response.json()
|
||||
results = data.get("results", [])
|
||||
|
||||
if not results:
|
||||
raise ValueError(f"No content extracted from URL: {url}")
|
||||
|
||||
# 返回第一个结果的内容
|
||||
return results[0].get("raw_content", "")
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise OSError(f"Failed to fetch URL {url}: {e}") from e
|
||||
except Exception as e:
|
||||
raise OSError(f"Failed to extract content from URL {url}: {e}") from e
|
||||
|
||||
|
||||
# 为了向后兼容,提供一个简单的函数接口
|
||||
async def extract_text_from_url(url: str, tavily_keys: list[str]) -> str:
|
||||
"""
|
||||
简单的函数接口,用于从 URL 提取文本内容
|
||||
|
||||
Args:
|
||||
url: 要提取内容的网页 URL
|
||||
tavily_keys: Tavily API 密钥列表
|
||||
|
||||
Returns:
|
||||
提取的文本内容
|
||||
"""
|
||||
extractor = URLExtractor(tavily_keys)
|
||||
return await extractor.extract_text_from_url(url)
|
||||
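A minimal usage sketch for the extractor above; the key strings are placeholders, and a real Tavily account is required for the call to succeed:

```python
import asyncio

from astrbot.core.knowledge_base.parsers.url_parser import URLExtractor


async def main() -> None:
    # Placeholder keys; _get_tavily_key cycles through them round-robin.
    extractor = URLExtractor(tavily_keys=["tvly-key-1", "tvly-key-2"])
    text = await extractor.extract_text_from_url("https://example.com")
    print(text[:200])


asyncio.run(main())
```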
@@ -6,7 +6,7 @@ async def select_parser(ext: str) -> BaseParser:
from .markitdown_parser import MarkitdownParser

return MarkitdownParser()
-elif ext == ".pdf":
+if ext == ".pdf":
from .pdf_parser import PDFParser

return PDFParser()
65
astrbot/core/knowledge_base/prompts.py
Normal file
@@ -0,0 +1,65 @@
TEXT_REPAIR_SYSTEM_PROMPT = """You are a meticulous digital archivist. Your mission is to reconstruct a clean, readable article from raw, noisy text chunks.

**Core Task:**
1. **Analyze:** Examine the text chunk to separate "signal" (substantive information) from "noise" (UI elements, ads, navigation, footers).
2. **Process:** Clean and repair the signal. **Do not translate it.** Keep the original language.

**Crucial Rules:**
- **NEVER discard a chunk if it contains ANY valuable information.** Your primary duty is to salvage content.
- **If a chunk contains multiple distinct topics, split them.** Enclose each topic in its own `<repaired_text>` tag.
- Your output MUST be ONLY `<repaired_text>...</repaired_text>` tags or a single `<discard_chunk />` tag.

---
**Example 1: Chunk with Noise and Signal**

*Input Chunk:*
"Home | About | Products | **The Llama is a domesticated South American camelid.** | © 2025 ACME Corp."

*Your Thought Process:*
1. "Home | About | Products..." and "© 2025 ACME Corp." are noise.
2. "The Llama is a domesticated..." is the signal.
3. I must extract the signal and wrap it.

*Your Output:*
<repaired_text>
The Llama is a domesticated South American camelid.
</repaired_text>

---
**Example 2: Chunk with ONLY Noise**

*Input Chunk:*
"Next Page > | Subscribe to our newsletter | Follow us on X"

*Your Thought Process:*
1. This entire chunk is noise. There is no signal.
2. I must discard this.

*Your Output:*
<discard_chunk />

---
**Example 3: Chunk with Multiple Topics (Requires Splitting)**

*Input Chunk:*
"## Chapter 1: The Sun
The Sun is the star at the center of the Solar System.

## Chapter 2: The Moon
The Moon is Earth's only natural satellite."

*Your Thought Process:*
1. This chunk contains two distinct topics.
2. I must process them separately to maintain semantic integrity.
3. I will create two `<repaired_text>` blocks.

*Your Output:*
<repaired_text>
## Chapter 1: The Sun
The Sun is the star at the center of the Solar System.
</repaired_text>
<repaired_text>
## Chapter 2: The Moon
The Moon is Earth's only natural satellite.
</repaired_text>
"""
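The prompt constrains the model's reply to `<repaired_text>` blocks or a single `<discard_chunk />`. A caller-side parser for that contract might look like the sketch below; the helper name is hypothetical, this file does not ship one:

```python
import re


def parse_repair_output(llm_output: str) -> list[str]:
    """Split an LLM reply into repaired text blocks; an empty list means discard."""
    if "<discard_chunk" in llm_output:
        return []
    blocks = re.findall(
        r"<repaired_text>(.*?)</repaired_text>", llm_output, flags=re.DOTALL
    )
    return [b.strip() for b in blocks if b.strip()]
```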
@@ -1,16 +1,14 @@
-"""
-Retrieval module
-"""
+"""Retrieval module"""

from .manager import RetrievalManager, RetrievalResult
-from .sparse_retriever import SparseRetriever, SparseResult
-from .rank_fusion import RankFusion, FusedResult
+from .rank_fusion import FusedResult, RankFusion
+from .sparse_retriever import SparseResult, SparseRetriever

__all__ = [
+    "FusedResult",
+    "RankFusion",
    "RetrievalManager",
    "RetrievalResult",
-    "SparseRetriever",
    "SparseResult",
-    "RankFusion",
-    "FusedResult",
+    "SparseRetriever",
]
@@ -4,18 +4,17 @@
"""

import time

from dataclasses import dataclass
-from typing import List

+from astrbot import logger
+from astrbot.core.db.vec_db.base import Result
+from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
from astrbot.core.provider.provider import RerankProvider
-from astrbot.core.db.vec_db.base import Result
-from astrbot.core.db.vec_db.faiss_impl import FaissVecDB

from ..kb_helper import KBHelper
-from astrbot import logger


@dataclass
@@ -53,6 +52,7 @@ class RetrievalManager:
sparse_retriever: The sparse retriever
rank_fusion: The result fusion component
kb_db: Knowledge base database instance
+
"""
self.sparse_retriever = sparse_retriever
self.rank_fusion = rank_fusion
@@ -61,11 +61,11 @@
async def retrieve(
self,
query: str,
-kb_ids: List[str],
+kb_ids: list[str],
kb_id_helper_map: dict[str, KBHelper],
top_k_fusion: int = 20,
top_m_final: int = 5,
-) -> List[RetrievalResult]:
+) -> list[RetrievalResult]:
"""Hybrid retrieval

Flow:
@@ -82,6 +82,7 @@ class RetrievalManager:

Returns:
List[RetrievalResult]: List of retrieval results
+
"""
if not kb_ids:
return []
@@ -114,7 +115,7 @@
)
time_end = time.time()
logger.debug(
-f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results."
+f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results.",
)

# 2. Sparse retrieval
@@ -126,7 +127,7 @@
)
time_end = time.time()
logger.debug(
-f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results."
+f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results.",
)

# 3. Rank fusion
@@ -138,7 +139,7 @@
)
time_end = time.time()
logger.debug(
-f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results."
+f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results.",
)

# 4. Convert to RetrievalResult (fetch metadata)
@@ -159,7 +160,7 @@
"chunk_index": fr.chunk_index,
"char_count": len(fr.content),
},
-)
+),
)

# 5. Rerank
@@ -188,7 +189,7 @@
async def _dense_retrieve(
self,
query: str,
-kb_ids: List[str],
+kb_ids: list[str],
kb_options: dict,
):
"""Dense retrieval (vector similarity)
@@ -202,6 +203,7 @@

Returns:
List[Result]: List of retrieval results
+
"""
all_results: list[Result] = []
for kb_id in kb_ids:
@@ -233,10 +235,10 @@
async def _rerank(
self,
query: str,
-results: List[RetrievalResult],
+results: list[RetrievalResult],
top_k: int,
rerank_provider: RerankProvider,
-) -> List[RetrievalResult]:
+) -> list[RetrievalResult]:
"""Rerank (re-rank the results)

Args:
@@ -246,6 +248,7 @@

Returns:
List[RetrievalResult]: The re-ranked result list
+
"""
if not results:
return []
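Putting these hunks together, a hedged sketch of how a caller might drive the pipeline; the KBHelper wiring and the result field names are assumptions, not shown in this diff:

```python
# Hypothetical call site: dense + sparse retrieval, rank fusion, optional rerank.
results = await retrieval_manager.retrieve(
    query="how are llamas domesticated",
    kb_ids=["kb_main"],
    kb_id_helper_map={"kb_main": kb_helper},  # KBHelper instance, assumed
    top_k_fusion=20,  # candidates kept after rank fusion
    top_m_final=5,    # results returned after the final rerank
)
for r in results:
    print(r.score, r.content)  # field names assumed from FusedResult/SparseResult
```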
@@ -37,6 +37,7 @@ class RankFusion:
Args:
kb_db: Knowledge base database instance
k: RRF parameter used to smooth the ranks
+
"""
self.kb_db = kb_db
self.k = k
@@ -59,6 +60,7 @@ class RankFusion:

Returns:
List[FusedResult]: The fused result list
+
"""
# 1. Build rank mappings
dense_ranks = {
@@ -101,7 +103,9 @@ class RankFusion:

# 4. Sort
sorted_ids = sorted(
-rrf_scores.keys(), key=lambda cid: rrf_scores[cid], reverse=True
+rrf_scores.keys(),
+key=lambda cid: rrf_scores[cid],
+reverse=True,
)[:top_k]

# 5. Build fused results
@@ -118,7 +122,7 @@
kb_id=sr.kb_id,
content=sr.content,
score=rrf_scores[identifier],
-)
+),
)
elif identifier in vec_doc_id_to_dense:
# Came from dense retrieval; the chunk details must be fetched from the database
@@ -132,7 +136,7 @@
kb_id=chunk_md["kb_id"],
content=vec_result.data["text"],
score=rrf_scores[identifier],
-)
+),
)

return fused_results
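For context, the `rrf_scores` mapping built above is characteristic of Reciprocal Rank Fusion, where each document's fused score is the sum of 1/(k + rank_i) over the rankers that returned it. A self-contained sketch of that scoring, for illustration only (`k = 60` is a common default; this class takes its own `k`):

```python
def reciprocal_rank_fusion(
    dense_ranks: dict[str, int],
    sparse_ranks: dict[str, int],
    k: int = 60,
) -> dict[str, float]:
    """Fuse two 1-based rank lists; a higher fused score is better."""
    scores: dict[str, float] = {}
    for ranks in (dense_ranks, sparse_ranks):
        for doc_id, rank in ranks.items():
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    return scores
```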
@@ -3,13 +3,15 @@
Keyword-based document retrieval using the BM25 algorithm
"""

-import jieba
-import os
import json
+import os
from dataclasses import dataclass

+import jieba
from rank_bm25 import BM25Okapi
-from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase

+from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
+from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase


@dataclass
@@ -37,6 +39,7 @@ class SparseRetriever:

Args:
kb_db: Knowledge base database instance
+
"""
self.kb_db = kb_db
self._index_cache = {}  # Cache of BM25 indexes
@@ -64,6 +67,7 @@ class SparseRetriever:

Returns:
List[SparseResult]: List of retrieval results
+
"""
# 1. Fetch all relevant chunks
top_k_sparse = 0
@@ -73,7 +77,9 @@ class SparseRetriever:
if not vec_db:
continue
result = await vec_db.document_storage.get_documents(
-metadata_filters={}, limit=None, offset=None
+metadata_filters={},
+limit=None,
+offset=None,
)
chunk_mds = [json.loads(doc["metadata"]) for doc in result]
result = [
@@ -122,7 +128,7 @@ class SparseRetriever:
kb_id=chunk["kb_id"],
content=chunk["text"],
score=float(score),
-)
+),
)

results.sort(key=lambda x: x.score, reverse=True)
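Since the retriever pairs jieba tokenization with BM25Okapi, here is a tiny self-contained sketch of that combination; the corpus and query are illustrative:

```python
import jieba
from rank_bm25 import BM25Okapi

corpus = ["知识库检索模块", "稀疏检索使用 BM25 算法", "稠密检索使用向量相似度"]
tokenized_corpus = [list(jieba.cut(doc)) for doc in corpus]
bm25 = BM25Okapi(tokenized_corpus)

query_tokens = list(jieba.cut("BM25 检索"))
print(bm25.get_scores(query_tokens))  # one keyword-match score per document
```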
@@ -1,5 +1,4 @@
-"""
-Logging system that supports logging for core components and plugins and provides log subscription
+"""Logging system that supports logging for core components and plugins and provides log subscription

const:
CACHED_SIZE: Log cache size, limits the number of cached log entries
@@ -21,14 +20,14 @@ function:
4. Subscribers can call register() on the LogBroker to subscribe to the log stream
"""

-import logging
-import colorlog
import asyncio
+import logging
import os
import sys
-from collections import deque
from asyncio import Queue
-from typing import List
from collections import deque

+import colorlog

# Log cache size
CACHED_SIZE = 200
@@ -52,6 +51,7 @@ def is_plugin_path(pathname):

Returns:
bool: True if the path comes from a plugin directory, False otherwise
+
"""
if not pathname:
return False
@@ -68,6 +68,7 @@ def get_short_level_name(level_name):

Returns:
str: Four-letter abbreviation of the log level
+
"""
level_map = {
"DEBUG": "DBUG",
@@ -87,13 +88,14 @@ class LogBroker:

def __init__(self):
self.log_cache = deque(maxlen=CACHED_SIZE)  # Ring buffer holding the most recent logs
-self.subscribers: List[Queue] = []  # Subscriber list
+self.subscribers: list[Queue] = []  # Subscriber list

def register(self) -> Queue:
"""Register a new subscriber and return a queue pre-filled with the log cache

Returns:
Queue: The subscriber's queue, used to receive log messages
+
"""
q = Queue(maxsize=CACHED_SIZE + 10)
self.subscribers.append(q)
@@ -104,6 +106,7 @@ class LogBroker:

Args:
q (Queue): The queue to unsubscribe
+
"""
self.subscribers.remove(q)

@@ -113,6 +116,7 @@ class LogBroker:
Args:
log_entry (dict): Log message containing the level and content.
example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
+
"""
self.log_cache.append(log_entry)
for q in self.subscribers:
@@ -138,6 +142,7 @@ class LogQueueHandler(logging.Handler):

Args:
record (logging.LogRecord): The log record object containing the log information
+
"""
log_entry = self.format(record)
self.log_broker.publish(
@@ -145,7 +150,7 @@ class LogQueueHandler(logging.Handler):
"level": record.levelname,
"time": record.asctime,
"data": log_entry,
-}
+},
)


@@ -164,6 +169,7 @@ class LogManager:

Returns:
logging.Logger: The configured logger
+
"""
logger = logging.getLogger(log_name)
# If this logger or a parent logger already has handlers, return it directly to avoid duplicate configuration
@@ -171,10 +177,10 @@ class LogManager:
return logger
# If the logger has no handlers
console_handler = logging.StreamHandler(
-sys.stdout
+sys.stdout,
)  # Create a StreamHandler for console output
console_handler.setLevel(
-logging.DEBUG
+logging.DEBUG,
)  # Set the level to DEBUG (the lowest, showing all logs); *defaults to DEBUG if a plugin sets no level

# Create a colored log formatter; output format: [time] [plugin tag] [level] [file:line]: message
@@ -195,7 +201,8 @@ class LogManager:

class FileNameFilter(logging.Filter):
"""Filename filter class that rewrites the filename format of log records
-For example: converts the file path /path/to/file.py into file.<file> format"""
+For example: converts the file path /path/to/file.py into file.<file> format
+"""

# Get the names of this file and its parent folder as <folder>.<file>, stripping .py
def filter(self, record):
@@ -231,6 +238,7 @@ class LogManager:
Args:
logger (logging.Logger): The logger
log_broker (LogBroker): Log broker used to cache and distribute log messages
+
"""
handler = LogQueueHandler(log_broker)
handler.setLevel(logging.DEBUG)
@@ -240,7 +248,7 @@ class LogManager:
# Set a formatter with the same format for the queue handler
handler.setFormatter(
logging.Formatter(
-"[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s"
-)
+"[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s",
+),
)
logger.addHandler(handler)
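A sketch of the subscriber side of this publish/subscribe design; the unsubscribe method's name is not visible in these hunks, so `unregister` here is an assumption:

```python
async def tail_logs(broker) -> None:
    q = broker.register()  # queue pre-filled with the cached recent logs
    try:
        while True:
            entry = await q.get()  # e.g. {"level": "INFO", "time": ..., "data": ...}
            print(entry["data"])
    finally:
        broker.unregister(q)  # assumed name for the unsubscribe method
```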
822
astrbot/core/memory/DESIGN.excalidraw
Normal file
@@ -0,0 +1,822 @@
[Excalidraw diagram source: 822 lines of editor JSON, elided here. Its labels outline the memory design: a "FAISS+SQLite" store of "Memories" (chunk1, chunk2); an "ADD Memory" flow that checks "top-n Similarity" against a "FACT" and may "update" an existing entry; and a "RETRIEVE Memory (STATIC)" flow returning memories "RANKED By DECAY SCORE, TOP K".]
76
astrbot/core/memory/_README.md
Normal file
@@ -0,0 +1,76 @@
## Decay Score

The memory decay score is defined as:

\[
\text{decay\_score}
= \alpha \cdot e^{-\lambda \cdot \Delta t \cdot \beta}
+ (1-\alpha)\cdot (1 - e^{-\gamma \cdot c})
\]

where:

+ \(\Delta t\): time elapsed since the last retrieval (days), computed from `last_retrieval_at`;
+ \(c\): retrieval count, corresponding to the `retrieval_count` field;
+ \(\alpha\): weight balancing time decay against the retrieval-count contribution;
+ \(\gamma\): rate controlling the retrieval-count contribution;
+ \(\lambda\): rate controlling time decay;
+ \(\beta\): time-decay modulation factor;

\[
\beta = \frac{1}{1 + a \cdot c}
\]

+ \(a\): weight controlling how the retrieval count dampens time decay.

## ADD MEMORY

+ The LLM calls the `astr_add_memory` tool, passing the memory content and memory type.
+ Generate `mem_id = uuid4()`.
+ Take `owner_id = unified_message_origin` from the context.

Steps:

1. Use the VecDB with the new memory content as the query to retrieve the top 20 similar memories.
2. From these, take the 5 most similar:
   + If similarity exceeds the "merge threshold" (e.g. `sim >= merge_threshold`):
     + Treat that entry as the same memory and merge the old content with the new via the LLM;
     + Update MemoryDB and VecDB under the same `mem_id` (UPDATE, not insert).
   + Otherwise:
     + Insert it as a brand-new memory:
       + Write to VecDB (metadata contains `mem_id`, `owner_id`);
       + Write to the MemoryDB `memory_chunks` table, initializing:
         + `created_at = now`
         + `last_retrieval_at = now`
         + `retrieval_count = 1`, etc.
3. For the top 20 memories returned by VecDB, if similarity exceeds a "Hebbian threshold" (`hebb_threshold`):
   + `retrieval_count += 1`
   + `last_retrieval_at = now`

This step realizes Hebbian learning: old memories co-activated with the new memory receive one reinforcement.

## QUERY MEMORY (STATIC)

+ The LLM calls the `astr_query_memory` tool with no parameters.

Steps:

1. Query all of the current user's active memories from the MemoryDB `memory_chunks` table:
   + `SELECT * FROM memory_chunks WHERE owner_id = ? AND is_active = 1`
2. For each memory, compute its `decay_score` from `last_retrieval_at` and `retrieval_count`.
3. Sort by `decay_score` descending and return the top `top_k` memory contents to the LLM.
4. For these returned `top_k` memories:
   + `retrieval_count += 1`
   + `last_retrieval_at = now`

## QUERY MEMORY (DYNAMIC) (not yet implemented)

+ The LLM provides query content as a semantic query.
+ Use the VecDB to retrieve the top `N` memories most similar to the query (`N > top_k`).
+ Load the corresponding records from `memory_chunks` by `mem_id`.
+ For this candidate batch, compute:
  + semantic similarity (from the VecDB)
  + `decay_score`
  + a final ranking score (e.g. `w1 * sim + w2 * decay_score`)
+ Return the top `top_k` memory contents ranked by the final score, and update their `retrieval_count` and `last_retrieval_at`.
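A worked check of the formula, using the constants hard-coded in the accompanying entities.py (α = 0.5, γ = 0.1, λ = 0.05, a = 0.1); the inputs are illustrative:

```python
import math


def decay_score(delta_t_days: float, c: int, alpha: float = 0.5,
                gamma: float = 0.1, lam: float = 0.05, a: float = 0.1) -> float:
    beta = 1 / (1 + a * c)  # frequent retrieval dampens time decay
    return alpha * math.exp(-lam * delta_t_days * beta) + (1 - alpha) * (
        1 - math.exp(-gamma * c)
    )


print(round(decay_score(10, 1), 3))   # seldom retrieved, 10 days old: ~0.365
print(round(decay_score(10, 10), 3))  # often retrieved, 10 days old:  ~0.705
```

The comparison shows the intended behavior: at the same age, a frequently retrieved memory keeps a markedly higher score.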
63
astrbot/core/memory/entities.py
Normal file
@@ -0,0 +1,63 @@
import uuid
from datetime import datetime, timezone

import numpy as np
from sqlmodel import Field, MetaData, SQLModel

MEMORY_TYPE_IMPORTANCE = {"persona": 1.3, "fact": 1.0, "ephemeral": 0.8}


class BaseMemoryModel(SQLModel, table=False):
    metadata = MetaData()


class MemoryChunk(BaseMemoryModel, table=True):
    """A chunk of memory stored in the system."""

    __tablename__ = "memory_chunks"  # type: ignore

    id: int | None = Field(
        primary_key=True,
        sa_column_kwargs={"autoincrement": True},
        default=None,
    )
    mem_id: str = Field(
        max_length=36,
        nullable=False,
        unique=True,
        default_factory=lambda: str(uuid.uuid4()),
        index=True,
    )
    fact: str = Field(nullable=False)
    """The factual content of the memory chunk."""
    owner_id: str = Field(max_length=255, nullable=False, index=True)
    """The identifier of the owner (user) of the memory chunk."""
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    """The timestamp when the memory chunk was created."""
    last_retrieval_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc)
    )
    """The timestamp when the memory chunk was last retrieved."""
    retrieval_count: int = Field(default=1, nullable=False)
    """The number of times the memory chunk has been retrieved."""
    memory_type: str = Field(max_length=20, nullable=False, default="fact")
    """The type of memory (e.g., 'persona', 'fact', 'ephemeral')."""
    is_active: bool = Field(default=True, nullable=False)
    """Whether the memory chunk is active."""

    def compute_decay_score(self, current_time: datetime) -> float:
        """Compute the decay score of the memory chunk based on time and retrievals."""
        # Constants for the decay formula
        alpha = 0.5
        gamma = 0.1
        lambda_ = 0.05
        a = 0.1

        # Calculate delta_t in days
        delta_t = (current_time - self.last_retrieval_at).total_seconds() / 86400
        c = self.retrieval_count
        beta = 1 / (1 + a * c)
        decay_score = alpha * np.exp(-lambda_ * delta_t * beta) + (1 - alpha) * (
            1 - np.exp(-gamma * c)
        )
        return decay_score * MEMORY_TYPE_IMPORTANCE.get(self.memory_type, 1.0)
174
astrbot/core/memory/mem_db_sqlite.py
Normal file
@@ -0,0 +1,174 @@
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from pathlib import Path

from sqlalchemy import select, text, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlmodel import col

from astrbot.core import logger

from .entities import BaseMemoryModel, MemoryChunk


class MemoryDatabase:
    def __init__(self, db_path: str = "data/astr_memory/memory.db") -> None:
        """Initialize memory database

        Args:
            db_path: Database file path, default is data/astr_memory/memory.db

        """
        self.db_path = db_path
        self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
        self.inited = False

        # Ensure directory exists
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)

        # Create async engine
        self.engine = create_async_engine(
            self.DATABASE_URL,
            echo=False,
            pool_pre_ping=True,
            pool_recycle=3600,
        )

        # Create session factory
        self.async_session = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )

    @asynccontextmanager
    async def get_db(self):
        """Get database session

        Usage:
            async with mem_db.get_db() as session:
                # Perform database operations
                result = await session.execute(stmt)
        """
        async with self.async_session() as session:
            yield session

    async def initialize(self) -> None:
        """Initialize database, create tables and configure SQLite parameters"""
        async with self.engine.begin() as conn:
            # Create all memory related tables
            await conn.run_sync(BaseMemoryModel.metadata.create_all)

            # Configure SQLite performance optimization parameters
            await conn.execute(text("PRAGMA journal_mode=WAL"))
            await conn.execute(text("PRAGMA synchronous=NORMAL"))
            await conn.execute(text("PRAGMA cache_size=20000"))
            await conn.execute(text("PRAGMA temp_store=MEMORY"))
            await conn.execute(text("PRAGMA mmap_size=134217728"))
            await conn.execute(text("PRAGMA optimize"))
            await conn.commit()

        await self._create_indexes()
        self.inited = True
        logger.info(f"Memory database initialized: {self.db_path}")

    async def _create_indexes(self) -> None:
        """Create indexes for memory_chunks table"""
        async with self.get_db() as session:
            async with session.begin():
                # Create memory chunks table indexes
                await session.execute(
                    text(
                        "CREATE INDEX IF NOT EXISTS idx_mem_mem_id "
                        "ON memory_chunks(mem_id)",
                    ),
                )
                await session.execute(
                    text(
                        "CREATE INDEX IF NOT EXISTS idx_mem_owner_id "
                        "ON memory_chunks(owner_id)",
                    ),
                )
                await session.execute(
                    text(
                        "CREATE INDEX IF NOT EXISTS idx_mem_owner_active "
                        "ON memory_chunks(owner_id, is_active)",
                    ),
                )
                await session.commit()

    async def close(self) -> None:
        """Close database connection"""
        await self.engine.dispose()
        logger.info(f"Memory database closed: {self.db_path}")

    async def insert_memory(self, memory: MemoryChunk) -> MemoryChunk:
        """Insert a new memory chunk"""
        async with self.get_db() as session:
            session.add(memory)
            await session.commit()
            await session.refresh(memory)
            return memory

    async def get_memory_by_id(self, mem_id: str) -> MemoryChunk | None:
        """Get memory chunk by mem_id"""
        async with self.get_db() as session:
            stmt = select(MemoryChunk).where(col(MemoryChunk.mem_id) == mem_id)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    async def update_memory(self, memory: MemoryChunk) -> MemoryChunk:
        """Update an existing memory chunk"""
        async with self.get_db() as session:
            session.add(memory)
            await session.commit()
            await session.refresh(memory)
            return memory

    async def get_active_memories(self, owner_id: str) -> list[MemoryChunk]:
        """Get all active memories for a user"""
        async with self.get_db() as session:
            stmt = select(MemoryChunk).where(
                col(MemoryChunk.owner_id) == owner_id,
                col(MemoryChunk.is_active) == True,  # noqa: E712
            )
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def update_retrieval_stats(
        self,
        mem_ids: list[str],
        current_time: datetime | None = None,
    ) -> None:
        """Update retrieval statistics for multiple memories"""
        if not mem_ids:
            return

        if current_time is None:
            current_time = datetime.now(timezone.utc)

        async with self.get_db() as session:
            async with session.begin():
                stmt = (
                    update(MemoryChunk)
                    .where(col(MemoryChunk.mem_id).in_(mem_ids))
                    .values(
                        retrieval_count=MemoryChunk.retrieval_count + 1,
                        last_retrieval_at=current_time,
                    )
                )
                await session.execute(stmt)
                await session.commit()

    async def deactivate_memory(self, mem_id: str) -> bool:
        """Deactivate a memory chunk"""
        async with self.get_db() as session:
            async with session.begin():
                stmt = (
                    update(MemoryChunk)
                    .where(col(MemoryChunk.mem_id) == mem_id)
                    .values(is_active=False)
                )
                result = await session.execute(stmt)
                await session.commit()
                return result.rowcount > 0 if result.rowcount else False  # type: ignore
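A minimal end-to-end sketch of the class above; the path and fact are illustrative:

```python
import asyncio
from datetime import datetime, timezone


async def demo() -> None:
    db = MemoryDatabase("data/astr_memory/memory.db")
    await db.initialize()
    chunk = await db.insert_memory(
        MemoryChunk(fact="The user prefers tea over coffee", owner_id="user:1")
    )
    print(await db.get_active_memories("user:1"))
    await db.update_retrieval_stats([chunk.mem_id], datetime.now(timezone.utc))
    await db.close()


asyncio.run(demo())
```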
281
astrbot/core/memory/memory_manager.py
Normal file
@@ -0,0 +1,281 @@
import json
import uuid
from datetime import datetime, timezone
from pathlib import Path

from astrbot.core import logger
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.provider.provider import EmbeddingProvider
from astrbot.core.provider.provider import Provider as LLMProvider

from .entities import MemoryChunk
from .mem_db_sqlite import MemoryDatabase

MERGE_THRESHOLD = 0.85
"""Similarity threshold for merging memories"""
HEBB_THRESHOLD = 0.70
"""Similarity threshold for Hebbian learning reinforcement"""
MERGE_SYSTEM_PROMPT = """You are a memory consolidation assistant. Your task is to merge two related memory entries into a single, comprehensive memory.

Input format:
- Old memory: [existing memory content]
- New memory: [new memory content to be integrated]

Your output should be a single, concise memory that combines the essential information from both entries. Preserve specific details, update outdated information, and eliminate redundancy. Output only the merged memory content without any explanations or meta-commentary."""


class MemoryManager:
    """Manager for user long-term memory storage and retrieval"""

    def __init__(self, memory_root_dir: str = "data/astr_memory"):
        self.memory_root_dir = Path(memory_root_dir)
        self.memory_root_dir.mkdir(parents=True, exist_ok=True)

        self.mem_db: MemoryDatabase | None = None
        self.vec_db: FaissVecDB | None = None

        self._initialized = False

    async def initialize(
        self,
        embedding_provider: EmbeddingProvider,
        merge_llm_provider: LLMProvider,
    ):
        """Initialize memory database and vector database"""
        # Initialize MemoryDB
        db_path = self.memory_root_dir / "memory.db"
        self.mem_db = MemoryDatabase(db_path.as_posix())
        await self.mem_db.initialize()

        self.embedding_provider = embedding_provider
        self.merge_llm_provider = merge_llm_provider

        # Initialize VecDB
        doc_store_path = self.memory_root_dir / "doc.db"
        index_store_path = self.memory_root_dir / "index.faiss"
        self.vec_db = FaissVecDB(
            doc_store_path=doc_store_path.as_posix(),
            index_store_path=index_store_path.as_posix(),
            embedding_provider=self.embedding_provider,
        )
        await self.vec_db.initialize()

        logger.info("Memory manager initialized")
        self._initialized = True

    async def terminate(self):
        """Close all database connections"""
        if self.vec_db:
            await self.vec_db.close()
        if self.mem_db:
            await self.mem_db.close()

    async def add_memory(
        self,
        fact: str,
        owner_id: str,
        memory_type: str = "fact",
    ) -> MemoryChunk:
        """Add a new memory with similarity check and merge logic

        Implements the ADD MEMORY workflow from _README.md:
        1. Search for similar memories using VecDB
        2. If similarity >= merge_threshold, merge with existing memory
        3. Otherwise, create new memory
        4. Apply Hebbian learning to similar memories (similarity >= hebb_threshold)

        Args:
            fact: Memory content
            owner_id: User identifier
            memory_type: Memory type ('persona', 'fact', 'ephemeral')

        Returns:
            The created or updated MemoryChunk

        """
        if not self.vec_db or not self.mem_db:
            raise RuntimeError("Memory manager not initialized")

        current_time = datetime.now(timezone.utc)

        # Step 1: Search for similar memories
        similar_results = await self.vec_db.retrieve(
            query=fact,
            k=20,
            fetch_k=50,
            metadata_filters={"owner_id": owner_id},
        )

        # Step 2: Check if we should merge with existing memories (top 3 similar ones)
        merge_candidates = [
            r for r in similar_results[:3] if r.similarity >= MERGE_THRESHOLD
        ]

        if merge_candidates:
            # Get all candidate memories from database
            candidate_memories: list[tuple[str, MemoryChunk]] = []
            for candidate in merge_candidates:
                mem_id = json.loads(candidate.data["metadata"])["mem_id"]
                memory = await self.mem_db.get_memory_by_id(mem_id)
                if memory:
                    candidate_memories.append((mem_id, memory))

            if candidate_memories:
                # Use the most similar memory as the base
                base_mem_id, base_memory = candidate_memories[0]

                # Collect all facts to merge (existing candidates + new fact)
                all_facts = [mem.fact for _, mem in candidate_memories] + [fact]
                merged_fact = await self._merge_multiple_memories(all_facts)

                # Update the base memory
                base_memory.fact = merged_fact
                base_memory.last_retrieval_at = current_time
                base_memory.retrieval_count += 1
                updated_memory = await self.mem_db.update_memory(base_memory)

                # Update VecDB for base memory
                await self.vec_db.delete(base_mem_id)
                await self.vec_db.insert(
                    content=merged_fact,
                    metadata={
                        "mem_id": base_mem_id,
                        "owner_id": owner_id,
                        "memory_type": memory_type,
                    },
                    id=base_mem_id,
                )

                # Deactivate and remove other merged memories
                for mem_id, _ in candidate_memories[1:]:
                    await self.mem_db.deactivate_memory(mem_id)
                    await self.vec_db.delete(mem_id)

                logger.info(
                    f"Merged {len(candidate_memories)} memories into {base_mem_id} for user {owner_id}"
                )
                return updated_memory

        # Step 3: Create new memory
        mem_id = str(uuid.uuid4())
        new_memory = MemoryChunk(
            mem_id=mem_id,
            fact=fact,
            owner_id=owner_id,
            memory_type=memory_type,
            created_at=current_time,
            last_retrieval_at=current_time,
            retrieval_count=1,
            is_active=True,
        )

        # Insert into MemoryDB
        created_memory = await self.mem_db.insert_memory(new_memory)

        # Insert into VecDB
        await self.vec_db.insert(
            content=fact,
            metadata={
                "mem_id": mem_id,
                "owner_id": owner_id,
                "memory_type": memory_type,
            },
            id=mem_id,
        )

        # Step 4: Apply Hebbian learning to similar memories
        hebb_mem_ids = [
            json.loads(r.data["metadata"])["mem_id"]
            for r in similar_results
            if r.similarity >= HEBB_THRESHOLD
        ]
        if hebb_mem_ids:
            await self.mem_db.update_retrieval_stats(hebb_mem_ids, current_time)
            logger.debug(
                f"Applied Hebbian learning to {len(hebb_mem_ids)} memories for user {owner_id}",
            )

        logger.info(f"Created new memory {mem_id} for user {owner_id}")
        return created_memory

    async def query_memory(
        self,
        owner_id: str,
        top_k: int = 5,
    ) -> list[MemoryChunk]:
        """Query user's memories using static retrieval with decay score ranking

        Implements the QUERY MEMORY (STATIC) workflow from _README.md:
        1. Get all active memories for user from MemoryDB
        2. Compute decay_score for each memory
        3. Sort by decay_score and return top_k
        4. Update retrieval statistics for returned memories

        Args:
            owner_id: User identifier
            top_k: Number of memories to return

        Returns:
            List of top_k MemoryChunk sorted by decay score
        """
        if not self.mem_db:
            raise RuntimeError("Memory manager not initialized")

        current_time = datetime.now(timezone.utc)

        # Step 1: Get all active memories for user
        all_memories = await self.mem_db.get_active_memories(owner_id)

        if not all_memories:
            return []

        # Step 2-3: Compute decay scores and sort
        memories_with_scores = [
            (mem, mem.compute_decay_score(current_time)) for mem in all_memories
        ]
        memories_with_scores.sort(key=lambda x: x[1], reverse=True)

        # Get top_k memories
        top_memories = [mem for mem, _ in memories_with_scores[:top_k]]

        # Step 4: Update retrieval statistics
        mem_ids = [mem.mem_id for mem in top_memories]
        await self.mem_db.update_retrieval_stats(mem_ids, current_time)

        logger.debug(f"Retrieved {len(top_memories)} memories for user {owner_id}")
        return top_memories

    async def _merge_multiple_memories(self, facts: list[str]) -> str:
        """Merge multiple memory facts using LLM in one call

        Args:
            facts: List of memory facts to merge

        Returns:
            Merged memory content
        """
        if not self.merge_llm_provider:
            return " ".join(facts)

        if len(facts) == 1:
            return facts[0]

        try:
            # Format all facts as a numbered list
            facts_list = "\n".join(f"{i + 1}. {fact}" for i, fact in enumerate(facts))
            user_prompt = (
                f"Please merge the following {len(facts)} related memory entries "
                "into a single, comprehensive memory:"
                f"\n{facts_list}\n\nOutput only the merged memory content."
            )
            response = await self.merge_llm_provider.text_chat(
                prompt=user_prompt,
                system_prompt=MERGE_SYSTEM_PROMPT,
            )

            merged_content = response.completion_text.strip()
            return merged_content if merged_content else " ".join(facts)

        except Exception as e:
            logger.warning(f"Failed to merge memories with LLM: {e}, using fallback")
            return " ".join(facts)
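A hedged lifecycle sketch for the manager; real EmbeddingProvider and LLMProvider instances come from AstrBot's provider layer and are assumed here:

```python
async def demo(embedding_provider, llm_provider) -> None:
    mm = MemoryManager(memory_root_dir="data/astr_memory")
    await mm.initialize(embedding_provider, llm_provider)
    try:
        await mm.add_memory(
            fact="The user's birthday is in March.",
            owner_id="user:12345",
            memory_type="fact",
        )
        for mem in await mm.query_memory(owner_id="user:12345", top_k=5):
            print(mem.fact)
    finally:
        await mm.terminate()
```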
156
astrbot/core/memory/tools.py
Normal file
156
astrbot/core/memory/tools.py
Normal file
@@ -0,0 +1,156 @@
from pydantic import Field
from pydantic.dataclasses import dataclass

from astrbot.core.agent.tool import FunctionTool, ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext, ContextWrapper


@dataclass
class AddMemory(FunctionTool[AstrAgentContext]):
    """Tool for adding memories to user's long-term memory storage"""

    name: str = "astr_add_memory"
    description: str = (
        "Add a new memory to the user's long-term memory storage. "
        "Use this tool only when the user explicitly asks you to remember something, "
        "or when they share stable preferences, identity, or long-term goals that will be useful in future interactions."
    )
    parameters: dict = Field(
        default_factory=lambda: {
            "type": "object",
            "properties": {
                "fact": {
                    "type": "string",
                    "description": (
                        "The concrete memory content to store, such as a user preference, "
                        "identity detail, long-term goal, or stable profile fact."
                    ),
                },
                "memory_type": {
                    "type": "string",
                    "enum": ["persona", "fact", "ephemeral"],
                    "description": (
                        "The relative importance of this memory. "
                        "Use 'persona' for core identity or highly impactful information, "
                        "'fact' for normal long-term preferences, "
                        "and 'ephemeral' for minor or tentative facts."
                    ),
                },
            },
            "required": ["fact", "memory_type"],
        }
    )

    async def call(
        self, context: ContextWrapper[AstrAgentContext], **kwargs
    ) -> ToolExecResult:
        """Add a memory to long-term storage

        Args:
            context: Agent context
            **kwargs: Must contain 'fact' and 'memory_type'

        Returns:
            ToolExecResult with success message

        """
        mm = context.context.context.memory_manager
        fact = kwargs.get("fact")
        memory_type = kwargs.get("memory_type", "fact")

        if not fact:
            return "Missing required parameter: fact"

        try:
            # Get owner_id from context
            owner_id = context.context.event.unified_msg_origin

            # Add memory using memory manager
            memory = await mm.add_memory(
                fact=fact,
                owner_id=owner_id,
                memory_type=memory_type,
            )

            return f"Memory added successfully (ID: {memory.mem_id})"

        except Exception as e:
            return f"Failed to add memory: {str(e)}"

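For illustration (not part of the diff): the tool-call arguments an LLM would emit for astr_add_memory, checked against the `parameters` schema above; the argument values are hypothetical.

import json

tool_call_args = json.loads(
    '{"fact": "User prefers concise answers", "memory_type": "fact"}'
)
assert set(tool_call_args) <= {"fact", "memory_type"}
assert tool_call_args["memory_type"] in ("persona", "fact", "ephemeral")
# AddMemory.call(context, **tool_call_args) would then persist the fact and
# reply: "Memory added successfully (ID: <mem_id>)"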
@dataclass
class QueryMemory(FunctionTool[AstrAgentContext]):
    """Tool for querying user's long-term memories"""

    name: str = "astr_query_memory"
    description: str = (
        "Query the user's long-term memory storage and return the most relevant memories. "
        "Use this tool when you need user-specific context, preferences, or past facts "
        "that are not explicitly present in the current conversation."
    )
    parameters: dict = Field(
        default_factory=lambda: {
            "type": "object",
            "properties": {
                "top_k": {
                    "type": "integer",
                    "description": (
                        "Maximum number of memories to retrieve after retention-based ranking. "
                        "Typically between 3 and 10."
                    ),
                    "default": 5,
                    "minimum": 1,
                    "maximum": 20,
                },
            },
            "required": [],
        }
    )

    async def call(
        self, context: ContextWrapper[AstrAgentContext], **kwargs
    ) -> ToolExecResult:
        """Query memories from long-term storage

        Args:
            context: Agent context
            **kwargs: Optional 'top_k' parameter

        Returns:
            ToolExecResult with formatted memory list

        """
        mm = context.context.context.memory_manager
        top_k = kwargs.get("top_k", 5)

        try:
            # Get owner_id from context
            owner_id = context.context.event.unified_msg_origin

            # Query memories using memory manager
            memories = await mm.query_memory(
                owner_id=owner_id,
                top_k=top_k,
            )

            if not memories:
                return "No memories found for this user."

            # Format memories for output
            formatted_memories = []
            for i, mem in enumerate(memories, 1):
                formatted_memories.append(
                    f"{i}. [{mem.memory_type.upper()}] {mem.fact} "
                    f"(retrieved {mem.retrieval_count} times, "
                    f"last: {mem.last_retrieval_at.strftime('%Y-%m-%d')})"
                )

            result_text = "Retrieved memories:\n" + "\n".join(formatted_memories)
            return result_text

        except Exception as e:
            return f"Failed to query memories: {str(e)}"


ADD_MEMORY_TOOL = AddMemory()
QUERY_MEMORY_TOOL = QueryMemory()
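For illustration (not part of the diff), the string QueryMemory returns for two hypothetical hits would look like:

# Retrieved memories:
# 1. [PERSONA] User is a backend developer (retrieved 12 times, last: 2025-01-03)
# 2. [FACT] User prefers concise answers (retrieved 4 times, last: 2025-01-02)

The ADD_MEMORY_TOOL and QUERY_MEMORY_TOOL singletons above are module-level instances, presumably intended to be registered with the agent's tool set.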
@@ -1,5 +1,4 @@
"""
MIT License
"""MIT License

Copyright (c) 2021 Lxns-Network

@@ -26,7 +25,6 @@ import asyncio
import base64
import json
import os
import typing as T
import uuid
from enum import Enum

@@ -38,60 +36,36 @@ from astrbot.core.utils.io import download_file, download_image_by_url, file_to_


class ComponentType(str, Enum):
    Plain = "Plain"  # plain text message
    Face = "Face"  # QQ emoji
    Record = "Record"  # voice
    Video = "Video"  # video
    At = "At"  # At (mention)
    Node = "Node"  # a single node of a forwarded message
    Nodes = "Nodes"  # multiple nodes of a forwarded message
    Poke = "Poke"  # QQ poke
    Image = "Image"  # image
    Reply = "Reply"  # reply
    Forward = "Forward"  # forwarded message
    File = "File"  # file
    # Basic Segment Types
    Plain = "Plain"  # plain text message
    Image = "Image"  # image
    Record = "Record"  # audio
    Video = "Video"  # video
    File = "File"  # file attachment

    # IM-specific Segment Types
    Face = "Face"  # Emoji segment for Tencent QQ platform
    At = "At"  # mention a user in IM apps
    Node = "Node"  # a node in a forwarded message
    Nodes = "Nodes"  # a forwarded message consisting of multiple nodes
    Poke = "Poke"  # a poke message for Tencent QQ platform
    Reply = "Reply"  # a reply message segment
    Forward = "Forward"  # a forwarded message segment
    RPS = "RPS"  # TODO
    Dice = "Dice"  # TODO
    Shake = "Shake"  # TODO
    Anonymous = "Anonymous"  # TODO
    Share = "Share"
    Contact = "Contact"  # TODO
    Location = "Location"  # TODO
    Music = "Music"
    RedBag = "RedBag"
    Xml = "Xml"
    Json = "Json"
    CardImage = "CardImage"
    TTS = "TTS"
    Unknown = "Unknown"

    WechatEmoji = "WechatEmoji"  # emoji sticker on WeChat

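For illustration (not part of the diff): because ComponentType mixes in str, its members behave like plain strings, which is exactly what the CQ-code serialization below relies on.

assert ComponentType.Plain == "Plain"
assert ComponentType.Face.lower() == "face"  # used as f"[CQ:{self.type.lower()}" below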
class BaseMessageComponent(BaseModel):
    type: ComponentType

    def toString(self):
        output = f"[CQ:{self.type.lower()}"
        for k, v in self.__dict__.items():
            if k == "type" or v is None:
                continue
            if k == "_type":
                k = "type"
            if isinstance(v, bool):
                v = 1 if v else 0
            output += ",%s=%s" % (
                k,
                str(v)
                .replace("&", "&amp;")
                .replace(",", "&#44;")
                .replace("[", "&#91;")
                .replace("]", "&#93;"),
            )
        output += "]"
        return output

    def toDict(self):
        data = {}
        for k, v in self.__dict__.items():
@@ -110,18 +84,11 @@ class BaseMessageComponent(BaseModel):
class Plain(BaseMessageComponent):
    type = ComponentType.Plain
    text: str
    convert: T.Optional[bool] = True  # if False, the message is sent without CQ-code conversion
    convert: bool | None = True

    def __init__(self, text: str, convert: bool = True, **_):
        super().__init__(text=text, convert=convert, **_)

    def toString(self):  # there is no such thing as [CQ:plain], so export the plain text directly
        if not self.convert:
            return self.text
        return (
            self.text.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;")
        )

    def toDict(self):
        return {"type": "text", "data": {"text": self.text.strip()}}

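For illustration (not part of the diff): how the CQ-code serialization above behaves.

Plain("a [b] & c").toString()                  # -> "a &#91;b&#93; &amp; c"
Plain("raw [text]", convert=False).toString()  # -> "raw [text]" (no escaping)
# Non-Plain segments render as CQ codes, roughly "[CQ:face,id=14]" for a QQ emoji.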
@@ -139,17 +106,17 @@ class Face(BaseMessageComponent):

class Record(BaseMessageComponent):
    type = ComponentType.Record
    file: T.Optional[str] = ""
    magic: T.Optional[bool] = False
    url: T.Optional[str] = ""
    cache: T.Optional[bool] = True
    proxy: T.Optional[bool] = True
    timeout: T.Optional[int] = 0
    file: str | None = ""
    magic: bool | None = False
    url: str | None = ""
    cache: bool | None = True
    proxy: bool | None = True
    timeout: int | None = 0
    # extra
    path: T.Optional[str]
    path: str | None

    def __init__(self, file: T.Optional[str], **_):
        for k in _.keys():
    def __init__(self, file: str | None, **_):
        for k in _:
            if k == "url":
                pass
                # Protocol.warn(f"go-cqhttp doesn't support send {self.type} by {k}")
@@ -174,15 +141,16 @@ class Record(BaseMessageComponent):

        Returns:
            str: the local path of the audio file, as an absolute path.

        """
        if not self.file:
            raise Exception(f"not a valid file: {self.file}")
        if self.file.startswith("file:///"):
            return self.file[8:]
        elif self.file.startswith("http"):
        if self.file.startswith("http"):
            file_path = await download_image_by_url(self.file)
            return os.path.abspath(file_path)
        elif self.file.startswith("base64://"):
        if self.file.startswith("base64://"):
            bs64_data = self.file.removeprefix("base64://")
            image_bytes = base64.b64decode(bs64_data)
            temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -190,16 +158,16 @@ class Record(BaseMessageComponent):
            with open(file_path, "wb") as f:
                f.write(image_bytes)
            return os.path.abspath(file_path)
        elif os.path.exists(self.file):
        if os.path.exists(self.file):
            return os.path.abspath(self.file)
        else:
            raise Exception(f"not a valid file: {self.file}")
        raise Exception(f"not a valid file: {self.file}")

    async def convert_to_base64(self) -> str:
        """Convert the audio to base64 uniformly. This method avoids manually checking the audio data type and directly returns the base64 encoding of the audio data.

        Returns:
            str: the base64 encoding of the audio, without the leading base64:// or data:image/jpeg;base64, prefix.

        """
        # convert to base64
        if not self.file:
@@ -219,14 +187,14 @@ class Record(BaseMessageComponent):
        return bs64_data

    async def register_to_file_service(self) -> str:
        """
        Register the audio with the file service.
        """Register the audio with the file service.

        Returns:
            str: the URL after registration

        Raises:
            Exception: if callback_api_base is not configured

        """
        callback_host = astrbot_config.get("callback_api_base")

@@ -245,10 +213,10 @@ class Record(BaseMessageComponent):
class Video(BaseMessageComponent):
    type = ComponentType.Video
    file: str
    cover: T.Optional[str] = ""
    c: T.Optional[int] = 2
    cover: str | None = ""
    c: int | None = 2
    # extra
    path: T.Optional[str] = ""
    path: str | None = ""

    def __init__(self, file: str, **_):
        super().__init__(file=file, **_)
@@ -268,32 +236,31 @@ class Video(BaseMessageComponent):

        Returns:
            str: the local path of the video, as an absolute path.

        """
        url = self.file
        if url and url.startswith("file:///"):
            return url[8:]
        elif url and url.startswith("http"):
        if url and url.startswith("http"):
            download_dir = os.path.join(get_astrbot_data_path(), "temp")
            video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
            await download_file(url, video_file_path)
            if os.path.exists(video_file_path):
                return os.path.abspath(video_file_path)
            else:
                raise Exception(f"download failed: {url}")
        elif os.path.exists(url):
            raise Exception(f"download failed: {url}")
        if os.path.exists(url):
            return os.path.abspath(url)
        else:
            raise Exception(f"not a valid file: {url}")
        raise Exception(f"not a valid file: {url}")

    async def register_to_file_service(self):
        """
        Register the video with the file service.
        """Register the video with the file service.

        Returns:
            str: the URL after registration

        Raises:
            Exception: if callback_api_base is not configured

        """
        callback_host = astrbot_config.get("callback_api_base")

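Across Record, Video, and Image, the diff flattens nested elif/else chains into early-return guards over the same source-resolution order. For illustration (a standalone sketch, not part of the diff):

def resolve_source_kind(src: str) -> str:
    # Resolution order shared by the convert_to_file_path methods above/below.
    if src.startswith("file:///"):
        return "local path (strip the file:/// prefix)"
    if src.startswith("http"):
        return "remote URL (download to a temp file first)"
    if src.startswith("base64://"):
        return "inline payload (decode and write to a temp file)"
    return "treated as a local path if it exists, otherwise an error"

for s in ("file:///tmp/a.ogg", "https://x/y.ogg", "base64://AAAA", "a.ogg"):
    print(s, "->", resolve_source_kind(s))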
@@ -330,8 +297,8 @@ class Video(BaseMessageComponent):

class At(BaseMessageComponent):
    type = ComponentType.At
    qq: T.Union[int, str]  # the str "all" here means everyone
    name: T.Optional[str] = ""
    qq: int | str  # the str "all" here means everyone
    name: str | None = ""

    def __init__(self, **_):
        super().__init__(**_)

@@ -371,20 +338,12 @@ class Shake(BaseMessageComponent):  # TODO
        super().__init__(**_)


class Anonymous(BaseMessageComponent):  # TODO
    type = ComponentType.Anonymous
    ignore: T.Optional[bool] = False

    def __init__(self, **_):
        super().__init__(**_)


class Share(BaseMessageComponent):
    type = ComponentType.Share
    url: str
    title: str
    content: T.Optional[str] = ""
    image: T.Optional[str] = ""
    content: str | None = ""
    image: str | None = ""

    def __init__(self, **_):
        super().__init__(**_)

@@ -393,7 +352,7 @@ class Share(BaseMessageComponent):
class Contact(BaseMessageComponent):  # TODO
    type = ComponentType.Contact
    _type: str  # the field name "type" conflicts
    id: T.Optional[int] = 0
    id: int | None = 0

    def __init__(self, **_):
        super().__init__(**_)

@@ -403,8 +362,8 @@ class Location(BaseMessageComponent):  # TODO
    type = ComponentType.Location
    lat: float
    lon: float
    title: T.Optional[str] = ""
    content: T.Optional[str] = ""
    title: str | None = ""
    content: str | None = ""

    def __init__(self, **_):
        super().__init__(**_)

@@ -413,12 +372,12 @@ class Location(BaseMessageComponent):  # TODO
class Music(BaseMessageComponent):
    type = ComponentType.Music
    _type: str
    id: T.Optional[int] = 0
    url: T.Optional[str] = ""
    audio: T.Optional[str] = ""
    title: T.Optional[str] = ""
    content: T.Optional[str] = ""
    image: T.Optional[str] = ""
    id: int | None = 0
    url: str | None = ""
    audio: str | None = ""
    title: str | None = ""
    content: str | None = ""
    image: str | None = ""

    def __init__(self, **_):
        # for k in _.keys():
@@ -429,18 +388,18 @@ class Music(BaseMessageComponent):

class Image(BaseMessageComponent):
    type = ComponentType.Image
    file: T.Optional[str] = ""
    _type: T.Optional[str] = ""
    subType: T.Optional[int] = 0
    url: T.Optional[str] = ""
    cache: T.Optional[bool] = True
    id: T.Optional[int] = 40000
    c: T.Optional[int] = 2
    file: str | None = ""
    _type: str | None = ""
    subType: int | None = 0
    url: str | None = ""
    cache: bool | None = True
    id: int | None = 40000
    c: int | None = 2
    # extra
    path: T.Optional[str] = ""
    file_unique: T.Optional[str] = ""  # some platforms may have a unique identifier for the cached image
    path: str | None = ""
    file_unique: str | None = ""  # some platforms may have a unique identifier for the cached image

    def __init__(self, file: T.Optional[str], **_):
    def __init__(self, file: str | None, **_):
        super().__init__(file=file, **_)

    @staticmethod
@@ -470,16 +429,17 @@ class Image(BaseMessageComponent):

        Returns:
            str: the local path of the image, as an absolute path.

        """
        url = self.url or self.file
        if not url:
            raise ValueError("No valid file or URL provided")
        if url.startswith("file:///"):
            return url[8:]
        elif url.startswith("http"):
        if url.startswith("http"):
            image_file_path = await download_image_by_url(url)
            return os.path.abspath(image_file_path)
        elif url.startswith("base64://"):
        if url.startswith("base64://"):
            bs64_data = url.removeprefix("base64://")
            image_bytes = base64.b64decode(bs64_data)
            temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -487,16 +447,16 @@ class Image(BaseMessageComponent):
            with open(image_file_path, "wb") as f:
                f.write(image_bytes)
            return os.path.abspath(image_file_path)
        elif os.path.exists(url):
        if os.path.exists(url):
            return os.path.abspath(url)
        else:
            raise Exception(f"not a valid file: {url}")
        raise Exception(f"not a valid file: {url}")

    async def convert_to_base64(self) -> str:
        """Convert this image to base64 uniformly. This method avoids manually checking the image data type and directly returns the base64 encoding of the image data.

        Returns:
            str: the base64 encoding of the image, without the leading base64:// or data:image/jpeg;base64, prefix.

        """
        # convert to base64
        url = self.url or self.file
@@ -517,14 +477,14 @@ class Image(BaseMessageComponent):
        return bs64_data

    async def register_to_file_service(self) -> str:
        """
        Register the image with the file service.
        """Register the image with the file service.

        Returns:
            str: the URL after registration

        Raises:
            Exception: if callback_api_base is not configured

        """
        callback_host = astrbot_config.get("callback_api_base")

@@ -542,42 +502,34 @@ class Image(BaseMessageComponent):

class Reply(BaseMessageComponent):
    type = ComponentType.Reply
    id: T.Union[str, int]
    id: str | int
    """ID of the quoted message"""
    chain: T.Optional[T.List["BaseMessageComponent"]] = []
    chain: list["BaseMessageComponent"] | None = []
    """List of message segments of the quoted message"""
    sender_id: T.Optional[int] | T.Optional[str] = 0
    sender_id: int | None | str = 0
    """ID of the sender of the quoted message"""
    sender_nickname: T.Optional[str] = ""
    sender_nickname: str | None = ""
    """Nickname of the sender of the quoted message"""
    time: T.Optional[int] = 0
    time: int | None = 0
    """Send time of the quoted message"""
    message_str: T.Optional[str] = ""
    message_str: str | None = ""
    """Plain-text string parsed from the quoted message"""

    text: T.Optional[str] = ""
    text: str | None = ""
    """deprecated"""
    qq: T.Optional[int] = 0
    qq: int | None = 0
    """deprecated"""
    seq: T.Optional[int] = 0
    seq: int | None = 0
    """deprecated"""

    def __init__(self, **_):
        super().__init__(**_)


class RedBag(BaseMessageComponent):
    type = ComponentType.RedBag
    title: str

    def __init__(self, **_):
        super().__init__(**_)


class Poke(BaseMessageComponent):
    type: str = ComponentType.Poke
    id: T.Optional[int] = 0
    qq: T.Optional[int] = 0
    id: int | None = 0
    qq: int | None = 0

    def __init__(self, type: str, **_):
        type = f"Poke:{type}"
@@ -596,12 +548,12 @@ class Node(BaseMessageComponent):
    """Merged-forward message for group chats"""

    type = ComponentType.Node
    id: T.Optional[int] = 0  # ignored
    name: T.Optional[str] = ""  # QQ nickname
    uin: T.Optional[str] = "0"  # QQ number
    content: T.Optional[list[BaseMessageComponent]] = []
    seq: T.Optional[T.Union[str, list]] = ""  # ignored
    time: T.Optional[int] = 0  # ignored
    id: int | None = 0  # ignored
    name: str | None = ""  # QQ nickname
    uin: str | None = "0"  # QQ number
    content: list[BaseMessageComponent] | None = []
    seq: str | list | None = ""  # ignored
    time: int | None = 0  # ignored

    def __init__(self, content: list[BaseMessageComponent], **_):
        if isinstance(content, Node):
@@ -619,7 +571,7 @@ class Node(BaseMessageComponent):
                    {
                        "type": comp.type.lower(),
                        "data": {"file": f"base64://{bs64}"},
                    }
                    },
                )
            elif isinstance(comp, Plain):
                # For Plain segments, we need to handle the plain differently
@@ -648,9 +600,9 @@ class Node(BaseMessageComponent):

class Nodes(BaseMessageComponent):
    type = ComponentType.Nodes
    nodes: T.List[Node]
    nodes: list[Node]

    def __init__(self, nodes: T.List[Node], **_):
    def __init__(self, nodes: list[Node], **_):
        super().__init__(nodes=nodes, **_)

    def toDict(self):
@@ -672,19 +624,10 @@ class Nodes(BaseMessageComponent):
        return ret

class Xml(BaseMessageComponent):
    type = ComponentType.Xml
    data: str
    resid: T.Optional[int] = 0

    def __init__(self, **_):
        super().__init__(**_)


class Json(BaseMessageComponent):
    type = ComponentType.Json
    data: T.Union[str, dict]
    resid: T.Optional[int] = 0
    data: str | dict
    resid: int | None = 0

    def __init__(self, data, **_):
        if isinstance(data, dict):
@@ -692,50 +635,18 @@ class Json(BaseMessageComponent):
        super().__init__(data=data, **_)


class CardImage(BaseMessageComponent):
    type = ComponentType.CardImage
    file: str
    cache: T.Optional[bool] = True
    minwidth: T.Optional[int] = 400
    minheight: T.Optional[int] = 400
    maxwidth: T.Optional[int] = 500
    maxheight: T.Optional[int] = 500
    source: T.Optional[str] = ""
    icon: T.Optional[str] = ""

    def __init__(self, **_):
        super().__init__(**_)

    @staticmethod
    def fromFileSystem(path, **_):
        return CardImage(file=f"file:///{os.path.abspath(path)}", **_)


class TTS(BaseMessageComponent):
    type = ComponentType.TTS
    text: str

    def __init__(self, **_):
        super().__init__(**_)


class Unknown(BaseMessageComponent):
    type = ComponentType.Unknown
    text: str

    def toString(self):
        return ""


class File(BaseMessageComponent):
    """
    File message segment
    """
    """File message segment"""

    type = ComponentType.File
    name: T.Optional[str] = ""  # name
    file_: T.Optional[str] = ""  # local path
    url: T.Optional[str] = ""  # url
    name: str | None = ""  # name
    file_: str | None = ""  # local path
    url: str | None = ""  # url

    def __init__(self, name: str, file: str = "", url: str = ""):
        """File message segment."""
@@ -743,11 +654,11 @@ class File(BaseMessageComponent):

    @property
    def file(self) -> str:
        """
        Get the file path; if the file does not exist but a URL is available, download it synchronously
        """Get the file path; if the file does not exist but a URL is available, download it synchronously

        Returns:
            str: file path

        """
        if self.file_ and os.path.exists(self.file_):
            return os.path.abspath(self.file_)
@@ -757,19 +668,16 @@ class File(BaseMessageComponent):
            loop = asyncio.get_event_loop()
            if loop.is_running():
                logger.warning(
                    (
                        "Do not wait for a download synchronously inside an async context! "
                        "This warning usually occurs when some logic tries to read the file content of a File segment via <File>.file. "
                        "Please use await get_file() instead of reading the <File>.file field directly"
                    )
                    "Do not wait for a download synchronously inside an async context! "
                    "This warning usually occurs when some logic tries to read the file content of a File segment via <File>.file. "
                    "Please use await get_file() instead of reading the <File>.file field directly",
                )
                return ""
            else:
                # wait for the download to finish
                loop.run_until_complete(self._download_file())
            # wait for the download to finish
            loop.run_until_complete(self._download_file())

            if self.file_ and os.path.exists(self.file_):
                return os.path.abspath(self.file_)
            if self.file_ and os.path.exists(self.file_):
                return os.path.abspath(self.file_)
        except Exception as e:
            logger.error(f"File download failed: {e}")

@@ -777,11 +685,11 @@ class File(BaseMessageComponent):

    @file.setter
    def file(self, value: str):
        """
        Backward compatible: set the file attribute; the argument may be a file path or a URL
        """Backward compatible: set the file attribute; the argument may be a file path or a URL

        Args:
            value (str): file path or URL

        """
        if value.startswith("http://") or value.startswith("https://"):
            self.url = value
@@ -796,6 +704,7 @@ class File(BaseMessageComponent):
        Note: if True, a file path may still be returned.
        Returns:
            str: file path or http download link

        """
        if allow_return_url and self.url:
            return self.url
@@ -818,14 +727,14 @@ class File(BaseMessageComponent):
            self.file_ = os.path.abspath(file_path)

    async def register_to_file_service(self):
        """
        Register the file with the file service.
        """Register the file with the file service.

        Returns:
            str: the URL after registration

        Raises:
            Exception: if callback_api_base is not configured

        """
        callback_host = astrbot_config.get("callback_api_base")

@@ -863,41 +772,38 @@ class File(BaseMessageComponent):

class WechatEmoji(BaseMessageComponent):
    type = ComponentType.WechatEmoji
    md5: T.Optional[str] = ""
    md5_len: T.Optional[int] = 0
    cdnurl: T.Optional[str] = ""
    md5: str | None = ""
    md5_len: int | None = 0
    cdnurl: str | None = ""

    def __init__(self, **_):
        super().__init__(**_)

ComponentTypes = {
    # Basic Message Segments
    "plain": Plain,
    "text": Plain,
    "face": Face,
    "image": Image,
    "record": Record,
    "video": Video,
    "file": File,
    # IM-specific Message Segments
    "face": Face,
    "at": At,
    "rps": RPS,
    "dice": Dice,
    "shake": Shake,
    "anonymous": Anonymous,
    "share": Share,
    "contact": Contact,
    "location": Location,
    "music": Music,
    "image": Image,
    "reply": Reply,
    "redbag": RedBag,
    "poke": Poke,
    "forward": Forward,
    "node": Node,
    "nodes": Nodes,
    "xml": Xml,
    "json": Json,
    "cardimage": CardImage,
    "tts": TTS,
    "unknown": Unknown,
    "file": File,
    "WechatEmoji": WechatEmoji,
}

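For illustration (a standalone sketch, not part of the diff): deserializing an incoming OneBot-style segment dict through the ComponentTypes registry above; the segment payload here is hypothetical.

segment = {"type": "at", "data": {"qq": "all"}}
cls = ComponentTypes[segment["type"]]  # -> At
component = cls(**segment["data"])     # -> At(qq="all"), i.e. mention everyone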
@@ -1,15 +1,16 @@
import enum

from typing import List, Optional, Union, AsyncGenerator
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field

from typing_extensions import deprecated

from astrbot.core.message.components import (
    BaseMessageComponent,
    Plain,
    Image,
    At,
    AtAll,
    BaseMessageComponent,
    Image,
    Plain,
)
from typing_extensions import deprecated


@dataclass
@@ -20,18 +21,18 @@ class MessageChain:

    Attributes:
        `chain` (list): stores the components in order.
        `use_t2i_` (bool): marks whether to use the text-to-image service. Defaults to None, i.e. follow the user's setting. When set to True, the text-to-image service will be used.

    """

    chain: List[BaseMessageComponent] = field(default_factory=list)
    use_t2i_: Optional[bool] = None  # None means follow the user's setting
    type: Optional[str] = None
    chain: list[BaseMessageComponent] = field(default_factory=list)
    use_t2i_: bool | None = None  # None means follow the user's setting
    type: str | None = None
    """Type of the message carried by this message chain. Optional; lets the platform distinguish message chains for different business scenarios."""

    def message(self, message: str):
        """Append a text message to the message chain `chain`.

        Example:

        CommandResult().message("Hello ").message("world!")
        # outputs Hello world!

@@ -39,11 +40,10 @@ class MessageChain:
        self.chain.append(Plain(message))
        return self

    def at(self, name: str, qq: Union[str, int]):
    def at(self, name: str, qq: str | int):
        """Append an At message to the message chain `chain`.

        Example:

        CommandResult().at("张三", "12345678910")
        # outputs @张三

@@ -55,7 +55,6 @@ class MessageChain:
        """Append an AtAll message to the message chain `chain`.

        Example:

        CommandResult().at_all()
        # outputs @everyone

@@ -68,7 +67,6 @@ class MessageChain:
        """Append an error message to the message chain `chain`

        Example:

        CommandResult().error("parse failed")

        """
@@ -82,7 +80,6 @@ class MessageChain:
        To send a local image, use the `file_image` method instead.

        Example:

        CommandResult().image("https://example.com/image.jpg")

        """
@@ -96,6 +93,7 @@ class MessageChain:
        To send a web image, use the `url_image` method instead.

        CommandResult().image("image.jpg")

        """
        self.chain.append(Image.fromFileSystem(path))
        return self
@@ -114,6 +112,7 @@ class MessageChain:

        Args:
            use_t2i (bool): whether to use the text-to-image service. Defaults to None, i.e. follow the user's setting. When set to True, the text-to-image service will be used.

        """
        self.use_t2i_ = use_t2i
        return self
@@ -125,7 +124,7 @@ class MessageChain:
    def squash_plain(self):
        """Squash all Plain segments in the message chain into the first Plain segment."""
        if not self.chain:
            return
            return None

        new_chain = []
        first_plain = None
@@ -153,6 +152,7 @@ class EventResultType(enum.Enum):

    Attributes:
        CONTINUE: the event will continue to propagate
        STOP: the event will stop propagating

    """

    CONTINUE = enum.auto()
@@ -181,17 +181,18 @@ class MessageEventResult(MessageChain):
        `chain` (list): stores the components in order.
        `use_t2i_` (bool): marks whether to use the text-to-image service. Defaults to None, i.e. follow the user's setting. When set to True, the text-to-image service will be used.
        `result_type` (EventResultType): the result type of the event handling.

    """

    result_type: Optional[EventResultType] = field(
        default_factory=lambda: EventResultType.CONTINUE
    result_type: EventResultType | None = field(
        default_factory=lambda: EventResultType.CONTINUE,
    )

    result_content_type: Optional[ResultContentType] = field(
        default_factory=lambda: ResultContentType.GENERAL_RESULT
    result_content_type: ResultContentType | None = field(
        default_factory=lambda: ResultContentType.GENERAL_RESULT,
    )

    async_stream: Optional[AsyncGenerator] = None
    async_stream: AsyncGenerator | None = None
    """Async stream"""

    def stop_event(self) -> "MessageEventResult":
@@ -205,9 +206,7 @@ class MessageEventResult(MessageChain):
        return self

    def is_stopped(self) -> bool:
        """
        Whether to stop event propagation.
        """
        """Whether to stop event propagation."""
        return self.result_type == EventResultType.STOP

    def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult":
@@ -220,6 +219,7 @@ class MessageEventResult(MessageChain):

        Args:
            result_type (EventResultType): the result type of the event handling.

        """
        self.result_content_type = typ
        return self

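For illustration (a standalone sketch, not part of the diff): the fluent MessageChain/MessageEventResult API above composes segments in order and each builder method returns self.

result = (
    MessageEventResult()
    .message("Done! ")
    .at("张三", "12345678910")
    .use_t2i(False)   # bypass the text-to-image service for this reply
    .stop_event()     # mark the result as STOP so the event no longer propagates
)
assert result.is_stopped()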
@@ -1,8 +1,8 @@
from astrbot import logger
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Persona, Personality
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.platform.message_session import MessageSession
from astrbot import logger

DEFAULT_PERSONALITY = Personality(
    prompt="You are a helpful and friendly assistant.",
@@ -41,12 +41,14 @@ class PersonaManager:
        return persona

    async def get_default_persona_v3(
        self, umo: str | MessageSession | None = None
        self,
        umo: str | MessageSession | None = None,
    ) -> Personality:
        """Get the default persona"""
        cfg = self.acm.get_conf(umo)
        default_persona_id = cfg.get("provider_settings", {}).get(
            "default_personality", "default"
            "default_personality",
            "default",
        )
        if not default_persona_id or default_persona_id == "default":
            return DEFAULT_PERSONALITY
@@ -66,16 +68,19 @@ class PersonaManager:
    async def update_persona(
        self,
        persona_id: str,
        system_prompt: str = None,
        begin_dialogs: list[str] = None,
        tools: list[str] = None,
        system_prompt: str | None = None,
        begin_dialogs: list[str] | None = None,
        tools: list[str] | None = None,
    ):
        """Update the given persona's info. A tools value of None means all tools are used; an empty list means no tools are used"""
        existing_persona = await self.db.get_persona_by_id(persona_id)
        if not existing_persona:
            raise ValueError(f"Persona with ID {persona_id} does not exist.")
        persona = await self.db.update_persona(
            persona_id, system_prompt, begin_dialogs, tools=tools
            persona_id,
            system_prompt,
            begin_dialogs,
            tools=tools,
        )
        if persona:
            for i, p in enumerate(self.personas):
@@ -100,7 +105,10 @@ class PersonaManager:
        if await self.db.get_persona_by_id(persona_id):
            raise ValueError(f"Persona with ID {persona_id} already exists.")
        new_persona = await self.db.insert_persona(
            persona_id, system_prompt, begin_dialogs, tools=tools
            persona_id,
            system_prompt,
            begin_dialogs,
            tools=tools,
        )
        self.personas.append(new_persona)
        self.get_v3_persona_data()
@@ -115,6 +123,7 @@ class PersonaManager:
        - list[dict]: a list of dicts containing the persona configuration.
        - list[Personality]: a list of Personality objects.
        - Personality: the Personality object selected by default.

        """
        v3_persona_config = [
            {
@@ -136,7 +145,7 @@ class PersonaManager:
        if begin_dialogs:
            if len(begin_dialogs) % 2 != 0:
                logger.error(
                    f"{persona_cfg['name']} persona preset dialogs are malformed; the number of entries must be even."
                    f"{persona_cfg['name']} persona preset dialogs are malformed; the number of entries must be even.",
                )
                begin_dialogs = []
            user_turn = True
@@ -146,7 +155,7 @@ class PersonaManager:
                    "role": "user" if user_turn else "assistant",
                    "content": dialog,
                    "_no_save": None,  # do not persist to db
                }
                },
            )
            user_turn = not user_turn

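For illustration (a standalone sketch, not part of the diff): the tools semantics of update_persona above — None keeps every tool enabled, an empty list disables all of them. `pm` is assumed to be an initialized PersonaManager, called from inside an async function.

await pm.update_persona("tech-support", system_prompt="You are a support agent.")
await pm.update_persona("tech-support", tools=None)                   # all tools enabled
await pm.update_persona("tech-support", tools=[])                     # no tools at all
await pm.update_persona("tech-support", tools=["astr_query_memory"])  # only this tool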
@@ -27,15 +27,15 @@ STAGES_ORDER = [
]

__all__ = [
    "WakingCheckStage",
    "WhitelistCheckStage",
    "SessionStatusCheckStage",
    "RateLimitStage",
    "ContentSafetyCheckStage",
    "EventResultType",
    "MessageEventResult",
    "PreProcessStage",
    "ProcessStage",
    "ResultDecorateStage",
    "RateLimitStage",
    "RespondStage",
    "MessageEventResult",
    "EventResultType",
    "ResultDecorateStage",
    "SessionStatusCheckStage",
    "WakingCheckStage",
    "WhitelistCheckStage",
]
Some files were not shown because too many files have changed in this diff.