Compare commits
109 Commits
v4.5.2
...
feat/memor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c76b7ec387 | ||
|
|
b7f3010d72 | ||
|
|
fbbaf1cd08 | ||
|
|
9c8025acce | ||
|
|
4e2154feb7 | ||
|
|
604958898c | ||
|
|
a093f5ad0a | ||
|
|
a7e9a7f30c | ||
|
|
98c5466b5d | ||
|
|
6345ac6ff8 | ||
|
|
5bcd683012 | ||
|
|
eaa193c6c5 | ||
|
|
1bdcaa1318 | ||
|
|
6b6c48354d | ||
|
|
774efb2fe0 | ||
|
|
3ec76636f9 | ||
|
|
283810d103 | ||
|
|
81a76bc8e5 | ||
|
|
788764be02 | ||
|
|
802ab26934 | ||
|
|
6857c81a14 | ||
|
|
a6ed511a30 | ||
|
|
44c2b58206 | ||
|
|
0e2adab3fd | ||
|
|
0fe87d6b98 | ||
|
|
31ef3d1084 | ||
|
|
5d1e9de096 | ||
|
|
89da4eb747 | ||
|
|
8899a1dee1 | ||
|
|
384a687ec3 | ||
|
|
70cfdd2f8b | ||
|
|
bdbd2f009a | ||
|
|
164e0d26e0 | ||
|
|
cb087b5ff9 | ||
|
|
1d3928d145 | ||
|
|
6dc3d161e7 | ||
|
|
e9805ba205 | ||
|
|
d5280dcd88 | ||
|
|
67a9663eff | ||
|
|
77dd89b8eb | ||
|
|
8e511bf14b | ||
|
|
164a4226ea | ||
|
|
6d6fefc435 | ||
|
|
aa59532287 | ||
|
|
b984bb2513 | ||
|
|
8488c9aeab | ||
|
|
676f9fd4ff | ||
|
|
1935ce4700 | ||
|
|
e760956353 | ||
|
|
be3e5f3f8b | ||
|
|
cdf617feac | ||
|
|
afb56cf707 | ||
|
|
cd2556ab94 | ||
|
|
cf4a5d9ea4 | ||
|
|
0747099cac | ||
|
|
323ec29b02 | ||
|
|
ae81d70685 | ||
|
|
270c89c12f | ||
|
|
c7a58252fe | ||
|
|
47ad8c86e5 | ||
|
|
937e879e5e | ||
|
|
1ecf26eead | ||
|
|
adbb84530a | ||
|
|
6cf169f4f2 | ||
|
|
5ab9ea12c0 | ||
|
|
fd9cb703db | ||
|
|
388c1ab16d | ||
|
|
f867c2a271 | ||
|
|
605bb2cb90 | ||
|
|
5ea15dde5a | ||
|
|
3ca545c4c7 | ||
|
|
e200835074 | ||
|
|
3a90348353 | ||
|
|
5a11d8f0ee | ||
|
|
824af5eeea | ||
|
|
08ec787491 | ||
|
|
b062e83d54 | ||
|
|
17422ba9c3 | ||
|
|
6849af2bad | ||
|
|
09c3da64f9 | ||
|
|
2c8470e8ac | ||
|
|
c4ea3db73d | ||
|
|
89e79863f6 | ||
|
|
d19945009f | ||
|
|
c77256ee0e | ||
|
|
7d823af627 | ||
|
|
3957861878 | ||
|
|
6ac43c600e | ||
|
|
27af9ebb6b | ||
|
|
b360c8446e | ||
|
|
6d00717655 | ||
|
|
bb5f06498e | ||
|
|
aca5743ab6 | ||
|
|
6903032f7e | ||
|
|
1ce0ff87bd | ||
|
|
e39d6bae0b | ||
|
|
8028e9e9a6 | ||
|
|
817f20ea01 | ||
|
|
ad5579a2f4 | ||
|
|
81a689a79b | ||
|
|
1893dd8336 | ||
|
|
021ca8175b | ||
|
|
39d6207fe1 | ||
|
|
23ce687229 | ||
|
|
3715312fd2 | ||
|
|
8196922cac | ||
|
|
8089ad91da | ||
|
|
2930cc3fd8 | ||
|
|
0e841a8b25 |
@@ -1,6 +1,7 @@
|
||||
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
|
||||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
||||
# github actions
|
||||
.git
|
||||
.github/
|
||||
.*ignore
|
||||
# User-specific stuff
|
||||
@@ -19,4 +20,5 @@ data/
|
||||
changelogs/
|
||||
tests/
|
||||
.ruff_cache/
|
||||
.astrbot
|
||||
.astrbot
|
||||
astrbot.lock
|
||||
2
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
2
.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml
vendored
@@ -16,7 +16,7 @@ body:
|
||||
|
||||
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
|
||||
|
||||
不熟悉 JSON ?可以从 [此处](https://plugins.astrbot.app/submit) 生成 JSON ,生成后记得复制粘贴过来.
|
||||
不熟悉 JSON ?可以从 [此站](https://plugins.astrbot.app) 右下角提交。
|
||||
|
||||
- type: textarea
|
||||
id: plugin-info
|
||||
|
||||
44
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
44
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -1,46 +1,44 @@
|
||||
name: '🐛 报告 Bug'
|
||||
name: '🐛 Report Bug / 报告 Bug'
|
||||
title: '[Bug]'
|
||||
description: 提交报告帮助我们改进。
|
||||
description: Submit bug report to help us improve. / 提交报告帮助我们改进。
|
||||
labels: [ 'bug' ]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
|
||||
Thank you for taking the time to report this issue! Please describe your problem accurately. If possible, please provide a reproducible snippet (this will help resolve the issue more quickly). Please note that issues that are not detailed or have no logs will be closed immediately. Thank you for your understanding. / 感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 发生了什么
|
||||
description: 描述你遇到的异常
|
||||
label: What happened / 发生了什么
|
||||
description: Description
|
||||
placeholder: >
|
||||
一个清晰且具体的描述这个异常是什么。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
|
||||
Please provide a clear and specific description of what this exception is. Please note that issues that are not detailed or have no logs will be closed immediately. Thank you for your understanding. / 一个清晰且具体的描述这个异常是什么。请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 如何复现?
|
||||
label: Reproduce / 如何复现?
|
||||
description: >
|
||||
复现该问题的步骤
|
||||
The steps to reproduce the issue. / 复现该问题的步骤
|
||||
placeholder: >
|
||||
如: 1. 打开 '...'
|
||||
Example: 1. Open '...'
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: AstrBot 版本、部署方式(如 Windows Docker Desktop 部署)、使用的提供商、使用的消息平台适配器
|
||||
description: >
|
||||
请提供您的 AstrBot 版本和部署方式。
|
||||
label: AstrBot version, deployment method (e.g., Windows Docker Desktop deployment), provider used, and messaging platform used. / AstrBot 版本、部署方式(如 Windows Docker Desktop 部署)、使用的提供商、使用的消息平台适配器
|
||||
placeholder: >
|
||||
如: 3.1.8 Docker, 3.1.7 Windows启动器
|
||||
Example: 4.5.7 Docker, 3.1.7 Windows Launcher
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 操作系统
|
||||
label: OS
|
||||
description: |
|
||||
你在哪个操作系统上遇到了这个问题?
|
||||
On which operating system did you encounter this problem? / 你在哪个操作系统上遇到了这个问题?
|
||||
multiple: false
|
||||
options:
|
||||
- 'Windows'
|
||||
@@ -53,30 +51,30 @@ body:
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 报错日志
|
||||
label: Logs / 报错日志
|
||||
description: >
|
||||
如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
|
||||
Please provide complete Debug-level logs, such as error logs and screenshots. Don't worry if they're long! Please note that issues with insufficient details or no logs will be closed immediately. Thank you for your understanding. / 如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!请注意,不详细 / 没有日志的 issue 会被直接关闭,谢谢理解。
|
||||
placeholder: >
|
||||
请提供完整的报错日志或截图。
|
||||
Please provide a complete error log or screenshot. / 请提供完整的报错日志或截图。
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: 你愿意提交 PR 吗?
|
||||
label: Are you willing to submit a PR? / 你愿意提交 PR 吗?
|
||||
description: >
|
||||
这不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
|
||||
This is not required, but we would be happy to provide guidance during the contribution process, especially if you already have a good understanding of how to implement the fix. / 这不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
|
||||
options:
|
||||
- label: 是的,我愿意提交 PR!
|
||||
- label: Yes!
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Code of Conduct
|
||||
options:
|
||||
- label: >
|
||||
我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||
I have read and agree to abide by the project's [Code of Conduct](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||
required: true
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: "感谢您填写我们的表单!"
|
||||
value: "Thank you for filling out our form! / 感谢您填写我们的表单!"
|
||||
|
||||
31
.github/PULL_REQUEST_TEMPLATE.md
vendored
31
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,44 +1,25 @@
|
||||
<!-- 如果有的话,请指定此 PR 旨在解决的 ISSUE 编号。 -->
|
||||
<!-- If applicable, please specify the ISSUE number this PR aims to resolve. -->
|
||||
|
||||
fixes #XYZ
|
||||
|
||||
---
|
||||
|
||||
### Motivation / 动机
|
||||
|
||||
<!--请描述此项更改的动机:它解决了什么问题?(例如:修复了 XX 错误,添加了 YY 功能)-->
|
||||
<!--Please describe the motivation for this change: What problem does it solve? (e.g., Fixes XX bug, adds YY feature)-->
|
||||
<!--Please describe the motivation for this change: What problem does it solve? (e.g., Fixes XX issue, adds YY feature)-->
|
||||
<!--请描述此项更改的动机:它解决了什么问题?(例如:修复了 XX issue,添加了 YY 功能)-->
|
||||
|
||||
### Modifications / 改动点
|
||||
|
||||
<!--请总结你的改动:哪些核心文件被修改了?实现了什么功能?-->
|
||||
<!--Please summarize your changes: What core files were modified? What functionality was implemented?-->
|
||||
|
||||
### Verification Steps / 验证步骤
|
||||
|
||||
<!--请为审查者 (Reviewer) 提供清晰、可复现的验证步骤(例如:1. 导航到... 2. 点击...)。-->
|
||||
<!--Please provide clear and reproducible verification steps for the Reviewer (e.g., 1. Navigate to... 2. Click...).-->
|
||||
- [x] This is NOT a breaking change. / 这不是一个破坏性变更。
|
||||
<!-- If your changes is a breaking change, please uncheck the checkbox above -->
|
||||
|
||||
### Screenshots or Test Results / 运行截图或测试结果
|
||||
|
||||
<!--请粘贴截图、GIF 或测试日志,作为执行“验证步骤”的证据,证明此改动有效。-->
|
||||
<!--Please paste screenshots, GIFs, or test logs here as evidence of executing the "Verification Steps" to prove this change is effective.-->
|
||||
|
||||
### Compatibility & Breaking Changes / 兼容性与破坏性变更
|
||||
|
||||
<!--请说明此变更的兼容性:哪些是破坏性变更?哪些地方做了向后兼容处理?是否提供了数据迁移方法?-->
|
||||
<!--Please explain the compatibility of this change: What are the breaking changes? What backward-compatible measures were taken? Are data migration paths provided?-->
|
||||
|
||||
- [ ] 这是一个破坏性变更 (Breaking Change)。/ This is a breaking change.
|
||||
- [ ] 这不是一个破坏性变更。/ This is NOT a breaking change.
|
||||
<!--请粘贴截图、GIF 或测试日志,作为执行“验证步骤”的证据,证明此改动有效。-->
|
||||
|
||||
---
|
||||
|
||||
### Checklist / 检查清单
|
||||
|
||||
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
|
||||
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
|
||||
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
|
||||
|
||||
- [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
|
||||
- [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
|
||||
|
||||
143
.github/workflows/docker-image.yml
vendored
143
.github/workflows/docker-image.yml
vendored
@@ -3,18 +3,125 @@ name: Docker Image CI/CD
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- "v*"
|
||||
schedule:
|
||||
# Run at 00:00 UTC every day
|
||||
- cron: "0 0 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
publish-docker:
|
||||
build-nightly-image:
|
||||
if: github.event_name == 'schedule'
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
GHCR_OWNER: soulter
|
||||
HAS_GHCR_TOKEN: ${{ secrets.GHCR_GITHUB_TOKEN != '' }}
|
||||
|
||||
steps:
|
||||
- name: Pull The Codes
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 0 # Must be 0 so we can fetch tags
|
||||
fetch-depth: 1
|
||||
fetch-tag: true
|
||||
|
||||
- name: Check for new commits today
|
||||
if: github.event_name == 'schedule'
|
||||
id: check-commits
|
||||
run: |
|
||||
# Get commits from the last 24 hours
|
||||
commits=$(git log --since="24 hours ago" --oneline)
|
||||
if [ -z "$commits" ]; then
|
||||
echo "No commits in the last 24 hours, skipping build"
|
||||
echo "has_commits=false" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "Found commits in the last 24 hours:"
|
||||
echo "$commits"
|
||||
echo "has_commits=true" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Exit if no commits
|
||||
if: github.event_name == 'schedule' && steps.check-commits.outputs.has_commits == 'false'
|
||||
run: exit 0
|
||||
|
||||
- name: Build Dashboard
|
||||
run: |
|
||||
cd dashboard
|
||||
npm install
|
||||
npm run build
|
||||
mkdir -p dist/assets
|
||||
echo $(git rev-parse HEAD) > dist/assets/version
|
||||
cd ..
|
||||
mkdir -p data
|
||||
cp -r dashboard/dist data/
|
||||
|
||||
- name: Determine test image tags
|
||||
id: test-meta
|
||||
run: |
|
||||
short_sha=$(echo "${GITHUB_SHA}" | cut -c1-12)
|
||||
build_date=$(date +%Y%m%d)
|
||||
echo "short_sha=$short_sha" >> $GITHUB_OUTPUT
|
||||
echo "build_date=$build_date" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||
|
||||
- name: Login to GitHub Container Registry
|
||||
if: env.HAS_GHCR_TOKEN == 'true'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ env.GHCR_OWNER }}
|
||||
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
|
||||
|
||||
- name: Build nightly image tags list
|
||||
id: test-tags
|
||||
run: |
|
||||
TAGS="${{ env.DOCKER_HUB_USERNAME }}/astrbot:nightly-latest
|
||||
${{ env.DOCKER_HUB_USERNAME }}/astrbot:nightly-${{ steps.test-meta.outputs.build_date }}-${{ steps.test-meta.outputs.short_sha }}"
|
||||
if [ "${{ env.HAS_GHCR_TOKEN }}" = "true" ]; then
|
||||
TAGS="$TAGS
|
||||
ghcr.io/${{ env.GHCR_OWNER }}/astrbot:nightly-latest
|
||||
ghcr.io/${{ env.GHCR_OWNER }}/astrbot:nightly-${{ steps.test-meta.outputs.build_date }}-${{ steps.test-meta.outputs.short_sha }}"
|
||||
fi
|
||||
echo "tags<<EOF" >> $GITHUB_OUTPUT
|
||||
echo "$TAGS" >> $GITHUB_OUTPUT
|
||||
echo "EOF" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Build and Push Nightly Image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.test-tags.outputs.tags }}
|
||||
|
||||
- name: Post build notifications
|
||||
run: echo "Test Docker image has been built and pushed successfully"
|
||||
|
||||
build-release-image:
|
||||
if: github.event_name == 'workflow_dispatch' || (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v'))
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
GHCR_OWNER: soulter
|
||||
HAS_GHCR_TOKEN: ${{ secrets.GHCR_GITHUB_TOKEN != '' }}
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 1
|
||||
fetch-tag: true
|
||||
|
||||
- name: Get latest tag (only on manual trigger)
|
||||
id: get-latest-tag
|
||||
@@ -27,21 +134,22 @@ jobs:
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
|
||||
|
||||
- name: Check if version is pre-release
|
||||
id: check-prerelease
|
||||
- name: Compute release metadata
|
||||
id: release-meta
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
|
||||
version="${{ steps.get-latest-tag.outputs.latest_tag }}"
|
||||
else
|
||||
version="${{ github.ref_name }}"
|
||||
version="${GITHUB_REF#refs/tags/}"
|
||||
fi
|
||||
if [[ "$version" == *"beta"* ]] || [[ "$version" == *"alpha"* ]]; then
|
||||
echo "is_prerelease=true" >> $GITHUB_OUTPUT
|
||||
echo "Version $version is a pre-release, will not push latest tag"
|
||||
echo "Version $version marked as pre-release"
|
||||
else
|
||||
echo "is_prerelease=false" >> $GITHUB_OUTPUT
|
||||
echo "Version $version is a stable release, will push latest tag"
|
||||
echo "Version $version marked as stable"
|
||||
fi
|
||||
echo "version=$version" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Build Dashboard
|
||||
run: |
|
||||
@@ -67,23 +175,24 @@ jobs:
|
||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||
|
||||
- name: Login to GitHub Container Registry
|
||||
if: env.HAS_GHCR_TOKEN == 'true'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: Soulter
|
||||
username: ${{ env.GHCR_OWNER }}
|
||||
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
|
||||
|
||||
- name: Build and Push Docker to DockerHub and Github GHCR
|
||||
- name: Build and Push Release Image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && format('{0}/astrbot:latest', secrets.DOCKER_HUB_USERNAME) || '' }}
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
|
||||
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && 'ghcr.io/soulter/astrbot:latest' || '' }}
|
||||
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
|
||||
${{ steps.release-meta.outputs.is_prerelease == 'false' && format('{0}/astrbot:latest', env.DOCKER_HUB_USERNAME) || '' }}
|
||||
${{ steps.release-meta.outputs.is_prerelease == 'false' && env.HAS_GHCR_TOKEN == 'true' && format('ghcr.io/{0}/astrbot:latest', env.GHCR_OWNER) || '' }}
|
||||
${{ format('{0}/astrbot:{1}', env.DOCKER_HUB_USERNAME, steps.release-meta.outputs.version) }}
|
||||
${{ env.HAS_GHCR_TOKEN == 'true' && format('ghcr.io/{0}/astrbot:{1}', env.GHCR_OWNER, steps.release-meta.outputs.version) || '' }}
|
||||
|
||||
- name: Post build notifications
|
||||
run: echo "Docker image has been built and pushed successfully"
|
||||
run: echo "Release Docker image has been built and pushed successfully"
|
||||
|
||||
16
Dockerfile
16
Dockerfile
@@ -18,15 +18,15 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
|
||||
|
||||
RUN curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
|
||||
apt-get install -y --no-install-recommends nodejs && \
|
||||
echo "3.11" > .python-version && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
RUN apt-get update && apt-get install -y curl gnupg \
|
||||
&& curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
|
||||
&& apt-get install -y nodejs
|
||||
|
||||
RUN python -m pip install --no-cache-dir uv && \
|
||||
uv pip install socksio pilk --no-cache-dir --system
|
||||
RUN python -m pip install uv \
|
||||
&& echo "3.11" > .python-version
|
||||
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||
RUN uv pip install socksio uv pilk --no-cache-dir --system
|
||||
|
||||
EXPOSE 6185
|
||||
EXPOSE 6186
|
||||
|
||||
CMD ["uv", "run", "main.py"]
|
||||
CMD ["python", "main.py"]
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /AstrBot
|
||||
|
||||
COPY . /AstrBot/
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gcc \
|
||||
build-essential \
|
||||
python3-dev \
|
||||
libffi-dev \
|
||||
libssl-dev \
|
||||
curl \
|
||||
unzip \
|
||||
ca-certificates \
|
||||
bash \
|
||||
git \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
|
||||
|
||||
ENV NVM_DIR="/root/.nvm" \
|
||||
NODE_VERSION=22
|
||||
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
|
||||
. "$NVM_DIR/nvm.sh" && \
|
||||
nvm install $NODE_VERSION && \
|
||||
nvm use $NODE_VERSION && \
|
||||
nvm alias default $NODE_VERSION && \
|
||||
node -v && npm -v && \
|
||||
echo "3.11" > .python-version
|
||||
ENV PATH="$NVM_DIR/versions/node/v$(node -v | cut -d 'v' -f 2)/bin:$PATH"
|
||||
|
||||
RUN python -m pip install --no-cache-dir uv
|
||||
|
||||
# 安装项目依赖(根据指南,使用 uv sync)
|
||||
RUN uv sync --no-cache
|
||||
|
||||
EXPOSE 6185
|
||||
EXPOSE 6186
|
||||
|
||||
CMD ["uv", "run", "main.py"]
|
||||
118
README.md
118
README.md
@@ -8,7 +8,7 @@
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=1" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
@@ -42,7 +42,7 @@ AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
|
||||
5. **WebUI**。可视化配置和管理机器人,功能齐全。
|
||||
|
||||
## 部署方式
|
||||
## 部署方式
|
||||
|
||||
#### Docker 部署(推荐 🥳)
|
||||
|
||||
@@ -119,83 +119,73 @@ uv run main.py
|
||||
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
## ⚡ 消息平台支持情况
|
||||
## 支持的消息平台
|
||||
|
||||
**官方维护**
|
||||
|
||||
| 平台 | 支持性 |
|
||||
| -------- | ------- |
|
||||
| QQ(官方平台) | ✔ |
|
||||
| QQ(OneBot) | ✔ |
|
||||
| Telegram | ✔ |
|
||||
| 企微应用 | ✔ |
|
||||
| 企微智能机器人 | ✔ |
|
||||
| 微信客服 | ✔ |
|
||||
| 微信公众号 | ✔ |
|
||||
| 飞书 | ✔ |
|
||||
| 钉钉 | ✔ |
|
||||
| Slack | ✔ |
|
||||
| Discord | ✔ |
|
||||
| Satori | ✔ |
|
||||
| Misskey | ✔ |
|
||||
| Whatsapp | 将支持 |
|
||||
| LINE | 将支持 |
|
||||
- QQ (官方平台 & OneBot)
|
||||
- Telegram
|
||||
- 企微应用 & 企微智能机器人
|
||||
- 微信客服 & 微信公众号
|
||||
- 飞书
|
||||
- 钉钉
|
||||
- Slack
|
||||
- Discord
|
||||
- Satori
|
||||
- Misskey
|
||||
- Whatsapp (将支持)
|
||||
- LINE (将支持)
|
||||
|
||||
**社区维护**
|
||||
|
||||
| 平台 | 支持性 |
|
||||
| -------- | ------- |
|
||||
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
||||
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
||||
| [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) | ✔ |
|
||||
| [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11) | ✔ |
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
|
||||
|
||||
## ⚡ 提供商支持情况
|
||||
## 支持的模型服务
|
||||
|
||||
**大模型服务**
|
||||
|
||||
| 名称 | 支持性 | 备注 |
|
||||
| -------- | ------- | ------- |
|
||||
| OpenAI | ✔ | 支持任何兼容 OpenAI API 的服务 |
|
||||
| Anthropic | ✔ | |
|
||||
| Google Gemini | ✔ | |
|
||||
| Moonshot AI | ✔ | |
|
||||
| 智谱 AI | ✔ | |
|
||||
| DeepSeek | ✔ | |
|
||||
| Ollama | ✔ | 本地部署 DeepSeek 等开源语言模型 |
|
||||
| LM Studio | ✔ | 本地部署 DeepSeek 等开源语言模型 |
|
||||
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | |
|
||||
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | |
|
||||
| [小马算力](https://www.tokenpony.cn/3YPyf) | ✔ | |
|
||||
| 硅基流动 | ✔ | |
|
||||
| PPIO 派欧云 | ✔ | |
|
||||
| ModelScope | ✔ | |
|
||||
| OneAPI | ✔ | |
|
||||
| Dify | ✔ | |
|
||||
| 阿里云百炼应用 | ✔ | |
|
||||
| Coze | ✔ | |
|
||||
- OpenAI 及兼容服务
|
||||
- Anthropic
|
||||
- Google Gemini
|
||||
- Moonshot AI
|
||||
- 智谱 AI
|
||||
- DeepSeek
|
||||
- Ollama (本地部署)
|
||||
- LM Studio (本地部署)
|
||||
- [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
|
||||
- [302.AI](https://share.302.ai/rr1M3l)
|
||||
- [小马算力](https://www.tokenpony.cn/3YPyf)
|
||||
- [硅基流动](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
|
||||
- [PPIO 派欧云](https://ppio.com/user/register?invited_by=AIOONE)
|
||||
- ModelScope
|
||||
- OneAPI
|
||||
|
||||
**LLMOps 平台**
|
||||
|
||||
- Dify
|
||||
- 阿里云百炼应用
|
||||
- Coze
|
||||
|
||||
**语音转文本服务**
|
||||
|
||||
| 名称 | 支持性 | 备注 |
|
||||
| -------- | ------- | ------- |
|
||||
| Whisper | ✔ | 支持 API、本地部署 |
|
||||
| SenseVoice | ✔ | 本地部署 |
|
||||
- OpenAI Whisper
|
||||
- SenseVoice
|
||||
|
||||
**文本转语音服务**
|
||||
|
||||
| 名称 | 支持性 | 备注 |
|
||||
| -------- | ------- | ------- |
|
||||
| OpenAI TTS | ✔ | |
|
||||
| Gemini TTS | ✔ | |
|
||||
| GSVI | ✔ | GPT-Sovits-Inference |
|
||||
| GPT-SoVITs | ✔ | GPT-Sovits |
|
||||
| FishAudio | ✔ | |
|
||||
| Edge TTS | ✔ | Edge 浏览器的免费 TTS |
|
||||
| 阿里云百炼 TTS | ✔ | |
|
||||
| Azure TTS | ✔ | |
|
||||
| Minimax TTS | ✔ | |
|
||||
| 火山引擎 TTS | ✔ | |
|
||||
- OpenAI TTS
|
||||
- Gemini TTS
|
||||
- GPT-Sovits-Inference
|
||||
- GPT-Sovits
|
||||
- FishAudio
|
||||
- Edge TTS
|
||||
- 阿里云百炼 TTS
|
||||
- Azure TTS
|
||||
- Minimax TTS
|
||||
- 火山引擎 TTS
|
||||
|
||||
## ❤️ 贡献
|
||||
|
||||
@@ -229,7 +219,7 @@ pre-commit install
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> [!TIP]
|
||||
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我们维护这个开源项目的动力 <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
287
README_en.md
287
README_en.md
@@ -1,182 +1,233 @@
|
||||
<p align="center">
|
||||
|
||||

|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
_✨ Easy-to-use Multi-platform LLM Chatbot & Development Framework ✨_
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
[](https://github.com/AstrBotDevs/AstrBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a>
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||

|
||||
[](https://codecov.io/gh/AstrBotDevs/AstrBot)
|
||||
|
||||
<a href="https://astrbot.app/">Documentation</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracking</a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
</div>
|
||||
|
||||
AstrBot is a loosely coupled, asynchronous chatbot and development framework that supports multi-platform deployment, featuring an easy-to-use plugin system and comprehensive Large Language Model (LLM) integration capabilities.
|
||||
<br>
|
||||
|
||||
## ✨ Key Features
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper).
|
||||
2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation.
|
||||
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://dify.ai/) for easy access to Dify assistants/knowledge bases/workflows.
|
||||
4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins.
|
||||
5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction.
|
||||
6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling.
|
||||
<br>
|
||||
|
||||
> [!TIP]
|
||||
> Dashboard Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
> Username: `astrbot`, Password: `astrbot` (LLM not configured for chat page)
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://astrbot.app/">Documentation</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">Roadmap</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracker</a>
|
||||
</div>
|
||||
|
||||
## ✨ Deployment
|
||||
AstrBot is an open-source all-in-one Agent chatbot platform and development framework.
|
||||
|
||||
#### Docker Deployment
|
||||
## Key Features
|
||||
|
||||
See docs: [Deploy with Docker](https://astrbot.app/deploy/astrbot/docker.html#docker-deployment)
|
||||
1. **LLM Conversations**. Supports integration with various large language model services. Features include multimodal capabilities, tool calling, MCP, native knowledge base, character personas, and more.
|
||||
2. **Multi-Platform Support**. Integrates with QQ, WeChat Work, WeChat Official Accounts, Feishu, Telegram, DingTalk, Discord, KOOK, and other platforms. Supports rate limiting, whitelisting, and Baidu content moderation.
|
||||
3. **Agent Capabilities**. Fully optimized agentic features including multi-turn tool calling, built-in sandboxed code executor, web search, and more.
|
||||
4. **Plugin Extensions**. Deeply optimized plugin mechanism supporting [plugin development](https://astrbot.app/dev/plugin.html) to extend functionality, with a rich community plugin ecosystem.
|
||||
5. **Web UI**. Visual configuration and management of your bot with comprehensive features.
|
||||
|
||||
#### Windows Installer
|
||||
## Deployment Methods
|
||||
|
||||
Requires Python (>3.10). See docs: [Windows Installer Guide](https://astrbot.app/deploy/astrbot/windows.html)
|
||||
#### Docker Deployment (Recommended 🥳)
|
||||
|
||||
#### Replit Deployment
|
||||
We recommend deploying AstrBot using Docker or Docker Compose.
|
||||
|
||||
Please refer to the official documentation: [Deploy AstrBot with Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
|
||||
|
||||
#### BT-Panel Deployment
|
||||
|
||||
AstrBot has partnered with BT-Panel and is now available in their marketplace.
|
||||
|
||||
Please refer to the official documentation: [BT-Panel Deployment](https://astrbot.app/deploy/astrbot/btpanel.html).
|
||||
|
||||
#### 1Panel Deployment
|
||||
|
||||
AstrBot has been officially listed on the 1Panel marketplace.
|
||||
|
||||
Please refer to the official documentation: [1Panel Deployment](https://astrbot.app/deploy/astrbot/1panel.html).
|
||||
|
||||
#### Deploy on RainYun
|
||||
|
||||
AstrBot has been officially listed on RainYun's cloud application platform with one-click deployment.
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
#### Deploy on Replit
|
||||
|
||||
Community-contributed deployment method.
|
||||
|
||||
[](https://repl.it/github/AstrBotDevs/AstrBot)
|
||||
|
||||
#### Windows One-Click Installer
|
||||
|
||||
Please refer to the official documentation: [Deploy AstrBot with Windows One-Click Installer](https://astrbot.app/deploy/astrbot/windows.html).
|
||||
|
||||
#### CasaOS Deployment
|
||||
|
||||
Community-contributed method.
|
||||
See docs: [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html)
|
||||
Community-contributed deployment method.
|
||||
|
||||
Please refer to the official documentation: [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html).
|
||||
|
||||
#### Manual Deployment
|
||||
|
||||
See docs: [Source Code Deployment](https://astrbot.app/deploy/astrbot/cli.html)
|
||||
First, install uv:
|
||||
|
||||
## ⚡ Platform Support
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
| Platform | Status | Details | Message Types |
|
||||
| -------------------------------------------------------------- | ------ | ------------------- | ------------------- |
|
||||
| QQ (Official Bot) | ✔ | Private/Group chats | Text, Images |
|
||||
| QQ (OneBot) | ✔ | Private/Group chats | Text, Images, Voice |
|
||||
| WeChat (Personal) | ✔ | Private/Group chats | Text, Images, Voice |
|
||||
| [Telegram](https://github.com/AstrBotDevs/AstrBot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
|
||||
| [WeChat Work](https://github.com/AstrBotDevs/AstrBot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
|
||||
| Feishu | ✔ | Group chats | Text, Images |
|
||||
| WeChat Open Platform | 🚧 | Planned | - |
|
||||
| Discord | 🚧 | Planned | - |
|
||||
| WhatsApp | 🚧 | Planned | - |
|
||||
| Xiaomi Speakers | 🚧 | Planned | - |
|
||||
Install AstrBot via Git Clone:
|
||||
|
||||
## Provider Support Status
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
| Name | Support | Type | Notes |
|
||||
|---------------------------|---------|------------------------|-----------------------------------------------------------------------|
|
||||
| OpenAI API | ✔ | Text Generation | Supports all OpenAI API-compatible services including DeepSeek, Google Gemini, GLM, Moonshot, Alibaba Cloud Bailian, Silicon Flow, xAI, etc. |
|
||||
| Claude API | ✔ | Text Generation | |
|
||||
| Google Gemini API | ✔ | Text Generation | |
|
||||
| Dify | ✔ | LLMOps | |
|
||||
| DashScope (Alibaba Cloud) | ✔ | LLMOps | |
|
||||
| Ollama | ✔ | Model Loader | Local deployment for open-source LLMs (DeepSeek, Llama, etc.) |
|
||||
| LM Studio | ✔ | Model Loader | Local deployment for open-source LLMs (DeepSeek, Llama, etc.) |
|
||||
| LLMTuner | ✔ | Model Loader | Local loading of fine-tuned models (e.g. LoRA) |
|
||||
| OneAPI | ✔ | LLM Distribution | |
|
||||
| Whisper | ✔ | Speech-to-Text | Supports API and local deployment |
|
||||
| SenseVoice | ✔ | Speech-to-Text | Local deployment |
|
||||
| OpenAI TTS API | ✔ | Text-to-Speech | |
|
||||
| Fishaudio | ✔ | Text-to-Speech | Project involving GPT-Sovits author |
|
||||
Or refer to the official documentation: [Deploy AstrBot from Source](https://astrbot.app/deploy/astrbot/cli.html).
|
||||
|
||||
# 🦌 Roadmap
|
||||
## 🌍 Community
|
||||
|
||||
> [!TIP]
|
||||
> Suggestions welcome via Issues <3
|
||||
### QQ Groups
|
||||
|
||||
- [ ] Ensure feature parity across all platform adapters
|
||||
- [ ] Optimize plugin APIs
|
||||
- [ ] Add default TTS services (e.g., GPT-Sovits)
|
||||
- [ ] Enhance chat features with persistent memory
|
||||
- [ ] i18n Planning
|
||||
- Group 1: 322154837
|
||||
- Group 3: 630166526
|
||||
- Group 5: 822130018
|
||||
- Group 6: 753075035
|
||||
- Developer Group: 975206796
|
||||
|
||||
## ❤️ Contributions
|
||||
### Telegram Group
|
||||
|
||||
All Issues/PRs welcome! Simply submit your changes to this project :)
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
For major features, please discuss via Issues first.
|
||||
### Discord Server
|
||||
|
||||
## 🌟 Support
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
- Star this project!
|
||||
- Support via [Afdian](https://afdian.com/a/soulter)
|
||||
- WeChat support: [QR Code](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)
|
||||
## Supported Messaging Platforms
|
||||
|
||||
## ✨ Demos
|
||||
**Officially Maintained**
|
||||
|
||||
> [!NOTE]
|
||||
> Code executor file I/O currently tested with Napcat(QQ)/Lagrange(QQ)
|
||||
- QQ (Official Platform & OneBot)
|
||||
- Telegram
|
||||
- WeChat Work Application & WeChat Work Intelligent Bot
|
||||
- WeChat Customer Service & WeChat Official Accounts
|
||||
- Feishu (Lark)
|
||||
- DingTalk
|
||||
- Slack
|
||||
- Discord
|
||||
- Satori
|
||||
- Misskey
|
||||
- WhatsApp (Coming Soon)
|
||||
- LINE (Coming Soon)
|
||||
|
||||
<div align='center'>
|
||||
**Community Maintained**
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili Direct Messages](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
|
||||
|
||||
_✨ Docker-based Sandboxed Code Executor (Beta) ✨_
|
||||
## Supported Model Services
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||
**LLM Services**
|
||||
|
||||
_✨ Multimodal Input, Web Search, Text-to-Image ✨_
|
||||
- OpenAI and Compatible Services
|
||||
- Anthropic
|
||||
- Google Gemini
|
||||
- Moonshot AI
|
||||
- Zhipu AI
|
||||
- DeepSeek
|
||||
- Ollama (Self-hosted)
|
||||
- LM Studio (Self-hosted)
|
||||
- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
|
||||
- [302.AI](https://share.302.ai/rr1M3l)
|
||||
- [TokenPony](https://www.tokenpony.cn/3YPyf)
|
||||
- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
|
||||
- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
|
||||
- ModelScope
|
||||
- OneAPI
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
|
||||
**LLMOps Platforms**
|
||||
|
||||
_✨ Natural Language TODO Lists ✨_
|
||||
- Dify
|
||||
- Alibaba Cloud Bailian Applications
|
||||
- Coze
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
||||
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
||||
**Speech-to-Text Services**
|
||||
|
||||
_✨ Plugin System Showcase ✨_
|
||||
- OpenAI Whisper
|
||||
- SenseVoice
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
|
||||
**Text-to-Speech Services**
|
||||
|
||||
_✨ Web Dashboard ✨_
|
||||
- OpenAI TTS
|
||||
- Gemini TTS
|
||||
- GPT-Sovits-Inference
|
||||
- GPT-Sovits
|
||||
- FishAudio
|
||||
- Edge TTS
|
||||
- Alibaba Cloud Bailian TTS
|
||||
- Azure TTS
|
||||
- Minimax TTS
|
||||
- Volcano Engine TTS
|
||||
|
||||

|
||||
## ❤️ Contributing
|
||||
|
||||
_✨ Built-in Web Chat Interface ✨_
|
||||
Issues and Pull Requests are always welcome! Feel free to submit your changes to this project :)
|
||||
|
||||
</div>
|
||||
### How to Contribute
|
||||
|
||||
You can contribute by reviewing issues or helping with pull request reviews. Any issues or PRs are welcome to encourage community participation. Of course, these are just suggestions—you can contribute in any way you like. For adding new features, please discuss through an Issue first.
|
||||
|
||||
### Development Environment
|
||||
|
||||
AstrBot uses `ruff` for code formatting and linting.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
Special thanks to all Contributors and plugin developers for their contributions to AstrBot ❤️
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
Additionally, the birth of this project would not have been possible without the help of the following open-source projects:
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - The amazing cat framework
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> If this project helps you, please give it a star <3
|
||||
> [!TIP]
|
||||
> If this project has helped you in your life or work, or if you're interested in its future development, please give the project a Star. It's the driving force behind maintaining this open-source project <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#AstrBotDevs/AstrBot&Date)
|
||||
|
||||
[](https://star-history.com/#astrbotdevs/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
## Disclaimer
|
||||
|
||||
1. Licensed under `AGPL-v3`.
|
||||
2. WeChat integration uses [Gewechat](https://github.com/Devo919/Gewechat). Use at your own risk with non-critical accounts.
|
||||
3. Users must comply with local laws and regulations.
|
||||
|
||||
<!-- ## ✨ ATRI [Beta]
|
||||
|
||||
Available as plugin: [astrbot_plugin_atri](https://github.com/AstrBotDevs/AstrBot_plugin_atri)
|
||||
|
||||
1. Qwen1.5-7B-Chat Lora model fine-tuned with ATRI character data
|
||||
2. Long-term memory
|
||||
3. Meme understanding & responses
|
||||
4. TTS integration
|
||||
-->
|
||||
|
||||
</details>
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
|
||||
270
README_ja.md
270
README_ja.md
@@ -1,167 +1,233 @@
|
||||
<p align="center">
|
||||
|
||||

|
||||

|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
_✨ 簡単に使えるマルチプラットフォーム LLM チャットボットおよび開発フレームワーク ✨_
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
[](https://github.com/AstrBotDevs/AstrBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple">
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||

|
||||
[](https://codecov.io/gh/AstrBotDevs/AstrBot)
|
||||
|
||||
<a href="https://astrbot.app/">ドキュメントを見る</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">問題を報告する</a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
</div>
|
||||
|
||||
AstrBot は、疎結合、非同期、複数のメッセージプラットフォームに対応したデプロイ、使いやすいプラグインシステム、および包括的な大規模言語モデル(LLM)接続機能を備えたチャットボットおよび開発フレームワークです。
|
||||
<br>
|
||||
|
||||
## ✨ 主な機能
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
|
||||
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
||||
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
|
||||
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
|
||||
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
|
||||
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
|
||||
<br>
|
||||
|
||||
> [!TIP]
|
||||
> 管理パネルのオンラインデモを体験する: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
>
|
||||
> ユーザー名: `astrbot`, パスワード: `astrbot`。LLM が設定されていないため、チャットページで大規模モデルを使用することはできません。(デモのログインパスワードを変更しないでください 😭)
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://astrbot.app/">ドキュメント</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">ロードマップ</a> |
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue</a>
|
||||
</div>
|
||||
|
||||
## ✨ 使用方法
|
||||
AstrBot は、オープンソースのオールインワン Agent チャットボットプラットフォーム及び開発フレームワークです。
|
||||
|
||||
#### Docker デプロイ
|
||||
## 主な機能
|
||||
|
||||
公式ドキュメント [Docker を使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) を参照してください。
|
||||
1. **大規模言語モデル対話**。多様な大規模言語モデルサービスとの統合をサポート。マルチモーダル、ツール呼び出し、MCP、ネイティブナレッジベース、キャラクター設定などの機能を搭載。
|
||||
2. **マルチメッセージプラットフォームサポート**。QQ、WeChat Work、WeChat公式アカウント、Feishu、Telegram、DingTalk、Discord、KOOK などのプラットフォームと統合可能。レート制限、ホワイトリスト、Baidu コンテンツ審査をサポート。
|
||||
3. **Agent**。完全に最適化された Agentic 機能。マルチターンツール呼び出し、内蔵サンドボックスコード実行環境、Web 検索などの機能をサポート。
|
||||
4. **プラグイン拡張**。深く最適化されたプラグインメカニズムで、[プラグイン開発](https://astrbot.app/dev/plugin.html)による機能拡張をサポート。豊富なコミュニティプラグインエコシステム。
|
||||
5. **WebUI**。ビジュアル設定とボット管理、充実した機能。
|
||||
|
||||
#### Windows ワンクリックインストーラーのデプロイ
|
||||
## デプロイ方法
|
||||
|
||||
コンピュータに Python(>3.10)がインストールされている必要があります。公式ドキュメント [Windows ワンクリックインストーラーを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/windows.html) を参照してください。
|
||||
#### Docker デプロイ(推奨 🥳)
|
||||
|
||||
#### Replit デプロイ
|
||||
Docker / Docker Compose を使用した AstrBot のデプロイを推奨します。
|
||||
|
||||
公式ドキュメント [Docker を使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) をご参照ください。
|
||||
|
||||
#### 宝塔パネルデプロイ
|
||||
|
||||
AstrBot は宝塔パネルと提携し、宝塔パネルに公開されています。
|
||||
|
||||
公式ドキュメント [宝塔パネルデプロイ](https://astrbot.app/deploy/astrbot/btpanel.html) をご参照ください。
|
||||
|
||||
#### 1Panel デプロイ
|
||||
|
||||
AstrBot は 1Panel 公式により 1Panel パネルに公開されています。
|
||||
|
||||
公式ドキュメント [1Panel デプロイ](https://astrbot.app/deploy/astrbot/1panel.html) をご参照ください。
|
||||
|
||||
#### 雨云でのデプロイ
|
||||
|
||||
AstrBot は雨云公式によりクラウドアプリケーションプラットフォームに公開され、ワンクリックでデプロイ可能です。
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
#### Replit でのデプロイ
|
||||
|
||||
コミュニティ貢献によるデプロイ方法。
|
||||
|
||||
[](https://repl.it/github/AstrBotDevs/AstrBot)
|
||||
|
||||
#### Windows ワンクリックインストーラーデプロイ
|
||||
|
||||
公式ドキュメント [Windows ワンクリックインストーラーを使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/windows.html) をご参照ください。
|
||||
|
||||
#### CasaOS デプロイ
|
||||
|
||||
コミュニティが提供するデプロイ方法です。
|
||||
コミュニティ貢献によるデプロイ方法。
|
||||
|
||||
公式ドキュメント [ソースコードを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/casaos.html) を参照してください。
|
||||
公式ドキュメント [CasaOS デプロイ](https://astrbot.app/deploy/astrbot/casaos.html) をご参照ください。
|
||||
|
||||
#### 手動デプロイ
|
||||
|
||||
公式ドキュメント [ソースコードを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/cli.html) を参照してください。
|
||||
まず uv をインストールします:
|
||||
|
||||
## ⚡ メッセージプラットフォームのサポート状況
|
||||
```bash
|
||||
pip install uv
|
||||
```
|
||||
|
||||
| プラットフォーム | サポート状況 | 詳細 | メッセージタイプ |
|
||||
| -------- | ------- | ------- | ------ |
|
||||
| QQ(公式ロボットインターフェース) | ✔ | プライベートチャット、グループチャット、QQ チャンネルプライベートチャット、グループチャット | テキスト、画像 |
|
||||
| QQ(OneBot) | ✔ | プライベートチャット、グループチャット | テキスト、画像、音声 |
|
||||
| WeChat(個人アカウント) | ✔ | WeChat 個人アカウントのプライベートチャット、グループチャット | テキスト、画像、音声 |
|
||||
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | プライベートチャット、グループチャット | テキスト、画像 |
|
||||
| [WeChat(企業 WeChat)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | プライベートチャット | テキスト、画像、音声 |
|
||||
| Feishu | ✔ | グループチャット | テキスト、画像 |
|
||||
| WeChat 対話オープンプラットフォーム | 🚧 | 計画中 | - |
|
||||
| Discord | 🚧 | 計画中 | - |
|
||||
| WhatsApp | 🚧 | 計画中 | - |
|
||||
| Xiaoai 音響 | 🚧 | 計画中 | - |
|
||||
Git Clone で AstrBot をインストール:
|
||||
|
||||
# 🦌 今後のロードマップ
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Issue でさらに多くの提案を歓迎します <3
|
||||
または、公式ドキュメント [ソースコードから AstrBot をデプロイ](https://astrbot.app/deploy/astrbot/cli.html) をご参照ください。
|
||||
|
||||
- [ ] 現在のすべてのプラットフォームアダプターの機能の一貫性を確保し、改善する
|
||||
- [ ] プラグインインターフェースの最適化
|
||||
- [ ] GPT-Sovits などの TTS サービスをデフォルトでサポート
|
||||
- [ ] "チャット強化" 部分を完成させ、永続的な記憶をサポート
|
||||
- [ ] i18n の計画
|
||||
## 🌍 コミュニティ
|
||||
|
||||
## ❤️ 貢献
|
||||
### QQ グループ
|
||||
|
||||
Issue や Pull Request を歓迎します!このプロジェクトに変更を加えるだけです :)
|
||||
- 1群:322154837
|
||||
- 3群:630166526
|
||||
- 5群:822130018
|
||||
- 6群:753075035
|
||||
- 開発者群:975206796
|
||||
|
||||
新機能の追加については、まず Issue で議論してください。
|
||||
### Telegram グループ
|
||||
|
||||
## 🌟 サポート
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
- このプロジェクトに Star を付けてください!
|
||||
- [愛発電](https://afdian.com/a/soulter)で私をサポートしてください!
|
||||
- [WeChat](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)で私をサポートしてください~
|
||||
### Discord サーバー
|
||||
|
||||
## ✨ デモ
|
||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
|
||||
> [!NOTE]
|
||||
> コードエグゼキューターのファイル入力/出力は現在 Napcat(QQ)、Lagrange(QQ) でのみテストされています
|
||||
## サポートされているメッセージプラットフォーム
|
||||
|
||||
<div align='center'>
|
||||
**公式メンテナンス**
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||
- QQ (公式プラットフォーム & OneBot)
|
||||
- Telegram
|
||||
- WeChat Work アプリケーション & WeChat Work インテリジェントボット
|
||||
- WeChat カスタマーサービス & WeChat 公式アカウント
|
||||
- Feishu (Lark)
|
||||
- DingTalk
|
||||
- Slack
|
||||
- Discord
|
||||
- Satori
|
||||
- Misskey
|
||||
- WhatsApp (近日対応予定)
|
||||
- LINE (近日対応予定)
|
||||
|
||||
_✨ Docker ベースのサンドボックス化されたコードエグゼキューター(ベータテスト中)✨_
|
||||
**コミュニティメンテナンス**
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili ダイレクトメッセージ](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
|
||||
|
||||
_✨ 多モーダル、ウェブ検索、長文の画像変換(設定可能)✨_
|
||||
## サポートされているモデルサービス
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
|
||||
**大規模言語モデルサービス**
|
||||
|
||||
_✨ 自然言語タスク ✨_
|
||||
- OpenAI および互換サービス
|
||||
- Anthropic
|
||||
- Google Gemini
|
||||
- Moonshot AI
|
||||
- 智谱 AI
|
||||
- DeepSeek
|
||||
- Ollama (セルフホスト)
|
||||
- LM Studio (セルフホスト)
|
||||
- [優云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
|
||||
- [302.AI](https://share.302.ai/rr1M3l)
|
||||
- [小馬算力](https://www.tokenpony.cn/3YPyf)
|
||||
- [硅基流動](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
|
||||
- [PPIO 派欧云](https://ppio.com/user/register?invited_by=AIOONE)
|
||||
- ModelScope
|
||||
- OneAPI
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
||||
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
||||
**LLMOps プラットフォーム**
|
||||
|
||||
_✨ プラグインシステム - 一部のプラグインの展示 ✨_
|
||||
- Dify
|
||||
- Alibaba Cloud 百炼アプリケーション
|
||||
- Coze
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width="600">
|
||||
**音声認識サービス**
|
||||
|
||||
_✨ 管理パネル ✨_
|
||||
- OpenAI Whisper
|
||||
- SenseVoice
|
||||
|
||||

|
||||
**音声合成サービス**
|
||||
|
||||
_✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
|
||||
- OpenAI TTS
|
||||
- Gemini TTS
|
||||
- GPT-Sovits-Inference
|
||||
- GPT-Sovits
|
||||
- FishAudio
|
||||
- Edge TTS
|
||||
- Alibaba Cloud 百炼 TTS
|
||||
- Azure TTS
|
||||
- Minimax TTS
|
||||
- Volcano Engine TTS
|
||||
|
||||
</div>
|
||||
## ❤️ コントリビューション
|
||||
|
||||
Issue や Pull Request は大歓迎です!このプロジェクトに変更を送信してください :)
|
||||
|
||||
### コントリビュート方法
|
||||
|
||||
Issue を確認したり、PR(プルリクエスト)のレビューを手伝うことで貢献できます。どんな Issue や PR への参加も歓迎され、コミュニティ貢献を促進します。もちろん、これらは提案に過ぎず、どんな方法でも貢献できます。新機能の追加については、まず Issue で議論してください。
|
||||
|
||||
### 開発環境
|
||||
|
||||
AstrBot はコードのフォーマットとチェックに `ruff` を使用しています。
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
AstrBot への貢献をしていただいたすべてのコントリビューターとプラグイン開発者に特別な感謝を ❤️
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
また、このプロジェクトの誕生は以下のオープンソースプロジェクトの助けなしには実現できませんでした:
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 素晴らしい猫猫フレームワーク
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> このプロジェクトがあなたの生活や仕事に役立った場合、またはこのプロジェクトの将来の発展に関心がある場合は、プロジェクトに Star を付けてください。これはこのオープンソースプロジェクトを維持するためのモチベーションです <3
|
||||
> このプロジェクトがあなたの生活や仕事に役立ったり、このプロジェクトの今後の発展に関心がある場合は、プロジェクトに Star をください。これがこのオープンソースプロジェクトを維持する原動力です <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://star-history.com/#soulter/astrbot&Date)
|
||||
[](https://star-history.com/#astrbotdevs/astrbot&Date)
|
||||
|
||||
</div>
|
||||
|
||||
## スポンサー
|
||||
|
||||
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
|
||||
|
||||
## 免責事項
|
||||
|
||||
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
|
||||
2. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
|
||||
|
||||
<!-- ## ✨ ATRI [ベータテスト]
|
||||
|
||||
この機能はプラグインとしてロードされます。プラグインリポジトリのアドレス:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
|
||||
|
||||
1. 《ATRI ~ My Dear Moments》の主人公 ATRI のキャラクターセリフを微調整データセットとして使用した `Qwen1.5-7B-Chat Lora` 微調整モデル。
|
||||
2. 長期記憶
|
||||
3. ミームの理解と返信
|
||||
4. TTS
|
||||
-->
|
||||
</details>
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
@@ -36,7 +36,8 @@ from astrbot.core.star.config import *
|
||||
|
||||
|
||||
# provider
|
||||
from astrbot.core.provider import Provider, Personality, ProviderMetaData
|
||||
from astrbot.core.provider import Provider, ProviderMetaData
|
||||
from astrbot.core.db.po import Personality
|
||||
|
||||
# platform
|
||||
from astrbot.core.platform import (
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from astrbot.core.provider import Personality, Provider, STTProvider
|
||||
from astrbot.core.db.po import Personality
|
||||
from astrbot.core.provider import Provider, STTProvider
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderMetaData,
|
||||
|
||||
@@ -4,6 +4,14 @@ from contextlib import AsyncExitStack
|
||||
from datetime import timedelta
|
||||
from typing import Generic
|
||||
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
@@ -12,21 +20,24 @@ from .run_context import TContext
|
||||
from .tool import FunctionTool
|
||||
|
||||
try:
|
||||
import anyio
|
||||
import mcp
|
||||
from mcp.client.sse import sse_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||
logger.warning(
|
||||
"Warning: Missing 'mcp' dependency, MCP services will be unavailable."
|
||||
)
|
||||
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning(
|
||||
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。",
|
||||
"Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
|
||||
)
|
||||
|
||||
|
||||
def _prepare_config(config: dict) -> dict:
|
||||
"""准备配置,处理嵌套格式"""
|
||||
"""Prepare configuration, handle nested format"""
|
||||
if config.get("mcpServers"):
|
||||
first_key = next(iter(config["mcpServers"]))
|
||||
config = config["mcpServers"][first_key]
|
||||
@@ -35,7 +46,7 @@ def _prepare_config(config: dict) -> dict:
|
||||
|
||||
|
||||
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
"""快速测试 MCP 服务器可达性"""
|
||||
"""Quick test MCP server connectivity"""
|
||||
import aiohttp
|
||||
|
||||
cfg = _prepare_config(config.copy())
|
||||
@@ -50,7 +61,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
elif "type" in cfg:
|
||||
transport_type = cfg["type"]
|
||||
else:
|
||||
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
|
||||
raise Exception("MCP connection config missing transport or type field")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if transport_type == "streamable_http":
|
||||
@@ -91,7 +102,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return False, f"连接超时: {timeout}秒"
|
||||
return False, f"Connection timeout: {timeout} seconds"
|
||||
except Exception as e:
|
||||
return False, f"{e!s}"
|
||||
|
||||
@@ -101,6 +112,7 @@ class MCPClient:
|
||||
# Initialize session and client objects
|
||||
self.session: mcp.ClientSession | None = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self._old_exit_stacks: list[AsyncExitStack] = [] # Track old stacks for cleanup
|
||||
|
||||
self.name: str | None = None
|
||||
self.active: bool = True
|
||||
@@ -108,22 +120,32 @@ class MCPClient:
|
||||
self.server_errlogs: list[str] = []
|
||||
self.running_event = asyncio.Event()
|
||||
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
"""连接到 MCP 服务器
|
||||
# Store connection config for reconnection
|
||||
self._mcp_server_config: dict | None = None
|
||||
self._server_name: str | None = None
|
||||
self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection
|
||||
self._reconnecting: bool = False # For logging and debugging
|
||||
|
||||
如果 `url` 参数存在:
|
||||
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
|
||||
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
|
||||
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
"""Connect to MCP server
|
||||
|
||||
If `url` parameter exists:
|
||||
1. When transport is specified as `streamable_http`, use Streamable HTTP connection.
|
||||
2. When transport is specified as `sse`, use SSE connection.
|
||||
3. If not specified, default to SSE connection to MCP service.
|
||||
|
||||
Args:
|
||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
|
||||
"""
|
||||
# Store config for reconnection
|
||||
self._mcp_server_config = mcp_server_config
|
||||
self._server_name = name
|
||||
|
||||
cfg = _prepare_config(mcp_server_config.copy())
|
||||
|
||||
def logging_callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
# Handle MCP service error logs
|
||||
print(f"MCP Server {name} Error: {msg}")
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
@@ -137,7 +159,7 @@ class MCPClient:
|
||||
elif "type" in cfg:
|
||||
transport_type = cfg["type"]
|
||||
else:
|
||||
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
|
||||
raise Exception("MCP connection config missing transport or type field")
|
||||
|
||||
if transport_type != "streamable_http":
|
||||
# SSE transport method
|
||||
@@ -193,7 +215,7 @@ class MCPClient:
|
||||
)
|
||||
|
||||
def callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
# Handle MCP service error logs
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
@@ -222,10 +244,120 @@ class MCPClient:
|
||||
self.tools = response.tools
|
||||
return response
|
||||
|
||||
async def _reconnect(self) -> None:
|
||||
"""Reconnect to the MCP server using the stored configuration.
|
||||
|
||||
Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments.
|
||||
|
||||
Raises:
|
||||
Exception: raised when reconnection fails
|
||||
"""
|
||||
async with self._reconnect_lock:
|
||||
# Check if already reconnecting (useful for logging)
|
||||
if self._reconnecting:
|
||||
logger.debug(
|
||||
f"MCP Client {self._server_name} is already reconnecting, skipping"
|
||||
)
|
||||
return
|
||||
|
||||
if not self._mcp_server_config or not self._server_name:
|
||||
raise Exception("Cannot reconnect: missing connection configuration")
|
||||
|
||||
self._reconnecting = True
|
||||
try:
|
||||
logger.info(
|
||||
f"Attempting to reconnect to MCP server {self._server_name}..."
|
||||
)
|
||||
|
||||
# Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues)
|
||||
if self.exit_stack:
|
||||
self._old_exit_stacks.append(self.exit_stack)
|
||||
|
||||
# Mark old session as invalid
|
||||
self.session = None
|
||||
|
||||
# Create new exit stack for new connection
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
# Reconnect using stored config
|
||||
await self.connect_to_server(self._mcp_server_config, self._server_name)
|
||||
await self.list_tools_and_save()
|
||||
|
||||
logger.info(
|
||||
f"Successfully reconnected to MCP server {self._server_name}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to reconnect to MCP server {self._server_name}: {e}"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
self._reconnecting = False
|
||||
|
||||
async def call_tool_with_reconnect(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: dict,
|
||||
read_timeout_seconds: timedelta,
|
||||
) -> mcp.types.CallToolResult:
|
||||
"""Call MCP tool with automatic reconnection on failure, max 2 retries.
|
||||
|
||||
Args:
|
||||
tool_name: tool name
|
||||
arguments: tool arguments
|
||||
read_timeout_seconds: read timeout
|
||||
|
||||
Returns:
|
||||
MCP tool call result
|
||||
|
||||
Raises:
|
||||
ValueError: MCP session is not available
|
||||
anyio.ClosedResourceError: raised after reconnection failure
|
||||
"""
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(anyio.ClosedResourceError),
|
||||
stop=stop_after_attempt(2),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=3),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
reraise=True,
|
||||
)
|
||||
async def _call_with_retry():
|
||||
if not self.session:
|
||||
raise ValueError("MCP session is not available for MCP function tools.")
|
||||
|
||||
try:
|
||||
return await self.session.call_tool(
|
||||
name=tool_name,
|
||||
arguments=arguments,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
except anyio.ClosedResourceError:
|
||||
logger.warning(
|
||||
f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
|
||||
)
|
||||
# Attempt to reconnect
|
||||
await self._reconnect()
|
||||
# Reraise the exception to trigger tenacity retry
|
||||
raise
|
||||
|
||||
return await _call_with_retry()
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
await self.exit_stack.aclose()
|
||||
self.running_event.set() # Set the running event to indicate cleanup is done
|
||||
"""Clean up resources including old exit stacks from reconnections"""
|
||||
# Set running_event first to unblock any waiting tasks
|
||||
self.running_event.set()
|
||||
|
||||
# Close current exit stack
|
||||
try:
|
||||
await self.exit_stack.aclose()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error closing current exit stack: {e}")
|
||||
|
||||
# Don't close old exit stacks as they may be in different task contexts
|
||||
# They will be garbage collected naturally
|
||||
# Just clear the list to release references
|
||||
self._old_exit_stacks.clear()
|
||||
|
||||
|
||||
class MCPTool(FunctionTool, Generic[TContext]):
|
||||
@@ -246,14 +378,8 @@ class MCPTool(FunctionTool, Generic[TContext]):
|
||||
async def call(
|
||||
self, context: ContextWrapper[TContext], **kwargs
|
||||
) -> mcp.types.CallToolResult:
|
||||
session = self.mcp_client.session
|
||||
if not session:
|
||||
raise ValueError("MCP session is not available for MCP function tools.")
|
||||
res = await session.call_tool(
|
||||
name=self.mcp_tool.name,
|
||||
return await self.mcp_client.call_tool_with_reconnect(
|
||||
tool_name=self.mcp_tool.name,
|
||||
arguments=kwargs,
|
||||
read_timeout_seconds=timedelta(
|
||||
seconds=context.tool_call_timeout,
|
||||
),
|
||||
read_timeout_seconds=timedelta(seconds=context.tool_call_timeout),
|
||||
)
|
||||
return res
|
||||
|
||||
@@ -76,7 +76,7 @@ class ImageURLPart(ContentPart):
|
||||
"""The ID of the image, to allow LLMs to distinguish different images."""
|
||||
|
||||
type: str = "image_url"
|
||||
image_url: str
|
||||
image_url: ImageURL
|
||||
|
||||
|
||||
class AudioURLPart(ContentPart):
|
||||
@@ -119,6 +119,13 @@ class ToolCall(BaseModel):
|
||||
"""The ID of the tool call."""
|
||||
function: FunctionBody
|
||||
"""The function body of the tool call."""
|
||||
extra_content: dict[str, Any] | None = None
|
||||
"""Extra metadata for the tool call."""
|
||||
|
||||
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
||||
if self.extra_content is None:
|
||||
kwargs.setdefault("exclude", set()).add("extra_content")
|
||||
return super().model_dump(**kwargs)
|
||||
|
||||
|
||||
class ToolCallPart(BaseModel):
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from .message import Message
|
||||
|
||||
TContext = TypeVar("TContext", default=Any)
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
class ContextWrapper(Generic[TContext]):
|
||||
"""A context for running an agent, which can be used to pass additional data or state."""
|
||||
|
||||
context: TContext
|
||||
messages: list[Message] = Field(default_factory=list)
|
||||
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
|
||||
tool_call_timeout: int = 60 # Default tool call timeout in seconds
|
||||
|
||||
|
||||
|
||||
@@ -40,6 +40,13 @@ class BaseAgentRunner(T.Generic[TContext]):
|
||||
"""Process a single step of the agent."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def step_until_done(
|
||||
self, max_step: int
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
"""Process steps until the agent is done."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def done(self) -> bool:
|
||||
"""Check if the agent has completed its task.
|
||||
|
||||
@@ -23,7 +23,7 @@ from astrbot.core.provider.entities import (
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..message import AssistantMessageSegment, ToolCallMessageSegment
|
||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||
from ..response import AgentResponseData
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
@@ -55,6 +55,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self.agent_hooks = agent_hooks
|
||||
self.run_context = run_context
|
||||
|
||||
messages = []
|
||||
# append existing messages in the run context
|
||||
for msg in request.contexts:
|
||||
messages.append(Message.model_validate(msg))
|
||||
if request.prompt is not None:
|
||||
m = await request.assemble_context()
|
||||
messages.append(Message.model_validate(m))
|
||||
if request.system_prompt:
|
||||
messages.insert(
|
||||
0,
|
||||
Message(role="system", content=request.system_prompt),
|
||||
)
|
||||
self.run_context.messages = messages
|
||||
|
||||
def _transition_state(self, new_state: AgentState) -> None:
|
||||
"""转换 Agent 状态"""
|
||||
if self._state != new_state:
|
||||
@@ -96,13 +110,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(chain=llm_response.result_chain),
|
||||
)
|
||||
else:
|
||||
elif llm_response.completion_text:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(llm_response.completion_text),
|
||||
),
|
||||
)
|
||||
elif llm_response.reasoning_content:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain(type="reasoning").message(
|
||||
llm_response.reasoning_content,
|
||||
),
|
||||
),
|
||||
)
|
||||
continue
|
||||
llm_resp_result = llm_response
|
||||
break # got final response
|
||||
@@ -130,6 +153,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# 如果没有工具调用,转换到完成状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.DONE)
|
||||
# record the final assistant message
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content=llm_resp.completion_text or "",
|
||||
),
|
||||
)
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
|
||||
except Exception as e:
|
||||
@@ -156,13 +186,16 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
yield AgentResponse(
|
||||
type="tool_call",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}"),
|
||||
chain=MessageChain(type="tool_call").message(
|
||||
f"🔨 调用工具: {tool_call_name}"
|
||||
),
|
||||
),
|
||||
)
|
||||
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||
if isinstance(result, list):
|
||||
tool_call_result_blocks = result
|
||||
elif isinstance(result, MessageChain):
|
||||
result.type = "tool_call_result"
|
||||
yield AgentResponse(
|
||||
type="tool_call_result",
|
||||
data=AgentResponseData(chain=result),
|
||||
@@ -175,8 +208,23 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
),
|
||||
tool_calls_result=tool_call_result_blocks,
|
||||
)
|
||||
# record the assistant message with tool calls
|
||||
self.run_context.messages.extend(
|
||||
tool_calls_result.to_openai_messages_model()
|
||||
)
|
||||
|
||||
self.req.append_tool_calls_result(tool_calls_result)
|
||||
|
||||
async def step_until_done(
|
||||
self, max_step: int
|
||||
) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
"""Process steps until the agent is done."""
|
||||
step_count = 0
|
||||
while not self.done() and step_count < max_step:
|
||||
step_count += 1
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
|
||||
@@ -4,12 +4,13 @@ from typing import Any, Generic
|
||||
import jsonschema
|
||||
import mcp
|
||||
from deprecated import deprecated
|
||||
from pydantic import model_validator
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from .run_context import ContextWrapper, TContext
|
||||
|
||||
ParametersType = dict[str, Any]
|
||||
ToolExecResult = str | mcp.types.CallToolResult
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -55,15 +56,14 @@ class FunctionTool(ToolSchema, Generic[TContext]):
|
||||
def __repr__(self):
|
||||
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
|
||||
|
||||
async def call(
|
||||
self, context: ContextWrapper[TContext], **kwargs
|
||||
) -> str | mcp.types.CallToolResult:
|
||||
async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult:
|
||||
"""Run the tool with the given arguments. The handler field has priority."""
|
||||
raise NotImplementedError(
|
||||
"FunctionTool.call() must be implemented by subclasses or set a handler."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolSet:
|
||||
"""A set of function tools that can be used in function calling.
|
||||
|
||||
@@ -71,8 +71,7 @@ class ToolSet:
|
||||
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).
|
||||
"""
|
||||
|
||||
def __init__(self, tools: list[FunctionTool] | None = None):
|
||||
self.tools: list[FunctionTool] = tools or []
|
||||
tools: list[FunctionTool] = Field(default_factory=list)
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Check if the tool set is empty."""
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
from dataclasses import dataclass
|
||||
from pydantic import Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core.star.context import Context
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
class AstrAgentContext:
|
||||
provider: Provider
|
||||
first_provider_request: ProviderRequest
|
||||
curr_provider_request: ProviderRequest
|
||||
streaming: bool
|
||||
context: Context
|
||||
"""The star context instance"""
|
||||
event: AstrMessageEvent
|
||||
"""The message event associated with the agent context."""
|
||||
extra: dict[str, str] = Field(default_factory=dict)
|
||||
"""Customized extra data."""
|
||||
|
||||
|
||||
AgentContextWrapper = ContextWrapper[AstrAgentContext]
|
||||
|
||||
36
astrbot/core/astr_agent_hooks.py
Normal file
36
astrbot/core/astr_agent_hooks.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import Any
|
||||
|
||||
from mcp.types import CallToolResult
|
||||
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.pipeline.context_utils import call_event_hook
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
|
||||
|
||||
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
# 执行事件钩子
|
||||
await call_event_hook(
|
||||
run_context.context.event,
|
||||
EventType.OnLLMResponseEvent,
|
||||
llm_response,
|
||||
)
|
||||
|
||||
async def on_tool_end(
|
||||
self,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
tool: FunctionTool[Any],
|
||||
tool_args: dict | None,
|
||||
tool_result: CallToolResult | None,
|
||||
):
|
||||
run_context.context.event.clear_result()
|
||||
|
||||
|
||||
class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
pass
|
||||
|
||||
|
||||
MAIN_AGENT_HOOKS = MainAgentHooks()
|
||||
80
astrbot/core/astr_agent_run_util.py
Normal file
80
astrbot/core/astr_agent_run_util.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
)
|
||||
|
||||
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
||||
|
||||
|
||||
async def run_agent(
|
||||
agent_runner: AgentRunner,
|
||||
max_step: int = 30,
|
||||
show_tool_use: bool = True,
|
||||
stream_to_general: bool = False,
|
||||
show_reasoning: bool = False,
|
||||
) -> AsyncGenerator[MessageChain | None, None]:
|
||||
step_idx = 0
|
||||
astr_event = agent_runner.run_context.context.event
|
||||
while step_idx < max_step:
|
||||
step_idx += 1
|
||||
try:
|
||||
async for resp in agent_runner.step():
|
||||
if astr_event.is_stopped():
|
||||
return
|
||||
if resp.type == "tool_call_result":
|
||||
msg_chain = resp.data["chain"]
|
||||
if msg_chain.type == "tool_direct_result":
|
||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||
await astr_event.send(resp.data["chain"])
|
||||
continue
|
||||
# 对于其他情况,暂时先不处理
|
||||
continue
|
||||
elif resp.type == "tool_call":
|
||||
if agent_runner.streaming:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
if show_tool_use:
|
||||
await astr_event.send(resp.data["chain"])
|
||||
continue
|
||||
|
||||
if stream_to_general and resp.type == "streaming_delta":
|
||||
continue
|
||||
|
||||
if stream_to_general or not agent_runner.streaming:
|
||||
content_typ = (
|
||||
ResultContentType.LLM_RESULT
|
||||
if resp.type == "llm_result"
|
||||
else ResultContentType.GENERAL_RESULT
|
||||
)
|
||||
astr_event.set_result(
|
||||
MessageEventResult(
|
||||
chain=resp.data["chain"].chain,
|
||||
result_content_type=content_typ,
|
||||
),
|
||||
)
|
||||
yield
|
||||
astr_event.clear_result()
|
||||
elif resp.type == "streaming_delta":
|
||||
chain = resp.data["chain"]
|
||||
if chain.type == "reasoning" and not show_reasoning:
|
||||
# display the reasoning content only when configured
|
||||
continue
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if agent_runner.done():
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
|
||||
if agent_runner.streaming:
|
||||
yield MessageChain().message(err_msg)
|
||||
else:
|
||||
astr_event.set_result(MessageEventResult().message(err_msg))
|
||||
return
|
||||
246
astrbot/core/astr_agent_tool_exec.py
Normal file
246
astrbot/core/astr_agent_tool_exec.py
Normal file
@@ -0,0 +1,246 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import traceback
|
||||
import typing as T
|
||||
|
||||
import mcp
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.mcp_client import MCPTool
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
||||
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.message_event_result import (
|
||||
CommandResult,
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
)
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
|
||||
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
@classmethod
|
||||
async def execute(cls, tool, run_context, **tool_args):
|
||||
"""执行函数调用。
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
|
||||
**kwargs: 函数调用的参数。
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None | mcp.types.CallToolResult, None]
|
||||
|
||||
"""
|
||||
if isinstance(tool, HandoffTool):
|
||||
async for r in cls._execute_handoff(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
elif isinstance(tool, MCPTool):
|
||||
async for r in cls._execute_mcp(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
else:
|
||||
async for r in cls._execute_local(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
@classmethod
|
||||
async def _execute_handoff(
|
||||
cls,
|
||||
tool: HandoffTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
input_ = tool_args.get("input")
|
||||
|
||||
# make toolset for the agent
|
||||
tools = tool.agent.tools
|
||||
if tools:
|
||||
toolset = ToolSet()
|
||||
for t in tools:
|
||||
if isinstance(t, str):
|
||||
_t = llm_tools.get_func(t)
|
||||
if _t:
|
||||
toolset.add_tool(_t)
|
||||
elif isinstance(t, FunctionTool):
|
||||
toolset.add_tool(t)
|
||||
else:
|
||||
toolset = None
|
||||
|
||||
ctx = run_context.context.context
|
||||
event = run_context.context.event
|
||||
umo = event.unified_msg_origin
|
||||
prov_id = await ctx.get_current_chat_provider_id(umo)
|
||||
llm_resp = await ctx.tool_loop_agent(
|
||||
event=event,
|
||||
chat_provider_id=prov_id,
|
||||
prompt=input_,
|
||||
system_prompt=tool.agent.instructions,
|
||||
tools=toolset,
|
||||
max_steps=30,
|
||||
run_hooks=tool.agent.run_hooks,
|
||||
)
|
||||
yield mcp.types.CallToolResult(
|
||||
content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _execute_local(
|
||||
cls,
|
||||
tool: FunctionTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
event = run_context.context.event
|
||||
if not event:
|
||||
raise ValueError("Event must be provided for local function tools.")
|
||||
|
||||
is_override_call = False
|
||||
for ty in type(tool).mro():
|
||||
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
|
||||
is_override_call = True
|
||||
break
|
||||
|
||||
# 检查 tool 下有没有 run 方法
|
||||
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
|
||||
raise ValueError("Tool must have a valid handler or override 'run' method.")
|
||||
|
||||
awaitable = None
|
||||
method_name = ""
|
||||
if tool.handler:
|
||||
awaitable = tool.handler
|
||||
method_name = "decorator_handler"
|
||||
elif is_override_call:
|
||||
awaitable = tool.call
|
||||
method_name = "call"
|
||||
elif hasattr(tool, "run"):
|
||||
awaitable = getattr(tool, "run")
|
||||
method_name = "run"
|
||||
if awaitable is None:
|
||||
raise ValueError("Tool must have a valid handler or override 'run' method.")
|
||||
|
||||
wrapper = call_local_llm_tool(
|
||||
context=run_context,
|
||||
handler=awaitable,
|
||||
method_name=method_name,
|
||||
**tool_args,
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
resp = await asyncio.wait_for(
|
||||
anext(wrapper),
|
||||
timeout=run_context.tool_call_timeout,
|
||||
)
|
||||
if resp is not None:
|
||||
if isinstance(resp, mcp.types.CallToolResult):
|
||||
yield resp
|
||||
else:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=str(resp),
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
# NOTE: Tool 在这里直接请求发送消息给用户
|
||||
# TODO: 是否需要判断 event.get_result() 是否为空?
|
||||
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
||||
if res := run_context.context.event.get_result():
|
||||
if res.chain:
|
||||
try:
|
||||
await event.send(
|
||||
MessageChain(
|
||||
chain=res.chain,
|
||||
type="tool_direct_result",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Tool 直接发送消息失败: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
yield None
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(
|
||||
f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
@classmethod
|
||||
async def _execute_mcp(
|
||||
cls,
|
||||
tool: FunctionTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
res = await tool.call(run_context, **tool_args)
|
||||
if not res:
|
||||
return
|
||||
yield res
|
||||
|
||||
|
||||
async def call_local_llm_tool(
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
method_name: str,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[T.Any, None]:
|
||||
"""执行本地 LLM 工具的处理函数并处理其返回结果"""
|
||||
ready_to_call = None # 一个协程或者异步生成器
|
||||
|
||||
trace_ = None
|
||||
|
||||
event = context.context.event
|
||||
|
||||
try:
|
||||
if method_name == "run" or method_name == "decorator_handler":
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
elif method_name == "call":
|
||||
ready_to_call = handler(context, *args, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"未知的方法名: {method_name}")
|
||||
except ValueError as e:
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
|
||||
except TypeError:
|
||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||
except Exception as e:
|
||||
trace_ = traceback.format_exc()
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
|
||||
|
||||
if not ready_to_call:
|
||||
return
|
||||
|
||||
if inspect.isasyncgen(ready_to_call):
|
||||
_has_yielded = False
|
||||
try:
|
||||
async for ret in ready_to_call:
|
||||
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
_has_yielded = True
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
# 如果返回值是 None, 则不设置结果并继续
|
||||
# 继续执行后续阶段
|
||||
yield ret
|
||||
if not _has_yielded:
|
||||
# 如果这个异步生成器没有执行到 yield 分支
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Previous Error: {trace_}")
|
||||
raise e
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个协程, 直接执行
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
yield ret
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.5.2"
|
||||
VERSION = "4.6.0"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
# 默认配置
|
||||
@@ -68,7 +68,7 @@ DEFAULT_CONFIG = {
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"show_tool_use_status": False,
|
||||
"streaming_segmented": False,
|
||||
"unsupported_streaming_strategy": "realtime_segmenting",
|
||||
"max_agent_step": 30,
|
||||
"tool_call_timeout": 60,
|
||||
},
|
||||
@@ -137,6 +137,7 @@ DEFAULT_CONFIG = {
|
||||
"kb_names": [], # 默认知识库名称列表
|
||||
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
|
||||
"kb_final_top_k": 5, # 知识库检索最终返回结果数量
|
||||
"kb_agentic_mode": False,
|
||||
}
|
||||
|
||||
|
||||
@@ -740,6 +741,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
|
||||
@@ -755,6 +757,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -768,6 +771,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"xai_native_search": False,
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
@@ -799,6 +803,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://localhost:11434/v1",
|
||||
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -813,6 +818,7 @@ CONFIG_METADATA_2 = {
|
||||
"model_config": {
|
||||
"model": "llama-3.1-8b",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -829,6 +835,7 @@ CONFIG_METADATA_2 = {
|
||||
"model": "gemini-1.5-flash",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -870,6 +877,24 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "tool_use"],
|
||||
},
|
||||
"Groq": {
|
||||
"id": "groq_default",
|
||||
"provider": "groq",
|
||||
"type": "groq_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.groq.com/openai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "openai/gpt-oss-20b",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "tool_use"],
|
||||
},
|
||||
@@ -883,6 +908,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://api.302.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -899,6 +925,7 @@ CONFIG_METADATA_2 = {
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -915,6 +942,7 @@ CONFIG_METADATA_2 = {
|
||||
"model": "deepseek/deepseek-r1",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"小马算力": {
|
||||
@@ -930,6 +958,7 @@ CONFIG_METADATA_2 = {
|
||||
"model": "kimi-k2-instruct-0905",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"优云智算": {
|
||||
@@ -944,6 +973,7 @@ CONFIG_METADATA_2 = {
|
||||
"model_config": {
|
||||
"model": "moonshotai/Kimi-K2-Instruct",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -957,6 +987,7 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -972,6 +1003,8 @@ CONFIG_METADATA_2 = {
|
||||
"model_config": {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Dify": {
|
||||
@@ -1028,6 +1061,7 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 120,
|
||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
@@ -1040,6 +1074,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.fastgpt.in/api/v1",
|
||||
"timeout": 60,
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"Whisper(API)": {
|
||||
@@ -1321,6 +1356,12 @@ CONFIG_METADATA_2 = {
|
||||
"render_type": "checkbox",
|
||||
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
|
||||
},
|
||||
"custom_headers": {
|
||||
"description": "自定义添加请求头",
|
||||
"type": "dict",
|
||||
"items": {},
|
||||
"hint": "此处添加的键值对将被合并到 OpenAI SDK 的 default_headers 中,用于自定义 HTTP 请求头。值必须为字符串。",
|
||||
},
|
||||
"custom_extra_body": {
|
||||
"description": "自定义请求体参数",
|
||||
"type": "dict",
|
||||
@@ -1970,8 +2011,8 @@ CONFIG_METADATA_2 = {
|
||||
"show_tool_use_status": {
|
||||
"type": "bool",
|
||||
},
|
||||
"streaming_segmented": {
|
||||
"type": "bool",
|
||||
"unsupported_streaming_strategy": {
|
||||
"type": "string",
|
||||
},
|
||||
"max_agent_step": {
|
||||
"description": "工具调用轮数上限",
|
||||
@@ -2106,6 +2147,7 @@ CONFIG_METADATA_2 = {
|
||||
"kb_names": {"type": "list", "items": {"type": "string"}},
|
||||
"kb_fusion_top_k": {"type": "int", "default": 20},
|
||||
"kb_final_top_k": {"type": "int", "default": 5},
|
||||
"kb_agentic_mode": {"type": "bool"},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -2201,6 +2243,11 @@ CONFIG_METADATA_3 = {
|
||||
"type": "int",
|
||||
"hint": "从知识库中检索到的结果数量,越大可能获得越多相关信息,但也可能引入噪音。建议根据实际需求调整",
|
||||
},
|
||||
"kb_agentic_mode": {
|
||||
"description": "Agentic 知识库检索",
|
||||
"type": "bool",
|
||||
"hint": "启用后,知识库检索将作为 LLM Tool,由模型自主决定何时调用知识库进行查询。需要模型支持函数调用能力。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"websearch": {
|
||||
@@ -2276,9 +2323,15 @@ CONFIG_METADATA_3 = {
|
||||
"description": "流式回复",
|
||||
"type": "bool",
|
||||
},
|
||||
"provider_settings.streaming_segmented": {
|
||||
"description": "不支持流式回复的平台采取分段输出",
|
||||
"type": "bool",
|
||||
"provider_settings.unsupported_streaming_strategy": {
|
||||
"description": "不支持流式回复的平台",
|
||||
"type": "string",
|
||||
"options": ["realtime_segmenting", "turn_off"],
|
||||
"hint": "选择在不支持流式回复的平台上的处理方式。实时分段回复会在系统接收流式响应检测到诸如标点符号等分段点时,立即发送当前已接收的内容",
|
||||
"labels": ["实时分段回复", "关闭流式回复"],
|
||||
"condition": {
|
||||
"provider_settings.streaming_response": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
|
||||
@@ -22,7 +22,9 @@ from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
|
||||
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
from astrbot.core.memory.memory_manager import MemoryManager
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
@@ -103,6 +105,13 @@ class AstrBotCoreLifecycle:
|
||||
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# migration for webchat session
|
||||
try:
|
||||
await migrate_webchat_session(self.db)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration for webchat session failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 初始化事件队列
|
||||
self.event_queue = Queue()
|
||||
|
||||
@@ -128,6 +137,8 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.kb_manager = KnowledgeBaseManager(self.provider_manager)
|
||||
# 初始化记忆管理器
|
||||
self.memory_manager = MemoryManager()
|
||||
|
||||
# 初始化提供给插件的上下文
|
||||
self.star_context = Context(
|
||||
@@ -141,6 +152,7 @@ class AstrBotCoreLifecycle:
|
||||
self.persona_mgr,
|
||||
self.astrbot_config_mgr,
|
||||
self.kb_manager,
|
||||
self.memory_manager,
|
||||
)
|
||||
|
||||
# 初始化插件管理器
|
||||
|
||||
@@ -13,6 +13,7 @@ from astrbot.core.db.po import (
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
PlatformSession,
|
||||
PlatformStat,
|
||||
Preference,
|
||||
Stats,
|
||||
@@ -183,7 +184,7 @@ class BaseDatabase(abc.ABC):
|
||||
user_id: str,
|
||||
offset_sec: int = 86400,
|
||||
) -> None:
|
||||
"""Delete platform message history records older than the specified offset."""
|
||||
"""Delete platform message history records newer than the specified offset."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -313,3 +314,51 @@ class BaseDatabase(abc.ABC):
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get paginated session conversations with joined conversation and persona details, support search and platform filter."""
|
||||
...
|
||||
|
||||
# ====
|
||||
# Platform Session Management
|
||||
# ====
|
||||
|
||||
@abc.abstractmethod
|
||||
async def create_platform_session(
|
||||
self,
|
||||
creator: str,
|
||||
platform_id: str = "webchat",
|
||||
session_id: str | None = None,
|
||||
display_name: str | None = None,
|
||||
is_group: int = 0,
|
||||
) -> PlatformSession:
|
||||
"""Create a new Platform session."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_platform_session_by_id(
|
||||
self, session_id: str
|
||||
) -> PlatformSession | None:
|
||||
"""Get a Platform session by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_platform_sessions_by_creator(
|
||||
self,
|
||||
creator: str,
|
||||
platform_id: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> list[PlatformSession]:
|
||||
"""Get all Platform sessions for a specific creator (username) and optionally platform."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_platform_session(
|
||||
self,
|
||||
session_id: str,
|
||||
display_name: str | None = None,
|
||||
) -> None:
|
||||
"""Update a Platform session's updated_at timestamp and optionally display_name."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_platform_session(self, session_id: str) -> None:
|
||||
"""Delete a Platform session by its ID."""
|
||||
...
|
||||
|
||||
131
astrbot/core/db/migration/migra_webchat_session.py
Normal file
131
astrbot/core/db/migration/migra_webchat_session.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Migration script for WebChat sessions.
|
||||
|
||||
This migration creates PlatformSession from existing platform_message_history records.
|
||||
|
||||
Changes:
|
||||
- Creates platform_sessions table
|
||||
- Adds platform_id field (default: 'webchat')
|
||||
- Adds display_name field
|
||||
- Session_id format: {platform_id}_{uuid}
|
||||
"""
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlmodel import col
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession
|
||||
|
||||
|
||||
async def migrate_webchat_session(db_helper: BaseDatabase):
|
||||
"""Create PlatformSession records from platform_message_history.
|
||||
|
||||
This migration extracts all unique user_ids from platform_message_history
|
||||
where platform_id='webchat' and creates corresponding PlatformSession records.
|
||||
"""
|
||||
# 检查是否已经完成迁移
|
||||
migration_done = await db_helper.get_preference(
|
||||
"global", "global", "migration_done_webchat_session"
|
||||
)
|
||||
if migration_done:
|
||||
return
|
||||
|
||||
logger.info("开始执行数据库迁移(WebChat 会话迁移)...")
|
||||
|
||||
try:
|
||||
async with db_helper.get_db() as session:
|
||||
# 从 platform_message_history 创建 PlatformSession
|
||||
query = (
|
||||
select(
|
||||
col(PlatformMessageHistory.user_id),
|
||||
col(PlatformMessageHistory.sender_name),
|
||||
func.min(PlatformMessageHistory.created_at).label("earliest"),
|
||||
func.max(PlatformMessageHistory.updated_at).label("latest"),
|
||||
)
|
||||
.where(col(PlatformMessageHistory.platform_id) == "webchat")
|
||||
.where(col(PlatformMessageHistory.sender_id) == "astrbot")
|
||||
.group_by(col(PlatformMessageHistory.user_id))
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
webchat_users = result.all()
|
||||
|
||||
if not webchat_users:
|
||||
logger.info("没有找到需要迁移的 WebChat 数据")
|
||||
await sp.put_async(
|
||||
"global", "global", "migration_done_webchat_session", True
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"找到 {len(webchat_users)} 个 WebChat 会话需要迁移")
|
||||
|
||||
# 检查已存在的会话
|
||||
existing_query = select(col(PlatformSession.session_id))
|
||||
existing_result = await session.execute(existing_query)
|
||||
existing_session_ids = {row[0] for row in existing_result.fetchall()}
|
||||
|
||||
# 查询 Conversations 表中的 title,用于设置 display_name
|
||||
# 对于每个 user_id,对应的 conversation user_id 格式为: webchat:FriendMessage:webchat!astrbot!{user_id}
|
||||
user_ids_to_query = [
|
||||
f"webchat:FriendMessage:webchat!astrbot!{user_id}"
|
||||
for user_id, _, _, _ in webchat_users
|
||||
]
|
||||
conv_query = select(
|
||||
col(ConversationV2.user_id), col(ConversationV2.title)
|
||||
).where(col(ConversationV2.user_id).in_(user_ids_to_query))
|
||||
conv_result = await session.execute(conv_query)
|
||||
# 创建 user_id -> title 的映射字典
|
||||
title_map = {
|
||||
user_id.replace("webchat:FriendMessage:webchat!astrbot!", ""): title
|
||||
for user_id, title in conv_result.fetchall()
|
||||
}
|
||||
|
||||
# 批量创建 PlatformSession 记录
|
||||
sessions_to_add = []
|
||||
skipped_count = 0
|
||||
|
||||
for user_id, sender_name, created_at, updated_at in webchat_users:
|
||||
# user_id 就是 webchat_conv_id (session_id)
|
||||
session_id = user_id
|
||||
|
||||
# sender_name 通常是 username,但可能为 None
|
||||
creator = sender_name if sender_name else "guest"
|
||||
|
||||
# 检查是否已经存在该会话
|
||||
if session_id in existing_session_ids:
|
||||
logger.debug(f"会话 {session_id} 已存在,跳过")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
# 从 Conversations 表中获取 display_name
|
||||
display_name = title_map.get(user_id)
|
||||
|
||||
# 创建新的 PlatformSession(保留原有的时间戳)
|
||||
new_session = PlatformSession(
|
||||
session_id=session_id,
|
||||
platform_id="webchat",
|
||||
creator=creator,
|
||||
is_group=0,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
display_name=display_name,
|
||||
)
|
||||
sessions_to_add.append(new_session)
|
||||
|
||||
# 批量插入
|
||||
if sessions_to_add:
|
||||
session.add_all(sessions_to_add)
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
f"WebChat 会话迁移完成!成功迁移: {len(sessions_to_add)}, 跳过: {skipped_count}",
|
||||
)
|
||||
else:
|
||||
logger.info("没有新会话需要迁移")
|
||||
|
||||
# 标记迁移完成
|
||||
await sp.put_async("global", "global", "migration_done_webchat_session", True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
|
||||
raise
|
||||
@@ -3,13 +3,7 @@ from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import TypedDict
|
||||
|
||||
from sqlmodel import (
|
||||
JSON,
|
||||
Field,
|
||||
SQLModel,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint
|
||||
|
||||
|
||||
class PlatformStat(SQLModel, table=True):
|
||||
@@ -18,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
|
||||
Note: In astrbot v4, we moved `platform` table to here.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_stats"
|
||||
__tablename__ = "platform_stats" # type: ignore
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
timestamp: datetime = Field(nullable=False)
|
||||
@@ -37,7 +31,7 @@ class PlatformStat(SQLModel, table=True):
|
||||
|
||||
|
||||
class ConversationV2(SQLModel, table=True):
|
||||
__tablename__ = "conversations"
|
||||
__tablename__ = "conversations" # type: ignore
|
||||
|
||||
inner_conversation_id: int = Field(
|
||||
primary_key=True,
|
||||
@@ -74,7 +68,7 @@ class Persona(SQLModel, table=True):
|
||||
It can be used to customize the behavior of LLMs.
|
||||
"""
|
||||
|
||||
__tablename__ = "personas"
|
||||
__tablename__ = "personas" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -104,7 +98,7 @@ class Persona(SQLModel, table=True):
|
||||
class Preference(SQLModel, table=True):
|
||||
"""This class represents preferences for bots."""
|
||||
|
||||
__tablename__ = "preferences"
|
||||
__tablename__ = "preferences" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
@@ -140,7 +134,7 @@ class PlatformMessageHistory(SQLModel, table=True):
|
||||
or platform-specific messages.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_message_history"
|
||||
__tablename__ = "platform_message_history" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -161,13 +155,55 @@ class PlatformMessageHistory(SQLModel, table=True):
|
||||
)
|
||||
|
||||
|
||||
class PlatformSession(SQLModel, table=True):
|
||||
"""Platform session table for managing user sessions across different platforms.
|
||||
|
||||
A session represents a chat window for a specific user on a specific platform.
|
||||
Each session can have multiple conversations (对话) associated with it.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_sessions" # type: ignore
|
||||
|
||||
inner_id: int | None = Field(
|
||||
primary_key=True,
|
||||
sa_column_kwargs={"autoincrement": True},
|
||||
default=None,
|
||||
)
|
||||
session_id: str = Field(
|
||||
max_length=100,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: f"webchat_{uuid.uuid4()}",
|
||||
)
|
||||
platform_id: str = Field(default="webchat", nullable=False)
|
||||
"""Platform identifier (e.g., 'webchat', 'qq', 'discord')"""
|
||||
creator: str = Field(nullable=False)
|
||||
"""Username of the session creator"""
|
||||
display_name: str | None = Field(default=None, max_length=255)
|
||||
"""Display name for the session"""
|
||||
is_group: int = Field(default=0, nullable=False)
|
||||
"""0 for private chat, 1 for group chat (not implemented yet)"""
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"session_id",
|
||||
name="uix_platform_session_id",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Attachment(SQLModel, table=True):
|
||||
"""This class represents attachments for messages in AstrBot.
|
||||
|
||||
Attachments can be images, files, or other media types.
|
||||
"""
|
||||
|
||||
__tablename__ = "attachments"
|
||||
__tablename__ = "attachments" # type: ignore
|
||||
|
||||
inner_attachment_id: int | None = Field(
|
||||
primary_key=True,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import typing as T
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
||||
@@ -12,6 +12,7 @@ from astrbot.core.db.po import (
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
PlatformSession,
|
||||
PlatformStat,
|
||||
Preference,
|
||||
SQLModel,
|
||||
@@ -412,7 +413,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
user_id,
|
||||
offset_sec=86400,
|
||||
):
|
||||
"""Delete platform message history records older than the specified offset."""
|
||||
"""Delete platform message history records newer than the specified offset."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
@@ -422,7 +423,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
delete(PlatformMessageHistory).where(
|
||||
col(PlatformMessageHistory.platform_id) == platform_id,
|
||||
col(PlatformMessageHistory.user_id) == user_id,
|
||||
col(PlatformMessageHistory.created_at) < cutoff_time,
|
||||
col(PlatformMessageHistory.created_at) >= cutoff_time,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -709,3 +710,101 @@ class SQLiteDatabase(BaseDatabase):
|
||||
t.start()
|
||||
t.join()
|
||||
return result
|
||||
|
||||
# ====
|
||||
# Platform Session Management
|
||||
# ====
|
||||
|
||||
async def create_platform_session(
|
||||
self,
|
||||
creator: str,
|
||||
platform_id: str = "webchat",
|
||||
session_id: str | None = None,
|
||||
display_name: str | None = None,
|
||||
is_group: int = 0,
|
||||
) -> PlatformSession:
|
||||
"""Create a new Platform session."""
|
||||
kwargs = {}
|
||||
if session_id:
|
||||
kwargs["session_id"] = session_id
|
||||
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
new_session = PlatformSession(
|
||||
creator=creator,
|
||||
platform_id=platform_id,
|
||||
display_name=display_name,
|
||||
is_group=is_group,
|
||||
**kwargs,
|
||||
)
|
||||
session.add(new_session)
|
||||
await session.flush()
|
||||
await session.refresh(new_session)
|
||||
return new_session
|
||||
|
||||
async def get_platform_session_by_id(
|
||||
self, session_id: str
|
||||
) -> PlatformSession | None:
|
||||
"""Get a Platform session by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(PlatformSession).where(
|
||||
PlatformSession.session_id == session_id,
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_platform_sessions_by_creator(
|
||||
self,
|
||||
creator: str,
|
||||
platform_id: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> list[PlatformSession]:
|
||||
"""Get all Platform sessions for a specific creator (username) and optionally platform."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
offset = (page - 1) * page_size
|
||||
query = select(PlatformSession).where(PlatformSession.creator == creator)
|
||||
|
||||
if platform_id:
|
||||
query = query.where(PlatformSession.platform_id == platform_id)
|
||||
|
||||
query = (
|
||||
query.order_by(desc(PlatformSession.updated_at))
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_platform_session(
|
||||
self,
|
||||
session_id: str,
|
||||
display_name: str | None = None,
|
||||
) -> None:
|
||||
"""Update a Platform session's updated_at timestamp and optionally display_name."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
|
||||
if display_name is not None:
|
||||
values["display_name"] = display_name
|
||||
|
||||
await session.execute(
|
||||
update(PlatformSession)
|
||||
.where(col(PlatformSession.session_id == session_id))
|
||||
.values(**values),
|
||||
)
|
||||
|
||||
async def delete_platform_session(self, session_id: str) -> None:
|
||||
"""Delete a Platform session by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(PlatformSession).where(
|
||||
col(PlatformSession.session_id == session_id),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,11 +1,20 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
class ResultData(TypedDict):
|
||||
id: str
|
||||
doc_id: str
|
||||
text: str
|
||||
metadata: str
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
similarity: float
|
||||
data: dict
|
||||
data: ResultData | dict
|
||||
|
||||
|
||||
class BaseVecDB:
|
||||
|
||||
9
astrbot/core/exceptions.py
Normal file
9
astrbot/core/exceptions.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class AstrBotError(Exception):
|
||||
"""Base exception for all AstrBot errors."""
|
||||
|
||||
|
||||
class ProviderNotFoundError(AstrBotError):
|
||||
"""Raised when a specified provider is not found."""
|
||||
@@ -1,4 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
@@ -8,12 +11,98 @@ from astrbot.core import logger
|
||||
from astrbot.core.db.vec_db.base import BaseVecDB
|
||||
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
|
||||
from astrbot.core.provider.provider import (
|
||||
EmbeddingProvider,
|
||||
RerankProvider,
|
||||
)
|
||||
from astrbot.core.provider.provider import (
|
||||
Provider as LLMProvider,
|
||||
)
|
||||
|
||||
from .chunking.base import BaseChunker
|
||||
from .chunking.recursive import RecursiveCharacterChunker
|
||||
from .kb_db_sqlite import KBSQLiteDatabase
|
||||
from .models import KBDocument, KBMedia, KnowledgeBase
|
||||
from .parsers.url_parser import extract_text_from_url
|
||||
from .parsers.util import select_parser
|
||||
from .prompts import TEXT_REPAIR_SYSTEM_PROMPT
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""一个简单的速率限制器"""
|
||||
|
||||
def __init__(self, max_rpm: int):
|
||||
self.max_per_minute = max_rpm
|
||||
self.interval = 60.0 / max_rpm if max_rpm > 0 else 0
|
||||
self.last_call_time = 0
|
||||
|
||||
async def __aenter__(self):
|
||||
if self.interval == 0:
|
||||
return
|
||||
|
||||
now = time.monotonic()
|
||||
elapsed = now - self.last_call_time
|
||||
|
||||
if elapsed < self.interval:
|
||||
await asyncio.sleep(self.interval - elapsed)
|
||||
|
||||
self.last_call_time = time.monotonic()
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
|
||||
async def _repair_and_translate_chunk_with_retry(
|
||||
chunk: str,
|
||||
repair_llm_service: LLMProvider,
|
||||
rate_limiter: RateLimiter,
|
||||
max_retries: int = 2,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Repairs, translates, and optionally re-chunks a single text chunk using the small LLM, with rate limiting.
|
||||
"""
|
||||
# 为了防止 LLM 上下文污染,在 user_prompt 中也加入明确的指令
|
||||
user_prompt = f"""IGNORE ALL PREVIOUS INSTRUCTIONS. Your ONLY task is to process the following text chunk according to the system prompt provided.
|
||||
|
||||
Text chunk to process:
|
||||
---
|
||||
{chunk}
|
||||
---
|
||||
"""
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
async with rate_limiter:
|
||||
response = await repair_llm_service.text_chat(
|
||||
prompt=user_prompt, system_prompt=TEXT_REPAIR_SYSTEM_PROMPT
|
||||
)
|
||||
|
||||
llm_output = response.completion_text
|
||||
|
||||
if "<discard_chunk />" in llm_output:
|
||||
return [] # Signal to discard this chunk
|
||||
|
||||
# More robust regex to handle potential LLM formatting errors (spaces, newlines in tags)
|
||||
matches = re.findall(
|
||||
r"<\s*repaired_text\s*>\s*(.*?)\s*<\s*/\s*repaired_text\s*>",
|
||||
llm_output,
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
if matches:
|
||||
# Further cleaning to ensure no empty strings are returned
|
||||
return [m.strip() for m in matches if m.strip()]
|
||||
else:
|
||||
# If no valid tags and not explicitly discarded, discard it to be safe.
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f" - LLM call failed on attempt {attempt + 1}/{max_retries + 1}. Error: {str(e)}"
|
||||
)
|
||||
|
||||
logger.error(
|
||||
f" - Failed to process chunk after {max_retries + 1} attempts. Using original text."
|
||||
)
|
||||
return [chunk]
|
||||
|
||||
|
||||
class KBHelper:
|
||||
@@ -100,7 +189,7 @@ class KBHelper:
|
||||
async def upload_document(
|
||||
self,
|
||||
file_name: str,
|
||||
file_content: bytes,
|
||||
file_content: bytes | None,
|
||||
file_type: str,
|
||||
chunk_size: int = 512,
|
||||
chunk_overlap: int = 50,
|
||||
@@ -108,6 +197,7 @@ class KBHelper:
|
||||
tasks_limit: int = 3,
|
||||
max_retries: int = 3,
|
||||
progress_callback=None,
|
||||
pre_chunked_text: list[str] | None = None,
|
||||
) -> KBDocument:
|
||||
"""上传并处理文档(带原子性保证和失败清理)
|
||||
|
||||
@@ -130,46 +220,63 @@ class KBHelper:
|
||||
await self._ensure_vec_db()
|
||||
doc_id = str(uuid.uuid4())
|
||||
media_paths: list[Path] = []
|
||||
file_size = 0
|
||||
|
||||
# file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
|
||||
# async with aiofiles.open(file_path, "wb") as f:
|
||||
# await f.write(file_content)
|
||||
|
||||
try:
|
||||
# 阶段1: 解析文档
|
||||
if progress_callback:
|
||||
await progress_callback("parsing", 0, 100)
|
||||
|
||||
parser = await select_parser(f".{file_type}")
|
||||
parse_result = await parser.parse(file_content, file_name)
|
||||
text_content = parse_result.text
|
||||
media_items = parse_result.media
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("parsing", 100, 100)
|
||||
|
||||
# 保存媒体文件
|
||||
chunks_text = []
|
||||
saved_media = []
|
||||
for media_item in media_items:
|
||||
media = await self._save_media(
|
||||
doc_id=doc_id,
|
||||
media_type=media_item.media_type,
|
||||
file_name=media_item.file_name,
|
||||
content=media_item.content,
|
||||
mime_type=media_item.mime_type,
|
||||
|
||||
if pre_chunked_text is not None:
|
||||
# 如果提供了预分块文本,直接使用
|
||||
chunks_text = pre_chunked_text
|
||||
file_size = sum(len(chunk) for chunk in chunks_text)
|
||||
logger.info(f"使用预分块文本进行上传,共 {len(chunks_text)} 个块。")
|
||||
else:
|
||||
# 否则,执行标准的文件解析和分块流程
|
||||
if file_content is None:
|
||||
raise ValueError(
|
||||
"当未提供 pre_chunked_text 时,file_content 不能为空。"
|
||||
)
|
||||
|
||||
file_size = len(file_content)
|
||||
|
||||
# 阶段1: 解析文档
|
||||
if progress_callback:
|
||||
await progress_callback("parsing", 0, 100)
|
||||
|
||||
parser = await select_parser(f".{file_type}")
|
||||
parse_result = await parser.parse(file_content, file_name)
|
||||
text_content = parse_result.text
|
||||
media_items = parse_result.media
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("parsing", 100, 100)
|
||||
|
||||
# 保存媒体文件
|
||||
for media_item in media_items:
|
||||
media = await self._save_media(
|
||||
doc_id=doc_id,
|
||||
media_type=media_item.media_type,
|
||||
file_name=media_item.file_name,
|
||||
content=media_item.content,
|
||||
mime_type=media_item.mime_type,
|
||||
)
|
||||
saved_media.append(media)
|
||||
media_paths.append(Path(media.file_path))
|
||||
|
||||
# 阶段2: 分块
|
||||
if progress_callback:
|
||||
await progress_callback("chunking", 0, 100)
|
||||
|
||||
chunks_text = await self.chunker.chunk(
|
||||
text_content,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
saved_media.append(media)
|
||||
media_paths.append(Path(media.file_path))
|
||||
|
||||
# 阶段2: 分块
|
||||
if progress_callback:
|
||||
await progress_callback("chunking", 0, 100)
|
||||
|
||||
chunks_text = await self.chunker.chunk(
|
||||
text_content,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
contents = []
|
||||
metadatas = []
|
||||
for idx, chunk_text in enumerate(chunks_text):
|
||||
@@ -205,7 +312,7 @@ class KBHelper:
|
||||
kb_id=self.kb.kb_id,
|
||||
doc_name=file_name,
|
||||
file_type=file_type,
|
||||
file_size=len(file_content),
|
||||
file_size=file_size,
|
||||
# file_path=str(file_path),
|
||||
file_path="",
|
||||
chunk_count=len(chunks_text),
|
||||
@@ -359,3 +466,177 @@ class KBHelper:
|
||||
)
|
||||
|
||||
return media
|
||||
|
||||
async def upload_from_url(
|
||||
self,
|
||||
url: str,
|
||||
chunk_size: int = 512,
|
||||
chunk_overlap: int = 50,
|
||||
batch_size: int = 32,
|
||||
tasks_limit: int = 3,
|
||||
max_retries: int = 3,
|
||||
progress_callback=None,
|
||||
enable_cleaning: bool = False,
|
||||
cleaning_provider_id: str | None = None,
|
||||
) -> KBDocument:
|
||||
"""从 URL 上传并处理文档(带原子性保证和失败清理)
|
||||
Args:
|
||||
url: 要提取内容的网页 URL
|
||||
chunk_size: 文本块大小
|
||||
chunk_overlap: 文本块重叠大小
|
||||
batch_size: 批处理大小
|
||||
tasks_limit: 并发任务限制
|
||||
max_retries: 最大重试次数
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total)
|
||||
- stage: 当前阶段 ('extracting', 'cleaning', 'parsing', 'chunking', 'embedding')
|
||||
- current: 当前进度
|
||||
- total: 总数
|
||||
Returns:
|
||||
KBDocument: 上传的文档对象
|
||||
Raises:
|
||||
ValueError: 如果 URL 为空或无法提取内容
|
||||
IOError: 如果网络请求失败
|
||||
"""
|
||||
# 获取 Tavily API 密钥
|
||||
config = self.prov_mgr.acm.default_conf
|
||||
tavily_keys = config.get("provider_settings", {}).get(
|
||||
"websearch_tavily_key", []
|
||||
)
|
||||
if not tavily_keys:
|
||||
raise ValueError(
|
||||
"Error: Tavily API key is not configured in provider_settings."
|
||||
)
|
||||
|
||||
# 阶段1: 从 URL 提取内容
|
||||
if progress_callback:
|
||||
await progress_callback("extracting", 0, 100)
|
||||
|
||||
try:
|
||||
text_content = await extract_text_from_url(url, tavily_keys)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract content from URL {url}: {e}")
|
||||
raise OSError(f"Failed to extract content from URL {url}: {e}") from e
|
||||
|
||||
if not text_content:
|
||||
raise ValueError(f"No content extracted from URL: {url}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("extracting", 100, 100)
|
||||
|
||||
# 阶段2: (可选)清洗内容并分块
|
||||
final_chunks = await self._clean_and_rechunk_content(
|
||||
content=text_content,
|
||||
url=url,
|
||||
progress_callback=progress_callback,
|
||||
enable_cleaning=enable_cleaning,
|
||||
cleaning_provider_id=cleaning_provider_id,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
|
||||
if enable_cleaning and not final_chunks:
|
||||
raise ValueError(
|
||||
"内容清洗后未提取到有效文本。请尝试关闭内容清洗功能,或更换更高性能的LLM模型后重试。"
|
||||
)
|
||||
|
||||
# 创建一个虚拟文件名
|
||||
file_name = url.split("/")[-1] or f"document_from_{url}"
|
||||
if not Path(file_name).suffix:
|
||||
file_name += ".url"
|
||||
|
||||
# 复用现有的 upload_document 方法,但传入预分块文本
|
||||
return await self.upload_document(
|
||||
file_name=file_name,
|
||||
file_content=None,
|
||||
file_type="url", # 使用 'url' 作为特殊文件类型
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
progress_callback=progress_callback,
|
||||
pre_chunked_text=final_chunks,
|
||||
)
|
||||
|
||||
async def _clean_and_rechunk_content(
|
||||
self,
|
||||
content: str,
|
||||
url: str,
|
||||
progress_callback=None,
|
||||
enable_cleaning: bool = False,
|
||||
cleaning_provider_id: str | None = None,
|
||||
repair_max_rpm: int = 60,
|
||||
chunk_size: int = 512,
|
||||
chunk_overlap: int = 50,
|
||||
) -> list[str]:
|
||||
"""
|
||||
对从 URL 获取的内容进行清洗、修复、翻译和重新分块。
|
||||
"""
|
||||
if not enable_cleaning:
|
||||
# 如果不启用清洗,则使用从前端传递的参数进行分块
|
||||
logger.info(
|
||||
f"内容清洗未启用,使用指定参数进行分块: chunk_size={chunk_size}, chunk_overlap={chunk_overlap}"
|
||||
)
|
||||
return await self.chunker.chunk(
|
||||
content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||
)
|
||||
|
||||
if not cleaning_provider_id:
|
||||
logger.warning(
|
||||
"启用了内容清洗,但未提供 cleaning_provider_id,跳过清洗并使用默认分块。"
|
||||
)
|
||||
return await self.chunker.chunk(content)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("cleaning", 0, 100)
|
||||
|
||||
try:
|
||||
# 获取指定的 LLM Provider
|
||||
llm_provider = await self.prov_mgr.get_provider_by_id(cleaning_provider_id)
|
||||
if not llm_provider or not isinstance(llm_provider, LLMProvider):
|
||||
raise ValueError(
|
||||
f"无法找到 ID 为 {cleaning_provider_id} 的 LLM Provider 或类型不正确"
|
||||
)
|
||||
|
||||
# 初步分块
|
||||
# 优化分隔符,优先按段落分割,以获得更高质量的文本块
|
||||
text_splitter = RecursiveCharacterChunker(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separators=["\n\n", "\n", " "], # 优先使用段落分隔符
|
||||
)
|
||||
initial_chunks = await text_splitter.chunk(content)
|
||||
logger.info(f"初步分块完成,生成 {len(initial_chunks)} 个块用于修复。")
|
||||
|
||||
# 并发处理所有块
|
||||
rate_limiter = RateLimiter(repair_max_rpm)
|
||||
tasks = [
|
||||
_repair_and_translate_chunk_with_retry(
|
||||
chunk, llm_provider, rate_limiter
|
||||
)
|
||||
for chunk in initial_chunks
|
||||
]
|
||||
|
||||
repaired_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
final_chunks = []
|
||||
for i, result in enumerate(repaired_results):
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"块 {i} 处理异常: {str(result)}. 回退到原始块。")
|
||||
final_chunks.append(initial_chunks[i])
|
||||
elif isinstance(result, list):
|
||||
final_chunks.extend(result)
|
||||
|
||||
logger.info(
|
||||
f"文本修复完成: {len(initial_chunks)} 个原始块 -> {len(final_chunks)} 个最终块。"
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("cleaning", 100, 100)
|
||||
|
||||
return final_chunks
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"使用 Provider '{cleaning_provider_id}' 清洗内容失败: {e}")
|
||||
# 清洗失败,返回默认分块结果,保证流程不中断
|
||||
return await self.chunker.chunk(content)
|
||||
|
||||
@@ -8,7 +8,7 @@ from astrbot.core.provider.manager import ProviderManager
|
||||
from .chunking.recursive import RecursiveCharacterChunker
|
||||
from .kb_db_sqlite import KBSQLiteDatabase
|
||||
from .kb_helper import KBHelper
|
||||
from .models import KnowledgeBase
|
||||
from .models import KBDocument, KnowledgeBase
|
||||
from .retrieval.manager import RetrievalManager, RetrievalResult
|
||||
from .retrieval.rank_fusion import RankFusion
|
||||
from .retrieval.sparse_retriever import SparseRetriever
|
||||
@@ -284,3 +284,47 @@ class KnowledgeBaseManager:
|
||||
await self.kb_db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"关闭知识库元数据数据库失败: {e}")
|
||||
|
||||
async def upload_from_url(
|
||||
self,
|
||||
kb_id: str,
|
||||
url: str,
|
||||
chunk_size: int = 512,
|
||||
chunk_overlap: int = 50,
|
||||
batch_size: int = 32,
|
||||
tasks_limit: int = 3,
|
||||
max_retries: int = 3,
|
||||
progress_callback=None,
|
||||
) -> KBDocument:
|
||||
"""从 URL 上传文档到指定的知识库
|
||||
|
||||
Args:
|
||||
kb_id: 知识库 ID
|
||||
url: 要提取内容的网页 URL
|
||||
chunk_size: 文本块大小
|
||||
chunk_overlap: 文本块重叠大小
|
||||
batch_size: 批处理大小
|
||||
tasks_limit: 并发任务限制
|
||||
max_retries: 最大重试次数
|
||||
progress_callback: 进度回调函数
|
||||
|
||||
Returns:
|
||||
KBDocument: 上传的文档对象
|
||||
|
||||
Raises:
|
||||
ValueError: 如果知识库不存在或 URL 为空
|
||||
IOError: 如果网络请求失败
|
||||
"""
|
||||
kb_helper = await self.get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
raise ValueError(f"Knowledge base with id {kb_id} not found.")
|
||||
|
||||
return await kb_helper.upload_from_url(
|
||||
url=url,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
103
astrbot/core/knowledge_base/parsers/url_parser.py
Normal file
103
astrbot/core/knowledge_base/parsers/url_parser.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import asyncio
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
class URLExtractor:
|
||||
"""URL 内容提取器,封装了 Tavily API 调用和密钥管理"""
|
||||
|
||||
def __init__(self, tavily_keys: list[str]):
|
||||
"""
|
||||
初始化 URL 提取器
|
||||
|
||||
Args:
|
||||
tavily_keys: Tavily API 密钥列表
|
||||
"""
|
||||
if not tavily_keys:
|
||||
raise ValueError("Error: Tavily API keys are not configured.")
|
||||
|
||||
self.tavily_keys = tavily_keys
|
||||
self.tavily_key_index = 0
|
||||
self.tavily_key_lock = asyncio.Lock()
|
||||
|
||||
async def _get_tavily_key(self) -> str:
|
||||
"""并发安全的从列表中获取并轮换Tavily API密钥。"""
|
||||
async with self.tavily_key_lock:
|
||||
key = self.tavily_keys[self.tavily_key_index]
|
||||
self.tavily_key_index = (self.tavily_key_index + 1) % len(self.tavily_keys)
|
||||
return key
|
||||
|
||||
async def extract_text_from_url(self, url: str) -> str:
|
||||
"""
|
||||
使用 Tavily API 从 URL 提取主要文本内容。
|
||||
这是 web_searcher 插件中 tavily_extract_web_page 方法的简化版本,
|
||||
专门为知识库模块设计,不依赖 AstrMessageEvent。
|
||||
|
||||
Args:
|
||||
url: 要提取内容的网页 URL
|
||||
|
||||
Returns:
|
||||
提取的文本内容
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 URL 为空或 API 密钥未配置
|
||||
IOError: 如果请求失败或返回错误
|
||||
"""
|
||||
if not url:
|
||||
raise ValueError("Error: url must be a non-empty string.")
|
||||
|
||||
tavily_key = await self._get_tavily_key()
|
||||
api_url = "https://api.tavily.com/extract"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {tavily_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"urls": [url],
|
||||
"extract_depth": "basic", # 使用基础提取深度
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.post(
|
||||
api_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=30.0, # 增加超时时间,因为内容提取可能需要更长时间
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
reason = await response.text()
|
||||
raise OSError(
|
||||
f"Tavily web extraction failed: {reason}, status: {response.status}"
|
||||
)
|
||||
|
||||
data = await response.json()
|
||||
results = data.get("results", [])
|
||||
|
||||
if not results:
|
||||
raise ValueError(f"No content extracted from URL: {url}")
|
||||
|
||||
# 返回第一个结果的内容
|
||||
return results[0].get("raw_content", "")
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise OSError(f"Failed to fetch URL {url}: {e}") from e
|
||||
except Exception as e:
|
||||
raise OSError(f"Failed to extract content from URL {url}: {e}") from e
|
||||
|
||||
|
||||
# 为了向后兼容,提供一个简单的函数接口
|
||||
async def extract_text_from_url(url: str, tavily_keys: list[str]) -> str:
|
||||
"""
|
||||
简单的函数接口,用于从 URL 提取文本内容
|
||||
|
||||
Args:
|
||||
url: 要提取内容的网页 URL
|
||||
tavily_keys: Tavily API 密钥列表
|
||||
|
||||
Returns:
|
||||
提取的文本内容
|
||||
"""
|
||||
extractor = URLExtractor(tavily_keys)
|
||||
return await extractor.extract_text_from_url(url)
|
||||
65
astrbot/core/knowledge_base/prompts.py
Normal file
65
astrbot/core/knowledge_base/prompts.py
Normal file
@@ -0,0 +1,65 @@
|
||||
TEXT_REPAIR_SYSTEM_PROMPT = """You are a meticulous digital archivist. Your mission is to reconstruct a clean, readable article from raw, noisy text chunks.
|
||||
|
||||
**Core Task:**
|
||||
1. **Analyze:** Examine the text chunk to separate "signal" (substantive information) from "noise" (UI elements, ads, navigation, footers).
|
||||
2. **Process:** Clean and repair the signal. **Do not translate it.** Keep the original language.
|
||||
|
||||
**Crucial Rules:**
|
||||
- **NEVER discard a chunk if it contains ANY valuable information.** Your primary duty is to salvage content.
|
||||
- **If a chunk contains multiple distinct topics, split them.** Enclose each topic in its own `<repaired_text>` tag.
|
||||
- Your output MUST be ONLY `<repaired_text>...</repaired_text>` tags or a single `<discard_chunk />` tag.
|
||||
|
||||
---
|
||||
**Example 1: Chunk with Noise and Signal**
|
||||
|
||||
*Input Chunk:*
|
||||
"Home | About | Products | **The Llama is a domesticated South American camelid.** | © 2025 ACME Corp."
|
||||
|
||||
*Your Thought Process:*
|
||||
1. "Home | About | Products..." and "© 2025 ACME Corp." are noise.
|
||||
2. "The Llama is a domesticated..." is the signal.
|
||||
3. I must extract the signal and wrap it.
|
||||
|
||||
*Your Output:*
|
||||
<repaired_text>
|
||||
The Llama is a domesticated South American camelid.
|
||||
</repaired_text>
|
||||
|
||||
---
|
||||
**Example 2: Chunk with ONLY Noise**
|
||||
|
||||
*Input Chunk:*
|
||||
"Next Page > | Subscribe to our newsletter | Follow us on X"
|
||||
|
||||
*Your Thought Process:*
|
||||
1. This entire chunk is noise. There is no signal.
|
||||
2. I must discard this.
|
||||
|
||||
*Your Output:*
|
||||
<discard_chunk />
|
||||
|
||||
---
|
||||
**Example 3: Chunk with Multiple Topics (Requires Splitting)**
|
||||
|
||||
*Input Chunk:*
|
||||
"## Chapter 1: The Sun
|
||||
The Sun is the star at the center of the Solar System.
|
||||
|
||||
## Chapter 2: The Moon
|
||||
The Moon is Earth's only natural satellite."
|
||||
|
||||
*Your Thought Process:*
|
||||
1. This chunk contains two distinct topics.
|
||||
2. I must process them separately to maintain semantic integrity.
|
||||
3. I will create two `<repaired_text>` blocks.
|
||||
|
||||
*Your Output:*
|
||||
<repaired_text>
|
||||
## Chapter 1: The Sun
|
||||
The Sun is the star at the center of the Solar System.
|
||||
</repaired_text>
|
||||
<repaired_text>
|
||||
## Chapter 2: The Moon
|
||||
The Moon is Earth's only natural satellite.
|
||||
</repaired_text>
|
||||
"""
|
||||
822
astrbot/core/memory/DESIGN.excalidraw
Normal file
822
astrbot/core/memory/DESIGN.excalidraw
Normal file
@@ -0,0 +1,822 @@
|
||||
{
|
||||
"type": "excalidraw",
|
||||
"version": 2,
|
||||
"source": "https://marketplace.visualstudio.com/items?itemName=pomdtr.excalidraw-editor",
|
||||
"elements": [
|
||||
{
|
||||
"id": "l6cYurMvF69IM4Kc33Qou",
|
||||
"type": "rectangle",
|
||||
"x": 173.140625,
|
||||
"y": -29.0234375,
|
||||
"width": 92.95703125,
|
||||
"height": 77.109375,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a0",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 1409469537,
|
||||
"version": 91,
|
||||
"versionNonce": 307958671,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763703733605,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "1ZvS6t8U6ihUjNU0dakgl",
|
||||
"type": "arrow",
|
||||
"x": 409.30859375,
|
||||
"y": 9.6875,
|
||||
"width": 118.2734375,
|
||||
"height": 1.9609375,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a1",
|
||||
"roundness": {
|
||||
"type": 2
|
||||
},
|
||||
"seed": 326508865,
|
||||
"version": 120,
|
||||
"versionNonce": 199367023,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703733605,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"points": [
|
||||
[
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
-118.2734375,
|
||||
-1.9609375
|
||||
]
|
||||
],
|
||||
"lastCommittedPoint": null,
|
||||
"startBinding": null,
|
||||
"endBinding": null,
|
||||
"startArrowhead": null,
|
||||
"endArrowhead": "arrow",
|
||||
"elbowed": false
|
||||
},
|
||||
{
|
||||
"id": "tfdUGiJdcMoOHGfqFHXK6",
|
||||
"type": "text",
|
||||
"x": 153.46875,
|
||||
"y": -70.9765625,
|
||||
"width": 136.4598846435547,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a2",
|
||||
"roundness": null,
|
||||
"seed": 688712865,
|
||||
"version": 67,
|
||||
"versionNonce": 300660705,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703743816,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "FAISS+SQLite",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "FAISS+SQLite",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "AeL3kEB9a8_TAvAXpAbpl",
|
||||
"type": "text",
|
||||
"x": 438.36328125,
|
||||
"y": -3.78125,
|
||||
"width": 116.109375,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a3",
|
||||
"roundness": null,
|
||||
"seed": 788579535,
|
||||
"version": 33,
|
||||
"versionNonce": 946602095,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703932431,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "FACT",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "FACT",
|
||||
"autoResize": false,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "Pe3TeMZvxQ8tRTcbD5v6P",
|
||||
"type": "arrow",
|
||||
"x": 297.125,
|
||||
"y": 40.2578125,
|
||||
"width": 120.2421875,
|
||||
"height": 1.421875,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a4",
|
||||
"roundness": {
|
||||
"type": 2
|
||||
},
|
||||
"seed": 1146229999,
|
||||
"version": 44,
|
||||
"versionNonce": 636917679,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703759050,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"points": [
|
||||
[
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
120.2421875,
|
||||
1.421875
|
||||
]
|
||||
],
|
||||
"lastCommittedPoint": null,
|
||||
"startBinding": null,
|
||||
"endBinding": null,
|
||||
"startArrowhead": null,
|
||||
"endArrowhead": "arrow",
|
||||
"elbowed": false
|
||||
},
|
||||
{
|
||||
"id": "GhmQoadtQRK8c8aEEbYKQ",
|
||||
"type": "text",
|
||||
"x": 283.53515625,
|
||||
"y": 64.76171875,
|
||||
"width": 130.85989379882812,
|
||||
"height": 50,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a5",
|
||||
"roundness": null,
|
||||
"seed": 1445650959,
|
||||
"version": 79,
|
||||
"versionNonce": 566193167,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703768982,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "top-n Similary\n",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "top-n Similary\n",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "uTEFJs8cNS09WFq2pi9P7",
|
||||
"type": "rectangle",
|
||||
"x": 528.1586158430439,
|
||||
"y": -173.43472375183552,
|
||||
"width": 135.7578125,
|
||||
"height": 128.73828125,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a6",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 223409231,
|
||||
"version": 44,
|
||||
"versionNonce": 1066827105,
|
||||
"isDeleted": false,
|
||||
"boundElements": [
|
||||
{
|
||||
"id": "FfWdx1_yCq6UYfXamJX9N",
|
||||
"type": "arrow"
|
||||
}
|
||||
],
|
||||
"updated": 1763704050188,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "2SzqzpJ4C2ymVj8-8vN7H",
|
||||
"type": "text",
|
||||
"x": 548.1480270948795,
|
||||
"y": -211,
|
||||
"width": 86.43992614746094,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "a7",
|
||||
"roundness": null,
|
||||
"seed": 1015608623,
|
||||
"version": 23,
|
||||
"versionNonce": 950374849,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763704047884,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "Memories",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "Memories",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "CgW6Yf9v0a9q1tsjhDl7b",
|
||||
"type": "text",
|
||||
"x": 568.3099317299038,
|
||||
"y": -154.69469411681115,
|
||||
"width": 62.099945068359375,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aA",
|
||||
"roundness": null,
|
||||
"seed": 452254927,
|
||||
"version": 10,
|
||||
"versionNonce": 972895023,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763704057762,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "chunk1",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "chunk1",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "knvlKpaFZ8lY-73Y-e9W6",
|
||||
"type": "text",
|
||||
"x": 569.11328125,
|
||||
"y": -116.91056665512056,
|
||||
"width": 67.55995178222656,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aB",
|
||||
"roundness": null,
|
||||
"seed": 914644015,
|
||||
"version": 90,
|
||||
"versionNonce": 158135631,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763704057762,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "chunk2",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "chunk2",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "Q7URqvTSMpvj08ye-afTT",
|
||||
"type": "rectangle",
|
||||
"x": 444.515625,
|
||||
"y": 36.7890625,
|
||||
"width": 58.859375,
|
||||
"height": 29.41796875,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aC",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 1642537601,
|
||||
"version": 19,
|
||||
"versionNonce": 948406575,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703870173,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "JjxBt9cZIZXNTd6CmwyKL",
|
||||
"type": "rectangle",
|
||||
"x": 452.203125,
|
||||
"y": 46.064453125,
|
||||
"width": 58.859375,
|
||||
"height": 29.41796875,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aD",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 1746916641,
|
||||
"version": 40,
|
||||
"versionNonce": 1650978255,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763703871882,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "XGBCPPFnjriqsL8LvLwyQ",
|
||||
"type": "rectangle",
|
||||
"x": 461.56640625,
|
||||
"y": 56.162109375,
|
||||
"width": 58.859375,
|
||||
"height": 29.41796875,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aE",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 529794575,
|
||||
"version": 85,
|
||||
"versionNonce": 2131900641,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763703874182,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "FfWdx1_yCq6UYfXamJX9N",
|
||||
"type": "arrow",
|
||||
"x": 537.6875,
|
||||
"y": 48.203125,
|
||||
"width": 6.615850226297994,
|
||||
"height": 75.81335873223107,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aF",
|
||||
"roundness": {
|
||||
"type": 2
|
||||
},
|
||||
"seed": 1982870689,
|
||||
"version": 90,
|
||||
"versionNonce": 25307457,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763704050188,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"points": [
|
||||
[
|
||||
0,
|
||||
0
|
||||
],
|
||||
[
|
||||
6.615850226297994,
|
||||
-75.81335873223107
|
||||
]
|
||||
],
|
||||
"lastCommittedPoint": null,
|
||||
"startBinding": null,
|
||||
"endBinding": {
|
||||
"elementId": "uTEFJs8cNS09WFq2pi9P7",
|
||||
"focus": 0.6071885090336794,
|
||||
"gap": 24.64453125
|
||||
},
|
||||
"startArrowhead": null,
|
||||
"endArrowhead": "arrow",
|
||||
"elbowed": false
|
||||
},
|
||||
{
|
||||
"id": "jgJgqGMRWcaNX_28wY4CU",
|
||||
"type": "text",
|
||||
"x": 570,
|
||||
"y": 10,
|
||||
"width": 67.11994934082031,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aG",
|
||||
"roundness": null,
|
||||
"seed": 1065220559,
|
||||
"version": 26,
|
||||
"versionNonce": 2115991521,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703959397,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "update",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "update",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "_5pSPPOpp9h1TpFCIc055",
|
||||
"type": "text",
|
||||
"x": 292.36328125,
|
||||
"y": -138.5703125,
|
||||
"width": 122.87992858886719,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aH",
|
||||
"roundness": null,
|
||||
"seed": 51461025,
|
||||
"version": 26,
|
||||
"versionNonce": 1647492655,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763703925147,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "ADD Memory",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "ADD Memory",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "YG6MdL14l7lk4ypQNMZ_k",
|
||||
"type": "text",
|
||||
"x": 296.71885397566257,
|
||||
"y": 161.399157096715,
|
||||
"width": 295.27984619140625,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aJ",
|
||||
"roundness": null,
|
||||
"seed": 1183210273,
|
||||
"version": 122,
|
||||
"versionNonce": 1702733281,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763704085083,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "RETRIEVE Memory (STATIC)",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "RETRIEVE Memory (STATIC)",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "Foa3VPJYqhj1uAX5mn3n0",
|
||||
"type": "rectangle",
|
||||
"x": 324.7616636099071,
|
||||
"y": 248.63213980937013,
|
||||
"width": 135.7578125,
|
||||
"height": 128.73828125,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aL",
|
||||
"roundness": {
|
||||
"type": 3
|
||||
},
|
||||
"seed": 995116257,
|
||||
"version": 225,
|
||||
"versionNonce": 1886900225,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763704055846,
|
||||
"link": null,
|
||||
"locked": false
|
||||
},
|
||||
{
|
||||
"id": "pe3veI_yBFKYtbaJwDKQT",
|
||||
"type": "text",
|
||||
"x": 344.7510748617428,
|
||||
"y": 211.06686356120565,
|
||||
"width": 86.43992614746094,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aM",
|
||||
"roundness": null,
|
||||
"seed": 26673345,
|
||||
"version": 204,
|
||||
"versionNonce": 1004546017,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763704055846,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "Memories",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "Memories",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "bOlhO8AaKE86_43viu5UG",
|
||||
"type": "text",
|
||||
"x": 365.50408375566445,
|
||||
"y": 269.24725381983865,
|
||||
"width": 62.099945068359375,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aN",
|
||||
"roundness": null,
|
||||
"seed": 1849784033,
|
||||
"version": 106,
|
||||
"versionNonce": 762320737,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763704060295,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "chunk1",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "chunk1",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "V_iDW10PKwMe7vWb5S5HF",
|
||||
"type": "text",
|
||||
"x": 366.3074332757606,
|
||||
"y": 307.03138128152926,
|
||||
"width": 67.55995178222656,
|
||||
"height": 25,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aO",
|
||||
"roundness": null,
|
||||
"seed": 1670509249,
|
||||
"version": 186,
|
||||
"versionNonce": 1964540737,
|
||||
"isDeleted": false,
|
||||
"boundElements": [],
|
||||
"updated": 1763704060295,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "chunk2",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "chunk2",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
},
|
||||
{
|
||||
"id": "LHKMRdSowgcl2LsKacxTz",
|
||||
"type": "text",
|
||||
"x": 484.9493410573871,
|
||||
"y": 292.45619471187945,
|
||||
"width": 273.579833984375,
|
||||
"height": 50,
|
||||
"angle": 0,
|
||||
"strokeColor": "#1e1e1e",
|
||||
"backgroundColor": "transparent",
|
||||
"fillStyle": "solid",
|
||||
"strokeWidth": 2,
|
||||
"strokeStyle": "solid",
|
||||
"roughness": 1,
|
||||
"opacity": 100,
|
||||
"groupIds": [],
|
||||
"frameId": null,
|
||||
"index": "aP",
|
||||
"roundness": null,
|
||||
"seed": 945666991,
|
||||
"version": 104,
|
||||
"versionNonce": 1512137505,
|
||||
"isDeleted": false,
|
||||
"boundElements": null,
|
||||
"updated": 1763704096016,
|
||||
"link": null,
|
||||
"locked": false,
|
||||
"text": "RANKED By DECAY SCORE,\nTOP K",
|
||||
"fontSize": 20,
|
||||
"fontFamily": 5,
|
||||
"textAlign": "left",
|
||||
"verticalAlign": "top",
|
||||
"containerId": null,
|
||||
"originalText": "RANKED By DECAY SCORE,\nTOP K",
|
||||
"autoResize": true,
|
||||
"lineHeight": 1.25
|
||||
}
|
||||
],
|
||||
"appState": {
|
||||
"gridSize": 20,
|
||||
"gridStep": 5,
|
||||
"gridModeEnabled": false,
|
||||
"viewBackgroundColor": "#ffffff"
|
||||
},
|
||||
"files": {}
|
||||
}
|
||||
76
astrbot/core/memory/_README.md
Normal file
76
astrbot/core/memory/_README.md
Normal file
@@ -0,0 +1,76 @@
|
||||
## Decay Score
|
||||
|
||||
记忆衰减分数定义为:
|
||||
|
||||
\[
|
||||
\text{decay\_score}
|
||||
= \alpha \cdot e^{-\lambda \cdot \Delta t \cdot \beta}
|
||||
|
||||
+ (1-\alpha)\cdot (1 - e^{-\gamma \cdot c})
|
||||
\]
|
||||
|
||||
其中:
|
||||
|
||||
+ \(\Delta t\):自上次检索以来经过的时间(天),由 `last_retrieval_at` 计算;
|
||||
+ \(c\):检索次数,对应字段 `retrieval_count`;
|
||||
+ \(\alpha\):控制时间衰减和检索次数影响的权重;
|
||||
+ \(\gamma\):控制检索次数影响的速率;
|
||||
+ \(\lambda\):控制时间衰减的速率;
|
||||
+ \(\beta\):时间衰减调节因子;
|
||||
|
||||
\[
|
||||
\beta = \frac{1}{1 + a \cdot c}
|
||||
\]
|
||||
|
||||
+ \(a\):控制检索次数对时间衰减影响的权重。
|
||||
|
||||
## ADD MEMORY
|
||||
|
||||
+ LLM 通过 `astr_add_memory` 工具调用,传入记忆内容和记忆类型。
|
||||
+ 生成 `mem_id = uuid4()`。
|
||||
+ 从上下文中获取 `owner_id = unified_message_origin`。
|
||||
|
||||
步骤:
|
||||
|
||||
1. 使用 VecDB 以新记忆内容为 query,检索前 20 条相似记忆。
|
||||
2. 从中取相似度最高的前 5 条:
|
||||
+ 若相似度超过“合并阈值”(如 `sim >= merge_threshold`):
|
||||
+ 将该条记忆视为同一记忆,使用 LLM 将旧内容与新内容合并;
|
||||
+ 在同一个 `mem_id` 上更新 MemoryDB 和 VecDB(UPDATE,而非新建)。
|
||||
+ 否则:
|
||||
+ 作为全新的记忆插入:
|
||||
+ 写入 VecDB(metadata 中包含 `mem_id`, `owner_id`);
|
||||
+ 写入 MemoryDB 的 `memory_chunks` 表,初始化:
|
||||
+ `created_at = now`
|
||||
+ `last_retrieval_at = now`
|
||||
+ `retrieval_count = 1` 等。
|
||||
3. 对 VecDB 返回的前 20 条记忆,如果相似度高于某个“赫布阈值”(`hebb_threshold`),则:
|
||||
+ `retrieval_count += 1`
|
||||
+ `last_retrieval_at = now`
|
||||
|
||||
这一步体现了赫布学习:与新记忆共同被激活的旧记忆会获得一次强化。
|
||||
|
||||
## QUERY MEMORY (STATIC)
|
||||
|
||||
+ LLM 通过 `astr_query_memory` 工具调用,无参数。
|
||||
|
||||
步骤:
|
||||
|
||||
1. 从 MemoryDB 的 `memory_chunks` 表中查询当前用户所有活跃记忆:
|
||||
+ `SELECT * FROM memory_chunks WHERE owner_id = ? AND is_active = 1`
|
||||
2. 对每条记忆,根据 `last_retrieval_at` 和 `retrieval_count` 计算对应的 `decay_score`。
|
||||
3. 按 `decay_score` 从高到低排序,返回前 `top_k` 条记忆内容给 LLM。
|
||||
4. 对返回的这 `top_k` 条记忆:
|
||||
+ `retrieval_count += 1`
|
||||
+ `last_retrieval_at = now`
|
||||
|
||||
## QUERY MEMORY (DYNAMIC)(暂不实现)
|
||||
|
||||
+ LLM 提供查询内容作为语义 query。
|
||||
+ 使用 VecDB 检索与该 query 最相似的前 `N` 条记忆(`N > top_k`)。
|
||||
+ 根据 `mem_id` 从 `memory_chunks` 中加载对应记录。
|
||||
+ 对这批候选记忆计算:
|
||||
+ 语义相似度(来自 VecDB)
|
||||
+ `decay_score`
|
||||
+ 最终排序分数(例如 `w1 * sim + w2 * decay_score`)
|
||||
+ 按最终排序分数从高到低返回前 `top_k` 条记忆内容,并更新它们的 `retrieval_count` 和 `last_retrieval_at`。
|
||||
63
astrbot/core/memory/entities.py
Normal file
63
astrbot/core/memory/entities.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import numpy as np
|
||||
from sqlmodel import Field, MetaData, SQLModel
|
||||
|
||||
MEMORY_TYPE_IMPORTANCE = {"persona": 1.3, "fact": 1.0, "ephemeral": 0.8}
|
||||
|
||||
|
||||
class BaseMemoryModel(SQLModel, table=False):
|
||||
metadata = MetaData()
|
||||
|
||||
|
||||
class MemoryChunk(BaseMemoryModel, table=True):
|
||||
"""A chunk of memory stored in the system."""
|
||||
|
||||
__tablename__ = "memory_chunks" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True,
|
||||
sa_column_kwargs={"autoincrement": True},
|
||||
default=None,
|
||||
)
|
||||
mem_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
index=True,
|
||||
)
|
||||
fact: str = Field(nullable=False)
|
||||
"""The factual content of the memory chunk."""
|
||||
owner_id: str = Field(max_length=255, nullable=False, index=True)
|
||||
"""The identifier of the owner (user) of the memory chunk."""
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
"""The timestamp when the memory chunk was created."""
|
||||
last_retrieval_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
"""The timestamp when the memory chunk was last retrieved."""
|
||||
retrieval_count: int = Field(default=1, nullable=False)
|
||||
"""The number of times the memory chunk has been retrieved."""
|
||||
memory_type: str = Field(max_length=20, nullable=False, default="fact")
|
||||
"""The type of memory (e.g., 'persona', 'fact', 'ephemeral')."""
|
||||
is_active: bool = Field(default=True, nullable=False)
|
||||
"""Whether the memory chunk is active."""
|
||||
|
||||
def compute_decay_score(self, current_time: datetime) -> float:
|
||||
"""Compute the decay score of the memory chunk based on time and retrievals."""
|
||||
# Constants for the decay formula
|
||||
alpha = 0.5
|
||||
gamma = 0.1
|
||||
lambda_ = 0.05
|
||||
a = 0.1
|
||||
|
||||
# Calculate delta_t in days
|
||||
delta_t = (current_time - self.last_retrieval_at).total_seconds() / 86400
|
||||
c = self.retrieval_count
|
||||
beta = 1 / (1 + a * c)
|
||||
decay_score = alpha * np.exp(-lambda_ * delta_t * beta) + (1 - alpha) * (
|
||||
1 - np.exp(-gamma * c)
|
||||
)
|
||||
return decay_score * MEMORY_TYPE_IMPORTANCE.get(self.memory_type, 1.0)
|
||||
174
astrbot/core/memory/mem_db_sqlite.py
Normal file
174
astrbot/core/memory/mem_db_sqlite.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select, text, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlmodel import col
|
||||
|
||||
from astrbot.core import logger
|
||||
|
||||
from .entities import BaseMemoryModel, MemoryChunk
|
||||
|
||||
|
||||
class MemoryDatabase:
|
||||
def __init__(self, db_path: str = "data/astr_memory/memory.db") -> None:
|
||||
"""Initialize memory database
|
||||
|
||||
Args:
|
||||
db_path: Database file path, default is data/astr_memory/memory.db
|
||||
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||
self.inited = False
|
||||
|
||||
# Ensure directory exists
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create async engine
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
)
|
||||
|
||||
# Create session factory
|
||||
self.async_session = async_sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db(self):
|
||||
"""Get database session
|
||||
|
||||
Usage:
|
||||
async with mem_db.get_db() as session:
|
||||
# Perform database operations
|
||||
result = await session.execute(stmt)
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
yield session
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize database, create tables and configure SQLite parameters"""
|
||||
async with self.engine.begin() as conn:
|
||||
# Create all memory related tables
|
||||
await conn.run_sync(BaseMemoryModel.metadata.create_all)
|
||||
|
||||
# Configure SQLite performance optimization parameters
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||
await conn.execute(text("PRAGMA cache_size=20000"))
|
||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||
await conn.execute(text("PRAGMA mmap_size=134217728"))
|
||||
await conn.execute(text("PRAGMA optimize"))
|
||||
await conn.commit()
|
||||
|
||||
await self._create_indexes()
|
||||
self.inited = True
|
||||
logger.info(f"Memory database initialized: {self.db_path}")
|
||||
|
||||
async def _create_indexes(self) -> None:
|
||||
"""Create indexes for memory_chunks table"""
|
||||
async with self.get_db() as session:
|
||||
async with session.begin():
|
||||
# Create memory chunks table indexes
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_mem_mem_id "
|
||||
"ON memory_chunks(mem_id)",
|
||||
),
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_mem_owner_id "
|
||||
"ON memory_chunks(owner_id)",
|
||||
),
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_mem_owner_active "
|
||||
"ON memory_chunks(owner_id, is_active)",
|
||||
),
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close database connection"""
|
||||
await self.engine.dispose()
|
||||
logger.info(f"Memory database closed: {self.db_path}")
|
||||
|
||||
async def insert_memory(self, memory: MemoryChunk) -> MemoryChunk:
|
||||
"""Insert a new memory chunk"""
|
||||
async with self.get_db() as session:
|
||||
session.add(memory)
|
||||
await session.commit()
|
||||
await session.refresh(memory)
|
||||
return memory
|
||||
|
||||
async def get_memory_by_id(self, mem_id: str) -> MemoryChunk | None:
|
||||
"""Get memory chunk by mem_id"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(MemoryChunk).where(col(MemoryChunk.mem_id) == mem_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_memory(self, memory: MemoryChunk) -> MemoryChunk:
|
||||
"""Update an existing memory chunk"""
|
||||
async with self.get_db() as session:
|
||||
session.add(memory)
|
||||
await session.commit()
|
||||
await session.refresh(memory)
|
||||
return memory
|
||||
|
||||
async def get_active_memories(self, owner_id: str) -> list[MemoryChunk]:
|
||||
"""Get all active memories for a user"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(MemoryChunk).where(
|
||||
col(MemoryChunk.owner_id) == owner_id,
|
||||
col(MemoryChunk.is_active) == True, # noqa: E712
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_retrieval_stats(
|
||||
self,
|
||||
mem_ids: list[str],
|
||||
current_time: datetime | None = None,
|
||||
) -> None:
|
||||
"""Update retrieval statistics for multiple memories"""
|
||||
if not mem_ids:
|
||||
return
|
||||
|
||||
if current_time is None:
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
async with self.get_db() as session:
|
||||
async with session.begin():
|
||||
stmt = (
|
||||
update(MemoryChunk)
|
||||
.where(col(MemoryChunk.mem_id).in_(mem_ids))
|
||||
.values(
|
||||
retrieval_count=MemoryChunk.retrieval_count + 1,
|
||||
last_retrieval_at=current_time,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
async def deactivate_memory(self, mem_id: str) -> bool:
|
||||
"""Deactivate a memory chunk"""
|
||||
async with self.get_db() as session:
|
||||
async with session.begin():
|
||||
stmt = (
|
||||
update(MemoryChunk)
|
||||
.where(col(MemoryChunk.mem_id) == mem_id)
|
||||
.values(is_active=False)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
return result.rowcount > 0 if result.rowcount else False # type: ignore
|
||||
281
astrbot/core/memory/memory_manager.py
Normal file
281
astrbot/core/memory/memory_manager.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider
|
||||
from astrbot.core.provider.provider import Provider as LLMProvider
|
||||
|
||||
from .entities import MemoryChunk
|
||||
from .mem_db_sqlite import MemoryDatabase
|
||||
|
||||
MERGE_THRESHOLD = 0.85
|
||||
"""Similarity threshold for merging memories"""
|
||||
HEBB_THRESHOLD = 0.70
|
||||
"""Similarity threshold for Hebbian learning reinforcement"""
|
||||
MERGE_SYSTEM_PROMPT = """You are a memory consolidation assistant. Your task is to merge two related memory entries into a single, comprehensive memory.
|
||||
|
||||
Input format:
|
||||
- Old memory: [existing memory content]
|
||||
- New memory: [new memory content to be integrated]
|
||||
|
||||
Your output should be a single, concise memory that combines the essential information from both entries. Preserve specific details, update outdated information, and eliminate redundancy. Output only the merged memory content without any explanations or meta-commentary."""
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""Manager for user long-term memory storage and retrieval"""
|
||||
|
||||
def __init__(self, memory_root_dir: str = "data/astr_memory"):
|
||||
self.memory_root_dir = Path(memory_root_dir)
|
||||
self.memory_root_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.mem_db: MemoryDatabase | None = None
|
||||
self.vec_db: FaissVecDB | None = None
|
||||
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
merge_llm_provider: LLMProvider,
|
||||
):
|
||||
"""Initialize memory database and vector database"""
|
||||
# Initialize MemoryDB
|
||||
db_path = self.memory_root_dir / "memory.db"
|
||||
self.mem_db = MemoryDatabase(db_path.as_posix())
|
||||
await self.mem_db.initialize()
|
||||
|
||||
self.embedding_provider = embedding_provider
|
||||
self.merge_llm_provider = merge_llm_provider
|
||||
|
||||
# Initialize VecDB
|
||||
doc_store_path = self.memory_root_dir / "doc.db"
|
||||
index_store_path = self.memory_root_dir / "index.faiss"
|
||||
self.vec_db = FaissVecDB(
|
||||
doc_store_path=doc_store_path.as_posix(),
|
||||
index_store_path=index_store_path.as_posix(),
|
||||
embedding_provider=self.embedding_provider,
|
||||
)
|
||||
await self.vec_db.initialize()
|
||||
|
||||
logger.info("Memory manager initialized")
|
||||
self._initialized = True
|
||||
|
||||
async def terminate(self):
|
||||
"""Close all database connections"""
|
||||
if self.vec_db:
|
||||
await self.vec_db.close()
|
||||
if self.mem_db:
|
||||
await self.mem_db.close()
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
fact: str,
|
||||
owner_id: str,
|
||||
memory_type: str = "fact",
|
||||
) -> MemoryChunk:
|
||||
"""Add a new memory with similarity check and merge logic
|
||||
|
||||
Implements the ADD MEMORY workflow from _README.md:
|
||||
1. Search for similar memories using VecDB
|
||||
2. If similarity >= merge_threshold, merge with existing memory
|
||||
3. Otherwise, create new memory
|
||||
4. Apply Hebbian learning to similar memories (similarity >= hebb_threshold)
|
||||
|
||||
Args:
|
||||
fact: Memory content
|
||||
owner_id: User identifier
|
||||
memory_type: Memory type ('persona', 'fact', 'ephemeral')
|
||||
|
||||
Returns:
|
||||
The created or updated MemoryChunk
|
||||
|
||||
"""
|
||||
if not self.vec_db or not self.mem_db:
|
||||
raise RuntimeError("Memory manager not initialized")
|
||||
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# Step 1: Search for similar memories
|
||||
similar_results = await self.vec_db.retrieve(
|
||||
query=fact,
|
||||
k=20,
|
||||
fetch_k=50,
|
||||
metadata_filters={"owner_id": owner_id},
|
||||
)
|
||||
|
||||
# Step 2: Check if we should merge with existing memories (top 3 similar ones)
|
||||
merge_candidates = [
|
||||
r for r in similar_results[:3] if r.similarity >= MERGE_THRESHOLD
|
||||
]
|
||||
|
||||
if merge_candidates:
|
||||
# Get all candidate memories from database
|
||||
candidate_memories: list[tuple[str, MemoryChunk]] = []
|
||||
for candidate in merge_candidates:
|
||||
mem_id = json.loads(candidate.data["metadata"])["mem_id"]
|
||||
memory = await self.mem_db.get_memory_by_id(mem_id)
|
||||
if memory:
|
||||
candidate_memories.append((mem_id, memory))
|
||||
|
||||
if candidate_memories:
|
||||
# Use the most similar memory as the base
|
||||
base_mem_id, base_memory = candidate_memories[0]
|
||||
|
||||
# Collect all facts to merge (existing candidates + new fact)
|
||||
all_facts = [mem.fact for _, mem in candidate_memories] + [fact]
|
||||
merged_fact = await self._merge_multiple_memories(all_facts)
|
||||
|
||||
# Update the base memory
|
||||
base_memory.fact = merged_fact
|
||||
base_memory.last_retrieval_at = current_time
|
||||
base_memory.retrieval_count += 1
|
||||
updated_memory = await self.mem_db.update_memory(base_memory)
|
||||
|
||||
# Update VecDB for base memory
|
||||
await self.vec_db.delete(base_mem_id)
|
||||
await self.vec_db.insert(
|
||||
content=merged_fact,
|
||||
metadata={
|
||||
"mem_id": base_mem_id,
|
||||
"owner_id": owner_id,
|
||||
"memory_type": memory_type,
|
||||
},
|
||||
id=base_mem_id,
|
||||
)
|
||||
|
||||
# Deactivate and remove other merged memories
|
||||
for mem_id, _ in candidate_memories[1:]:
|
||||
await self.mem_db.deactivate_memory(mem_id)
|
||||
await self.vec_db.delete(mem_id)
|
||||
|
||||
logger.info(
|
||||
f"Merged {len(candidate_memories)} memories into {base_mem_id} for user {owner_id}"
|
||||
)
|
||||
return updated_memory
|
||||
|
||||
# Step 3: Create new memory
|
||||
mem_id = str(uuid.uuid4())
|
||||
new_memory = MemoryChunk(
|
||||
mem_id=mem_id,
|
||||
fact=fact,
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
created_at=current_time,
|
||||
last_retrieval_at=current_time,
|
||||
retrieval_count=1,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
# Insert into MemoryDB
|
||||
created_memory = await self.mem_db.insert_memory(new_memory)
|
||||
|
||||
# Insert into VecDB
|
||||
await self.vec_db.insert(
|
||||
content=fact,
|
||||
metadata={
|
||||
"mem_id": mem_id,
|
||||
"owner_id": owner_id,
|
||||
"memory_type": memory_type,
|
||||
},
|
||||
id=mem_id,
|
||||
)
|
||||
|
||||
# Step 4: Apply Hebbian learning to similar memories
|
||||
hebb_mem_ids = [
|
||||
json.loads(r.data["metadata"])["mem_id"]
|
||||
for r in similar_results
|
||||
if r.similarity >= HEBB_THRESHOLD
|
||||
]
|
||||
if hebb_mem_ids:
|
||||
await self.mem_db.update_retrieval_stats(hebb_mem_ids, current_time)
|
||||
logger.debug(
|
||||
f"Applied Hebbian learning to {len(hebb_mem_ids)} memories for user {owner_id}",
|
||||
)
|
||||
|
||||
logger.info(f"Created new memory {mem_id} for user {owner_id}")
|
||||
return created_memory
|
||||
|
||||
async def query_memory(
|
||||
self,
|
||||
owner_id: str,
|
||||
top_k: int = 5,
|
||||
) -> list[MemoryChunk]:
|
||||
"""Query user's memories using static retrieval with decay score ranking
|
||||
|
||||
Implements the QUERY MEMORY (STATIC) workflow from _README.md:
|
||||
1. Get all active memories for user from MemoryDB
|
||||
2. Compute decay_score for each memory
|
||||
3. Sort by decay_score and return top_k
|
||||
4. Update retrieval statistics for returned memories
|
||||
|
||||
Args:
|
||||
owner_id: User identifier
|
||||
top_k: Number of memories to return
|
||||
|
||||
Returns:
|
||||
List of top_k MemoryChunk sorted by decay score
|
||||
"""
|
||||
if not self.mem_db:
|
||||
raise RuntimeError("Memory manager not initialized")
|
||||
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# Step 1: Get all active memories for user
|
||||
all_memories = await self.mem_db.get_active_memories(owner_id)
|
||||
|
||||
if not all_memories:
|
||||
return []
|
||||
|
||||
# Step 2-3: Compute decay scores and sort
|
||||
memories_with_scores = [
|
||||
(mem, mem.compute_decay_score(current_time)) for mem in all_memories
|
||||
]
|
||||
memories_with_scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Get top_k memories
|
||||
top_memories = [mem for mem, _ in memories_with_scores[:top_k]]
|
||||
|
||||
# Step 4: Update retrieval statistics
|
||||
mem_ids = [mem.mem_id for mem in top_memories]
|
||||
await self.mem_db.update_retrieval_stats(mem_ids, current_time)
|
||||
|
||||
logger.debug(f"Retrieved {len(top_memories)} memories for user {owner_id}")
|
||||
return top_memories
|
||||
|
||||
async def _merge_multiple_memories(self, facts: list[str]) -> str:
|
||||
"""Merge multiple memory facts using LLM in one call
|
||||
|
||||
Args:
|
||||
facts: List of memory facts to merge
|
||||
|
||||
Returns:
|
||||
Merged memory content
|
||||
"""
|
||||
if not self.merge_llm_provider:
|
||||
return " ".join(facts)
|
||||
|
||||
if len(facts) == 1:
|
||||
return facts[0]
|
||||
|
||||
try:
|
||||
# Format all facts as a numbered list
|
||||
facts_list = "\n".join(f"{i + 1}. {fact}" for i, fact in enumerate(facts))
|
||||
user_prompt = (
|
||||
f"Please merge the following {len(facts)} related memory entries "
|
||||
"into a single, comprehensive memory:"
|
||||
f"\n{facts_list}\n\nOutput only the merged memory content."
|
||||
)
|
||||
response = await self.merge_llm_provider.text_chat(
|
||||
prompt=user_prompt,
|
||||
system_prompt=MERGE_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
merged_content = response.completion_text.strip()
|
||||
return merged_content if merged_content else " ".join(facts)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to merge memories with LLM: {e}, using fallback")
|
||||
return " ".join(facts)
|
||||
156
astrbot/core/memory/tools.py
Normal file
156
astrbot/core/memory/tools.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from pydantic import Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from astrbot.core.agent.tool import FunctionTool, ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext, ContextWrapper
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddMemory(FunctionTool[AstrAgentContext]):
|
||||
"""Tool for adding memories to user's long-term memory storage"""
|
||||
|
||||
name: str = "astr_add_memory"
|
||||
description: str = (
|
||||
"Add a new memory to the user's long-term memory storage. "
|
||||
"Use this tool only when the user explicitly asks you to remember something, "
|
||||
"or when they share stable preferences, identity, or long-term goals that will be useful in future interactions."
|
||||
)
|
||||
parameters: dict = Field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"fact": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The concrete memory content to store, such as a user preference, "
|
||||
"identity detail, long-term goal, or stable profile fact."
|
||||
),
|
||||
},
|
||||
"memory_type": {
|
||||
"type": "string",
|
||||
"enum": ["persona", "fact", "ephemeral"],
|
||||
"description": (
|
||||
"The relative importance of this memory. "
|
||||
"Use 'persona' for core identity or highly impactful information, "
|
||||
"'fact' for normal long-term preferences, "
|
||||
"and 'ephemeral' for minor or tentative facts."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["fact", "memory_type"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self, context: ContextWrapper[AstrAgentContext], **kwargs
|
||||
) -> ToolExecResult:
|
||||
"""Add a memory to long-term storage
|
||||
|
||||
Args:
|
||||
context: Agent context
|
||||
**kwargs: Must contain 'fact' and 'memory_type'
|
||||
|
||||
Returns:
|
||||
ToolExecResult with success message
|
||||
|
||||
"""
|
||||
mm = context.context.context.memory_manager
|
||||
fact = kwargs.get("fact")
|
||||
memory_type = kwargs.get("memory_type", "fact")
|
||||
|
||||
if not fact:
|
||||
return "Missing required parameter: fact"
|
||||
|
||||
try:
|
||||
# Get owner_id from context
|
||||
owner_id = context.context.event.unified_msg_origin
|
||||
|
||||
# Add memory using memory manager
|
||||
memory = await mm.add_memory(
|
||||
fact=fact,
|
||||
owner_id=owner_id,
|
||||
memory_type=memory_type,
|
||||
)
|
||||
|
||||
return f"Memory added successfully (ID: {memory.mem_id})"
|
||||
|
||||
except Exception as e:
|
||||
return f"Failed to add memory: {str(e)}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryMemory(FunctionTool[AstrAgentContext]):
|
||||
"""Tool for querying user's long-term memories"""
|
||||
|
||||
name: str = "astr_query_memory"
|
||||
description: str = (
|
||||
"Query the user's long-term memory storage and return the most relevant memories. "
|
||||
"Use this tool when you need user-specific context, preferences, or past facts "
|
||||
"that are not explicitly present in the current conversation."
|
||||
)
|
||||
parameters: dict = Field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"top_k": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Maximum number of memories to retrieve after retention-based ranking. "
|
||||
"Typically between 3 and 10."
|
||||
),
|
||||
"default": 5,
|
||||
"minimum": 1,
|
||||
"maximum": 20,
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self, context: ContextWrapper[AstrAgentContext], **kwargs
|
||||
) -> ToolExecResult:
|
||||
"""Query memories from long-term storage
|
||||
|
||||
Args:
|
||||
context: Agent context
|
||||
**kwargs: Optional 'top_k' parameter
|
||||
|
||||
Returns:
|
||||
ToolExecResult with formatted memory list
|
||||
|
||||
"""
|
||||
mm = context.context.context.memory_manager
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
|
||||
try:
|
||||
# Get owner_id from context
|
||||
owner_id = context.context.event.unified_msg_origin
|
||||
|
||||
# Query memories using memory manager
|
||||
memories = await mm.query_memory(
|
||||
owner_id=owner_id,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
if not memories:
|
||||
return "No memories found for this user."
|
||||
|
||||
# Format memories for output
|
||||
formatted_memories = []
|
||||
for i, mem in enumerate(memories, 1):
|
||||
formatted_memories.append(
|
||||
f"{i}. [{mem.memory_type.upper()}] {mem.fact} "
|
||||
f"(retrieved {mem.retrieval_count} times, "
|
||||
f"last: {mem.last_retrieval_at.strftime('%Y-%m-%d')})"
|
||||
)
|
||||
|
||||
result_text = "Retrieved memories:\n" + "\n".join(formatted_memories)
|
||||
return result_text
|
||||
|
||||
except Exception as e:
|
||||
return f"Failed to query memories: {str(e)}"
|
||||
|
||||
|
||||
ADD_MEMORY_TOOL = AddMemory()
|
||||
QUERY_MEMORY_TOOL = QueryMemory()
|
||||
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.star import PluginManager
|
||||
|
||||
from .context_utils import call_event_hook, call_handler, call_local_llm_tool
|
||||
from .context_utils import call_event_hook, call_handler
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -15,4 +15,3 @@ class PipelineContext:
|
||||
astrbot_config_id: str
|
||||
call_handler = call_handler
|
||||
call_event_hook = call_event_hook
|
||||
call_local_llm_tool = call_local_llm_tool
|
||||
|
||||
@@ -3,8 +3,6 @@ import traceback
|
||||
import typing as T
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.message_event_result import CommandResult, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.star import star_map
|
||||
@@ -107,66 +105,3 @@ async def call_event_hook(
|
||||
return True
|
||||
|
||||
return event.is_stopped()
|
||||
|
||||
|
||||
async def call_local_llm_tool(
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
method_name: str,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[T.Any, None]:
|
||||
"""执行本地 LLM 工具的处理函数并处理其返回结果"""
|
||||
ready_to_call = None # 一个协程或者异步生成器
|
||||
|
||||
trace_ = None
|
||||
|
||||
event = context.context.event
|
||||
|
||||
try:
|
||||
if method_name == "run" or method_name == "decorator_handler":
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
elif method_name == "call":
|
||||
ready_to_call = handler(context, *args, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"未知的方法名: {method_name}")
|
||||
except ValueError as e:
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
|
||||
except TypeError:
|
||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||
except Exception as e:
|
||||
trace_ = traceback.format_exc()
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
|
||||
|
||||
if not ready_to_call:
|
||||
return
|
||||
|
||||
if inspect.isasyncgen(ready_to_call):
|
||||
_has_yielded = False
|
||||
try:
|
||||
async for ret in ready_to_call:
|
||||
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
_has_yielded = True
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
# 如果返回值是 None, 则不设置结果并继续
|
||||
# 继续执行后续阶段
|
||||
yield ret
|
||||
if not _has_yielded:
|
||||
# 如果这个异步生成器没有执行到 yield 分支
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Previous Error: {trace_}")
|
||||
raise e
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个协程, 直接执行
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
yield ret
|
||||
|
||||
@@ -3,20 +3,10 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from mcp.types import CallToolResult
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.mcp_client import MCPTool
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
||||
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
from astrbot.core.message.components import Image
|
||||
@@ -31,323 +21,19 @@ from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
from ...context import PipelineContext, call_event_hook, call_local_llm_tool
|
||||
from ....astr_agent_context import AgentContextWrapper
|
||||
from ....astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from ....astr_agent_run_util import AgentRunner, run_agent
|
||||
from ....astr_agent_tool_exec import FunctionToolExecutor
|
||||
from ....memory.tools import ADD_MEMORY_TOOL, QUERY_MEMORY_TOOL
|
||||
from ...context import PipelineContext, call_event_hook
|
||||
from ..stage import Stage
|
||||
from ..utils import inject_kb_context
|
||||
|
||||
try:
|
||||
import mcp
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||
|
||||
|
||||
AgentContextWrapper = ContextWrapper[AstrAgentContext]
|
||||
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
||||
|
||||
|
||||
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
@classmethod
|
||||
async def execute(cls, tool, run_context, **tool_args):
|
||||
"""执行函数调用。
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
|
||||
**kwargs: 函数调用的参数。
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None | mcp.types.CallToolResult, None]
|
||||
|
||||
"""
|
||||
if isinstance(tool, HandoffTool):
|
||||
async for r in cls._execute_handoff(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
elif isinstance(tool, MCPTool):
|
||||
async for r in cls._execute_mcp(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
else:
|
||||
async for r in cls._execute_local(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
@classmethod
|
||||
async def _execute_handoff(
|
||||
cls,
|
||||
tool: HandoffTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
input_ = tool_args.get("input", "agent")
|
||||
agent_runner = AgentRunner()
|
||||
|
||||
# make toolset for the agent
|
||||
tools = tool.agent.tools
|
||||
if tools:
|
||||
toolset = ToolSet()
|
||||
for t in tools:
|
||||
if isinstance(t, str):
|
||||
_t = llm_tools.get_func(t)
|
||||
if _t:
|
||||
toolset.add_tool(_t)
|
||||
elif isinstance(t, FunctionTool):
|
||||
toolset.add_tool(t)
|
||||
else:
|
||||
toolset = None
|
||||
|
||||
request = ProviderRequest(
|
||||
prompt=input_,
|
||||
system_prompt=tool.description or "",
|
||||
image_urls=[], # 暂时不传递原始 agent 的上下文
|
||||
contexts=[], # 暂时不传递原始 agent 的上下文
|
||||
func_tool=toolset,
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
provider=run_context.context.provider,
|
||||
first_provider_request=run_context.context.first_provider_request,
|
||||
curr_provider_request=request,
|
||||
streaming=run_context.context.streaming,
|
||||
event=run_context.context.event,
|
||||
)
|
||||
|
||||
event = run_context.context.event
|
||||
|
||||
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
|
||||
await event.send(
|
||||
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
|
||||
)
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=run_context.context.provider,
|
||||
request=request,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=run_context.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
||||
streaming=run_context.context.streaming,
|
||||
)
|
||||
|
||||
async for _ in run_agent(agent_runner, 15, True):
|
||||
pass
|
||||
|
||||
if agent_runner.done():
|
||||
llm_response = agent_runner.get_final_llm_resp()
|
||||
|
||||
if not llm_response:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=f"error when deligate task to {tool.agent.name}",
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}",
|
||||
)
|
||||
|
||||
result = (
|
||||
f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n"
|
||||
"Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)."
|
||||
)
|
||||
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=result,
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=f"error when deligate task to {tool.agent.name}",
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
return
|
||||
|
||||
@classmethod
|
||||
async def _execute_local(
|
||||
cls,
|
||||
tool: FunctionTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
event = run_context.context.event
|
||||
if not event:
|
||||
raise ValueError("Event must be provided for local function tools.")
|
||||
|
||||
is_override_call = False
|
||||
for ty in type(tool).mro():
|
||||
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
|
||||
logger.debug(f"Found call in: {ty}")
|
||||
is_override_call = True
|
||||
break
|
||||
|
||||
# 检查 tool 下有没有 run 方法
|
||||
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
|
||||
raise ValueError("Tool must have a valid handler or override 'run' method.")
|
||||
|
||||
awaitable = None
|
||||
method_name = ""
|
||||
if tool.handler:
|
||||
awaitable = tool.handler
|
||||
method_name = "decorator_handler"
|
||||
elif is_override_call:
|
||||
awaitable = tool.call
|
||||
method_name = "call"
|
||||
elif hasattr(tool, "run"):
|
||||
awaitable = getattr(tool, "run")
|
||||
method_name = "run"
|
||||
if awaitable is None:
|
||||
raise ValueError("Tool must have a valid handler or override 'run' method.")
|
||||
|
||||
wrapper = call_local_llm_tool(
|
||||
context=run_context,
|
||||
handler=awaitable,
|
||||
method_name=method_name,
|
||||
**tool_args,
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
resp = await asyncio.wait_for(
|
||||
anext(wrapper),
|
||||
timeout=run_context.tool_call_timeout,
|
||||
)
|
||||
if resp is not None:
|
||||
if isinstance(resp, mcp.types.CallToolResult):
|
||||
yield resp
|
||||
else:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=str(resp),
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
# NOTE: Tool 在这里直接请求发送消息给用户
|
||||
# TODO: 是否需要判断 event.get_result() 是否为空?
|
||||
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
||||
if res := run_context.context.event.get_result():
|
||||
if res.chain:
|
||||
try:
|
||||
await event.send(
|
||||
MessageChain(
|
||||
chain=res.chain,
|
||||
type="tool_direct_result",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Tool 直接发送消息失败: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
yield None
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(
|
||||
f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
@classmethod
|
||||
async def _execute_mcp(
|
||||
cls,
|
||||
tool: FunctionTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
res = await tool.call(run_context, **tool_args)
|
||||
if not res:
|
||||
return
|
||||
yield res
|
||||
|
||||
|
||||
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
# 执行事件钩子
|
||||
await call_event_hook(
|
||||
run_context.context.event,
|
||||
EventType.OnLLMResponseEvent,
|
||||
llm_response,
|
||||
)
|
||||
|
||||
async def on_tool_end(
|
||||
self,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
tool: FunctionTool[Any],
|
||||
tool_args: dict | None,
|
||||
tool_result: CallToolResult | None,
|
||||
):
|
||||
run_context.context.event.clear_result()
|
||||
|
||||
|
||||
MAIN_AGENT_HOOKS = MainAgentHooks()
|
||||
|
||||
|
||||
async def run_agent(
|
||||
agent_runner: AgentRunner,
|
||||
max_step: int = 30,
|
||||
show_tool_use: bool = True,
|
||||
) -> AsyncGenerator[MessageChain, None]:
|
||||
step_idx = 0
|
||||
astr_event = agent_runner.run_context.context.event
|
||||
while step_idx < max_step:
|
||||
step_idx += 1
|
||||
try:
|
||||
async for resp in agent_runner.step():
|
||||
if astr_event.is_stopped():
|
||||
return
|
||||
if resp.type == "tool_call_result":
|
||||
msg_chain = resp.data["chain"]
|
||||
if msg_chain.type == "tool_direct_result":
|
||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||
resp.data["chain"].type = "tool_call_result"
|
||||
await astr_event.send(resp.data["chain"])
|
||||
continue
|
||||
# 对于其他情况,暂时先不处理
|
||||
continue
|
||||
elif resp.type == "tool_call":
|
||||
if agent_runner.streaming:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
if show_tool_use or astr_event.get_platform_name() == "webchat":
|
||||
resp.data["chain"].type = "tool_call"
|
||||
await astr_event.send(resp.data["chain"])
|
||||
continue
|
||||
|
||||
if not agent_runner.streaming:
|
||||
content_typ = (
|
||||
ResultContentType.LLM_RESULT
|
||||
if resp.type == "llm_result"
|
||||
else ResultContentType.GENERAL_RESULT
|
||||
)
|
||||
astr_event.set_result(
|
||||
MessageEventResult(
|
||||
chain=resp.data["chain"].chain,
|
||||
result_content_type=content_typ,
|
||||
),
|
||||
)
|
||||
yield
|
||||
astr_event.clear_result()
|
||||
elif resp.type == "streaming_delta":
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if agent_runner.done():
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
|
||||
if agent_runner.streaming:
|
||||
yield MessageChain().message(err_msg)
|
||||
else:
|
||||
astr_event.set_result(MessageEventResult().message(err_msg))
|
||||
return
|
||||
from ..utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -363,11 +49,16 @@ class LLMRequestSubStage(Stage):
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.unsupported_streaming_strategy: str = settings[
|
||||
"unsupported_streaming_strategy"
|
||||
]
|
||||
self.max_step: int = settings.get("max_agent_step", 30)
|
||||
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
|
||||
if isinstance(self.max_step, bool): # workaround: #2622
|
||||
self.max_step = 30
|
||||
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
||||
self.show_reasoning = settings.get("display_reasoning_text", False)
|
||||
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
|
||||
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.provider_wake_prefix.startswith(bwp):
|
||||
@@ -406,109 +97,77 @@ class LLMRequestSubStage(Stage):
|
||||
raise RuntimeError("无法创建新的对话。")
|
||||
return conversation
|
||||
|
||||
async def process(
|
||||
async def _apply_kb(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
_nested: bool = False,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
|
||||
# 检查会话级别的LLM启停状态
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
||||
return
|
||||
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
|
||||
return
|
||||
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
|
||||
else:
|
||||
req = ProviderRequest(prompt="", image_urls=[])
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if self.provider_wake_prefix:
|
||||
if not event.message_str.startswith(self.provider_wake_prefix):
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""Apply knowledge base context to the provider request"""
|
||||
if not self.kb_agentic_mode:
|
||||
if req.prompt is None:
|
||||
return
|
||||
try:
|
||||
kb_result = await retrieve_knowledge_base(
|
||||
query=req.prompt,
|
||||
umo=event.unified_msg_origin,
|
||||
context=self.ctx.plugin_manager.context,
|
||||
)
|
||||
if not kb_result:
|
||||
return
|
||||
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
if req.system_prompt is not None:
|
||||
req.system_prompt += (
|
||||
f"\n\n[Related Knowledge Base Results]:\n{kb_result}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while retrieving knowledge base: {e}")
|
||||
else:
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
|
||||
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
async def _apply_memory(self, req: ProviderRequest):
|
||||
mm = self.ctx.plugin_manager.context.memory_manager
|
||||
if not mm or not mm._initialized:
|
||||
return
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(ADD_MEMORY_TOOL)
|
||||
req.func_tool.add_tool(QUERY_MEMORY_TOOL)
|
||||
|
||||
# 应用知识库
|
||||
try:
|
||||
await inject_kb_context(
|
||||
umo=event.unified_msg_origin,
|
||||
p_ctx=self.ctx,
|
||||
req=req,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"调用知识库时遇到问题: {e}")
|
||||
def _truncate_contexts(
|
||||
self,
|
||||
contexts: list[dict],
|
||||
) -> list[dict]:
|
||||
"""截断上下文列表,确保不超过最大长度"""
|
||||
if self.max_context_length == -1:
|
||||
return contexts
|
||||
|
||||
# 执行请求 LLM 前事件钩子。
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
if len(contexts) // 2 <= self.max_context_length:
|
||||
return contexts
|
||||
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
truncated_contexts = contexts[
|
||||
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||
]
|
||||
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(
|
||||
i
|
||||
for i, item in enumerate(truncated_contexts)
|
||||
if item.get("role") == "user"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
|
||||
# max context length
|
||||
if (
|
||||
self.max_context_length != -1 # -1 为不限制
|
||||
and len(req.contexts) // 2 > self.max_context_length
|
||||
):
|
||||
logger.debug("上下文长度超过限制,将截断。")
|
||||
req.contexts = req.contexts[
|
||||
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||
]
|
||||
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(
|
||||
i
|
||||
for i, item in enumerate(req.contexts)
|
||||
if item.get("role") == "user"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
req.contexts = req.contexts[index:]
|
||||
return truncated_contexts
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
# fix messages
|
||||
req.contexts = self.fix_messages(req.contexts)
|
||||
|
||||
# check provider modalities
|
||||
# 如果提供商不支持图像/工具使用,但请求中包含图像/工具列表,则清空。图片转述等的检测和调用发生在这之前,因此这里可以这样处理。
|
||||
def _modalities_fix(
|
||||
self,
|
||||
provider: Provider,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""检查提供商的模态能力,清理请求中的不支持内容"""
|
||||
if req.image_urls:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["image"])
|
||||
if "image" not in provider_cfg:
|
||||
@@ -522,7 +181,13 @@ class LLMRequestSubStage(Stage):
|
||||
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
|
||||
)
|
||||
req.func_tool = None
|
||||
# 插件可用性设置
|
||||
|
||||
def _plugin_tool_fix(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""根据事件中的插件设置,过滤请求中的工具列表"""
|
||||
if event.plugins_name is not None and req.func_tool:
|
||||
new_tool_set = ToolSet()
|
||||
for tool in req.func_tool.tools:
|
||||
@@ -536,80 +201,6 @@ class LLMRequestSubStage(Stage):
|
||||
new_tool_set.add_tool(tool)
|
||||
req.func_tool = new_tool_set
|
||||
|
||||
# 备份 req.contexts
|
||||
backup_contexts = copy.deepcopy(req.contexts)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
provider=provider,
|
||||
first_provider_request=req,
|
||||
curr_provider_request=req,
|
||||
streaming=self.streaming_response,
|
||||
event=event,
|
||||
)
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=self.streaming_response,
|
||||
)
|
||||
|
||||
if self.streaming_response:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_agent(agent_runner, self.max_step, self.show_tool_use),
|
||||
),
|
||||
)
|
||||
yield
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain().message(final_llm_resp.completion_text).chain
|
||||
)
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
),
|
||||
)
|
||||
else:
|
||||
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
|
||||
yield
|
||||
|
||||
# 恢复备份的 contexts
|
||||
req.contexts = backup_contexts
|
||||
|
||||
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_webchat(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
@@ -657,9 +248,6 @@ class LLMRequestSubStage(Stage):
|
||||
),
|
||||
)
|
||||
if llm_resp and llm_resp.completion_text:
|
||||
logger.debug(
|
||||
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}",
|
||||
)
|
||||
title = llm_resp.completion_text.strip()
|
||||
if not title or "<None>" in title:
|
||||
return
|
||||
@@ -687,6 +275,9 @@ class LLMRequestSubStage(Stage):
|
||||
logger.debug("LLM 响应为空,不保存记录。")
|
||||
return
|
||||
|
||||
if req.contexts is None:
|
||||
req.contexts = []
|
||||
|
||||
# 历史上下文
|
||||
messages = copy.deepcopy(req.contexts)
|
||||
# 这一轮对话请求的用户输入
|
||||
@@ -706,7 +297,7 @@ class LLMRequestSubStage(Stage):
|
||||
history=messages,
|
||||
)
|
||||
|
||||
def fix_messages(self, messages: list[dict]) -> list[dict]:
|
||||
def _fix_messages(self, messages: list[dict]) -> list[dict]:
|
||||
"""验证并且修复上下文"""
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
@@ -721,3 +312,187 @@ class LLMRequestSubStage(Stage):
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
_nested: bool = False,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
|
||||
# 检查会话级别的LLM启停状态
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
||||
return
|
||||
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
|
||||
return
|
||||
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
streaming_response = bool(enable_streaming)
|
||||
|
||||
logger.debug("ready to request llm provider")
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
logger.debug("acquired session lock for llm request")
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
|
||||
else:
|
||||
req = ProviderRequest()
|
||||
req.prompt = ""
|
||||
req.image_urls = []
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if self.provider_wake_prefix and not event.message_str.startswith(
|
||||
self.provider_wake_prefix
|
||||
):
|
||||
return
|
||||
|
||||
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
# call event hook
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
# apply knowledge base feature
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
# apply memory feature
|
||||
await self._apply_memory(req)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# truncate contexts to fit max length
|
||||
if req.contexts:
|
||||
req.contexts = self._truncate_contexts(req.contexts)
|
||||
self._fix_messages(req.contexts)
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
# check provider modalities, if provider does not support image/tool_use, clear them in request.
|
||||
self._modalities_fix(provider, req)
|
||||
|
||||
# filter tools, only keep tools from this pipeline's selected plugins
|
||||
self._plugin_tool_fix(event, req)
|
||||
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
# 备份 req.contexts
|
||||
backup_contexts = copy.deepcopy(req.contexts)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=streaming_response,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
show_reasoning=self.show_reasoning,
|
||||
),
|
||||
),
|
||||
)
|
||||
yield
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain()
|
||||
.message(final_llm_resp.completion_text)
|
||||
.chain
|
||||
)
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
),
|
||||
)
|
||||
else:
|
||||
async for _ in run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
stream_to_general,
|
||||
show_reasoning=self.show_reasoning,
|
||||
):
|
||||
yield
|
||||
|
||||
# 恢复备份的 contexts
|
||||
req.contexts = backup_contexts
|
||||
|
||||
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,23 +1,64 @@
|
||||
from pydantic import Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import FunctionTool, ToolExecResult
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.star.context import Context
|
||||
|
||||
|
||||
async def inject_kb_context(
|
||||
@dataclass
|
||||
class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]):
|
||||
name: str = "astr_kb_search"
|
||||
description: str = (
|
||||
"Query the knowledge base for facts or relevant context. "
|
||||
"Use this tool when the user's question requires factual information, "
|
||||
"definitions, background knowledge, or previously indexed content. "
|
||||
"Only send short keywords or a concise question as the query."
|
||||
)
|
||||
parameters: dict = Field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "A concise keyword query for the knowledge base.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
)
|
||||
|
||||
async def call(
|
||||
self, context: ContextWrapper[AstrAgentContext], **kwargs
|
||||
) -> ToolExecResult:
|
||||
query = kwargs.get("query", "")
|
||||
if not query:
|
||||
return "error: Query parameter is empty."
|
||||
result = await retrieve_knowledge_base(
|
||||
query=kwargs.get("query", ""),
|
||||
umo=context.context.event.unified_msg_origin,
|
||||
context=context.context.context,
|
||||
)
|
||||
if not result:
|
||||
return "No relevant knowledge found."
|
||||
return result
|
||||
|
||||
|
||||
async def retrieve_knowledge_base(
|
||||
query: str,
|
||||
umo: str,
|
||||
p_ctx: PipelineContext,
|
||||
req: ProviderRequest,
|
||||
) -> None:
|
||||
context: Context,
|
||||
) -> str | None:
|
||||
"""Inject knowledge base context into the provider request
|
||||
|
||||
Args:
|
||||
umo: Unique message object (session ID)
|
||||
p_ctx: Pipeline context
|
||||
req: Provider request
|
||||
|
||||
"""
|
||||
kb_mgr = p_ctx.plugin_manager.context.kb_manager
|
||||
kb_mgr = context.kb_manager
|
||||
config = context.get_config(umo=umo)
|
||||
|
||||
# 1. 优先读取会话级配置
|
||||
session_config = await sp.session_get(umo, "kb_config", default={})
|
||||
@@ -54,18 +95,18 @@ async def inject_kb_context(
|
||||
|
||||
logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}")
|
||||
else:
|
||||
kb_names = p_ctx.astrbot_config.get("kb_names", [])
|
||||
top_k = p_ctx.astrbot_config.get("kb_final_top_k", 5)
|
||||
kb_names = config.get("kb_names", [])
|
||||
top_k = config.get("kb_final_top_k", 5)
|
||||
logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}")
|
||||
|
||||
top_k_fusion = p_ctx.astrbot_config.get("kb_fusion_top_k", 20)
|
||||
top_k_fusion = config.get("kb_fusion_top_k", 20)
|
||||
|
||||
if not kb_names:
|
||||
return
|
||||
|
||||
logger.debug(f"[知识库] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}")
|
||||
kb_context = await kb_mgr.retrieve(
|
||||
query=req.prompt,
|
||||
query=query,
|
||||
kb_names=kb_names,
|
||||
top_k_fusion=top_k_fusion,
|
||||
top_m_final=top_k,
|
||||
@@ -78,4 +119,7 @@ async def inject_kb_context(
|
||||
if formatted:
|
||||
results = kb_context.get("results", [])
|
||||
logger.debug(f"[知识库] 为会话 {umo} 注入了 {len(results)} 条相关知识块")
|
||||
req.system_prompt = f"{formatted}\n\n{req.system_prompt or ''}"
|
||||
return formatted
|
||||
|
||||
|
||||
KNOWLEDGE_BASE_QUERY_TOOL = KnowledgeBaseQueryTool()
|
||||
|
||||
@@ -10,7 +10,6 @@ from astrbot.core.message.message_event_result import MessageChain, ResultConten
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.path_util import path_Mapping
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
from ..context import PipelineContext, call_event_hook
|
||||
from ..stage import Stage, register_stage
|
||||
@@ -169,12 +168,15 @@ class RespondStage(Stage):
|
||||
logger.warning("async_stream 为空,跳过发送。")
|
||||
return
|
||||
# 流式结果直接交付平台适配器处理
|
||||
use_fallback = self.config.get("provider_settings", {}).get(
|
||||
"streaming_segmented",
|
||||
False,
|
||||
realtime_segmenting = (
|
||||
self.config.get("provider_settings", {}).get(
|
||||
"unsupported_streaming_strategy",
|
||||
"realtime_segmenting",
|
||||
)
|
||||
== "realtime_segmenting"
|
||||
)
|
||||
logger.info(f"应用流式输出({event.get_platform_id()})")
|
||||
await event.send_streaming(result.async_stream, use_fallback)
|
||||
await event.send_streaming(result.async_stream, realtime_segmenting)
|
||||
return
|
||||
if len(result.chain) > 0:
|
||||
# 检查路径映射
|
||||
@@ -218,21 +220,20 @@ class RespondStage(Stage):
|
||||
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}",
|
||||
)
|
||||
return
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
for comp in result.chain:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
if comp.type in need_separately:
|
||||
await event.send(MessageChain([comp]))
|
||||
else:
|
||||
await event.send(MessageChain([*header_comps, comp]))
|
||||
header_comps.clear()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
for comp in result.chain:
|
||||
i = await self._calc_comp_interval(comp)
|
||||
await asyncio.sleep(i)
|
||||
try:
|
||||
if comp.type in need_separately:
|
||||
await event.send(MessageChain([comp]))
|
||||
else:
|
||||
await event.send(MessageChain([*header_comps, comp]))
|
||||
header_comps.clear()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
if all(
|
||||
comp.type in {ComponentType.Reply, ComponentType.At}
|
||||
|
||||
@@ -16,3 +16,6 @@ class PlatformMetadata:
|
||||
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
|
||||
logo_path: str | None = None
|
||||
"""平台适配器的 logo 文件路径(相对于插件目录)"""
|
||||
|
||||
support_streaming_message: bool = True
|
||||
"""平台是否支持真实流式传输"""
|
||||
|
||||
@@ -14,6 +14,7 @@ def register_platform_adapter(
|
||||
default_config_tmpl: dict | None = None,
|
||||
adapter_display_name: str | None = None,
|
||||
logo_path: str | None = None,
|
||||
support_streaming_message: bool = True,
|
||||
):
|
||||
"""用于注册平台适配器的带参装饰器。
|
||||
|
||||
@@ -42,6 +43,7 @@ def register_platform_adapter(
|
||||
default_config_tmpl=default_config_tmpl,
|
||||
adapter_display_name=adapter_display_name,
|
||||
logo_path=logo_path,
|
||||
support_streaming_message=support_streaming_message,
|
||||
)
|
||||
platform_registry.append(pm)
|
||||
platform_cls_map[adapter_name] = cls
|
||||
|
||||
@@ -29,6 +29,7 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
|
||||
@register_platform_adapter(
|
||||
"aiocqhttp",
|
||||
"适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。",
|
||||
support_streaming_message=False,
|
||||
)
|
||||
class AiocqhttpAdapter(Platform):
|
||||
def __init__(
|
||||
@@ -49,6 +50,7 @@ class AiocqhttpAdapter(Platform):
|
||||
name="aiocqhttp",
|
||||
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
||||
id=self.config.get("id"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
self.bot = CQHttp(
|
||||
|
||||
@@ -37,7 +37,9 @@ class MyEventHandler(dingtalk_stream.EventHandler):
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
|
||||
@register_platform_adapter("dingtalk", "钉钉机器人官方 API 适配器")
|
||||
@register_platform_adapter(
|
||||
"dingtalk", "钉钉机器人官方 API 适配器", support_streaming_message=False
|
||||
)
|
||||
class DingtalkPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -74,6 +76,14 @@ class DingtalkPlatformAdapter(Platform):
|
||||
)
|
||||
self.client_ = client # 用于 websockets 的 client
|
||||
|
||||
def _id_to_sid(self, dingtalk_id: str | None) -> str | None:
|
||||
if not dingtalk_id:
|
||||
return dingtalk_id
|
||||
prefix = "$:LWCP_v1:$"
|
||||
if dingtalk_id.startswith(prefix):
|
||||
return dingtalk_id[len(prefix) :]
|
||||
return dingtalk_id
|
||||
|
||||
async def send_by_session(
|
||||
self,
|
||||
session: MessageSesion,
|
||||
@@ -86,6 +96,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
name="dingtalk",
|
||||
description="钉钉机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
async def convert_msg(
|
||||
@@ -102,10 +113,10 @@ class DingtalkPlatformAdapter(Platform):
|
||||
else MessageType.FRIEND_MESSAGE
|
||||
)
|
||||
abm.sender = MessageMember(
|
||||
user_id=message.sender_id,
|
||||
user_id=self._id_to_sid(message.sender_id),
|
||||
nickname=message.sender_nick,
|
||||
)
|
||||
abm.self_id = message.chatbot_user_id
|
||||
abm.self_id = self._id_to_sid(message.chatbot_user_id)
|
||||
abm.message_id = message.message_id
|
||||
abm.raw_message = message
|
||||
|
||||
@@ -113,8 +124,8 @@ class DingtalkPlatformAdapter(Platform):
|
||||
# 处理所有被 @ 的用户(包括机器人自己,因 at_users 已包含)
|
||||
if message.at_users:
|
||||
for user in message.at_users:
|
||||
if user.dingtalk_id:
|
||||
abm.message.append(At(qq=user.dingtalk_id))
|
||||
if id := self._id_to_sid(user.dingtalk_id):
|
||||
abm.message.append(At(qq=id))
|
||||
abm.group_id = message.conversation_id
|
||||
if self.unique_session:
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
@@ -34,7 +34,9 @@ else:
|
||||
|
||||
|
||||
# 注册平台适配器
|
||||
@register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)")
|
||||
@register_platform_adapter(
|
||||
"discord", "Discord 适配器 (基于 Pycord)", support_streaming_message=False
|
||||
)
|
||||
class DiscordPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -111,6 +113,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
"Discord 适配器",
|
||||
id=self.config.get("id"),
|
||||
default_config_tmpl=self.config,
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import sys
|
||||
from collections.abc import AsyncGenerator
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
@@ -21,11 +21,6 @@ from astrbot.api.platform import AstrBotMessage, At, PlatformMetadata
|
||||
from .client import DiscordBotClient
|
||||
from .components import DiscordEmbed, DiscordView
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# 自定义Discord视图组件(兼容旧版本)
|
||||
class DiscordViewComponent(BaseMessageComponent):
|
||||
@@ -49,7 +44,6 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
self.client = client
|
||||
self.interaction_followup_webhook = interaction_followup_webhook
|
||||
|
||||
@override
|
||||
async def send(self, message: MessageChain):
|
||||
"""发送消息到Discord平台"""
|
||||
# 解析消息链为 Discord 所需的对象
|
||||
@@ -98,6 +92,21 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
||||
):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return None
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def _get_channel(self) -> discord.abc.Messageable | None:
|
||||
"""获取当前事件对应的频道对象"""
|
||||
try:
|
||||
|
||||
@@ -23,7 +23,9 @@ from ...register import register_platform_adapter
|
||||
from .lark_event import LarkMessageEvent
|
||||
|
||||
|
||||
@register_platform_adapter("lark", "飞书机器人官方 API 适配器")
|
||||
@register_platform_adapter(
|
||||
"lark", "飞书机器人官方 API 适配器", support_streaming_message=False
|
||||
)
|
||||
class LarkPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -115,6 +117,7 @@ class LarkPlatformAdapter(Platform):
|
||||
name="lark",
|
||||
description="飞书机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
|
||||
@@ -45,7 +45,9 @@ MAX_FILE_UPLOAD_COUNT = 16
|
||||
DEFAULT_UPLOAD_CONCURRENCY = 3
|
||||
|
||||
|
||||
@register_platform_adapter("misskey", "Misskey 平台适配器")
|
||||
@register_platform_adapter(
|
||||
"misskey", "Misskey 平台适配器", support_streaming_message=False
|
||||
)
|
||||
class MisskeyPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -120,6 +122,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
description="Misskey 平台适配器",
|
||||
id=self.config.get("id", "misskey"),
|
||||
default_config_tmpl=default_config,
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
|
||||
@@ -29,8 +29,7 @@ from astrbot.core.platform.astr_message_event import MessageSession
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
"satori",
|
||||
"Satori 协议适配器",
|
||||
"satori", "Satori 协议适配器", support_streaming_message=False
|
||||
)
|
||||
class SatoriPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
@@ -60,6 +59,7 @@ class SatoriPlatformAdapter(Platform):
|
||||
name="satori",
|
||||
description="Satori 通用协议适配器",
|
||||
id=self.config["id"],
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
self.ws: ClientConnection | None = None
|
||||
|
||||
@@ -30,6 +30,7 @@ from .slack_event import SlackMessageEvent
|
||||
@register_platform_adapter(
|
||||
"slack",
|
||||
"适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
|
||||
support_streaming_message=False,
|
||||
)
|
||||
class SlackAdapter(Platform):
|
||||
def __init__(
|
||||
@@ -68,6 +69,7 @@ class SlackAdapter(Platform):
|
||||
name="slack",
|
||||
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
|
||||
id=self.config.get("id"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
# 初始化 Slack Web Client
|
||||
|
||||
@@ -163,6 +163,9 @@ class WebChatAdapter(Platform):
|
||||
_, _, payload = message.raw_message # type: ignore
|
||||
message_event.set_extra("selected_provider", payload.get("selected_provider"))
|
||||
message_event.set_extra("selected_model", payload.get("selected_model"))
|
||||
message_event.set_extra(
|
||||
"enable_streaming", payload.get("enable_streaming", True)
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
|
||||
@@ -109,6 +109,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
final_data = ""
|
||||
reasoning_content = ""
|
||||
cid = self.session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
async for chain in generator:
|
||||
@@ -124,16 +125,22 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
final_data = ""
|
||||
continue
|
||||
final_data += await WebChatMessageEvent._send(
|
||||
|
||||
r = await WebChatMessageEvent._send(
|
||||
chain,
|
||||
session_id=self.session_id,
|
||||
streaming=True,
|
||||
)
|
||||
if chain.type == "reasoning":
|
||||
reasoning_content += chain.get_plain_text()
|
||||
else:
|
||||
final_data += r
|
||||
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "complete", # complete means we return the final result
|
||||
"data": final_data,
|
||||
"reasoning": reasoning_content,
|
||||
"streaming": True,
|
||||
"cid": cid,
|
||||
},
|
||||
|
||||
@@ -32,7 +32,9 @@ except ImportError as e:
|
||||
)
|
||||
|
||||
|
||||
@register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器")
|
||||
@register_platform_adapter(
|
||||
"wechatpadpro", "WeChatPadPro 消息平台适配器", support_streaming_message=False
|
||||
)
|
||||
class WeChatPadProAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -51,6 +53,7 @@ class WeChatPadProAdapter(Platform):
|
||||
name="wechatpadpro",
|
||||
description="WeChatPadPro 消息平台适配器",
|
||||
id=self.config.get("id", "wechatpadpro"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
# 保存配置信息
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
@@ -50,6 +51,21 @@ class WeChatPadProMessageEvent(AstrMessageEvent):
|
||||
await self._send_voice(session, comp)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
||||
):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return None
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def _send_image(self, session: aiohttp.ClientSession, comp: Image):
|
||||
b64 = await comp.convert_to_base64()
|
||||
raw = self._validate_base64(b64)
|
||||
|
||||
@@ -110,7 +110,7 @@ class WecomServer:
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
|
||||
@register_platform_adapter("wecom", "wecom 适配器")
|
||||
@register_platform_adapter("wecom", "wecom 适配器", support_streaming_message=False)
|
||||
class WecomPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -196,6 +196,7 @@ class WecomPlatformAdapter(Platform):
|
||||
"wecom",
|
||||
"wecom 适配器",
|
||||
id=self.config.get("id", "wecom"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
@@ -30,7 +30,7 @@ from .wecomai_api import (
|
||||
WecomAIBotStreamMessageBuilder,
|
||||
)
|
||||
from .wecomai_event import WecomAIBotMessageEvent
|
||||
from .wecomai_queue_mgr import WecomAIQueueMgr, wecomai_queue_mgr
|
||||
from .wecomai_queue_mgr import WecomAIQueueMgr
|
||||
from .wecomai_server import WecomAIBotServer
|
||||
from .wecomai_utils import (
|
||||
WecomAIBotConstants,
|
||||
@@ -144,9 +144,12 @@ class WecomAIBotAdapter(Platform):
|
||||
# 事件循环和关闭信号
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
# 队列管理器
|
||||
self.queue_mgr = WecomAIQueueMgr()
|
||||
|
||||
# 队列监听器
|
||||
self.queue_listener = WecomAIQueueListener(
|
||||
wecomai_queue_mgr,
|
||||
self.queue_mgr,
|
||||
self._handle_queued_message,
|
||||
)
|
||||
|
||||
@@ -189,7 +192,7 @@ class WecomAIBotAdapter(Platform):
|
||||
stream_id,
|
||||
session_id,
|
||||
)
|
||||
wecomai_queue_mgr.set_pending_response(stream_id, callback_params)
|
||||
self.queue_mgr.set_pending_response(stream_id, callback_params)
|
||||
|
||||
resp = WecomAIBotStreamMessageBuilder.make_text_stream(
|
||||
stream_id,
|
||||
@@ -207,7 +210,7 @@ class WecomAIBotAdapter(Platform):
|
||||
elif msgtype == "stream":
|
||||
# wechat server is requesting for updates of a stream
|
||||
stream_id = message_data["stream"]["id"]
|
||||
if not wecomai_queue_mgr.has_back_queue(stream_id):
|
||||
if not self.queue_mgr.has_back_queue(stream_id):
|
||||
logger.error(f"Cannot find back queue for stream_id: {stream_id}")
|
||||
|
||||
# 返回结束标志,告诉微信服务器流已结束
|
||||
@@ -222,7 +225,7 @@ class WecomAIBotAdapter(Platform):
|
||||
callback_params["timestamp"],
|
||||
)
|
||||
return resp
|
||||
queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||
queue = self.queue_mgr.get_or_create_back_queue(stream_id)
|
||||
if queue.empty():
|
||||
logger.debug(
|
||||
f"No new messages in back queue for stream_id: {stream_id}",
|
||||
@@ -242,10 +245,9 @@ class WecomAIBotAdapter(Platform):
|
||||
elif msg["type"] == "end":
|
||||
# stream end
|
||||
finish = True
|
||||
wecomai_queue_mgr.remove_queues(stream_id)
|
||||
self.queue_mgr.remove_queues(stream_id)
|
||||
break
|
||||
else:
|
||||
pass
|
||||
|
||||
logger.debug(
|
||||
f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}",
|
||||
)
|
||||
@@ -313,8 +315,8 @@ class WecomAIBotAdapter(Platform):
|
||||
session_id: str,
|
||||
):
|
||||
"""将消息放入队列进行异步处理"""
|
||||
input_queue = wecomai_queue_mgr.get_or_create_queue(stream_id)
|
||||
_ = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||
input_queue = self.queue_mgr.get_or_create_queue(stream_id)
|
||||
_ = self.queue_mgr.get_or_create_back_queue(stream_id)
|
||||
message_payload = {
|
||||
"message_data": message_data,
|
||||
"callback_params": callback_params,
|
||||
@@ -453,6 +455,7 @@ class WecomAIBotAdapter(Platform):
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
api_client=self.api_client,
|
||||
queue_mgr=self.queue_mgr,
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
@@ -8,7 +8,7 @@ from astrbot.api.message_components import (
|
||||
)
|
||||
|
||||
from .wecomai_api import WecomAIBotAPIClient
|
||||
from .wecomai_queue_mgr import wecomai_queue_mgr
|
||||
from .wecomai_queue_mgr import WecomAIQueueMgr
|
||||
|
||||
|
||||
class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
@@ -21,6 +21,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
platform_meta,
|
||||
session_id: str,
|
||||
api_client: WecomAIBotAPIClient,
|
||||
queue_mgr: WecomAIQueueMgr,
|
||||
):
|
||||
"""初始化消息事件
|
||||
|
||||
@@ -34,14 +35,16 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
"""
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.api_client = api_client
|
||||
self.queue_mgr = queue_mgr
|
||||
|
||||
@staticmethod
|
||||
async def _send(
|
||||
message_chain: MessageChain,
|
||||
stream_id: str,
|
||||
queue_mgr: WecomAIQueueMgr,
|
||||
streaming: bool = False,
|
||||
):
|
||||
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||
back_queue = queue_mgr.get_or_create_back_queue(stream_id)
|
||||
|
||||
if not message_chain:
|
||||
await back_queue.put(
|
||||
@@ -94,7 +97,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
"wecom_ai_bot platform event raw_message should be a dict"
|
||||
)
|
||||
stream_id = raw.get("stream_id", self.session_id)
|
||||
await WecomAIBotMessageEvent._send(message, stream_id)
|
||||
await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator, use_fallback=False):
|
||||
@@ -105,7 +108,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
"wecom_ai_bot platform event raw_message should be a dict"
|
||||
)
|
||||
stream_id = raw.get("stream_id", self.session_id)
|
||||
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
|
||||
|
||||
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送
|
||||
increment_plain = ""
|
||||
@@ -134,6 +137,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
final_data += await WecomAIBotMessageEvent._send(
|
||||
chain,
|
||||
stream_id=stream_id,
|
||||
queue_mgr=self.queue_mgr,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -151,7 +151,3 @@ class WecomAIQueueMgr:
|
||||
"output_queues": len(self.back_queues),
|
||||
"pending_responses": len(self.pending_responses),
|
||||
}
|
||||
|
||||
|
||||
# 全局队列管理器实例
|
||||
wecomai_queue_mgr = WecomAIQueueMgr()
|
||||
|
||||
@@ -113,7 +113,9 @@ class WecomServer:
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
|
||||
@register_platform_adapter("weixin_official_account", "微信公众平台 适配器")
|
||||
@register_platform_adapter(
|
||||
"weixin_official_account", "微信公众平台 适配器", support_streaming_message=False
|
||||
)
|
||||
class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -195,6 +197,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
"weixin_official_account",
|
||||
"微信公众平台 适配器",
|
||||
id=self.config.get("id", "weixin_official_account"),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .entities import ProviderMetaData
|
||||
from .provider import Personality, Provider, STTProvider
|
||||
from .provider import Provider, STTProvider
|
||||
|
||||
__all__ = ["Personality", "Provider", "ProviderMetaData", "STTProvider"]
|
||||
__all__ = ["Provider", "ProviderMetaData", "STTProvider"]
|
||||
|
||||
@@ -30,18 +30,31 @@ class ProviderType(enum.Enum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMetaData:
|
||||
type: str
|
||||
"""提供商适配器名称,如 openai, ollama"""
|
||||
desc: str = ""
|
||||
"""提供商适配器描述"""
|
||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||
cls_type: Any = None
|
||||
class ProviderMeta:
|
||||
"""The basic metadata of a provider instance."""
|
||||
|
||||
id: str
|
||||
"""the unique id of the provider instance that user configured"""
|
||||
model: str | None
|
||||
"""the model name of the provider instance currently used"""
|
||||
type: str
|
||||
"""the name of the provider adapter, such as openai, ollama"""
|
||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||
"""the capability type of the provider adapter"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMetaData(ProviderMeta):
|
||||
"""The metadata of a provider adapter for registration."""
|
||||
|
||||
desc: str = ""
|
||||
"""the short description of the provider adapter"""
|
||||
cls_type: Any = None
|
||||
"""the class type of the provider adapter"""
|
||||
default_config_tmpl: dict | None = None
|
||||
"""平台的默认配置模板"""
|
||||
"""the default configuration template of the provider adapter"""
|
||||
provider_display_name: str | None = None
|
||||
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
|
||||
"""the display name of the provider shown in the WebUI configuration page; if empty, the type is used"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -60,12 +73,20 @@ class ToolCallsResult:
|
||||
]
|
||||
return ret
|
||||
|
||||
def to_openai_messages_model(
|
||||
self,
|
||||
) -> list[AssistantMessageSegment | ToolCallMessageSegment]:
|
||||
return [
|
||||
self.tool_calls_info,
|
||||
*self.tool_calls_result,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderRequest:
|
||||
prompt: str
|
||||
prompt: str | None = None
|
||||
"""提示词"""
|
||||
session_id: str = ""
|
||||
session_id: str | None = ""
|
||||
"""会话 ID"""
|
||||
image_urls: list[str] = field(default_factory=list)
|
||||
"""图片 URL 列表"""
|
||||
@@ -181,25 +202,30 @@ class ProviderRequest:
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
role: str
|
||||
"""角色, assistant, tool, err"""
|
||||
"""The role of the message, e.g., assistant, tool, err"""
|
||||
result_chain: MessageChain | None = None
|
||||
"""返回的消息链"""
|
||||
"""A chain of message components representing the text completion from LLM."""
|
||||
tools_call_args: list[dict[str, Any]] = field(default_factory=list)
|
||||
"""工具调用参数"""
|
||||
"""Tool call arguments."""
|
||||
tools_call_name: list[str] = field(default_factory=list)
|
||||
"""工具调用名称"""
|
||||
"""Tool call names."""
|
||||
tools_call_ids: list[str] = field(default_factory=list)
|
||||
"""工具调用 ID"""
|
||||
"""Tool call IDs."""
|
||||
tools_call_extra_content: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
"""Tool call extra content. tool_call_id -> extra_content dict"""
|
||||
reasoning_content: str = ""
|
||||
"""The reasoning content extracted from the LLM, if any."""
|
||||
|
||||
raw_completion: (
|
||||
ChatCompletion | GenerateContentResponse | AnthropicMessage | None
|
||||
) = None
|
||||
_new_record: dict[str, Any] | None = None
|
||||
"""The raw completion response from the LLM provider."""
|
||||
|
||||
_completion_text: str = ""
|
||||
"""The plain text of the completion."""
|
||||
|
||||
is_chunk: bool = False
|
||||
"""是否是流式输出的单个 Chunk"""
|
||||
"""Indicates if the response is a chunked response."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -209,11 +235,11 @@ class LLMResponse:
|
||||
tools_call_args: list[dict[str, Any]] | None = None,
|
||||
tools_call_name: list[str] | None = None,
|
||||
tools_call_ids: list[str] | None = None,
|
||||
tools_call_extra_content: dict[str, dict[str, Any]] | None = None,
|
||||
raw_completion: ChatCompletion
|
||||
| GenerateContentResponse
|
||||
| AnthropicMessage
|
||||
| None = None,
|
||||
_new_record: dict[str, Any] | None = None,
|
||||
is_chunk: bool = False,
|
||||
):
|
||||
"""初始化 LLMResponse
|
||||
@@ -233,6 +259,8 @@ class LLMResponse:
|
||||
tools_call_name = []
|
||||
if tools_call_ids is None:
|
||||
tools_call_ids = []
|
||||
if tools_call_extra_content is None:
|
||||
tools_call_extra_content = {}
|
||||
|
||||
self.role = role
|
||||
self.completion_text = completion_text
|
||||
@@ -240,8 +268,8 @@ class LLMResponse:
|
||||
self.tools_call_args = tools_call_args
|
||||
self.tools_call_name = tools_call_name
|
||||
self.tools_call_ids = tools_call_ids
|
||||
self.tools_call_extra_content = tools_call_extra_content
|
||||
self.raw_completion = raw_completion
|
||||
self._new_record = _new_record
|
||||
self.is_chunk = is_chunk
|
||||
|
||||
@property
|
||||
@@ -266,16 +294,19 @@ class LLMResponse:
|
||||
"""Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead."""
|
||||
ret = []
|
||||
for idx, tool_call_arg in enumerate(self.tools_call_args):
|
||||
ret.append(
|
||||
{
|
||||
"id": self.tools_call_ids[idx],
|
||||
"function": {
|
||||
"name": self.tools_call_name[idx],
|
||||
"arguments": json.dumps(tool_call_arg),
|
||||
},
|
||||
"type": "function",
|
||||
payload = {
|
||||
"id": self.tools_call_ids[idx],
|
||||
"function": {
|
||||
"name": self.tools_call_name[idx],
|
||||
"arguments": json.dumps(tool_call_arg),
|
||||
},
|
||||
)
|
||||
"type": "function",
|
||||
}
|
||||
if self.tools_call_extra_content.get(self.tools_call_ids[idx]):
|
||||
payload["extra_content"] = self.tools_call_extra_content[
|
||||
self.tools_call_ids[idx]
|
||||
]
|
||||
ret.append(payload)
|
||||
return ret
|
||||
|
||||
def to_openai_to_calls_model(self) -> list[ToolCall]:
|
||||
@@ -289,6 +320,10 @@ class LLMResponse:
|
||||
name=self.tools_call_name[idx],
|
||||
arguments=json.dumps(tool_call_arg),
|
||||
),
|
||||
# the extra_content will not serialize if it's None when calling ToolCall.model_dump()
|
||||
extra_content=self.tools_call_extra_content.get(
|
||||
self.tools_call_ids[idx]
|
||||
),
|
||||
),
|
||||
)
|
||||
return ret
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Awaitable, Callable
|
||||
@@ -24,7 +25,16 @@ SUPPORTED_TYPES = [
|
||||
"boolean",
|
||||
] # json schema 支持的数据类型
|
||||
|
||||
|
||||
PY_TO_JSON_TYPE = {
|
||||
"int": "number",
|
||||
"float": "number",
|
||||
"bool": "boolean",
|
||||
"str": "string",
|
||||
"dict": "object",
|
||||
"list": "array",
|
||||
"tuple": "array",
|
||||
"set": "array",
|
||||
}
|
||||
# alias
|
||||
FuncTool = FunctionTool
|
||||
|
||||
@@ -106,7 +116,7 @@ class FunctionToolManager:
|
||||
def spec_to_func(
|
||||
self,
|
||||
name: str,
|
||||
func_args: list,
|
||||
func_args: list[dict],
|
||||
desc: str,
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
) -> FuncTool:
|
||||
@@ -115,10 +125,9 @@ class FunctionToolManager:
|
||||
"properties": {},
|
||||
}
|
||||
for param in func_args:
|
||||
params["properties"][param["name"]] = {
|
||||
"type": param["type"],
|
||||
"description": param["description"],
|
||||
}
|
||||
p = copy.deepcopy(param)
|
||||
p.pop("name", None)
|
||||
params["properties"][param["name"]] = p
|
||||
return FuncTool(
|
||||
name=name,
|
||||
parameters=params,
|
||||
@@ -271,19 +280,22 @@ class FunctionToolManager:
|
||||
async def _terminate_mcp_client(self, name: str) -> None:
|
||||
"""关闭并清理MCP客户端"""
|
||||
if name in self.mcp_client_dict:
|
||||
client = self.mcp_client_dict[name]
|
||||
try:
|
||||
# 关闭MCP连接
|
||||
await self.mcp_client_dict[name].cleanup()
|
||||
self.mcp_client_dict.pop(name)
|
||||
await client.cleanup()
|
||||
except Exception as e:
|
||||
logger.error(f"清空 MCP 客户端资源 {name}: {e}。")
|
||||
# 移除关联的FuncTool
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
||||
]
|
||||
logger.info(f"已关闭 MCP 服务 {name}")
|
||||
finally:
|
||||
# Remove client from dict after cleanup attempt (successful or not)
|
||||
self.mcp_client_dict.pop(name, None)
|
||||
# 移除关联的FuncTool
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
||||
]
|
||||
logger.info(f"已关闭 MCP 服务 {name}")
|
||||
|
||||
@staticmethod
|
||||
async def test_mcp_server_connection(config: dict) -> list[str]:
|
||||
|
||||
@@ -241,6 +241,8 @@ class ProviderManager:
|
||||
)
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
|
||||
case "groq_chat_completion":
|
||||
from .sources.groq_source import ProviderGroq as ProviderGroq
|
||||
case "anthropic_chat_completion":
|
||||
from .sources.anthropic_source import (
|
||||
ProviderAnthropic as ProviderAnthropic,
|
||||
@@ -354,6 +356,8 @@ class ProviderManager:
|
||||
logger.error(f"无法找到 {provider_metadata.type} 的类")
|
||||
return
|
||||
|
||||
provider_metadata.id = provider_config["id"]
|
||||
|
||||
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
# STT 任务
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
@@ -394,7 +398,6 @@ class ProviderManager:
|
||||
inst = cls_type(
|
||||
provider_config,
|
||||
self.provider_settings,
|
||||
self.selected_default_persona,
|
||||
)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
|
||||
@@ -1,28 +1,18 @@
|
||||
import abc
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.db.po import Personality
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderType,
|
||||
ProviderMeta,
|
||||
RerankResult,
|
||||
ToolCallsResult,
|
||||
)
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMeta:
|
||||
id: str
|
||||
model: str
|
||||
type: str
|
||||
provider_type: ProviderType
|
||||
|
||||
|
||||
class AbstractProvider(abc.ABC):
|
||||
"""Provider Abstract Class"""
|
||||
|
||||
@@ -43,15 +33,15 @@ class AbstractProvider(abc.ABC):
|
||||
"""Get the provider metadata"""
|
||||
provider_type_name = self.provider_config["type"]
|
||||
meta_data = provider_cls_map.get(provider_type_name)
|
||||
provider_type = meta_data.provider_type if meta_data else None
|
||||
if provider_type is None:
|
||||
raise ValueError(f"Cannot find provider type: {provider_type_name}")
|
||||
return ProviderMeta(
|
||||
id=self.provider_config["id"],
|
||||
if not meta_data:
|
||||
raise ValueError(f"Provider type {provider_type_name} not registered")
|
||||
meta = ProviderMeta(
|
||||
id=self.provider_config.get("id", "default"),
|
||||
model=self.get_model(),
|
||||
type=provider_type_name,
|
||||
provider_type=provider_type,
|
||||
provider_type=meta_data.provider_type,
|
||||
)
|
||||
return meta
|
||||
|
||||
|
||||
class Provider(AbstractProvider):
|
||||
@@ -61,15 +51,10 @@ class Provider(AbstractProvider):
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
default_persona: Personality | None = None,
|
||||
) -> None:
|
||||
super().__init__(provider_config)
|
||||
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
self.curr_personality = default_persona
|
||||
"""维护了当前的使用的 persona,即人格。可能为 None"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_current_key(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -36,6 +36,8 @@ def register_provider_adapter(
|
||||
default_config_tmpl["id"] = provider_type_name
|
||||
|
||||
pm = ProviderMetaData(
|
||||
id="default", # will be replaced when instantiated
|
||||
model=None,
|
||||
type=provider_type_name,
|
||||
desc=desc,
|
||||
provider_type=provider_type,
|
||||
|
||||
@@ -25,12 +25,10 @@ class ProviderAnthropic(Provider):
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
|
||||
self.chosen_api_key: str = ""
|
||||
|
||||
@@ -20,12 +20,10 @@ class ProviderCoze(Provider):
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
self.api_key = provider_config.get("coze_api_key", "")
|
||||
if not self.api_key:
|
||||
|
||||
@@ -8,7 +8,7 @@ from dashscope.app.application_response import ApplicationResponse
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
from .. import Personality, Provider
|
||||
from .. import Provider
|
||||
from ..entities import LLMResponse
|
||||
from ..register import register_provider_adapter
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
@@ -20,13 +20,11 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
default_persona: Personality | None = None,
|
||||
) -> None:
|
||||
Provider.__init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
self.api_key = provider_config.get("dashscope_api_key", "")
|
||||
if not self.api_key:
|
||||
|
||||
@@ -18,12 +18,10 @@ class ProviderDify(Provider):
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
self.api_key = provider_config.get("dify_api_key", "")
|
||||
if not self.api_key:
|
||||
|
||||
@@ -53,12 +53,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
self.api_keys: list = super().get_keys()
|
||||
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
||||
@@ -292,13 +290,24 @@ class ProviderGoogleGenAI(Provider):
|
||||
parts = [types.Part.from_text(text=content)]
|
||||
append_or_extend(gemini_contents, parts, types.ModelContent)
|
||||
elif not native_tool_enabled and "tool_calls" in message:
|
||||
parts = [
|
||||
types.Part.from_function_call(
|
||||
parts = []
|
||||
for tool in message["tool_calls"]:
|
||||
part = types.Part.from_function_call(
|
||||
name=tool["function"]["name"],
|
||||
args=json.loads(tool["function"]["arguments"]),
|
||||
)
|
||||
for tool in message["tool_calls"]
|
||||
]
|
||||
# we should set thought_signature back to part if exists
|
||||
# for more info about thought_signature, see:
|
||||
# https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||
if "extra_content" in tool and tool["extra_content"]:
|
||||
ts_bs64 = (
|
||||
tool["extra_content"]
|
||||
.get("google", {})
|
||||
.get("thought_signature")
|
||||
)
|
||||
if ts_bs64:
|
||||
part.thought_signature = base64.b64decode(ts_bs64)
|
||||
parts.append(part)
|
||||
append_or_extend(gemini_contents, parts, types.ModelContent)
|
||||
else:
|
||||
logger.warning("assistant 角色的消息内容为空,已添加空格占位")
|
||||
@@ -326,8 +335,18 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
return gemini_contents
|
||||
|
||||
@staticmethod
|
||||
def _extract_reasoning_content(self, candidate: types.Candidate) -> str:
|
||||
"""Extract reasoning content from candidate parts"""
|
||||
if not candidate.content or not candidate.content.parts:
|
||||
return ""
|
||||
|
||||
thought_buf: list[str] = [
|
||||
(p.text or "") for p in candidate.content.parts if p.thought
|
||||
]
|
||||
return "".join(thought_buf).strip()
|
||||
|
||||
def _process_content_parts(
|
||||
self,
|
||||
candidate: types.Candidate,
|
||||
llm_response: LLMResponse,
|
||||
) -> MessageChain:
|
||||
@@ -358,6 +377,11 @@ class ProviderGoogleGenAI(Provider):
|
||||
logger.warning(f"收到的 candidate.content.parts 为空: {candidate}")
|
||||
raise Exception("API 返回的 candidate.content.parts 为空。")
|
||||
|
||||
# 提取 reasoning content
|
||||
reasoning = self._extract_reasoning_content(candidate)
|
||||
if reasoning:
|
||||
llm_response.reasoning_content = reasoning
|
||||
|
||||
chain = []
|
||||
part: types.Part
|
||||
|
||||
@@ -380,10 +404,15 @@ class ProviderGoogleGenAI(Provider):
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_name.append(part.function_call.name)
|
||||
llm_response.tools_call_args.append(part.function_call.args)
|
||||
# gemini 返回的 function_call.id 可能为 None
|
||||
llm_response.tools_call_ids.append(
|
||||
part.function_call.id or part.function_call.name,
|
||||
)
|
||||
# function_call.id might be None, use name as fallback
|
||||
tool_call_id = part.function_call.id or part.function_call.name
|
||||
llm_response.tools_call_ids.append(tool_call_id)
|
||||
# extra_content
|
||||
if part.thought_signature:
|
||||
ts_bs64 = base64.b64encode(part.thought_signature).decode("utf-8")
|
||||
llm_response.tools_call_extra_content[tool_call_id] = {
|
||||
"google": {"thought_signature": ts_bs64}
|
||||
}
|
||||
elif (
|
||||
part.inline_data
|
||||
and part.inline_data.mime_type
|
||||
@@ -422,6 +451,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
contents=conversation,
|
||||
config=config,
|
||||
)
|
||||
logger.debug(f"genai result: {result}")
|
||||
|
||||
if not result.candidates:
|
||||
logger.error(f"请求失败, 返回的 candidates 为空: {result}")
|
||||
@@ -515,6 +545,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
# Accumulate the complete response text for the final response
|
||||
accumulated_text = ""
|
||||
accumulated_reasoning = ""
|
||||
final_response = None
|
||||
|
||||
async for chunk in result:
|
||||
@@ -539,9 +570,19 @@ class ProviderGoogleGenAI(Provider):
|
||||
yield llm_response
|
||||
return
|
||||
|
||||
_f = False
|
||||
|
||||
# 提取 reasoning content
|
||||
reasoning = self._extract_reasoning_content(chunk.candidates[0])
|
||||
if reasoning:
|
||||
_f = True
|
||||
accumulated_reasoning += reasoning
|
||||
llm_response.reasoning_content = reasoning
|
||||
if chunk.text:
|
||||
_f = True
|
||||
accumulated_text += chunk.text
|
||||
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
|
||||
if _f:
|
||||
yield llm_response
|
||||
|
||||
if chunk.candidates[0].finish_reason:
|
||||
@@ -559,6 +600,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
if not final_response:
|
||||
final_response = LLMResponse("assistant", is_chunk=False)
|
||||
|
||||
# Set the complete accumulated reasoning in the final response
|
||||
if accumulated_reasoning:
|
||||
final_response.reasoning_content = accumulated_reasoning
|
||||
|
||||
# Set the complete accumulated text in the final response
|
||||
if accumulated_text:
|
||||
final_response.result_chain = MessageChain(
|
||||
|
||||
15
astrbot/core/provider/sources/groq_source.py
Normal file
15
astrbot/core/provider/sources/groq_source.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from ..register import register_provider_adapter
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"groq_chat_completion", "Groq Chat Completion Provider Adapter"
|
||||
)
|
||||
class ProviderGroq(ProviderOpenAIOfficial):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.reasoning_key = "reasoning"
|
||||
@@ -4,12 +4,14 @@ import inspect
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
from openai._exceptions import NotFoundError, UnprocessableEntityError
|
||||
from openai._exceptions import NotFoundError
|
||||
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
@@ -28,37 +30,37 @@ from ..register import register_provider_adapter
|
||||
"OpenAI API Chat Completion 提供商适配器",
|
||||
)
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
def __init__(self, provider_config, provider_settings) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: list = super().get_keys()
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
self.custom_headers = provider_config.get("custom_headers", {})
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
# 适配 azure openai #332
|
||||
|
||||
if not isinstance(self.custom_headers, dict) or not self.custom_headers:
|
||||
self.custom_headers = None
|
||||
else:
|
||||
for key in self.custom_headers:
|
||||
self.custom_headers[key] = str(self.custom_headers[key])
|
||||
|
||||
if "api_version" in provider_config:
|
||||
# 使用 azure api
|
||||
# Using Azure OpenAI API
|
||||
self.client = AsyncAzureOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
api_version=provider_config.get("api_version", None),
|
||||
default_headers=self.custom_headers,
|
||||
base_url=provider_config.get("api_base", ""),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
else:
|
||||
# 使用 openai api
|
||||
# Using OpenAI Official API
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base", None),
|
||||
default_headers=self.custom_headers,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
@@ -70,6 +72,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
model = model_config.get("model", "unknown")
|
||||
self.set_model(model)
|
||||
|
||||
self.reasoning_key = "reasoning_content"
|
||||
|
||||
def _maybe_inject_xai_search(self, payloads: dict, **kwargs):
|
||||
"""当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
|
||||
|
||||
@@ -147,7 +151,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
logger.debug(f"completion: {completion}")
|
||||
|
||||
llm_response = await self.parse_openai_completion(completion, tools)
|
||||
llm_response = await self._parse_openai_completion(completion, tools)
|
||||
|
||||
return llm_response
|
||||
|
||||
@@ -200,39 +204,82 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
# 处理文本内容
|
||||
# logger.debug(f"chunk delta: {delta}")
|
||||
# handle the content delta
|
||||
reasoning = self._extract_reasoning_content(chunk)
|
||||
_y = False
|
||||
if reasoning:
|
||||
llm_response.reasoning_content = reasoning
|
||||
_y = True
|
||||
if delta.content:
|
||||
completion_text = delta.content
|
||||
llm_response.result_chain = MessageChain(
|
||||
chain=[Comp.Plain(completion_text)],
|
||||
)
|
||||
_y = True
|
||||
if _y:
|
||||
yield llm_response
|
||||
|
||||
final_completion = state.get_final_completion()
|
||||
llm_response = await self.parse_openai_completion(final_completion, tools)
|
||||
llm_response = await self._parse_openai_completion(final_completion, tools)
|
||||
|
||||
yield llm_response
|
||||
|
||||
async def parse_openai_completion(
|
||||
def _extract_reasoning_content(
|
||||
self,
|
||||
completion: ChatCompletion | ChatCompletionChunk,
|
||||
) -> str:
|
||||
"""Extract reasoning content from OpenAI ChatCompletion if available."""
|
||||
reasoning_text = ""
|
||||
if len(completion.choices) == 0:
|
||||
return reasoning_text
|
||||
if isinstance(completion, ChatCompletion):
|
||||
choice = completion.choices[0]
|
||||
reasoning_attr = getattr(choice.message, self.reasoning_key, None)
|
||||
if reasoning_attr:
|
||||
reasoning_text = str(reasoning_attr)
|
||||
elif isinstance(completion, ChatCompletionChunk):
|
||||
delta = completion.choices[0].delta
|
||||
reasoning_attr = getattr(delta, self.reasoning_key, None)
|
||||
if reasoning_attr:
|
||||
reasoning_text = str(reasoning_attr)
|
||||
return reasoning_text
|
||||
|
||||
async def _parse_openai_completion(
|
||||
self, completion: ChatCompletion, tools: ToolSet | None
|
||||
) -> LLMResponse:
|
||||
"""解析 OpenAI 的 ChatCompletion 响应"""
|
||||
"""Parse OpenAI ChatCompletion into LLMResponse"""
|
||||
llm_response = LLMResponse("assistant")
|
||||
|
||||
if len(completion.choices) == 0:
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
choice = completion.choices[0]
|
||||
|
||||
# parse the text completion
|
||||
if choice.message.content is not None:
|
||||
# text completion
|
||||
completion_text = str(choice.message.content).strip()
|
||||
# specially, some providers may set <think> tags around reasoning content in the completion text,
|
||||
# we use regex to remove them, and store then in reasoning_content field
|
||||
reasoning_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
|
||||
matches = reasoning_pattern.findall(completion_text)
|
||||
if matches:
|
||||
llm_response.reasoning_content = "\n".join(
|
||||
[match.strip() for match in matches],
|
||||
)
|
||||
completion_text = reasoning_pattern.sub("", completion_text).strip()
|
||||
llm_response.result_chain = MessageChain().message(completion_text)
|
||||
|
||||
# parse the reasoning content if any
|
||||
# the priority is higher than the <think> tag extraction
|
||||
llm_response.reasoning_content = self._extract_reasoning_content(completion)
|
||||
|
||||
# parse tool calls if any
|
||||
if choice.message.tool_calls and tools is not None:
|
||||
# tools call (function calling)
|
||||
args_ls = []
|
||||
func_name_ls = []
|
||||
tool_call_ids = []
|
||||
tool_call_extra_content_dict = {}
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if isinstance(tool_call, str):
|
||||
# workaround for #1359
|
||||
@@ -250,16 +297,21 @@ class ProviderOpenAIOfficial(Provider):
|
||||
args_ls.append(args)
|
||||
func_name_ls.append(tool_call.function.name)
|
||||
tool_call_ids.append(tool_call.id)
|
||||
|
||||
# gemini-2.5 / gemini-3 series extra_content handling
|
||||
extra_content = getattr(tool_call, "extra_content", None)
|
||||
if extra_content is not None:
|
||||
tool_call_extra_content_dict[tool_call.id] = extra_content
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_args = args_ls
|
||||
llm_response.tools_call_name = func_name_ls
|
||||
llm_response.tools_call_ids = tool_call_ids
|
||||
|
||||
llm_response.tools_call_extra_content = tool_call_extra_content_dict
|
||||
# specially handle finish reason
|
||||
if choice.finish_reason == "content_filter":
|
||||
raise Exception(
|
||||
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。",
|
||||
)
|
||||
|
||||
if llm_response.completion_text is None and not llm_response.tools_call_args:
|
||||
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||
@@ -307,7 +359,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
|
||||
# xAI 原生搜索参数(最小侵入地在此处注入)
|
||||
# xAI origin search tool inject
|
||||
self._maybe_inject_xai_search(payloads, **kwargs)
|
||||
|
||||
return payloads, context_query
|
||||
@@ -429,12 +481,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.client.api_key = chosen_key
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except UnprocessableEntityError as e:
|
||||
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
|
||||
# 尝试删除所有 image
|
||||
new_contexts = await self._remove_image_from_context(context_query)
|
||||
payloads["messages"] = new_contexts
|
||||
context_query = new_contexts
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
(
|
||||
@@ -499,12 +545,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
async for response in self._query_stream(payloads, func_tool):
|
||||
yield response
|
||||
break
|
||||
except UnprocessableEntityError as e:
|
||||
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
|
||||
# 尝试删除所有 image
|
||||
new_contexts = await self._remove_image_from_context(context_query)
|
||||
payloads["messages"] = new_contexts
|
||||
context_query = new_contexts
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
(
|
||||
@@ -600,4 +640,3 @@ class ProviderOpenAIOfficial(Provider):
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
return ""
|
||||
|
||||
@@ -12,10 +12,5 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
super().__init__(provider_config, provider_settings)
|
||||
|
||||
@@ -5,18 +5,23 @@ from typing import Any
|
||||
|
||||
from deprecated import deprecated
|
||||
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
from astrbot.core.memory.memory_manager import MemoryManager
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from astrbot.core.platform import Platform
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType
|
||||
from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.provider.provider import (
|
||||
@@ -31,6 +36,7 @@ from astrbot.core.star.filter.platform_adapter_type import (
|
||||
PlatformAdapterType,
|
||||
)
|
||||
|
||||
from ..exceptions import ProviderNotFoundError
|
||||
from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
from .star import StarMetadata, star_map, star_registry
|
||||
@@ -60,6 +66,7 @@ class Context:
|
||||
persona_manager: PersonaManager,
|
||||
astrbot_config_mgr: AstrBotConfigManager,
|
||||
knowledge_base_manager: KnowledgeBaseManager,
|
||||
memory_manager: MemoryManager,
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
"""事件队列。消息平台通过事件队列传递消息事件。"""
|
||||
@@ -74,6 +81,154 @@ class Context:
|
||||
self.persona_manager = persona_manager
|
||||
self.astrbot_config_mgr = astrbot_config_mgr
|
||||
self.kb_manager = knowledge_base_manager
|
||||
self.memory_manager = memory_manager
|
||||
|
||||
async def llm_generate(
|
||||
self,
|
||||
*,
|
||||
chat_provider_id: str,
|
||||
prompt: str | None = None,
|
||||
image_urls: list[str] | None = None,
|
||||
tools: ToolSet | None = None,
|
||||
system_prompt: str | None = None,
|
||||
contexts: list[Message] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""Call the LLM to generate a response. The method will not automatically execute tool calls. If you want to use tool calls, please use `tool_loop_agent()`.
|
||||
|
||||
.. versionadded:: 4.5.7 (sdk)
|
||||
|
||||
Args:
|
||||
chat_provider_id: The chat provider ID to use.
|
||||
prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message
|
||||
image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message
|
||||
tools: ToolSet of tools available to the LLM
|
||||
system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context
|
||||
contexts: context messages for the LLM
|
||||
**kwargs: Additional keyword arguments for LLM generation, OpenAI compatible
|
||||
|
||||
Raises:
|
||||
ChatProviderNotFoundError: If the specified chat provider ID is not found
|
||||
Exception: For other errors during LLM generation
|
||||
"""
|
||||
prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
|
||||
if not prov or not isinstance(prov, Provider):
|
||||
raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")
|
||||
llm_resp = await prov.text_chat(
|
||||
prompt=prompt,
|
||||
image_urls=image_urls,
|
||||
func_tool=tools,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
**kwargs,
|
||||
)
|
||||
return llm_resp
|
||||
|
||||
async def tool_loop_agent(
|
||||
self,
|
||||
*,
|
||||
event: AstrMessageEvent,
|
||||
chat_provider_id: str,
|
||||
prompt: str | None = None,
|
||||
image_urls: list[str] | None = None,
|
||||
tools: ToolSet | None = None,
|
||||
system_prompt: str | None = None,
|
||||
contexts: list[Message] | None = None,
|
||||
max_steps: int = 30,
|
||||
tool_call_timeout: int = 60,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""Run an agent loop that allows the LLM to call tools iteratively until a final answer is produced.
|
||||
If you do not pass the agent_context parameter, the method will recreate a new agent context.
|
||||
|
||||
.. versionadded:: 4.5.7 (sdk)
|
||||
|
||||
Args:
|
||||
chat_provider_id: The chat provider ID to use.
|
||||
prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message
|
||||
image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message
|
||||
tools: ToolSet of tools available to the LLM
|
||||
system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context
|
||||
contexts: context messages for the LLM
|
||||
max_steps: Maximum number of tool calls before stopping the loop
|
||||
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
|
||||
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
|
||||
agent_context: AstrAgentContext - context to use for the agent
|
||||
|
||||
Returns:
|
||||
The final LLMResponse after tool calls are completed.
|
||||
|
||||
Raises:
|
||||
ChatProviderNotFoundError: If the specified chat provider ID is not found
|
||||
Exception: For other errors during LLM generation
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from astrbot.core.astr_agent_context import (
|
||||
AgentContextWrapper,
|
||||
AstrAgentContext,
|
||||
)
|
||||
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
|
||||
|
||||
prov = await self.provider_manager.get_provider_by_id(chat_provider_id)
|
||||
if not prov or not isinstance(prov, Provider):
|
||||
raise ProviderNotFoundError(f"Provider {chat_provider_id} not found")
|
||||
|
||||
agent_hooks = kwargs.get("agent_hooks") or BaseAgentRunHooks[AstrAgentContext]()
|
||||
agent_context = kwargs.get("agent_context")
|
||||
|
||||
context_ = []
|
||||
for msg in contexts or []:
|
||||
if isinstance(msg, Message):
|
||||
context_.append(msg.model_dump())
|
||||
else:
|
||||
context_.append(msg)
|
||||
|
||||
request = ProviderRequest(
|
||||
prompt=prompt,
|
||||
image_urls=image_urls or [],
|
||||
func_tool=tools,
|
||||
contexts=context_,
|
||||
system_prompt=system_prompt or "",
|
||||
)
|
||||
if agent_context is None:
|
||||
agent_context = AstrAgentContext(
|
||||
context=self,
|
||||
event=event,
|
||||
)
|
||||
agent_runner = ToolLoopAgentRunner()
|
||||
tool_executor = FunctionToolExecutor()
|
||||
await agent_runner.reset(
|
||||
provider=prov,
|
||||
request=request,
|
||||
run_context=AgentContextWrapper(
|
||||
context=agent_context,
|
||||
tool_call_timeout=tool_call_timeout,
|
||||
),
|
||||
tool_executor=tool_executor,
|
||||
agent_hooks=agent_hooks,
|
||||
streaming=kwargs.get("stream", False),
|
||||
)
|
||||
async for _ in agent_runner.step_until_done(max_steps):
|
||||
pass
|
||||
llm_resp = agent_runner.get_final_llm_resp()
|
||||
if not llm_resp:
|
||||
raise Exception("Agent did not produce a final LLM response")
|
||||
return llm_resp
|
||||
|
||||
async def get_current_chat_provider_id(self, umo: str) -> str:
|
||||
"""Get the ID of the currently used chat provider.
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin value, if provided and user has enabled provider session isolation, the provider preferred by that session will be used.
|
||||
|
||||
Raises:
|
||||
ProviderNotFoundError: If the specified chat provider is not found
|
||||
|
||||
"""
|
||||
prov = self.get_using_provider(umo)
|
||||
if not prov:
|
||||
raise ProviderNotFoundError("Provider not found")
|
||||
return prov.meta().id
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata | None:
|
||||
"""根据插件名获取插件的 Metadata"""
|
||||
@@ -107,10 +262,6 @@ class Context:
|
||||
"""
|
||||
return self.provider_manager.llm_tools.deactivate_llm_tool(name)
|
||||
|
||||
def register_provider(self, provider: Provider):
|
||||
"""注册一个 LLM Provider(Chat_Completion 类型)。"""
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(
|
||||
self,
|
||||
provider_id: str,
|
||||
@@ -189,45 +340,6 @@ class Context:
|
||||
return self._config
|
||||
return self.astrbot_config_mgr.get_conf(umo)
|
||||
|
||||
def get_db(self) -> BaseDatabase:
|
||||
"""获取 AstrBot 数据库。"""
|
||||
return self._db
|
||||
|
||||
def get_event_queue(self) -> Queue:
|
||||
"""获取事件队列。"""
|
||||
return self._event_queue
|
||||
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
|
||||
def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
|
||||
"""获取指定类型的平台适配器。
|
||||
|
||||
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
|
||||
"""
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
name = platform.meta().name
|
||||
if isinstance(platform_type, str):
|
||||
if name == platform_type:
|
||||
return platform
|
||||
elif (
|
||||
name in ADAPTER_NAME_2_TYPE
|
||||
and ADAPTER_NAME_2_TYPE[name] & platform_type
|
||||
):
|
||||
return platform
|
||||
|
||||
def get_platform_inst(self, platform_id: str) -> Platform | None:
|
||||
"""获取指定 ID 的平台适配器实例。
|
||||
|
||||
Args:
|
||||
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
|
||||
|
||||
Returns:
|
||||
Platform: 平台适配器实例,如果未找到则返回 None。
|
||||
|
||||
"""
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.meta().id == platform_id:
|
||||
return platform
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
session: str | MessageSesion,
|
||||
@@ -300,6 +412,49 @@ class Context:
|
||||
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
|
||||
"""
|
||||
|
||||
def get_event_queue(self) -> Queue:
|
||||
"""获取事件队列。"""
|
||||
return self._event_queue
|
||||
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
|
||||
def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
|
||||
"""获取指定类型的平台适配器。
|
||||
|
||||
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
|
||||
"""
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
name = platform.meta().name
|
||||
if isinstance(platform_type, str):
|
||||
if name == platform_type:
|
||||
return platform
|
||||
elif (
|
||||
name in ADAPTER_NAME_2_TYPE
|
||||
and ADAPTER_NAME_2_TYPE[name] & platform_type
|
||||
):
|
||||
return platform
|
||||
|
||||
def get_platform_inst(self, platform_id: str) -> Platform | None:
|
||||
"""获取指定 ID 的平台适配器实例。
|
||||
|
||||
Args:
|
||||
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
|
||||
|
||||
Returns:
|
||||
Platform: 平台适配器实例,如果未找到则返回 None。
|
||||
|
||||
"""
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.meta().id == platform_id:
|
||||
return platform
|
||||
|
||||
def get_db(self) -> BaseDatabase:
|
||||
"""获取 AstrBot 数据库。"""
|
||||
return self._db
|
||||
|
||||
def register_provider(self, provider: Provider):
|
||||
"""注册一个 LLM Provider(Chat_Completion 类型)。"""
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def register_llm_tool(
|
||||
self,
|
||||
name: str,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
@@ -11,7 +12,7 @@ from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
|
||||
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
from ..filter.command import CommandFilter
|
||||
@@ -417,18 +418,37 @@ def register_llm_tool(name: str | None = None, **kwargs):
|
||||
docstring = docstring_parser.parse(func_doc)
|
||||
args = []
|
||||
for arg in docstring.params:
|
||||
if arg.type_name not in SUPPORTED_TYPES:
|
||||
sub_type_name = None
|
||||
type_name = arg.type_name
|
||||
if not type_name:
|
||||
raise ValueError(
|
||||
f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 的参数 {arg.arg_name} 缺少类型注释。",
|
||||
)
|
||||
# parse type_name to handle cases like "list[string]"
|
||||
match = re.match(r"(\w+)\[(\w+)\]", type_name)
|
||||
if match:
|
||||
type_name = match.group(1)
|
||||
sub_type_name = match.group(2)
|
||||
type_name = PY_TO_JSON_TYPE.get(type_name, type_name)
|
||||
if sub_type_name:
|
||||
sub_type_name = PY_TO_JSON_TYPE.get(sub_type_name, sub_type_name)
|
||||
if type_name not in SUPPORTED_TYPES or (
|
||||
sub_type_name and sub_type_name not in SUPPORTED_TYPES
|
||||
):
|
||||
raise ValueError(
|
||||
f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}",
|
||||
)
|
||||
args.append(
|
||||
{
|
||||
"type": arg.type_name,
|
||||
"name": arg.arg_name,
|
||||
"description": arg.description,
|
||||
},
|
||||
)
|
||||
# print(llm_tool_name, registering_agent)
|
||||
|
||||
arg_json_schema = {
|
||||
"type": type_name,
|
||||
"name": arg.arg_name,
|
||||
"description": arg.description,
|
||||
}
|
||||
if sub_type_name:
|
||||
if type_name == "array":
|
||||
arg_json_schema["items"] = {"type": sub_type_name}
|
||||
args.append(arg_json_schema)
|
||||
|
||||
if not registering_agent:
|
||||
doc_desc = docstring.description.strip() if docstring.description else ""
|
||||
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
||||
|
||||
@@ -5,6 +5,7 @@ from .conversation import ConversationRoute
|
||||
from .file import FileRoute
|
||||
from .knowledge_base import KnowledgeBaseRoute
|
||||
from .log import LogRoute
|
||||
from .memory import MemoryRoute
|
||||
from .persona import PersonaRoute
|
||||
from .plugin import PluginRoute
|
||||
from .session_management import SessionManagementRoute
|
||||
@@ -21,6 +22,7 @@ __all__ = [
|
||||
"FileRoute",
|
||||
"KnowledgeBaseRoute",
|
||||
"LogRoute",
|
||||
"MemoryRoute",
|
||||
"PersonaRoute",
|
||||
"PluginRoute",
|
||||
"SessionManagementRoute",
|
||||
|
||||
@@ -10,7 +10,6 @@ from quart import g, make_response, request
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.platform.astr_message_event import MessageSession
|
||||
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
@@ -36,11 +35,14 @@ class ChatRoute(Route):
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
"/chat/send": ("POST", self.chat),
|
||||
"/chat/new_conversation": ("GET", self.new_conversation),
|
||||
"/chat/conversations": ("GET", self.get_conversations),
|
||||
"/chat/get_conversation": ("GET", self.get_conversation),
|
||||
"/chat/delete_conversation": ("GET", self.delete_conversation),
|
||||
"/chat/rename_conversation": ("POST", self.rename_conversation),
|
||||
"/chat/new_session": ("GET", self.new_session),
|
||||
"/chat/sessions": ("GET", self.get_sessions),
|
||||
"/chat/get_session": ("GET", self.get_session),
|
||||
"/chat/delete_session": ("GET", self.delete_webchat_session),
|
||||
"/chat/update_session_display_name": (
|
||||
"POST",
|
||||
self.update_session_display_name,
|
||||
),
|
||||
"/chat/get_file": ("GET", self.get_file),
|
||||
"/chat/post_image": ("POST", self.post_image),
|
||||
"/chat/post_file": ("POST", self.post_file),
|
||||
@@ -53,6 +55,7 @@ class ChatRoute(Route):
|
||||
self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"]
|
||||
self.conv_mgr = core_lifecycle.conversation_manager
|
||||
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
|
||||
self.db = db
|
||||
|
||||
self.running_convs: dict[str, bool] = {}
|
||||
|
||||
@@ -116,26 +119,31 @@ class ChatRoute(Route):
|
||||
if "message" not in post_data and "image_url" not in post_data:
|
||||
return Response().error("Missing key: message or image_url").__dict__
|
||||
|
||||
if "conversation_id" not in post_data:
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
if "session_id" not in post_data and "conversation_id" not in post_data:
|
||||
return (
|
||||
Response().error("Missing key: session_id or conversation_id").__dict__
|
||||
)
|
||||
|
||||
message = post_data["message"]
|
||||
conversation_id = post_data["conversation_id"]
|
||||
# conversation_id = post_data["conversation_id"]
|
||||
session_id = post_data.get("session_id", post_data.get("conversation_id"))
|
||||
image_url = post_data.get("image_url")
|
||||
audio_url = post_data.get("audio_url")
|
||||
selected_provider = post_data.get("selected_provider")
|
||||
selected_model = post_data.get("selected_model")
|
||||
enable_streaming = post_data.get("enable_streaming", True) # 默认为 True
|
||||
|
||||
if not message and not image_url and not audio_url:
|
||||
return (
|
||||
Response()
|
||||
.error("Message and image_url and audio_url are empty")
|
||||
.__dict__
|
||||
)
|
||||
if not conversation_id:
|
||||
return Response().error("conversation_id is empty").__dict__
|
||||
if not session_id:
|
||||
return Response().error("session_id is empty").__dict__
|
||||
|
||||
# 追加用户消息
|
||||
webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
|
||||
webchat_conv_id = session_id
|
||||
|
||||
# 获取会话特定的队列
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id)
|
||||
@@ -202,6 +210,8 @@ class ChatRoute(Route):
|
||||
):
|
||||
# 追加机器人消息
|
||||
new_his = {"type": "bot", "message": result_text}
|
||||
if "reasoning" in result:
|
||||
new_his["reasoning"] = result["reasoning"]
|
||||
await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
@@ -224,6 +234,7 @@ class ChatRoute(Route):
|
||||
"audio_url": audio_url,
|
||||
"selected_provider": selected_provider,
|
||||
"selected_model": selected_model,
|
||||
"enable_streaming": enable_streaming,
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -240,88 +251,110 @@ class ChatRoute(Route):
|
||||
response.timeout = None # fix SSE auto disconnect issue
|
||||
return response
|
||||
|
||||
async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str:
|
||||
"""从对话 ID 中提取 WebChat 会话 ID
|
||||
|
||||
NOTE: 关于这里为什么要单独做一个 WebChat 的 Conversation ID 出来,这个是为了向前兼容。
|
||||
"""
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin="webchat",
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
if not conversation:
|
||||
raise ValueError(f"Conversation with ID {conversation_id} not found.")
|
||||
conv_user_id = conversation.user_id
|
||||
webchat_session_id = MessageSession.from_str(conv_user_id).session_id
|
||||
if "!" not in webchat_session_id:
|
||||
raise ValueError(f"Invalid conv user ID: {conv_user_id}")
|
||||
return webchat_session_id.split("!")[-1]
|
||||
|
||||
async def delete_conversation(self):
|
||||
conversation_id = request.args.get("conversation_id")
|
||||
if not conversation_id:
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
async def delete_webchat_session(self):
|
||||
"""Delete a Platform session and all its related data."""
|
||||
session_id = request.args.get("session_id")
|
||||
if not session_id:
|
||||
return Response().error("Missing key: session_id").__dict__
|
||||
username = g.get("username", "guest")
|
||||
|
||||
# Clean up queues when deleting conversation
|
||||
webchat_queue_mgr.remove_queues(conversation_id)
|
||||
webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
|
||||
await self.conv_mgr.delete_conversation(
|
||||
unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}",
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
# 验证会话是否存在且属于当前用户
|
||||
session = await self.db.get_platform_session_by_id(session_id)
|
||||
if not session:
|
||||
return Response().error(f"Session {session_id} not found").__dict__
|
||||
if session.creator != username:
|
||||
return Response().error("Permission denied").__dict__
|
||||
|
||||
# 删除该会话下的所有对话
|
||||
unified_msg_origin = f"{session.platform_id}:FriendMessage:{session.platform_id}!{username}!{session_id}"
|
||||
await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin)
|
||||
|
||||
# 删除消息历史
|
||||
await self.platform_history_mgr.delete(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
platform_id=session.platform_id,
|
||||
user_id=session_id,
|
||||
offset_sec=99999999,
|
||||
)
|
||||
|
||||
# 清理队列(仅对 webchat)
|
||||
if session.platform_id == "webchat":
|
||||
webchat_queue_mgr.remove_queues(session_id)
|
||||
|
||||
# 删除会话
|
||||
await self.db.delete_platform_session(session_id)
|
||||
|
||||
return Response().ok().__dict__
|
||||
|
||||
async def new_conversation(self):
|
||||
async def new_session(self):
|
||||
"""Create a new Platform session (default: webchat)."""
|
||||
username = g.get("username", "guest")
|
||||
webchat_conv_id = str(uuid.uuid4())
|
||||
conv_id = await self.conv_mgr.new_conversation(
|
||||
unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}",
|
||||
platform_id="webchat",
|
||||
content=[],
|
||||
|
||||
# 获取可选的 platform_id 参数,默认为 webchat
|
||||
platform_id = request.args.get("platform_id", "webchat")
|
||||
|
||||
# 创建新会话
|
||||
session = await self.db.create_platform_session(
|
||||
creator=username,
|
||||
platform_id=platform_id,
|
||||
is_group=0,
|
||||
)
|
||||
return Response().ok(data={"conversation_id": conv_id}).__dict__
|
||||
|
||||
async def rename_conversation(self):
|
||||
post_data = await request.json
|
||||
if "conversation_id" not in post_data or "title" not in post_data:
|
||||
return Response().error("Missing key: conversation_id or title").__dict__
|
||||
|
||||
conversation_id = post_data["conversation_id"]
|
||||
title = post_data["title"]
|
||||
|
||||
await self.conv_mgr.update_conversation(
|
||||
unified_msg_origin="webchat", # fake
|
||||
conversation_id=conversation_id,
|
||||
title=title,
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
data={
|
||||
"session_id": session.session_id,
|
||||
"platform_id": session.platform_id,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
return Response().ok(message="重命名成功!").__dict__
|
||||
|
||||
async def get_conversations(self):
|
||||
conversations = await self.conv_mgr.get_conversations(platform_id="webchat")
|
||||
# remove content
|
||||
conversations_ = []
|
||||
for conv in conversations:
|
||||
conv.history = None
|
||||
conversations_.append(conv)
|
||||
return Response().ok(data=conversations_).__dict__
|
||||
async def get_sessions(self):
|
||||
"""Get all Platform sessions for the current user."""
|
||||
username = g.get("username", "guest")
|
||||
|
||||
async def get_conversation(self):
|
||||
conversation_id = request.args.get("conversation_id")
|
||||
if not conversation_id:
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
# 获取可选的 platform_id 参数
|
||||
platform_id = request.args.get("platform_id")
|
||||
|
||||
webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
|
||||
sessions = await self.db.get_platform_sessions_by_creator(
|
||||
creator=username,
|
||||
platform_id=platform_id,
|
||||
page=1,
|
||||
page_size=100, # 暂时返回前100个
|
||||
)
|
||||
|
||||
# Get platform message history
|
||||
# 转换为字典格式,并添加额外信息
|
||||
sessions_data = []
|
||||
for session in sessions:
|
||||
sessions_data.append(
|
||||
{
|
||||
"session_id": session.session_id,
|
||||
"platform_id": session.platform_id,
|
||||
"creator": session.creator,
|
||||
"display_name": session.display_name,
|
||||
"is_group": session.is_group,
|
||||
"created_at": session.created_at.astimezone().isoformat(),
|
||||
"updated_at": session.updated_at.astimezone().isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
return Response().ok(data=sessions_data).__dict__
|
||||
|
||||
async def get_session(self):
|
||||
"""Get session information and message history by session_id."""
|
||||
session_id = request.args.get("session_id")
|
||||
if not session_id:
|
||||
return Response().error("Missing key: session_id").__dict__
|
||||
|
||||
# 获取会话信息以确定 platform_id
|
||||
session = await self.db.get_platform_session_by_id(session_id)
|
||||
platform_id = session.platform_id if session else "webchat"
|
||||
|
||||
# Get platform message history using session_id
|
||||
history_ls = await self.platform_history_mgr.get(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
platform_id=platform_id,
|
||||
user_id=session_id,
|
||||
page=1,
|
||||
page_size=1000,
|
||||
)
|
||||
@@ -333,8 +366,37 @@ class ChatRoute(Route):
|
||||
.ok(
|
||||
data={
|
||||
"history": history_res,
|
||||
"is_running": self.running_convs.get(webchat_conv_id, False),
|
||||
"is_running": self.running_convs.get(session_id, False),
|
||||
},
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
async def update_session_display_name(self):
|
||||
"""Update a Platform session's display name."""
|
||||
post_data = await request.json
|
||||
|
||||
session_id = post_data.get("session_id")
|
||||
display_name = post_data.get("display_name")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("Missing key: session_id").__dict__
|
||||
if display_name is None:
|
||||
return Response().error("Missing key: display_name").__dict__
|
||||
|
||||
username = g.get("username", "guest")
|
||||
|
||||
# 验证会话是否存在且属于当前用户
|
||||
session = await self.db.get_platform_session_by_id(session_id)
|
||||
if not session:
|
||||
return Response().error(f"Session {session_id} not found").__dict__
|
||||
if session.creator != username:
|
||||
return Response().error("Permission denied").__dict__
|
||||
|
||||
# 更新 display_name
|
||||
await self.db.update_platform_session(
|
||||
session_id=session_id,
|
||||
display_name=display_name,
|
||||
)
|
||||
|
||||
return Response().ok().__dict__
|
||||
|
||||
@@ -48,6 +48,7 @@ class KnowledgeBaseRoute(Route):
|
||||
# 文档管理
|
||||
"/kb/document/list": ("GET", self.list_documents),
|
||||
"/kb/document/upload": ("POST", self.upload_document),
|
||||
"/kb/document/upload/url": ("POST", self.upload_document_from_url),
|
||||
"/kb/document/upload/progress": ("GET", self.get_upload_progress),
|
||||
"/kb/document/get": ("GET", self.get_document),
|
||||
"/kb/document/delete": ("POST", self.delete_document),
|
||||
@@ -1070,3 +1071,174 @@ class KnowledgeBaseRoute(Route):
|
||||
logger.error(f"删除会话知识库配置失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"删除会话知识库配置失败: {e!s}").__dict__
|
||||
|
||||
async def upload_document_from_url(self):
|
||||
"""从 URL 上传文档
|
||||
|
||||
Body:
|
||||
- kb_id: 知识库 ID (必填)
|
||||
- url: 要提取内容的网页 URL (必填)
|
||||
- chunk_size: 分块大小 (可选, 默认512)
|
||||
- chunk_overlap: 块重叠大小 (可选, 默认50)
|
||||
- batch_size: 批处理大小 (可选, 默认32)
|
||||
- tasks_limit: 并发任务限制 (可选, 默认3)
|
||||
- max_retries: 最大重试次数 (可选, 默认3)
|
||||
|
||||
返回:
|
||||
- task_id: 任务ID,用于查询上传进度和结果
|
||||
"""
|
||||
try:
|
||||
kb_manager = self._get_kb_manager()
|
||||
data = await request.json
|
||||
|
||||
kb_id = data.get("kb_id")
|
||||
if not kb_id:
|
||||
return Response().error("缺少参数 kb_id").__dict__
|
||||
|
||||
url = data.get("url")
|
||||
if not url:
|
||||
return Response().error("缺少参数 url").__dict__
|
||||
|
||||
chunk_size = data.get("chunk_size", 512)
|
||||
chunk_overlap = data.get("chunk_overlap", 50)
|
||||
batch_size = data.get("batch_size", 32)
|
||||
tasks_limit = data.get("tasks_limit", 3)
|
||||
max_retries = data.get("max_retries", 3)
|
||||
enable_cleaning = data.get("enable_cleaning", False)
|
||||
cleaning_provider_id = data.get("cleaning_provider_id")
|
||||
|
||||
# 获取知识库
|
||||
kb_helper = await kb_manager.get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
return Response().error("知识库不存在").__dict__
|
||||
|
||||
# 生成任务ID
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 初始化任务状态
|
||||
self.upload_tasks[task_id] = {
|
||||
"status": "pending",
|
||||
"result": None,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
# 启动后台任务
|
||||
asyncio.create_task(
|
||||
self._background_upload_from_url_task(
|
||||
task_id=task_id,
|
||||
kb_helper=kb_helper,
|
||||
url=url,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
enable_cleaning=enable_cleaning,
|
||||
cleaning_provider_id=cleaning_provider_id,
|
||||
),
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"url": url,
|
||||
"message": "URL upload task created, processing in background",
|
||||
},
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"从URL上传文档失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"从URL上传文档失败: {e!s}").__dict__
|
||||
|
||||
async def _background_upload_from_url_task(
|
||||
self,
|
||||
task_id: str,
|
||||
kb_helper,
|
||||
url: str,
|
||||
chunk_size: int,
|
||||
chunk_overlap: int,
|
||||
batch_size: int,
|
||||
tasks_limit: int,
|
||||
max_retries: int,
|
||||
enable_cleaning: bool,
|
||||
cleaning_provider_id: str | None,
|
||||
):
|
||||
"""后台上传URL任务"""
|
||||
try:
|
||||
# 初始化任务状态
|
||||
self.upload_tasks[task_id] = {
|
||||
"status": "processing",
|
||||
"result": None,
|
||||
"error": None,
|
||||
}
|
||||
self.upload_progress[task_id] = {
|
||||
"status": "processing",
|
||||
"file_index": 0,
|
||||
"file_total": 1,
|
||||
"file_name": f"URL: {url}",
|
||||
"stage": "extracting",
|
||||
"current": 0,
|
||||
"total": 100,
|
||||
}
|
||||
|
||||
# 创建进度回调函数
|
||||
async def progress_callback(stage, current, total):
|
||||
if task_id in self.upload_progress:
|
||||
self.upload_progress[task_id].update(
|
||||
{
|
||||
"status": "processing",
|
||||
"file_index": 0,
|
||||
"file_name": f"URL: {url}",
|
||||
"stage": stage,
|
||||
"current": current,
|
||||
"total": total,
|
||||
},
|
||||
)
|
||||
|
||||
# 上传文档
|
||||
doc = await kb_helper.upload_from_url(
|
||||
url=url,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
progress_callback=progress_callback,
|
||||
enable_cleaning=enable_cleaning,
|
||||
cleaning_provider_id=cleaning_provider_id,
|
||||
)
|
||||
|
||||
# 更新任务完成状态
|
||||
result = {
|
||||
"task_id": task_id,
|
||||
"uploaded": [doc.model_dump()],
|
||||
"failed": [],
|
||||
"total": 1,
|
||||
"success_count": 1,
|
||||
"failed_count": 0,
|
||||
}
|
||||
|
||||
self.upload_tasks[task_id] = {
|
||||
"status": "completed",
|
||||
"result": result,
|
||||
"error": None,
|
||||
}
|
||||
self.upload_progress[task_id]["status"] = "completed"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"后台上传URL任务 {task_id} 失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
self.upload_tasks[task_id] = {
|
||||
"status": "failed",
|
||||
"result": None,
|
||||
"error": str(e),
|
||||
}
|
||||
if task_id in self.upload_progress:
|
||||
self.upload_progress[task_id]["status"] = "failed"
|
||||
|
||||
174
astrbot/dashboard/routes/memory.py
Normal file
174
astrbot/dashboard/routes/memory.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""Memory management API routes"""
|
||||
|
||||
from quart import jsonify, request
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
class MemoryRoute(Route):
|
||||
"""Memory management routes"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RouteContext,
|
||||
db: BaseDatabase,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
):
|
||||
super().__init__(context)
|
||||
self.db = db
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.memory_manager = core_lifecycle.memory_manager
|
||||
self.provider_manager = core_lifecycle.provider_manager
|
||||
self.routes = [
|
||||
("/memory/status", ("GET", self.get_status)),
|
||||
("/memory/initialize", ("POST", self.initialize)),
|
||||
("/memory/update_merge_llm", ("POST", self.update_merge_llm)),
|
||||
]
|
||||
self.register_routes()
|
||||
|
||||
async def get_status(self):
|
||||
"""Get memory system status"""
|
||||
try:
|
||||
is_initialized = self.memory_manager._initialized
|
||||
|
||||
status_data = {
|
||||
"initialized": is_initialized,
|
||||
"embedding_provider_id": None,
|
||||
"merge_llm_provider_id": None,
|
||||
}
|
||||
|
||||
if is_initialized:
|
||||
# Get embedding provider info
|
||||
if self.memory_manager.embedding_provider:
|
||||
status_data["embedding_provider_id"] = (
|
||||
self.memory_manager.embedding_provider.provider_config["id"]
|
||||
)
|
||||
# Get merge LLM provider info
|
||||
if self.memory_manager.merge_llm_provider:
|
||||
status_data["merge_llm_provider_id"] = (
|
||||
self.memory_manager.merge_llm_provider.provider_config["id"]
|
||||
)
|
||||
|
||||
return jsonify(Response().ok(status_data).__dict__)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get memory status: {e}")
|
||||
return jsonify(Response().error(str(e)).__dict__)
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize memory system with embedding and merge LLM providers"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
embedding_provider_id = data.get("embedding_provider_id")
|
||||
merge_llm_provider_id = data.get("merge_llm_provider_id")
|
||||
|
||||
if not embedding_provider_id or not merge_llm_provider_id:
|
||||
return jsonify(
|
||||
Response()
|
||||
.error(
|
||||
"embedding_provider_id and merge_llm_provider_id are required"
|
||||
)
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
# Check if already initialized
|
||||
if self.memory_manager._initialized:
|
||||
return jsonify(
|
||||
Response()
|
||||
.error(
|
||||
"Memory system already initialized. Embedding provider cannot be changed.",
|
||||
)
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
# Get providers
|
||||
embedding_provider = await self.provider_manager.get_provider_by_id(
|
||||
embedding_provider_id,
|
||||
)
|
||||
merge_llm_provider = await self.provider_manager.get_provider_by_id(
|
||||
merge_llm_provider_id,
|
||||
)
|
||||
|
||||
if not embedding_provider:
|
||||
return jsonify(
|
||||
Response()
|
||||
.error(f"Embedding provider {embedding_provider_id} not found")
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
if not merge_llm_provider:
|
||||
return jsonify(
|
||||
Response()
|
||||
.error(f"Merge LLM provider {merge_llm_provider_id} not found")
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
# Initialize memory manager
|
||||
await self.memory_manager.initialize(
|
||||
embedding_provider=embedding_provider,
|
||||
merge_llm_provider=merge_llm_provider,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Memory system initialized with embedding: {embedding_provider_id}, "
|
||||
f"merge LLM: {merge_llm_provider_id}",
|
||||
)
|
||||
|
||||
return jsonify(
|
||||
Response()
|
||||
.ok({"message": "Memory system initialized successfully"})
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize memory system: {e}")
|
||||
return jsonify(Response().error(str(e)).__dict__)
|
||||
|
||||
async def update_merge_llm(self):
|
||||
"""Update merge LLM provider (only allowed after initialization)"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
merge_llm_provider_id = data.get("merge_llm_provider_id")
|
||||
|
||||
if not merge_llm_provider_id:
|
||||
return jsonify(
|
||||
Response().error("merge_llm_provider_id is required").__dict__,
|
||||
)
|
||||
|
||||
# Check if initialized
|
||||
if not self.memory_manager._initialized:
|
||||
return jsonify(
|
||||
Response()
|
||||
.error("Memory system not initialized. Please initialize first.")
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
# Get new merge LLM provider
|
||||
merge_llm_provider = await self.provider_manager.get_provider_by_id(
|
||||
merge_llm_provider_id,
|
||||
)
|
||||
|
||||
if not merge_llm_provider:
|
||||
return jsonify(
|
||||
Response()
|
||||
.error(f"Merge LLM provider {merge_llm_provider_id} not found")
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
# Update merge LLM provider
|
||||
self.memory_manager.merge_llm_provider = merge_llm_provider
|
||||
|
||||
logger.info(f"Updated merge LLM provider to: {merge_llm_provider_id}")
|
||||
|
||||
return jsonify(
|
||||
Response()
|
||||
.ok({"message": "Merge LLM provider updated successfully"})
|
||||
.__dict__,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update merge LLM provider: {e}")
|
||||
return jsonify(Response().error(str(e)).__dict__)
|
||||
@@ -296,7 +296,15 @@ class ToolsRoute(Route):
|
||||
"""获取所有注册的工具列表"""
|
||||
try:
|
||||
tools = self.tool_mgr.func_list
|
||||
tools_dict = [tool.__dict__() for tool in tools]
|
||||
tools_dict = [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
"active": tool.active,
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
return Response().ok(data=tools_dict).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -79,6 +79,7 @@ class AstrBotDashboard:
|
||||
self.persona_route = PersonaRoute(self.context, db, core_lifecycle)
|
||||
self.t2i_route = T2iRoute(self.context, core_lifecycle)
|
||||
self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle)
|
||||
self.memory_route = MemoryRoute(self.context, db, core_lifecycle)
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/api/plug/<path:subpath>",
|
||||
|
||||
5
changelogs/v4.5.3.md
Normal file
5
changelogs/v4.5.3.md
Normal file
@@ -0,0 +1,5 @@
|
||||
## What's Changed
|
||||
|
||||
> hotfix version of 4.5.2
|
||||
|
||||
1. 修复:修正 `get_tool_list` 方法中工具字典推导式的错误导致的 WebUI MCP 页面工具列表无法显示的问题。
|
||||
5
changelogs/v4.5.4.md
Normal file
5
changelogs/v4.5.4.md
Normal file
@@ -0,0 +1,5 @@
|
||||
## What's Changed
|
||||
|
||||
1. 修复:Docker 镜像部分依赖问题导致某些情况下无法启动容器的问题;
|
||||
2. 优化:插件卡片样式
|
||||
3. 修复:部分情况下 Windows 一键启动部署时,更新 / 部署失败的问题;
|
||||
3
changelogs/v4.5.5.md
Normal file
3
changelogs/v4.5.5.md
Normal file
@@ -0,0 +1,3 @@
|
||||
## What's Changed
|
||||
|
||||
1. 修复:部署失败
|
||||
3
changelogs/v4.5.6.md
Normal file
3
changelogs/v4.5.6.md
Normal file
@@ -0,0 +1,3 @@
|
||||
## What's Changed
|
||||
|
||||
1. 修复:构建失败
|
||||
12
changelogs/v4.5.7.md
Normal file
12
changelogs/v4.5.7.md
Normal file
@@ -0,0 +1,12 @@
|
||||
## What's Changed
|
||||
|
||||
1. 新增:支持为 OpenAI API 提供商自定义请求头 ([#3581](https://github.com/AstrBotDevs/AstrBot/issues/3581))
|
||||
2. 新增:为 WebChat 为 Thinking 模型添加思考过程展示功能;支持快捷切换流式输出 / 非流式输出。([#3632](https://github.com/AstrBotDevs/AstrBot/issues/3632))
|
||||
3. 新增:优化插件调用 LLM 和 Agent 的路径,为 Context 类引入多个调用 LLM 和 Agent 的便捷方法 ([#3636](https://github.com/AstrBotDevs/AstrBot/issues/3636))
|
||||
4. 优化:改善不支持流式输出的消息平台的回退策略 ([#3547](https://github.com/AstrBotDevs/AstrBot/issues/3547))
|
||||
5. 优化:当同一个会话(umo)下同时有多个请求时,执行排队处理,避免并发请求导致的上下文混乱问题 ([#3607](https://github.com/AstrBotDevs/AstrBot/issues/3607))
|
||||
6. 优化:优化 WebUI 的登录界面和 Changelog 页面的显示效果
|
||||
7. 修复:修复在知识库名字过长的情况下,“选择知识库”按钮显示异常的问题 ([#3582](https://github.com/AstrBotDevs/AstrBot/issues/3582))
|
||||
8. 修复:修复部分情况下,分段消息发送时导致的死锁问题(由 PR #3607 引入)
|
||||
9. 修复:钉钉适配器使用部分指令无法生效的问题 ([#3634](https://github.com/AstrBotDevs/AstrBot/issues/3634))
|
||||
10. 其他:为部分适配器添加缺失的 send_streaming 方法 ([#3545](https://github.com/AstrBotDevs/AstrBot/issues/3545))
|
||||
5
changelogs/v4.5.8.md
Normal file
5
changelogs/v4.5.8.md
Normal file
@@ -0,0 +1,5 @@
|
||||
## What's Changed
|
||||
|
||||
hot fix of 4.5.7
|
||||
|
||||
fix: 无法正常发送图片,报错 `pydantic_core._pydantic_core.ValidationError`
|
||||
23
changelogs/v4.6.0.md
Normal file
23
changelogs/v4.6.0.md
Normal file
@@ -0,0 +1,23 @@
|
||||
## What's Changed
|
||||
|
||||
1. 新增: 支持 gemini-3 系列的 thought signature ([#3698](https://github.com/AstrBotDevs/AstrBot/issues/3698))
|
||||
2. 新增: 支持知识库的 Agentic 检索功能 ([#3667](https://github.com/AstrBotDevs/AstrBot/issues/3667))
|
||||
3. 新增: 为知识库添加 URL 文档解析器 ([#3622](https://github.com/AstrBotDevs/AstrBot/issues/3622))
|
||||
4. 修复(core.platform): 修复启用多个企业微信智能机器人适配器时消息混乱的问题 ([#3693](https://github.com/AstrBotDevs/AstrBot/issues/3693))
|
||||
5. 修复: MCP Server 连接成功一段时间后,调用 mcp 工具时可能出现 `anyio.ClosedResourceError` 错误 ([#3700](https://github.com/AstrBotDevs/AstrBot/issues/3700))
|
||||
6. 新增(chat): 重构聊天组件结构并添加新功能 ([#3701](https://github.com/AstrBotDevs/AstrBot/issues/3701))
|
||||
7. 修复(dashboard.i18n): 完善缺失的英文国际化键值 ([#3699](https://github.com/AstrBotDevs/AstrBot/issues/3699))
|
||||
8. 重构: 实现 WebChat 会话管理及从版本 4.6 迁移到 4.7
|
||||
9. 持续集成(docker-build): 每日构建 Nightly 版本 Docker 镜像 ([#3120](https://github.com/AstrBotDevs/AstrBot/issues/3120))
|
||||
|
||||
---
|
||||
|
||||
1. feat: add supports for gemini-3 series thought signature ([#3698](https://github.com/AstrBotDevs/AstrBot/issues/3698))
|
||||
2. feat: supports knowledge base agentic search ([#3667](https://github.com/AstrBotDevs/AstrBot/issues/3667))
|
||||
3. feat: Add URL document parser for knowledge base ([#3622](https://github.com/AstrBotDevs/AstrBot/issues/3622))
|
||||
4. fix(core.platform): fix message mix-up issue when enabling multiple WeCom AI Bot adapters ([#3693](https://github.com/AstrBotDevs/AstrBot/issues/3693))
|
||||
5. fix: fix `anyio.ClosedResourceError` that may occur when calling mcp tools after a period of successful connection to MCP Server ([#3700](https://github.com/AstrBotDevs/AstrBot/issues/3700))
|
||||
6. feat(chat): refactor chat component structure and add new features ([#3701](https://github.com/AstrBotDevs/AstrBot/issues/3701))
|
||||
7. fix(dashboard.i18n): complete the missing i18n keys for en([#3699](https://github.com/AstrBotDevs/AstrBot/issues/3699))
|
||||
8. refactor: Implement WebChat session management and migration from version 4.6 to 4.7
|
||||
9. ci(docker-build): build nightly image everyday ([#3120](https://github.com/AstrBotDevs/AstrBot/issues/3120))
|
||||
1
dashboard/src/assets/images/icon-no-shadow.svg
Normal file
1
dashboard/src/assets/images/icon-no-shadow.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 58 KiB |
File diff suppressed because it is too large
Load Diff
283
dashboard/src/components/chat/ChatInput.vue
Normal file
283
dashboard/src/components/chat/ChatInput.vue
Normal file
@@ -0,0 +1,283 @@
|
||||
<template>
|
||||
<div class="input-area fade-in">
|
||||
<div class="input-container"
|
||||
style="width: 85%; max-width: 900px; margin: 0 auto; border: 1px solid #e0e0e0; border-radius: 24px;">
|
||||
<textarea
|
||||
ref="inputField"
|
||||
v-model="localPrompt"
|
||||
@keydown="handleKeyDown"
|
||||
:disabled="disabled"
|
||||
placeholder="Ask AstrBot..."
|
||||
style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 8px 16px; min-height: 40px; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);"></textarea>
|
||||
<div style="display: flex; justify-content: space-between; align-items: center; padding: 0px 12px;">
|
||||
<div style="display: flex; justify-content: flex-start; margin-top: 4px; align-items: center; gap: 8px;">
|
||||
<ProviderModelSelector ref="providerModelSelectorRef" />
|
||||
|
||||
<v-tooltip :text="enableStreaming ? tm('streaming.enabled') : tm('streaming.disabled')" location="top">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-chip v-bind="props" @click="$emit('toggleStreaming')" size="x-small" class="streaming-toggle-chip">
|
||||
<v-icon start :icon="enableStreaming ? 'mdi-flash' : 'mdi-flash-off'" size="small"></v-icon>
|
||||
{{ enableStreaming ? tm('streaming.on') : tm('streaming.off') }}
|
||||
</v-chip>
|
||||
</template>
|
||||
</v-tooltip>
|
||||
</div>
|
||||
<div style="display: flex; justify-content: flex-end; margin-top: 8px; align-items: center;">
|
||||
<input type="file" ref="imageInputRef" @change="handleFileSelect" accept="image/*"
|
||||
style="display: none" multiple />
|
||||
<v-progress-circular v-if="disabled" indeterminate size="16" class="mr-1" width="1.5" />
|
||||
<v-btn @click="triggerImageInput" icon="mdi-plus" variant="text" color="deep-purple"
|
||||
class="add-btn" size="small" />
|
||||
<v-btn @click="handleRecordClick"
|
||||
:icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text"
|
||||
:color="isRecording ? 'error' : 'deep-purple'" class="record-btn" size="small" />
|
||||
<v-btn @click="$emit('send')" icon="mdi-send" variant="text" color="deep-purple"
|
||||
:disabled="!canSend" class="send-btn" size="small" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 附件预览区 -->
|
||||
<div class="attachments-preview" v-if="stagedImagesUrl.length > 0 || stagedAudioUrl">
|
||||
<div v-for="(img, index) in stagedImagesUrl" :key="index" class="image-preview">
|
||||
<img :src="img" class="preview-image" />
|
||||
<v-btn @click="$emit('removeImage', index)" class="remove-attachment-btn" icon="mdi-close"
|
||||
size="small" color="error" variant="text" />
|
||||
</div>
|
||||
|
||||
<div v-if="stagedAudioUrl" class="audio-preview">
|
||||
<v-chip color="deep-purple-lighten-4" class="audio-chip">
|
||||
<v-icon start icon="mdi-microphone" size="small"></v-icon>
|
||||
{{ tm('voice.recording') }}
|
||||
</v-chip>
|
||||
<v-btn @click="$emit('removeAudio')" class="remove-attachment-btn" icon="mdi-close" size="small"
|
||||
color="error" variant="text" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, onBeforeUnmount, watch } from 'vue';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
import ProviderModelSelector from './ProviderModelSelector.vue';
|
||||
|
||||
interface Props {
|
||||
prompt: string;
|
||||
stagedImagesUrl: string[];
|
||||
stagedAudioUrl: string;
|
||||
disabled: boolean;
|
||||
enableStreaming: boolean;
|
||||
isRecording: boolean;
|
||||
}
|
||||
|
||||
const props = defineProps<Props>();
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:prompt': [value: string];
|
||||
send: [];
|
||||
toggleStreaming: [];
|
||||
removeImage: [index: number];
|
||||
removeAudio: [];
|
||||
startRecording: [];
|
||||
stopRecording: [];
|
||||
pasteImage: [event: ClipboardEvent];
|
||||
fileSelect: [files: FileList];
|
||||
}>();
|
||||
|
||||
const { tm } = useModuleI18n('features/chat');
|
||||
|
||||
const inputField = ref<HTMLTextAreaElement | null>(null);
|
||||
const imageInputRef = ref<HTMLInputElement | null>(null);
|
||||
const providerModelSelectorRef = ref<InstanceType<typeof ProviderModelSelector> | null>(null);
|
||||
|
||||
const localPrompt = computed({
|
||||
get: () => props.prompt,
|
||||
set: (value) => emit('update:prompt', value)
|
||||
});
|
||||
|
||||
const canSend = computed(() => {
|
||||
return (props.prompt && props.prompt.trim()) || props.stagedImagesUrl.length > 0 || props.stagedAudioUrl;
|
||||
});
|
||||
|
||||
// Ctrl+B 长按录音相关
|
||||
const ctrlKeyDown = ref(false);
|
||||
const ctrlKeyTimer = ref<number | null>(null);
|
||||
const ctrlKeyLongPressThreshold = 300;
|
||||
|
||||
function handleKeyDown(e: KeyboardEvent) {
|
||||
// Enter 发送消息
|
||||
if (e.keyCode === 13 && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
if (canSend.value) {
|
||||
emit('send');
|
||||
}
|
||||
}
|
||||
|
||||
// Ctrl+B 录音
|
||||
if (e.ctrlKey && e.keyCode === 66) {
|
||||
e.preventDefault();
|
||||
if (ctrlKeyDown.value) return;
|
||||
|
||||
ctrlKeyDown.value = true;
|
||||
ctrlKeyTimer.value = window.setTimeout(() => {
|
||||
if (ctrlKeyDown.value && !props.isRecording) {
|
||||
emit('startRecording');
|
||||
}
|
||||
}, ctrlKeyLongPressThreshold);
|
||||
}
|
||||
}
|
||||
|
||||
function handleKeyUp(e: KeyboardEvent) {
|
||||
if (e.keyCode === 66) {
|
||||
ctrlKeyDown.value = false;
|
||||
|
||||
if (ctrlKeyTimer.value) {
|
||||
clearTimeout(ctrlKeyTimer.value);
|
||||
ctrlKeyTimer.value = null;
|
||||
}
|
||||
|
||||
if (props.isRecording) {
|
||||
emit('stopRecording');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handlePaste(e: ClipboardEvent) {
|
||||
emit('pasteImage', e);
|
||||
}
|
||||
|
||||
function triggerImageInput() {
|
||||
imageInputRef.value?.click();
|
||||
}
|
||||
|
||||
function handleFileSelect(event: Event) {
|
||||
const target = event.target as HTMLInputElement;
|
||||
const files = target.files;
|
||||
if (files) {
|
||||
emit('fileSelect', files);
|
||||
}
|
||||
target.value = '';
|
||||
}
|
||||
|
||||
function handleRecordClick() {
|
||||
if (props.isRecording) {
|
||||
emit('stopRecording');
|
||||
} else {
|
||||
emit('startRecording');
|
||||
}
|
||||
}
|
||||
|
||||
function getCurrentSelection() {
|
||||
return providerModelSelectorRef.value?.getCurrentSelection();
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
if (inputField.value) {
|
||||
inputField.value.addEventListener('paste', handlePaste);
|
||||
}
|
||||
document.addEventListener('keyup', handleKeyUp);
|
||||
});
|
||||
|
||||
onBeforeUnmount(() => {
|
||||
if (inputField.value) {
|
||||
inputField.value.removeEventListener('paste', handlePaste);
|
||||
}
|
||||
document.removeEventListener('keyup', handleKeyUp);
|
||||
});
|
||||
|
||||
defineExpose({
|
||||
getCurrentSelection
|
||||
});
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.input-area {
|
||||
padding: 16px;
|
||||
background-color: var(--v-theme-surface);
|
||||
position: relative;
|
||||
border-top: 1px solid var(--v-theme-border);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.attachments-preview {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
margin-top: 8px;
|
||||
max-width: 900px;
|
||||
margin: 8px auto 0;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.image-preview,
|
||||
.audio-preview {
|
||||
position: relative;
|
||||
display: inline-flex;
|
||||
}
|
||||
|
||||
.preview-image {
|
||||
width: 60px;
|
||||
height: 60px;
|
||||
object-fit: cover;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.audio-chip {
|
||||
height: 36px;
|
||||
border-radius: 18px;
|
||||
}
|
||||
|
||||
.remove-attachment-btn {
|
||||
position: absolute;
|
||||
top: -8px;
|
||||
right: -8px;
|
||||
opacity: 0.8;
|
||||
transition: opacity 0.2s;
|
||||
}
|
||||
|
||||
.remove-attachment-btn:hover {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.streaming-toggle-chip {
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
.streaming-toggle-chip:hover {
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.fade-in {
|
||||
animation: fadeIn 0.3s ease-in-out;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(10px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.input-area {
|
||||
padding: 0 !important;
|
||||
}
|
||||
|
||||
.input-container {
|
||||
width: 100% !important;
|
||||
max-width: 100% !important;
|
||||
margin: 0 !important;
|
||||
border-radius: 0 !important;
|
||||
border-left: none !important;
|
||||
border-right: none !important;
|
||||
border-bottom: none !important;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
295
dashboard/src/components/chat/ConversationSidebar.vue
Normal file
295
dashboard/src/components/chat/ConversationSidebar.vue
Normal file
@@ -0,0 +1,295 @@
|
||||
<template>
|
||||
<div class="sidebar-panel"
|
||||
:class="{
|
||||
'sidebar-collapsed': sidebarCollapsed && !isMobile,
|
||||
'mobile-sidebar-open': isMobile && mobileMenuOpen,
|
||||
'mobile-sidebar': isMobile
|
||||
}"
|
||||
:style="{ 'background-color': isDark ? sidebarCollapsed ? '#1e1e1e' : '#2d2d2d' : sidebarCollapsed ? '#ffffff' : '#f1f4f9' }"
|
||||
@mouseenter="handleSidebarMouseEnter"
|
||||
@mouseleave="handleSidebarMouseLeave">
|
||||
|
||||
<div style="display: flex; align-items: center; justify-content: center; padding: 16px; padding-bottom: 0px;"
|
||||
v-if="chatboxMode">
|
||||
<img width="50" src="@/assets/images/icon-no-shadow.svg" alt="AstrBot Logo">
|
||||
<span v-if="!sidebarCollapsed"
|
||||
style="font-weight: 1000; font-size: 26px; margin-left: 8px;">AstrBot</span>
|
||||
</div>
|
||||
|
||||
<div class="sidebar-collapse-btn-container" v-if="!isMobile">
|
||||
<v-btn icon class="sidebar-collapse-btn" @click="toggleSidebar" variant="text" color="deep-purple">
|
||||
<v-icon>{{ (sidebarCollapsed || (!sidebarCollapsed && sidebarHoverExpanded)) ?
|
||||
'mdi-chevron-right' : 'mdi-chevron-left' }}</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<div class="sidebar-collapse-btn-container" v-if="isMobile">
|
||||
<v-btn icon class="sidebar-collapse-btn" @click="$emit('closeMobileSidebar')" variant="text"
|
||||
color="deep-purple">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<div style="padding: 16px; padding-top: 8px;">
|
||||
<v-btn block variant="text" class="new-chat-btn" @click="$emit('newChat')" :disabled="!currSessionId"
|
||||
v-if="!sidebarCollapsed || isMobile" prepend-icon="mdi-plus"
|
||||
style="background-color: transparent !important; border-radius: 4px;">{{ tm('actions.newChat') }}</v-btn>
|
||||
<v-btn icon="mdi-plus" rounded="lg" @click="$emit('newChat')" :disabled="!currSessionId"
|
||||
v-if="sidebarCollapsed && !isMobile" elevation="0"></v-btn>
|
||||
</div>
|
||||
|
||||
<div v-if="!sidebarCollapsed || isMobile">
|
||||
<v-divider class="mx-4"></v-divider>
|
||||
</div>
|
||||
|
||||
<div style="overflow-y: auto; flex-grow: 1;" :class="{ 'fade-in': sidebarHoverExpanded }"
|
||||
v-if="!sidebarCollapsed || isMobile">
|
||||
<v-card v-if="sessions.length > 0" flat style="background-color: transparent;">
|
||||
<v-list density="compact" nav class="conversation-list"
|
||||
style="background-color: transparent;" :selected="selectedSessions"
|
||||
@update:selected="$emit('selectConversation', $event)">
|
||||
<v-list-item v-for="item in sessions" :key="item.session_id" :value="item.session_id"
|
||||
rounded="lg" class="conversation-item" active-color="secondary">
|
||||
<v-list-item-title v-if="!sidebarCollapsed || isMobile" class="conversation-title">
|
||||
{{ item.display_name || tm('conversation.newConversation') }}
|
||||
</v-list-item-title>
|
||||
<v-list-item-subtitle v-if="!sidebarCollapsed || isMobile" class="timestamp">
|
||||
{{ new Date(item.updated_at).toLocaleString() }}
|
||||
</v-list-item-subtitle>
|
||||
|
||||
<template v-if="!sidebarCollapsed || isMobile" v-slot:append>
|
||||
<div class="conversation-actions">
|
||||
<v-btn icon="mdi-pencil" size="x-small" variant="text"
|
||||
class="edit-title-btn"
|
||||
@click.stop="$emit('editTitle', item.session_id, item.display_name)" />
|
||||
<v-btn icon="mdi-delete" size="x-small" variant="text"
|
||||
class="delete-conversation-btn" color="error"
|
||||
@click.stop="$emit('deleteConversation', item.session_id)" />
|
||||
</div>
|
||||
</template>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
</v-card>
|
||||
|
||||
<v-fade-transition>
|
||||
<div class="no-conversations" v-if="sessions.length === 0">
|
||||
<v-icon icon="mdi-message-text-outline" size="large" color="grey-lighten-1"></v-icon>
|
||||
<div class="no-conversations-text" v-if="!sidebarCollapsed || sidebarHoverExpanded || isMobile">
|
||||
{{ tm('conversation.noHistory') }}
|
||||
</div>
|
||||
</div>
|
||||
</v-fade-transition>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref } from 'vue';
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||
import type { Session } from '@/composables/useSessions';
|
||||
|
||||
interface Props {
|
||||
sessions: Session[];
|
||||
selectedSessions: string[];
|
||||
currSessionId: string;
|
||||
isDark: boolean;
|
||||
chatboxMode: boolean;
|
||||
isMobile: boolean;
|
||||
mobileMenuOpen: boolean;
|
||||
}
|
||||
|
||||
const props = defineProps<Props>();
|
||||
|
||||
const emit = defineEmits<{
|
||||
newChat: [];
|
||||
selectConversation: [sessionIds: string[]];
|
||||
editTitle: [sessionId: string, title: string];
|
||||
deleteConversation: [sessionId: string];
|
||||
closeMobileSidebar: [];
|
||||
}>();
|
||||
|
||||
const { tm } = useModuleI18n('features/chat');
|
||||
const { t } = useI18n();
|
||||
|
||||
const sidebarCollapsed = ref(true);
|
||||
const sidebarHovered = ref(false);
|
||||
const sidebarHoverTimer = ref<number | null>(null);
|
||||
const sidebarHoverExpanded = ref(false);
|
||||
const sidebarHoverDelay = 100;
|
||||
|
||||
// 从 localStorage 读取侧边栏折叠状态
|
||||
const savedCollapsedState = localStorage.getItem('sidebarCollapsed');
|
||||
if (savedCollapsedState !== null) {
|
||||
sidebarCollapsed.value = JSON.parse(savedCollapsedState);
|
||||
} else {
|
||||
sidebarCollapsed.value = true;
|
||||
}
|
||||
|
||||
function toggleSidebar() {
|
||||
if (sidebarHoverExpanded.value) {
|
||||
sidebarHoverExpanded.value = false;
|
||||
return;
|
||||
}
|
||||
sidebarCollapsed.value = !sidebarCollapsed.value;
|
||||
localStorage.setItem('sidebarCollapsed', JSON.stringify(sidebarCollapsed.value));
|
||||
}
|
||||
|
||||
function handleSidebarMouseEnter() {
|
||||
if (!sidebarCollapsed.value || props.isMobile) return;
|
||||
|
||||
sidebarHovered.value = true;
|
||||
sidebarHoverTimer.value = window.setTimeout(() => {
|
||||
if (sidebarHovered.value) {
|
||||
sidebarHoverExpanded.value = true;
|
||||
sidebarCollapsed.value = false;
|
||||
}
|
||||
}, sidebarHoverDelay);
|
||||
}
|
||||
|
||||
function handleSidebarMouseLeave() {
|
||||
sidebarHovered.value = false;
|
||||
|
||||
if (sidebarHoverTimer.value) {
|
||||
clearTimeout(sidebarHoverTimer.value);
|
||||
sidebarHoverTimer.value = null;
|
||||
}
|
||||
|
||||
if (sidebarHoverExpanded.value) {
|
||||
sidebarCollapsed.value = true;
|
||||
}
|
||||
sidebarHoverExpanded.value = false;
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.sidebar-panel {
|
||||
max-width: 270px;
|
||||
min-width: 240px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
padding: 0;
|
||||
border-right: 1px solid rgba(0, 0, 0, 0.04);
|
||||
height: 100%;
|
||||
max-height: 100%;
|
||||
position: relative;
|
||||
transition: all 0.3s ease;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.sidebar-collapsed {
|
||||
max-width: 75px;
|
||||
min-width: 75px;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.mobile-sidebar {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
bottom: 0;
|
||||
max-width: 280px !important;
|
||||
min-width: 280px !important;
|
||||
transform: translateX(-100%);
|
||||
transition: transform 0.3s ease;
|
||||
z-index: 1000;
|
||||
}
|
||||
|
||||
.mobile-sidebar-open {
|
||||
transform: translateX(0) !important;
|
||||
}
|
||||
|
||||
.sidebar-collapse-btn-container {
|
||||
margin: 16px;
|
||||
margin-bottom: 0px;
|
||||
z-index: 10;
|
||||
}
|
||||
|
||||
.sidebar-collapse-btn {
|
||||
opacity: 0.6;
|
||||
max-height: none;
|
||||
overflow-y: visible;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.conversation-item {
|
||||
margin-bottom: 4px;
|
||||
border-radius: 8px !important;
|
||||
transition: all 0.2s ease;
|
||||
height: auto !important;
|
||||
min-height: 56px;
|
||||
padding: 8px 16px !important;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.conversation-item:hover {
|
||||
background-color: rgba(103, 58, 183, 0.05);
|
||||
}
|
||||
|
||||
.conversation-item:hover .conversation-actions {
|
||||
opacity: 1;
|
||||
visibility: visible;
|
||||
}
|
||||
|
||||
.conversation-actions {
|
||||
display: flex;
|
||||
gap: 4px;
|
||||
opacity: 0;
|
||||
visibility: hidden;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.edit-title-btn,
|
||||
.delete-conversation-btn {
|
||||
opacity: 0.7;
|
||||
transition: opacity 0.2s ease;
|
||||
}
|
||||
|
||||
.edit-title-btn:hover,
|
||||
.delete-conversation-btn:hover {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.conversation-title {
|
||||
font-weight: 500;
|
||||
font-size: 14px;
|
||||
line-height: 1.3;
|
||||
margin-bottom: 2px;
|
||||
transition: opacity 0.25s ease;
|
||||
}
|
||||
|
||||
.timestamp {
|
||||
font-size: 11px;
|
||||
color: var(--v-theme-secondaryText);
|
||||
line-height: 1;
|
||||
transition: opacity 0.25s ease;
|
||||
}
|
||||
|
||||
.no-conversations {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
height: 150px;
|
||||
opacity: 0.6;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.no-conversations-text {
|
||||
font-size: 14px;
|
||||
color: var(--v-theme-secondaryText);
|
||||
transition: opacity 0.25s ease;
|
||||
}
|
||||
|
||||
.fade-in {
|
||||
animation: fadeInContent 0.3s ease;
|
||||
}
|
||||
|
||||
@keyframes fadeInContent {
|
||||
from {
|
||||
opacity: 0;
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@@ -33,33 +33,53 @@
|
||||
<v-avatar class="bot-avatar" size="36">
|
||||
<v-progress-circular :index="index" v-if="isStreaming && index === messages.length - 1" indeterminate size="28"
|
||||
width="2"></v-progress-circular>
|
||||
<span v-else-if="messages[index - 1]?.content.type !== 'bot'" class="text-h2">✨</span>
|
||||
<v-icon v-else-if="messages[index - 1]?.content.type !== 'bot'" size="64" color="#8fb6d2">mdi-star-four-points-small</v-icon>
|
||||
</v-avatar>
|
||||
<div class="bot-message-content">
|
||||
<div class="message-bubble bot-bubble">
|
||||
<!-- Text -->
|
||||
<div v-if="msg.content.message && msg.content.message.trim()"
|
||||
v-html="md.render(msg.content.message)" class="markdown-content"></div>
|
||||
|
||||
<!-- Image -->
|
||||
<div class="embedded-images"
|
||||
v-if="msg.content.embedded_images && msg.content.embedded_images.length > 0">
|
||||
<div v-for="(img, imgIndex) in msg.content.embedded_images" :key="imgIndex"
|
||||
class="embedded-image">
|
||||
<img :src="img" class="bot-embedded-image"
|
||||
@click="$emit('openImagePreview', img)" />
|
||||
<!-- Loading state -->
|
||||
<div v-if="msg.content.isLoading" class="loading-container">
|
||||
<span class="loading-text">{{ tm('message.loading') }}</span>
|
||||
</div>
|
||||
|
||||
<template v-else>
|
||||
<!-- Reasoning Block (Collapsible) -->
|
||||
<div v-if="msg.content.reasoning && msg.content.reasoning.trim()" class="reasoning-container">
|
||||
<div class="reasoning-header" @click="toggleReasoning(index)">
|
||||
<v-icon size="small" class="reasoning-icon">
|
||||
{{ isReasoningExpanded(index) ? 'mdi-chevron-down' : 'mdi-chevron-right' }}
|
||||
</v-icon>
|
||||
<span class="reasoning-label">{{ tm('reasoning.thinking') }}</span>
|
||||
</div>
|
||||
<div v-if="isReasoningExpanded(index)" class="reasoning-content">
|
||||
<div v-html="md.render(msg.content.reasoning)" class="markdown-content reasoning-text"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Text -->
|
||||
<div v-if="msg.content.message && msg.content.message.trim()"
|
||||
v-html="md.render(msg.content.message)" class="markdown-content"></div>
|
||||
|
||||
<!-- Audio -->
|
||||
<div class="embedded-audio" v-if="msg.content.embedded_audio">
|
||||
<audio controls class="audio-player">
|
||||
<source :src="msg.content.embedded_audio" type="audio/wav">
|
||||
{{ t('messages.errors.browser.audioNotSupported') }}
|
||||
</audio>
|
||||
</div>
|
||||
<!-- Image -->
|
||||
<div class="embedded-images"
|
||||
v-if="msg.content.embedded_images && msg.content.embedded_images.length > 0">
|
||||
<div v-for="(img, imgIndex) in msg.content.embedded_images" :key="imgIndex"
|
||||
class="embedded-image">
|
||||
<img :src="img" class="bot-embedded-image"
|
||||
@click="$emit('openImagePreview', img)" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Audio -->
|
||||
<div class="embedded-audio" v-if="msg.content.embedded_audio">
|
||||
<audio controls class="audio-player">
|
||||
<source :src="msg.content.embedded_audio" type="audio/wav">
|
||||
{{ t('messages.errors.browser.audioNotSupported') }}
|
||||
</audio>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
<div class="message-actions">
|
||||
<div class="message-actions" v-if="!msg.content.isLoading">
|
||||
<v-btn :icon="getCopyIcon(index)" size="small" variant="text" class="copy-message-btn"
|
||||
:class="{ 'copy-success': isCopySuccess(index) }"
|
||||
@click="copyBotMessage(msg.content.message, index)" :title="t('core.common.copy')" />
|
||||
@@ -125,7 +145,8 @@ export default {
|
||||
copiedMessages: new Set(),
|
||||
isUserNearBottom: true,
|
||||
scrollThreshold: 1,
|
||||
scrollTimer: null
|
||||
scrollTimer: null,
|
||||
expandedReasoning: new Set(), // Track which reasoning blocks are expanded
|
||||
};
|
||||
},
|
||||
mounted() {
|
||||
@@ -142,6 +163,22 @@ export default {
|
||||
}
|
||||
},
|
||||
methods: {
|
||||
// Toggle reasoning expansion state
|
||||
toggleReasoning(messageIndex) {
|
||||
if (this.expandedReasoning.has(messageIndex)) {
|
||||
this.expandedReasoning.delete(messageIndex);
|
||||
} else {
|
||||
this.expandedReasoning.add(messageIndex);
|
||||
}
|
||||
// Force reactivity
|
||||
this.expandedReasoning = new Set(this.expandedReasoning);
|
||||
},
|
||||
|
||||
// Check if reasoning is expanded
|
||||
isReasoningExpanded(messageIndex) {
|
||||
return this.expandedReasoning.has(messageIndex);
|
||||
},
|
||||
|
||||
// 复制代码到剪贴板
|
||||
copyCodeToClipboard(code) {
|
||||
navigator.clipboard.writeText(code).then(() => {
|
||||
@@ -348,7 +385,7 @@ export default {
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(10px);
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
to {
|
||||
@@ -539,6 +576,69 @@ export default {
|
||||
.fade-in {
|
||||
animation: fadeIn 0.3s ease-in-out;
|
||||
}
|
||||
|
||||
/* Reasoning 区块样式 */
|
||||
.reasoning-container {
|
||||
margin-bottom: 12px;
|
||||
margin-top: 6px;
|
||||
border: 1px solid var(--v-theme-border);
|
||||
border-radius: 8px;
|
||||
overflow: hidden;
|
||||
width: fit-content;
|
||||
}
|
||||
|
||||
.v-theme--dark .reasoning-container {
|
||||
background-color: rgba(103, 58, 183, 0.08);
|
||||
}
|
||||
|
||||
.reasoning-header {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
padding: 8px 8px;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
transition: background-color 0.2s ease;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
.reasoning-header:hover {
|
||||
background-color: rgba(103, 58, 183, 0.08);
|
||||
}
|
||||
|
||||
.v-theme--dark .reasoning-header:hover {
|
||||
background-color: rgba(103, 58, 183, 0.15);
|
||||
}
|
||||
|
||||
.reasoning-icon {
|
||||
margin-right: 6px;
|
||||
color: var(--v-theme-secondary);
|
||||
transition: transform 0.2s ease;
|
||||
}
|
||||
|
||||
.reasoning-label {
|
||||
font-size: 13px;
|
||||
font-weight: 500;
|
||||
color: var(--v-theme-secondary);
|
||||
letter-spacing: 0.3px;
|
||||
}
|
||||
|
||||
.reasoning-content {
|
||||
padding: 0px 12px;
|
||||
border-top: 1px solid var(--v-theme-border);
|
||||
color: gray;
|
||||
animation: fadeIn 0.2s ease-in-out;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.reasoning-text {
|
||||
font-size: 14px;
|
||||
line-height: 1.6;
|
||||
color: var(--v-theme-secondaryText);
|
||||
}
|
||||
|
||||
.v-theme--dark .reasoning-text {
|
||||
opacity: 0.85;
|
||||
}
|
||||
</style>
|
||||
|
||||
<style>
|
||||
@@ -748,6 +848,29 @@ export default {
|
||||
margin: 10px 0;
|
||||
}
|
||||
|
||||
.loading-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
padding: 8px 0;
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.loading-text {
|
||||
font-size: 14px;
|
||||
color: var(--v-theme-secondaryText);
|
||||
animation: pulse 1.5s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0%, 100% {
|
||||
opacity: 0.6;
|
||||
}
|
||||
50% {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
|
||||
.markdown-content blockquote {
|
||||
border-left: 4px solid var(--v-theme-secondary);
|
||||
padding-left: 16px;
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
<template>
|
||||
<div>
|
||||
<!-- 选择提供商和模型按钮 -->
|
||||
<v-btn class="text-none" variant="tonal" rounded="xl" size="small"
|
||||
<v-chip class="text-none" variant="tonal" size="x-small"
|
||||
v-if="selectedProviderId && selectedModelName" @click="openDialog">
|
||||
{{ selectedProviderId }} / {{ selectedModelName }}
|
||||
</v-btn>
|
||||
<v-btn variant="tonal" rounded="xl" size="small" v-else @click="openDialog">
|
||||
</v-chip>
|
||||
<v-chip variant="tonal" rounded="xl" size="x-small" v-else @click="openDialog">
|
||||
选择模型
|
||||
</v-btn>
|
||||
</v-chip>
|
||||
|
||||
<!-- 选择提供商和模型对话框 -->
|
||||
<v-dialog v-model="showDialog" max-width="800" persistent>
|
||||
|
||||
@@ -154,7 +154,8 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
<div class="w-100" v-if="!itemMeta?._special">
|
||||
<!-- Select input for JSON selector -->
|
||||
<v-select v-if="itemMeta?.options" v-model="createSelectorModel(itemKey).value"
|
||||
:items="itemMeta?.options" :disabled="itemMeta?.readonly" density="compact" variant="outlined"
|
||||
:items="itemMeta?.labels ? itemMeta.options.map((value, index) => ({ title: itemMeta.labels[index] || value, value: value })) : itemMeta.options"
|
||||
:disabled="itemMeta?.readonly" density="compact" variant="outlined"
|
||||
class="config-field" hide-details></v-select>
|
||||
|
||||
<!-- Code Editor for JSON selector -->
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user