Compare commits

..

1 Commits

Author SHA1 Message Date
Soulter
3b54a24037 feat: 初步实现可视化编辑 object 配置项 2025-02-22 17:12:35 +08:00
481 changed files with 10522 additions and 63005 deletions

3
.codecov.yml Normal file
View File

@@ -0,0 +1,3 @@
comment:
layout: "condensed_header, condensed_files, condensed_footer"
hide_project_coverage: TRUE

5
.coveragerc Normal file
View File

@@ -0,0 +1,5 @@
[run]
omit =
*/site-packages/*
*/dist-packages/*
your_package_name/tests/*

View File

@@ -17,8 +17,4 @@ ENV/
.conda/
README*.md
dashboard/
data/
changelogs/
tests/
.ruff_cache/
.astrbot
data/

15
.github/FUNDING.yml vendored
View File

@@ -1,15 +0,0 @@
# These are supported funding model platforms
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: astrbot
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
polar: # Replace with a single Polar username
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
thanks_dev: # Replace with a single thanks.dev username
custom: ['https://afdian.com/a/astrbot_team']

View File

@@ -1,56 +0,0 @@
name: 🥳 发布插件
description: 提交插件到插件市场
title: "[Plugin] 插件名"
labels: ["plugin-publish"]
assignees: []
body:
- type: markdown
attributes:
value: |
欢迎发布插件到插件市场!
- type: markdown
attributes:
value: |
## 插件基本信息
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
不熟悉 JSON ?现在可以从 [这里](https://plugins.astrbot.app/#/submit) 获取你的 JSON 啦!获取到了记得复制粘贴过来哦!
- type: textarea
id: plugin-info
attributes:
label: 插件信息
description: 请在下方代码块中填写您的插件信息确保反引号包裹了JSON
value: |
```json
{
"name": "插件名",
"desc": "插件介绍",
"author": "作者名",
"repo": "插件仓库链接",
"tags": [],
"social_link": ""
}
```
validations:
required: true
- type: markdown
attributes:
value: |
## 检查
- type: checkboxes
id: checks
attributes:
label: 插件检查清单
description: 请确认以下所有项目
options:
- label: 我的插件经过完整的测试
required: true
- label: 我的插件不包含恶意代码
required: true
- label: 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
required: true

View File

@@ -28,7 +28,7 @@ body:
- type: textarea
attributes:
label: AstrBot 版本部署方式(如 Windows Docker Desktop 部署)、使用的提供商、使用的消息平台适配器
label: AstrBot 版本部署方式
description: >
请提供您的 AstrBot 版本和部署方式。
placeholder: >
@@ -53,9 +53,9 @@ body:
- type: textarea
attributes:
label: 报错日志
label: 额外信息
description: >
如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!
任何额外信息,如报错日志、截图等。
placeholder: >
请提供完整的报错日志或截图。
validations:
@@ -65,7 +65,7 @@ body:
attributes:
label: 你愿意提交 PR 吗?
description: >
这不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
绝对不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
options:
- label: 是的,我愿意提交 PR!
@@ -79,4 +79,4 @@ body:
- type: markdown
attributes:
value: "感谢您填写我们的表单!"
value: "感谢您填写我们的表单!"

View File

@@ -1,46 +1,10 @@
<!-- 如果有的话,指定 PR 旨在解决的 ISSUE 编号。 -->
<!-- If applicable, please specify the ISSUE number this PR aims to resolve. -->
<!-- 如果有的话,指定这个 PR 解决的 ISSUE -->
修复了 #XYZ
fixes #XYZ
### Motivation
---
<!--解释为什么要改动-->
### Motivation / 动机
### Modifications
<!--请描述此项更改的动机:它解决了什么问题?(例如:修复了 XX 错误,添加了 YY 功能)-->
<!--Please describe the motivation for this change: What problem does it solve? (e.g., Fixes XX bug, adds YY feature)-->
### Modifications / 改动点
<!--请总结你的改动:哪些核心文件被修改了?实现了什么功能?-->
<!--Please summarize your changes: What core files were modified? What functionality was implemented?-->
### Verification Steps / 验证步骤
<!--请为审查者 (Reviewer) 提供清晰、可复现的验证步骤例如1. 导航到... 2. 点击...)。-->
<!--Please provide clear and reproducible verification steps for the Reviewer (e.g., 1. Navigate to... 2. Click...).-->
### Screenshots or Test Results / 运行截图或测试结果
<!--请粘贴截图、GIF 或测试日志,作为执行“验证步骤”的证据,证明此改动有效。-->
<!--Please paste screenshots, GIFs, or test logs here as evidence of executing the "Verification Steps" to prove this change is effective.-->
### Compatibility & Breaking Changes / 兼容性与破坏性变更
<!--请说明此变更的兼容性:哪些是破坏性变更?哪些地方做了向后兼容处理?是否提供了数据迁移方法?-->
<!--Please explain the compatibility of this change: What are the breaking changes? What backward-compatible measures were taken? Are data migration paths provided?-->
- [ ] 这是一个破坏性变更 (Breaking Change)。/ This is a breaking change.
- [ ] 这不是一个破坏性变更。/ This is NOT a breaking change.
---
### Checklist / 检查清单
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
- [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
- [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt``pyproject.toml` 文件相应位置。/ I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in `requirements.txt` and `pyproject.toml`.
- [ ] 😮 我的更改没有引入恶意代码。/ My changes do not introduce malicious code.
<!--简单解释你的改动-->

View File

@@ -1,36 +0,0 @@
# Set to true to add reviewers to pull requests
addReviewers: true
# Set to true to add assignees to pull requests
addAssignees: false
# A list of reviewers to be added to pull requests (GitHub user name)
reviewers:
- Soulter
- Raven95676
- Larch-C
- anka-afk
- advent259141
# - zouyonghe
# A number of reviewers added to the pull request
# Set 0 to add all the reviewers (default: 0)
numberOfReviewers: 2
# A list of assignees, overrides reviewers if set
# assignees:
# - assigneeA
# A number of assignees to add to the pull request
# Set to 0 to add all of the assignees.
# Uses numberOfReviewers if unset.
# numberOfAssignees: 2
# A list of keywords to be skipped the process that add reviewers if pull requests include it
skipKeywords:
- wip
- draft
# A list of users to be skipped by both the add reviewers and add assignees processes
# skipUsers:
# - dependabot[bot]

View File

@@ -1,63 +0,0 @@
# AstrBot Development Instructions
AstrBot is a multi-platform LLM chatbot and development framework written in Python with a Vue.js dashboard. It supports multiple messaging platforms (QQ, Telegram, Discord, etc.) and various LLM providers (OpenAI, Anthropic, Google Gemini, etc.).
Always reference these instructions first and fallback to search or bash commands only when you encounter unexpected information that does not match the info here.
## Working Effectively
### Bootstrap and Install Dependencies
- **Python 3.10+ required** - Check `.python-version` file
- Install UV package manager: `pip install uv`
- Install project dependencies: `uv sync` -- takes 6-7 minutes. NEVER CANCEL. Set timeout to 10+ minutes.
- Create required directories: `mkdir -p data/plugins data/config data/temp`
### Running the Application
- Run main application: `uv run main.py` -- starts in ~3 seconds
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
### Dashboard Build (Vue.js/Node.js)
- **Prerequisites**: Node.js 20+ and npm 10+ required
- Navigate to dashboard: `cd dashboard`
- Install dashboard dependencies: `npm install` -- takes 2-3 minutes. NEVER CANCEL. Set timeout to 5+ minutes.
- Build dashboard: `npm run build` -- takes 25-30 seconds. NEVER CANCEL.
- Dashboard creates optimized production build in `dashboard/dist/`
### Testing
- Do not generate test files for now.
### Code Quality and Linting
- Install ruff linter: `uv add --dev ruff`
- Check code style: `uv run ruff check .` -- takes <1 second
- Check formatting: `uv run ruff format --check .` -- takes <1 second
- Fix formatting: `uv run ruff format .`
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
### Plugin Development
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
- Plugin system supports function tools and message handlers
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
### Common Issues and Workarounds
- **Dashboard download fails**: Known issue with "division by zero" error - application still works
- **Import errors in tests**: Ensure `uv run` is used to run tests in proper environment
=- **Build timeouts**: Always set appropriate timeouts (10+ minutes for uv sync, 5+ minutes for npm install)
## CI/CD Integration
- GitHub Actions workflows in `.github/workflows/`
- Docker builds supported via `Dockerfile`
- Pre-commit hooks enforce ruff formatting and linting
## Docker Support
- Primary deployment method: `docker run soulter/astrbot:latest`
- Compose file available: `compose.yml`
- Exposes ports: 6185 (WebUI), 6195 (WeChat), 6199 (QQ), etc.
- Volume mount required: `./data:/AstrBot/data`
## Multi-language Support
- Documentation in Chinese (README.md), English (README_en.md), Japanese (README_ja.md)
- UI supports internationalization
- Default language is Chinese
Remember: This is a production chatbot framework with real users. Always test thoroughly and ensure changes don't break existing functionality.

View File

@@ -1,13 +0,0 @@
# Keep GitHub Actions up to date with GitHub's Dependabot...
# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
version: 2
updates:
- package-ecosystem: github-actions
directory: /
groups:
github-actions:
patterns:
- "*" # Group all Actions updates into a single larger pull request
schedule:
interval: weekly

View File

@@ -7,13 +7,13 @@ on:
name: Auto Release
jobs:
build-and-publish-to-github-release:
build:
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
- name: Dashboard Build
run: |
@@ -23,70 +23,13 @@ jobs:
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
echo ${{ github.ref_name }} > dist/assets/version
zip -r dist.zip dist
- name: Upload to Cloudflare R2
env:
R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }}
R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
R2_BUCKET_NAME: "astrbot"
R2_OBJECT_NAME: "astrbot-webui-latest.zip"
VERSION_TAG: ${{ github.ref_name }}
run: |
echo "Installing rclone..."
curl https://rclone.org/install.sh | sudo bash
echo "Configuring rclone remote..."
mkdir -p ~/.config/rclone
cat <<EOF > ~/.config/rclone/rclone.conf
[r2]
type = s3
provider = Cloudflare
access_key_id = $R2_ACCESS_KEY_ID
secret_access_key = $R2_SECRET_ACCESS_KEY
endpoint = https://${R2_ACCOUNT_ID}.r2.cloudflarestorage.com
EOF
echo "Uploading dist.zip to R2 bucket: $R2_BUCKET_NAME/$R2_OBJECT_NAME"
mv dashboard/dist.zip dashboard/$R2_OBJECT_NAME
rclone copy dashboard/$R2_OBJECT_NAME r2:$R2_BUCKET_NAME --progress
mv dashboard/$R2_OBJECT_NAME dashboard/astrbot-webui-${VERSION_TAG}.zip
rclone copy dashboard/astrbot-webui-${VERSION_TAG}.zip r2:$R2_BUCKET_NAME --progress
mv dashboard/astrbot-webui-${VERSION_TAG}.zip dashboard/dist.zip
- name: Fetch Changelog
run: |
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
- name: Create GitHub Release
- name: Create Release
uses: ncipollo/release-action@v1
with:
bodyFile: ${{ env.changelog }}
artifacts: "dashboard/dist.zip"
build-and-publish-to-pypi:
# 构建并发布到 PyPI
runs-on: ubuntu-latest
needs: build-and-publish-to-github-release
steps:
- name: Checkout repository
uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.10'
- name: Install uv
run: |
python -m pip install uv
- name: Build package
run: |
uv build
- name: Publish to PyPI
env:
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
run: |
uv publish
artifacts: "dashboard/dist.zip"

View File

@@ -1,34 +0,0 @@
name: Code Format Check
on:
pull_request:
branches: [ master ]
push:
branches: [ master ]
jobs:
format-check:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install UV
run: pip install uv
- name: Install dependencies
run: uv sync
- name: Check code formatting with ruff
run: |
uv run ruff format --check .
- name: Check code style with ruff
run: |
uv run ruff check .

View File

@@ -56,7 +56,7 @@ jobs:
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v4
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL

View File

@@ -1,6 +1,6 @@
name: Run tests and upload coverage
on:
on:
push:
branches:
- master
@@ -8,7 +8,6 @@ on:
- 'README.md'
- 'changelogs/**'
- 'dashboard/**'
pull_request:
workflow_dispatch:
jobs:
@@ -17,29 +16,30 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v5
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pytest-asyncio pytest-cov
pip install --editable .
pip install -r requirements.txt
pip install pytest pytest-cov pytest-asyncio
- name: Run tests
run: |
mkdir -p data/plugins
mkdir -p data/config
mkdir -p data/temp
mkdir data
mkdir data/plugins
mkdir data/config
mkdir data/temp
export TESTING=true
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
pytest --cov=. -v -o log_cli=true -o log_level=DEBUG
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
token: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -1,48 +0,0 @@
name: AstrBot Dashboard CI
on:
push:
branches: [ "master" ]
pull_request:
branches: [ "master" ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v5
- name: npm install, build
run: |
cd dashboard
npm install
npm run build
- name: Inject Commit SHA
id: get_sha
run: |
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
mkdir -p dashboard/dist/assets
echo $COMMIT_SHA > dashboard/dist/assets/version
cd dashboard
zip -r dist.zip dist
- name: Archive production artifacts
uses: actions/upload-artifact@v4
with:
name: dist-without-markdown
path: |
dashboard/dist
!dist/**/*.md
- name: Create GitHub Release
if: github.event_name == 'push'
uses: ncipollo/release-action@v1
with:
tag: release-${{ github.sha }}
owner: AstrBotDevs
repo: astrbot-release-harbour
body: "Automated release from commit ${{ github.sha }}"
token: ${{ secrets.ASTRBOT_HARBOUR_TOKEN }}
artifacts: "dashboard/dist.zip"

View File

@@ -11,79 +11,33 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Pull The Codes
uses: actions/checkout@v5
- name: 拉取源码
uses: actions/checkout@v3
with:
fetch-depth: 0 # Must be 0 so we can fetch tags
fetch-depth: 1
- name: Get latest tag (only on manual trigger)
id: get-latest-tag
if: github.event_name == 'workflow_dispatch'
run: |
tag=$(git describe --tags --abbrev=0)
echo "latest_tag=$tag" >> $GITHUB_OUTPUT
- name: Checkout to latest tag (only on manual trigger)
if: github.event_name == 'workflow_dispatch'
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
- name: Check if version is pre-release
id: check-prerelease
run: |
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
version="${{ steps.get-latest-tag.outputs.latest_tag }}"
else
version="${{ github.ref_name }}"
fi
if [[ "$version" == *"beta"* ]] || [[ "$version" == *"alpha"* ]]; then
echo "is_prerelease=true" >> $GITHUB_OUTPUT
echo "Version $version is a pre-release, will not push latest tag"
else
echo "is_prerelease=false" >> $GITHUB_OUTPUT
echo "Version $version is a stable release, will push latest tag"
fi
- name: Build Dashboard
run: |
cd dashboard
npm install
npm run build
mkdir -p dist/assets
echo $(git rev-parse HEAD) > dist/assets/version
cd ..
mkdir -p data
cp -r dashboard/dist data/
- name: Set QEMU
- name: 设置 QEMU
uses: docker/setup-qemu-action@v3
- name: Set Docker Buildx
- name: 设置 Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to DockerHub
- name: 登录到 DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: Soulter
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
- name: Build and Push Docker to DockerHub and Github GHCR
- name: 构建和推送 Docker hub
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && format('{0}/astrbot:latest', secrets.DOCKER_HUB_USERNAME) || '' }}
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && 'ghcr.io/soulter/astrbot:latest' || '' }}
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.ref_name }}
- name: Post build notifications
run: echo "Docker image has been built and pushed successfully"

View File

@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/stale@v10
- uses: actions/stale@v5
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'Stale issue message'

11
.gitignore vendored
View File

@@ -1,8 +1,6 @@
__pycache__
botpy.log
.vscode
.venv*
.idea
data_v2.db
data_v3.db
configs/session
@@ -19,15 +17,12 @@ addons/plugins
tests/astrbot_plugin_openai
chroma
dashboard/node_modules/
dashboard/dist/
node_modules/
.DS_Store
package-lock.json
package.json
venv/*
packages/python_interpreter/workplace
.venv/*
.conda/
.idea
pytest.ini
.astrbot
.conda/

View File

@@ -1,13 +0,0 @@
default_install_hook_types: [pre-commit, prepare-commit-msg]
ci:
autofix_commit_msg: ":balloon: auto fixes by pre-commit hooks"
autofix_prs: true
autoupdate_branch: master
autoupdate_schedule: weekly
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.2
hooks:
- id: ruff
- id: ruff-format

View File

@@ -1 +0,0 @@
3.10

View File

@@ -1,35 +1,22 @@
FROM python:3.11-slim
FROM python:3.10-slim
WORKDIR /AstrBot
COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \
nodejs \
npm \
gcc \
build-essential \
python3-dev \
libffi-dev \
libssl-dev \
ca-certificates \
bash \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
RUN python -m pip install -r requirements.txt --no-cache-dir
# 释出 ffmpeg
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
RUN python -m pip install socksio wechatpy cryptography --no-cache-dir
EXPOSE 6185
EXPOSE 6186
CMD [ "python", "main.py" ]

View File

@@ -1,35 +0,0 @@
FROM python:3.10-slim
WORKDIR /AstrBot
COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
build-essential \
python3-dev \
libffi-dev \
libssl-dev \
curl \
unzip \
ca-certificates \
bash \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Installation of Node.js
ENV NVM_DIR="/root/.nvm"
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
. "$NVM_DIR/nvm.sh" && \
nvm install 22 && \
nvm use 22
RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system
EXPOSE 6185
EXPOSE 6186
CMD ["python", "main.py"]

View File

@@ -629,8 +629,8 @@ to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
AstrBot is a llm-powered chatbot and develop framework.
Copyright (C) 2022-2099 Soulter
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published

240
README.md
View File

@@ -1,198 +1,141 @@
<img width="430" height="31" alt="image" src="https://github.com/user-attachments/assets/474c822c-fab7-41be-8c23-6dae252823ed" /><p align="center">
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
<p align="center">
![6e1279651f16d7fdf4727558b72bbaf1](https://github.com/user-attachments/assets/ead4c551-fc3c-48f7-a6f7-afbfdb820512)
</p>
<div align="center">
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot?style=for-the-badge&color=76bad9)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600)
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://astrbot.app/">查看文档</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
</div>
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型LLM接入功能的聊天机器人及开发框架。
## 主要功能
## 主要功能
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
3. **Agent**完善适配的 Agentic 能力。支持多轮工具调用、内置沙盒代码执行器、网页搜索等功能
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富
5. **WebUI**。可视化配置和管理机器人,功能齐全
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力支持图片理解、语音转文字Whisper
2. **多消息平台接入**。支持接入 QQOneBot、QQ 频道、微信Gewechat、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
3. **Agent**原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://astrbot.app/others/dify.html),便捷接入 Dify 智能助手、知识库和 Dify 工作流
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件
5. **可视化管理面板**支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat可在面板上与大模型对话
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
## 部署方式
> [!TIP]
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM无法在聊天页使用大模型。不要再修改 demo 的登录密码了 😭)
## ✨ 使用方式
#### Docker 部署
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
#### 宝塔面板部署
AstrBot 与宝塔面板合作,已上架至宝塔面板。
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
#### 1Panel 部署
AstrBot 已由 1Panel 官方上架至 1Panel 面板。
请参阅官方文档 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html) 。
#### 在 雨云 上部署
AstrBot 已由雨云官方上架至云应用平台,可一键部署。
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
#### 在 Replit 上部署
社区贡献的部署方式。
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### Windows 一键安装器部署
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
需要电脑上安装有 Python>3.10)。请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
#### Replit 部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### CasaOS 部署
社区贡献的部署方式。
请参阅官方文档 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html) 。
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/casaos.html) 。
#### 手动部署
首先安装 uv
```bash
pip install uv
```
通过 Git Clone 安装 AstrBot
```bash
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
uv run main.py
```
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
## 🌍 社区
### QQ 群组
- 1 群322154837
- 3 群630166526
- 5 群822130018
- 6 群753075035
- 开发者群975206796
- 开发者群备份295657329
### Telegram 群组
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
### Discord 群组
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
## ⚡ 消息平台支持情况
| 平台 | 支持性 |
| -------- | ------- |
| QQ(官方机器人接口) | ✔ |
| QQ(OneBot) | ✔ |
| Telegram | ✔ |
| 企业微信 | ✔ |
| 微信客服 | ✔ |
| 微信公众号 | ✔ |
| 飞书 | ✔ |
| 钉钉 | ✔ |
| Slack | |
| Discord | |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
| 平台 | 支持性 | 详情 | 消息类型 |
| -------- | ------- | ------- | ------ |
| QQ(官方机器人接口) | ✔ | 私聊、群聊QQ 频道私聊、群聊 | 文字、图片 |
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 |
| 飞书 | ✔ | 群聊 | 文字、图片 |
| 微信对话开放平台 | 🚧 | 计划内 | - |
| Discord | 🚧 | 计划内 | - |
| WhatsApp | 🚧 | 计划内 | - |
| 小爱音响 | 🚧 | 计划内 | - |
## ⚡ 提供商支持情况
# 🦌 接下来的路线图
> [!TIP]
> 欢迎在 Issue 提出更多建议 <3
- [ ] 完善并保证目前所有平台适配器的功能一致性
- [ ] 优化插件接口
- [ ] 默认支持更多 TTS 服务,如 GPT-Sovits
- [ ] 完善“聊天增强”部分,支持持久化记忆
- [ ] 规划 i18n
| 名称 | 支持性 | 类型 | 备注 |
| -------- | ------- | ------- | ------- |
| OpenAI | ✔ | 文本生成 | 支持任何兼容 OpenAI API 的服务 |
| Anthropic | ✔ | 文本生成 | |
| Google Gemini | ✔ | 文本生成 | |
| Dify | ✔ | LLMOps | |
| 阿里云百炼应用 | ✔ | LLMOps | |
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
| 硅基流动 | ✔ | 模型 API 服务平台 | |
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
| OneAPI | ✔ | LLM 分发系统 | |
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
| OpenAI TTS API | ✔ | 文本转语音 | |
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | 文本转语音 | GPT-Sovits-Inference |
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
## ❤️ 贡献
欢迎任何 Issues/Pull Requests只需要将你的更改提交到此项目 )
### 如何贡献
对于新功能的添加,请先通过 Issue 讨论。
你可以通过查看问题或帮助审核 PR拉取请求来贡献。任何问题或 PR 都欢迎参与,以促进社区贡献。当然,这些只是建议,你可以以任何方式进行贡献。对于新功能的添加,请先通过 Issue 讨论。
## 🌟 支持
### 开发环境
- Star 这个项目!
- 在[爱发电](https://afdian.com/a/soulter)支持我!
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
AstrBot 使用 `ruff` 进行代码格式化和检查。
## ✨ Demo
```bash
git clone https://github.com/Soulter/AstrBot
pip install pre-commit
pre-commit install
```
> [!NOTE]
> 代码执行器的文件输入/输出目前仅测试了 Napcat(QQ), Lagrange(QQ)
<div align='center'>
## ❤️ Special Thanks
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
_✨基于 Docker 的沙箱化代码执行器Beta 测试中✨_
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
此外,本项目的诞生离不开以下开源项目的帮助:
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
另外,一些同类型其他的活跃开源 Bot 项目:
_✨ 自然语言待办事项 ✨_
- [nonebot/nonebot2](https://github.com/nonebot/nonebot2) - 扩展性极强的 Bot 框架
- [koishijs/koishi](https://github.com/koishijs/koishi) - 扩展性极强的 Bot 框架
- [MaiM-with-u/MaiBot](https://github.com/MaiM-with-u/MaiBot) - 注重拟人功能的 ChatBot
- [langbot-app/LangBot](https://github.com/langbot-app/LangBot) - 功能丰富的 Bot 平台
- [LroMiose/nekro-agent](https://github.com/KroMiose/nekro-agent) - 注重 Agent 的 ChatBot
- [zhenxun-org/zhenxun_bot](https://github.com/zhenxun-org/zhenxun_bot) - 功能完善的 ChatBot
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ 插件系统——部分插件展示 ✨_
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
_✨ 管理面板 ✨_
![webchat](https://drive.soulter.top/f/vlsA/ezgif-5-fb044b2542.gif)
_✨ 内置 Web Chat在线与机器人交互 ✨_
</div>
## ⭐ Star History
@@ -205,8 +148,21 @@ pre-commit install
</div>
## Disclaimer
</details>
1. The project is protected under the `AGPL-v3` opensource license.
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
3. Please ensure compliance with local laws and regulations when using this project.
<!-- ## ✨ ATRI [Beta 测试]
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. 基于《ATRI ~ My Dear Moments》主角 ATRI 角色台词作为微调数据集的 `Qwen1.5-7B-Chat Lora` 微调模型。
2. 长期记忆
3. 表情包理解与回复
4. TTS
-->
_私は、高性能ですから!_

View File

@@ -1,182 +0,0 @@
<p align="center">
![6e1279651f16d7fdf4727558b72bbaf1](https://github.com/user-attachments/assets/ead4c551-fc3c-48f7-a6f7-afbfdb820512)
</p>
<div align="center">
_✨ Easy-to-use Multi-platform LLM Chatbot & Development Framework ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
<a href="https://astrbot.app/">Documentation</a>
<a href="https://github.com/Soulter/AstrBot/issues">Issue Tracking</a>
</div>
AstrBot is a loosely coupled, asynchronous chatbot and development framework that supports multi-platform deployment, featuring an easy-to-use plugin system and comprehensive Large Language Model (LLM) integration capabilities.
## ✨ Key Features
1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper).
2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation.
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://dify.ai/) for easy access to Dify assistants/knowledge bases/workflows.
4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins.
5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction.
6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling.
> [!TIP]
> Dashboard Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
> Username: `astrbot`, Password: `astrbot` (LLM not configured for chat page)
## ✨ Deployment
#### Docker Deployment
See docs: [Deploy with Docker](https://astrbot.app/deploy/astrbot/docker.html#docker-deployment)
#### Windows Installer
Requires Python (>3.10). See docs: [Windows Installer Guide](https://astrbot.app/deploy/astrbot/windows.html)
#### Replit Deployment
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### CasaOS Deployment
Community-contributed method.
See docs: [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html)
#### Manual Deployment
See docs: [Source Code Deployment](https://astrbot.app/deploy/astrbot/cli.html)
## ⚡ Platform Support
| Platform | Status | Details | Message Types |
| -------------------------------------------------------------- | ------ | ------------------- | ------------------- |
| QQ (Official Bot) | ✔ | Private/Group chats | Text, Images |
| QQ (OneBot) | ✔ | Private/Group chats | Text, Images, Voice |
| WeChat (Personal) | ✔ | Private/Group chats | Text, Images, Voice |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
| [WeChat Work](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
| Feishu | ✔ | Group chats | Text, Images |
| WeChat Open Platform | 🚧 | Planned | - |
| Discord | 🚧 | Planned | - |
| WhatsApp | 🚧 | Planned | - |
| Xiaomi Speakers | 🚧 | Planned | - |
## Provider Support Status
| Name | Support | Type | Notes |
|---------------------------|---------|------------------------|-----------------------------------------------------------------------|
| OpenAI API | ✔ | Text Generation | Supports all OpenAI API-compatible services including DeepSeek, Google Gemini, GLM, Moonshot, Alibaba Cloud Bailian, Silicon Flow, xAI, etc. |
| Claude API | ✔ | Text Generation | |
| Google Gemini API | ✔ | Text Generation | |
| Dify | ✔ | LLMOps | |
| DashScope (Alibaba Cloud) | ✔ | LLMOps | |
| Ollama | ✔ | Model Loader | Local deployment for open-source LLMs (DeepSeek, Llama, etc.) |
| LM Studio | ✔ | Model Loader | Local deployment for open-source LLMs (DeepSeek, Llama, etc.) |
| LLMTuner | ✔ | Model Loader | Local loading of fine-tuned models (e.g. LoRA) |
| OneAPI | ✔ | LLM Distribution | |
| Whisper | ✔ | Speech-to-Text | Supports API and local deployment |
| SenseVoice | ✔ | Speech-to-Text | Local deployment |
| OpenAI TTS API | ✔ | Text-to-Speech | |
| Fishaudio | ✔ | Text-to-Speech | Project involving GPT-Sovits author |
# 🦌 Roadmap
> [!TIP]
> Suggestions welcome via Issues <3
- [ ] Ensure feature parity across all platform adapters
- [ ] Optimize plugin APIs
- [ ] Add default TTS services (e.g., GPT-Sovits)
- [ ] Enhance chat features with persistent memory
- [ ] i18n Planning
## ❤️ Contributions
All Issues/PRs welcome! Simply submit your changes to this project :)
For major features, please discuss via Issues first.
## 🌟 Support
- Star this project!
- Support via [Afdian](https://afdian.com/a/soulter)
- WeChat support: [QR Code](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)
## ✨ Demos
> [!NOTE]
> Code executor file I/O currently tested with Napcat(QQ)/Lagrange(QQ)
<div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
_✨ Docker-based Sandboxed Code Executor (Beta) ✨_
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ Multimodal Input, Web Search, Text-to-Image ✨_
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
_✨ Natural Language TODO Lists ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ Plugin System Showcase ✨_
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
_✨ Web Dashboard ✨_
![webchat](https://drive.soulter.top/f/vlsA/ezgif-5-fb044b2542.gif)
_✨ Built-in Web Chat Interface ✨_
</div>
## ⭐ Star History
> [!TIP]
> If this project helps you, please give it a star <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div>
## Disclaimer
1. Licensed under `AGPL-v3`.
2. WeChat integration uses [Gewechat](https://github.com/Devo919/Gewechat). Use at your own risk with non-critical accounts.
3. Users must comply with local laws and regulations.
<!-- ## ✨ ATRI [Beta]
Available as plugin: [astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. Qwen1.5-7B-Chat Lora model fine-tuned with ATRI character data
2. Long-term memory
3. Meme understanding & responses
4. TTS integration
-->
_私は、高性能ですから!_

View File

@@ -1,5 +1,5 @@
<p align="center">
![6e1279651f16d7fdf4727558b72bbaf1](https://github.com/user-attachments/assets/ead4c551-fc3c-48f7-a6f7-afbfdb820512)
</p>
@@ -27,15 +27,15 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
## ✨ 主な機能
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換Whisperをサポートします。
2. **複数のメッセージプラットフォームの接続**。QQOneBot、QQ チャンネル、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
2. **複数のメッセージプラットフォームの接続**。QQOneBot、QQ チャンネル、WeChatGewechatFeishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://astrbot.app/others/dify.html)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
> [!TIP]
> 管理パネルのオンラインデモを体験する: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
>
> ユーザー名: `astrbot`, パスワード: `astrbot`。LLM が設定されていないため、チャットページで大規模モデルを使用することはできません。(デモのログインパスワードを変更しないでください 😭)
## ✨ 使用方法
@@ -136,11 +136,11 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
## ⭐ Star History
> [!TIP]
> [!TIP]
> このプロジェクトがあなたの生活や仕事に役立った場合、またはこのプロジェクトの将来の発展に関心がある場合は、プロジェクトに Star を付けてください。これはこのオープンソースプロジェクトを維持するためのモチベーションです <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div>
@@ -152,7 +152,8 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
## 免責事項
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
2. このプロジェクトを使用する際は、現地の法律および規制を遵守してください
2. WeChat個人アカウントのデプロイメントには [Gewechat](https://github.com/Devo919/Gewechat) サービスを利用しています。AstrBot は Gewechat との接続を保証するだけであり、アカウントのリスク管理に関しては、このプロジェクトの著者は一切の責任を負いません
3. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
<!-- ## ✨ ATRI [ベータテスト]
@@ -164,4 +165,6 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
4. TTS
-->
_私は、高性能ですから!_

View File

@@ -1,3 +1,2 @@
from .core.log import LogManager
logger = LogManager.GetLogger(log_name="astrbot")
logger = LogManager.GetLogger(log_name='astrbot')

View File

@@ -3,18 +3,11 @@ from astrbot import logger
from astrbot.core import html_renderer
from astrbot.core import sp
from astrbot.core.star.register import register_llm_tool as llm_tool
from astrbot.core.star.register import register_agent as agent
from astrbot.core.agent.tool import ToolSet, FunctionTool
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
__all__ = [
"AstrBotConfig",
"logger",
"html_renderer",
"llm_tool",
"agent",
"sp",
"ToolSet",
"FunctionTool",
"BaseFunctionToolExecutor",
]
"sp"
]

View File

@@ -1,3 +1,4 @@
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core import html_renderer
@@ -5,11 +6,8 @@ from astrbot.core.star.register import register_llm_tool as llm_tool
# event
from astrbot.core.message.message_event_result import (
MessageEventResult,
MessageChain,
CommandResult,
EventResultType,
)
MessageEventResult, MessageChain, CommandResult, EventResultType
)
from astrbot.core.platform import AstrMessageEvent
# star register
@@ -20,16 +18,10 @@ from astrbot.core.star.register import (
register_regex as regex,
register_platform_adapter_type as platform_adapter_type,
)
from astrbot.core.star.filter.event_message_type import (
EventMessageTypeFilter,
EventMessageType,
)
from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterTypeFilter,
PlatformAdapterType,
)
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
from astrbot.core.star.register import (
register_star as register, # 注册插件Star
register_star as register # 注册插件Star
)
from astrbot.core.star import Context, Star
from astrbot.core.star.config import *
@@ -40,14 +32,9 @@ from astrbot.core.provider import Provider, Personality, ProviderMetaData
# platform
from astrbot.core.platform import (
AstrMessageEvent,
Platform,
AstrBotMessage,
MessageMember,
MessageType,
PlatformMetadata,
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
)
from astrbot.core.platform.register import register_platform_adapter
from .message_components import *
from .message_components import *

View File

@@ -6,46 +6,36 @@ from astrbot.core.star.register import (
register_platform_adapter_type as platform_adapter_type,
register_permission_type as permission_type,
register_custom_filter as custom_filter,
register_on_astrbot_loaded as on_astrbot_loaded,
register_on_platform_loaded as on_platform_loaded,
register_on_llm_request as on_llm_request,
register_on_llm_response as on_llm_response,
register_llm_tool as llm_tool,
register_on_decorating_result as on_decorating_result,
register_after_message_sent as after_message_sent,
register_after_message_sent as after_message_sent
)
from astrbot.core.star.filter.event_message_type import (
EventMessageTypeFilter,
EventMessageType,
)
from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterTypeFilter,
PlatformAdapterType,
)
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
from astrbot.core.star.filter.permission import PermissionTypeFilter, PermissionType
from astrbot.core.star.filter.custom_filter import CustomFilter
__all__ = [
"command",
"command_group",
"event_message_type",
"regex",
"platform_adapter_type",
"permission_type",
"EventMessageTypeFilter",
"EventMessageType",
"PlatformAdapterTypeFilter",
"PlatformAdapterType",
"PermissionTypeFilter",
"CustomFilter",
"custom_filter",
"PermissionType",
"on_astrbot_loaded",
"on_platform_loaded",
"on_llm_request",
"llm_tool",
"on_decorating_result",
"after_message_sent",
"on_llm_response",
]
'command',
'command_group',
'event_message_type',
'regex',
'platform_adapter_type',
'permission_type',
'EventMessageTypeFilter',
'EventMessageType',
'PlatformAdapterTypeFilter',
'PlatformAdapterType',
'PermissionTypeFilter',
'CustomFilter',
'custom_filter',
'PermissionType',
'on_llm_request',
'llm_tool',
'on_decorating_result',
'after_message_sent',
'on_llm_response'
]

View File

@@ -1 +1 @@
from astrbot.core.message.components import *
from astrbot.core.message.components import *

View File

@@ -1,23 +1,6 @@
from astrbot.core.platform import (
AstrMessageEvent,
Platform,
AstrBotMessage,
MessageMember,
MessageType,
PlatformMetadata,
Group,
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
)
from astrbot.core.platform.register import register_platform_adapter
from astrbot.core.message.components import *
__all__ = [
"AstrMessageEvent",
"Platform",
"AstrBotMessage",
"MessageMember",
"MessageType",
"PlatformMetadata",
"register_platform_adapter",
"Group",
]
from astrbot.core.message.components import *

View File

@@ -1,17 +1,2 @@
from astrbot.core.provider import Provider, STTProvider, Personality
from astrbot.core.provider.entities import (
ProviderRequest,
ProviderType,
ProviderMetaData,
LLMResponse,
)
__all__ = [
"Provider",
"STTProvider",
"Personality",
"ProviderRequest",
"ProviderType",
"ProviderMetaData",
"LLMResponse",
]
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData, LLMResponse

View File

@@ -1,8 +1,6 @@
from astrbot.core.star.register import (
register_star as register, # 注册插件Star
register_star as register # 注册插件Star
)
from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star import Context, Star
from astrbot.core.star.config import *
__all__ = ["register", "Context", "Star", "StarTools"]

View File

@@ -1,7 +0,0 @@
from astrbot.core.utils.session_waiter import (
SessionWaiter,
SessionController,
session_waiter,
)
__all__ = ["SessionWaiter", "SessionController", "session_waiter"]

View File

@@ -1 +0,0 @@
__version__ = "3.5.23"

View File

@@ -1,59 +0,0 @@
"""
AstrBot CLI入口
"""
import click
import sys
from . import __version__
from .commands import init, run, plug, conf
logo_tmpl = r"""
___ _______.___________..______ .______ ______ .___________.
/ \ / | || _ \ | _ \ / __ \ | |
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
/ /_\ \ \ \ | | | / | _ < | | | | | |
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
"""
@click.group()
@click.version_option(__version__, prog_name="AstrBot")
def cli() -> None:
"""The AstrBot CLI"""
click.echo(logo_tmpl)
click.echo("Welcome to AstrBot CLI!")
click.echo(f"AstrBot CLI version: {__version__}")
@click.command()
@click.argument("command_name", required=False, type=str)
def help(command_name: str | None) -> None:
"""显示命令的帮助信息
如果提供了 COMMAND_NAME则显示该命令的详细帮助信息。
否则,显示通用帮助信息。
"""
ctx = click.get_current_context()
if command_name:
# 查找指定命令
command = cli.get_command(ctx, command_name)
if command:
# 显示特定命令的帮助信息
click.echo(command.get_help(ctx))
else:
click.echo(f"Unknown command: {command_name}")
sys.exit(1)
else:
# 显示通用帮助信息
click.echo(cli.get_help(ctx))
cli.add_command(init)
cli.add_command(run)
cli.add_command(help)
cli.add_command(plug)
cli.add_command(conf)
if __name__ == "__main__":
cli()

View File

@@ -1,6 +0,0 @@
from .cmd_init import init
from .cmd_run import run
from .cmd_plug import plug
from .cmd_conf import conf
__all__ = ["init", "run", "plug", "conf"]

View File

@@ -1,206 +0,0 @@
import json
import click
import hashlib
import zoneinfo
from typing import Any, Callable
from ..utils import get_astrbot_root, check_astrbot_root
def _validate_log_level(value: str) -> str:
"""验证日志级别"""
value = value.upper()
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
raise click.ClickException(
"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一"
)
return value
def _validate_dashboard_port(value: str) -> int:
"""验证 Dashboard 端口"""
try:
port = int(value)
if port < 1 or port > 65535:
raise click.ClickException("端口必须在 1-65535 范围内")
return port
except ValueError:
raise click.ClickException("端口必须是数字")
def _validate_dashboard_username(value: str) -> str:
"""验证 Dashboard 用户名"""
if not value:
raise click.ClickException("用户名不能为空")
return value
def _validate_dashboard_password(value: str) -> str:
"""验证 Dashboard 密码"""
if not value:
raise click.ClickException("密码不能为空")
return hashlib.md5(value.encode()).hexdigest()
def _validate_timezone(value: str) -> str:
"""验证时区"""
try:
zoneinfo.ZoneInfo(value)
except Exception:
raise click.ClickException(f"无效的时区: {value}请使用有效的IANA时区名称")
return value
def _validate_callback_api_base(value: str) -> str:
"""验证回调接口基址"""
if not value.startswith("http://") and not value.startswith("https://"):
raise click.ClickException("回调接口基址必须以 http:// 或 https:// 开头")
return value
# 可通过CLI设置的配置项配置键到验证器函数的映射
CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
"timezone": _validate_timezone,
"log_level": _validate_log_level,
"dashboard.port": _validate_dashboard_port,
"dashboard.username": _validate_dashboard_username,
"dashboard.password": _validate_dashboard_password,
"callback_api_base": _validate_callback_api_base,
}
def _load_config() -> dict[str, Any]:
"""加载或初始化配置文件"""
root = get_astrbot_root()
if not check_astrbot_root(root):
raise click.ClickException(
f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
)
config_path = root / "data" / "cmd_config.json"
if not config_path.exists():
from astrbot.core.config.default import DEFAULT_CONFIG
config_path.write_text(
json.dumps(DEFAULT_CONFIG, ensure_ascii=False, indent=2),
encoding="utf-8-sig",
)
try:
return json.loads(config_path.read_text(encoding="utf-8-sig"))
except json.JSONDecodeError as e:
raise click.ClickException(f"配置文件解析失败: {str(e)}")
def _save_config(config: dict[str, Any]) -> None:
"""保存配置文件"""
config_path = get_astrbot_root() / "data" / "cmd_config.json"
config_path.write_text(
json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8-sig"
)
def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
"""设置嵌套字典中的值"""
parts = path.split(".")
for part in parts[:-1]:
if part not in obj:
obj[part] = {}
elif not isinstance(obj[part], dict):
raise click.ClickException(
f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典"
)
obj = obj[part]
obj[parts[-1]] = value
def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
"""获取嵌套字典中的值"""
parts = path.split(".")
for part in parts:
obj = obj[part]
return obj
@click.group(name="conf")
def conf():
"""配置管理命令
支持的配置项:
- timezone: 时区设置 (例如: Asia/Shanghai)
- log_level: 日志级别 (DEBUG/INFO/WARNING/ERROR/CRITICAL)
- dashboard.port: Dashboard 端口
- dashboard.username: Dashboard 用户名
- dashboard.password: Dashboard 密码
- callback_api_base: 回调接口基址
"""
pass
@conf.command(name="set")
@click.argument("key")
@click.argument("value")
def set_config(key: str, value: str):
"""设置配置项的值"""
if key not in CONFIG_VALIDATORS.keys():
raise click.ClickException(f"不支持的配置项: {key}")
config = _load_config()
try:
old_value = _get_nested_item(config, key)
validated_value = CONFIG_VALIDATORS[key](value)
_set_nested_item(config, key, validated_value)
_save_config(config)
click.echo(f"配置已更新: {key}")
if key == "dashboard.password":
click.echo(" 原值: ********")
click.echo(" 新值: ********")
else:
click.echo(f" 原值: {old_value}")
click.echo(f" 新值: {validated_value}")
except KeyError:
raise click.ClickException(f"未知的配置项: {key}")
except Exception as e:
raise click.UsageError(f"设置配置失败: {str(e)}")
@conf.command(name="get")
@click.argument("key", required=False)
def get_config(key: str = None):
"""获取配置项的值不提供key则显示所有可配置项"""
config = _load_config()
if key:
if key not in CONFIG_VALIDATORS.keys():
raise click.ClickException(f"不支持的配置项: {key}")
try:
value = _get_nested_item(config, key)
if key == "dashboard.password":
value = "********"
click.echo(f"{key}: {value}")
except KeyError:
raise click.ClickException(f"未知的配置项: {key}")
except Exception as e:
raise click.UsageError(f"获取配置失败: {str(e)}")
else:
click.echo("当前配置:")
for key in CONFIG_VALIDATORS.keys():
try:
value = (
"********"
if key == "dashboard.password"
else _get_nested_item(config, key)
)
click.echo(f" {key}: {value}")
except (KeyError, TypeError):
pass

View File

@@ -1,55 +0,0 @@
import asyncio
import click
from filelock import FileLock, Timeout
from ..utils import check_dashboard, get_astrbot_root
async def initialize_astrbot(astrbot_root) -> None:
"""执行 AstrBot 初始化逻辑"""
dot_astrbot = astrbot_root / ".astrbot"
if not dot_astrbot.exists():
click.echo(f"Current Directory: {astrbot_root}")
click.echo(
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。"
)
if click.confirm(
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",
default=True,
abort=True,
):
dot_astrbot.touch()
click.echo(f"Created {dot_astrbot}")
paths = {
"data": astrbot_root / "data",
"config": astrbot_root / "data" / "config",
"plugins": astrbot_root / "data" / "plugins",
"temp": astrbot_root / "data" / "temp",
}
for name, path in paths.items():
path.mkdir(parents=True, exist_ok=True)
click.echo(f"{'Created' if not path.exists() else 'Directory exists'}: {path}")
await check_dashboard(astrbot_root / "data")
@click.command()
def init() -> None:
"""初始化 AstrBot"""
click.echo("Initializing AstrBot...")
astrbot_root = get_astrbot_root()
lock_file = astrbot_root / "astrbot.lock"
lock = FileLock(lock_file, timeout=5)
try:
with lock.acquire():
asyncio.run(initialize_astrbot(astrbot_root))
except Timeout:
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
except Exception as e:
raise click.ClickException(f"初始化失败: {e!s}")

View File

@@ -1,247 +0,0 @@
import re
from pathlib import Path
import click
import shutil
from ..utils import (
get_git_repo,
build_plug_list,
manage_plugin,
PluginStatus,
check_astrbot_root,
get_astrbot_root,
)
@click.group()
def plug():
"""插件管理"""
pass
def _get_data_path() -> Path:
base = get_astrbot_root()
if not check_astrbot_root(base):
raise click.ClickException(
f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
)
return (base / "data").resolve()
def display_plugins(plugins, title=None, color=None):
if title:
click.echo(click.style(title, fg=color, bold=True))
click.echo(f"{'名称':<20} {'版本':<10} {'状态':<10} {'作者':<15} {'描述':<30}")
click.echo("-" * 85)
for p in plugins:
desc = p["desc"][:30] + ("..." if len(p["desc"]) > 30 else "")
click.echo(
f"{p['name']:<20} {p['version']:<10} {p['status']:<10} "
f"{p['author']:<15} {desc:<30}"
)
@plug.command()
@click.argument("name")
def new(name: str):
"""创建新插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins" / name
if plug_path.exists():
raise click.ClickException(f"插件 {name} 已存在")
author = click.prompt("请输入插件作者", type=str)
desc = click.prompt("请输入插件描述", type=str)
version = click.prompt("请输入插件版本", type=str)
if not re.match(r"^\d+\.\d+(\.\d+)?$", version.lower().lstrip("v")):
raise click.ClickException("版本号必须为 x.y 或 x.y.z 格式")
repo = click.prompt("请输入插件仓库:", type=str)
if not repo.startswith("http"):
raise click.ClickException("仓库地址必须以 http 开头")
click.echo("下载插件模板...")
get_git_repo(
"https://github.com/Soulter/helloworld",
plug_path,
)
click.echo("重写插件信息...")
# 重写 metadata.yaml
with open(plug_path / "metadata.yaml", "w", encoding="utf-8") as f:
f.write(
f"name: {name}\n"
f"desc: {desc}\n"
f"version: {version}\n"
f"author: {author}\n"
f"repo: {repo}\n"
)
# 重写 README.md
with open(plug_path / "README.md", "w", encoding="utf-8") as f:
f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n")
# 重写 main.py
with open(plug_path / "main.py", "r", encoding="utf-8") as f:
content = f.read()
new_content = content.replace(
'@register("helloworld", "YourName", "一个简单的 Hello World 插件", "1.0.0")',
f'@register("{name}", "{author}", "{desc}", "{version}")',
)
with open(plug_path / "main.py", "w", encoding="utf-8") as f:
f.write(new_content)
click.echo(f"插件 {name} 创建成功")
@plug.command()
@click.option("--all", "-a", is_flag=True, help="列出未安装的插件")
def list(all: bool):
"""列出插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
# 未发布的插件
not_published_plugins = [
p for p in plugins if p["status"] == PluginStatus.NOT_PUBLISHED
]
if not_published_plugins:
display_plugins(not_published_plugins, "未发布的插件", "red")
# 需要更新的插件
need_update_plugins = [
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
]
if need_update_plugins:
display_plugins(need_update_plugins, "需要更新的插件", "yellow")
# 已安装的插件
installed_plugins = [p for p in plugins if p["status"] == PluginStatus.INSTALLED]
if installed_plugins:
display_plugins(installed_plugins, "已安装的插件", "green")
# 未安装的插件
not_installed_plugins = [
p for p in plugins if p["status"] == PluginStatus.NOT_INSTALLED
]
if not_installed_plugins and all:
display_plugins(not_installed_plugins, "未安装的插件", "blue")
if (
not any([not_published_plugins, need_update_plugins, installed_plugins])
and not all
):
click.echo("未安装任何插件")
@plug.command()
@click.argument("name")
@click.option("--proxy", help="代理服务器地址")
def install(name: str, proxy: str | None):
"""安装插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins"
plugins = build_plug_list(base_path / "plugins")
plugin = next(
(
p
for p in plugins
if p["name"] == name and p["status"] == PluginStatus.NOT_INSTALLED
),
None,
)
if not plugin:
raise click.ClickException(f"未找到可安装的插件 {name},可能是不存在或已安装")
manage_plugin(plugin, plug_path, is_update=False, proxy=proxy)
@plug.command()
@click.argument("name")
def remove(name: str):
"""卸载插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
plugin = next((p for p in plugins if p["name"] == name), None)
if not plugin or not plugin.get("local_path"):
raise click.ClickException(f"插件 {name} 不存在或未安装")
plugin_path = plugin["local_path"]
click.confirm(f"确定要卸载插件 {name} 吗?", default=False, abort=True)
try:
shutil.rmtree(plugin_path)
click.echo(f"插件 {name} 已卸载")
except Exception as e:
raise click.ClickException(f"卸载插件 {name} 失败: {e}")
@plug.command()
@click.argument("name", required=False)
@click.option("--proxy", help="Github代理地址")
def update(name: str, proxy: str | None):
"""更新插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins"
plugins = build_plug_list(base_path / "plugins")
if name:
plugin = next(
(
p
for p in plugins
if p["name"] == name and p["status"] == PluginStatus.NEED_UPDATE
),
None,
)
if not plugin:
raise click.ClickException(f"插件 {name} 不需要更新或无法更新")
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
else:
need_update_plugins = [
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
]
if not need_update_plugins:
click.echo("没有需要更新的插件")
return
click.echo(f"发现 {len(need_update_plugins)} 个插件需要更新")
for plugin in need_update_plugins:
plugin_name = plugin["name"]
click.echo(f"正在更新插件 {plugin_name}...")
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
@plug.command()
@click.argument("query")
def search(query: str):
"""搜索插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
matched_plugins = [
p
for p in plugins
if query.lower() in p["name"].lower()
or query.lower() in p["desc"].lower()
or query.lower() in p["author"].lower()
]
if not matched_plugins:
click.echo(f"未找到匹配 '{query}' 的插件")
return
display_plugins(matched_plugins, f"搜索结果: '{query}'", "cyan")

View File

@@ -1,63 +0,0 @@
import os
import sys
from pathlib import Path
import click
import asyncio
import traceback
from filelock import FileLock, Timeout
from ..utils import check_dashboard, check_astrbot_root, get_astrbot_root
async def run_astrbot(astrbot_root: Path):
"""运行 AstrBot"""
from astrbot.core import logger, LogManager, LogBroker, db_helper
from astrbot.core.initial_loader import InitialLoader
await check_dashboard(astrbot_root / "data")
log_broker = LogBroker()
LogManager.set_queue_handler(logger, log_broker)
db = db_helper
core_lifecycle = InitialLoader(db, log_broker)
await core_lifecycle.start()
@click.option("--reload", "-r", is_flag=True, help="插件自动重载")
@click.option("--port", "-p", help="Astrbot Dashboard端口", required=False, type=str)
@click.command()
def run(reload: bool, port: str) -> None:
"""运行 AstrBot"""
try:
os.environ["ASTRBOT_CLI"] = "1"
astrbot_root = get_astrbot_root()
if not check_astrbot_root(astrbot_root):
raise click.ClickException(
f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
)
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
sys.path.insert(0, str(astrbot_root))
if port:
os.environ["DASHBOARD_PORT"] = port
if reload:
click.echo("启用插件自动重载")
os.environ["ASTRBOT_RELOAD"] = "1"
lock_file = astrbot_root / "astrbot.lock"
lock = FileLock(lock_file, timeout=5)
with lock.acquire():
asyncio.run(run_astrbot(astrbot_root))
except KeyboardInterrupt:
click.echo("AstrBot 已关闭...")
except Timeout:
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
except Exception as e:
raise click.ClickException(f"运行时出现错误: {e}\n{traceback.format_exc()}")

View File

@@ -1,18 +0,0 @@
from .basic import (
get_astrbot_root,
check_astrbot_root,
check_dashboard,
)
from .plugin import get_git_repo, manage_plugin, build_plug_list, PluginStatus
from .version_comparator import VersionComparator
__all__ = [
"get_astrbot_root",
"check_astrbot_root",
"check_dashboard",
"get_git_repo",
"manage_plugin",
"build_plug_list",
"VersionComparator",
"PluginStatus",
]

View File

@@ -1,76 +0,0 @@
from pathlib import Path
import click
def check_astrbot_root(path: str | Path) -> bool:
"""检查路径是否为 AstrBot 根目录"""
if not isinstance(path, Path):
path = Path(path)
if not path.exists() or not path.is_dir():
return False
if not (path / ".astrbot").exists():
return False
return True
def get_astrbot_root() -> Path:
"""获取Astrbot根目录路径"""
return Path.cwd()
async def check_dashboard(astrbot_root: Path) -> None:
"""检查是否安装了dashboard"""
from astrbot.core.utils.io import get_dashboard_version, download_dashboard
from astrbot.core.config.default import VERSION
from .version_comparator import VersionComparator
try:
dashboard_version = await get_dashboard_version()
match dashboard_version:
case None:
click.echo("未安装管理面板")
if click.confirm(
"是否安装管理面板?",
default=True,
abort=True,
):
click.echo("正在安装管理面板...")
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
)
click.echo("管理面板安装完成")
case str():
if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
click.echo("管理面板已是最新版本")
return
else:
try:
version = dashboard_version.split("v")[1]
click.echo(f"管理面板版本: {version}")
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
)
except Exception as e:
click.echo(f"下载管理面板失败: {e}")
return
except FileNotFoundError:
click.echo("初始化管理面板目录...")
try:
await download_dashboard(
path=str(astrbot_root / "dashboard.zip"),
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
)
click.echo("管理面板初始化完成")
except Exception as e:
click.echo(f"下载管理面板失败: {e}")
return

View File

@@ -1,237 +0,0 @@
import shutil
import tempfile
import httpx
import yaml
from enum import Enum
from io import BytesIO
from pathlib import Path
from zipfile import ZipFile
import click
from .version_comparator import VersionComparator
class PluginStatus(str, Enum):
INSTALLED = "已安装"
NEED_UPDATE = "需更新"
NOT_INSTALLED = "未安装"
NOT_PUBLISHED = "未发布"
def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
"""从 Git 仓库下载代码并解压到指定路径"""
temp_dir = Path(tempfile.mkdtemp())
try:
# 解析仓库信息
repo_namespace = url.split("/")[-2:]
author = repo_namespace[0]
repo = repo_namespace[1]
# 尝试获取最新的 release
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
try:
with httpx.Client(
proxy=proxy if proxy else None, follow_redirects=True
) as client:
resp = client.get(release_url)
resp.raise_for_status()
releases = resp.json()
if releases:
# 使用最新的 release
download_url = releases[0]["zipball_url"]
else:
# 没有 release使用默认分支
click.echo(f"正在从默认分支下载 {author}/{repo}")
download_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
except Exception as e:
click.echo(f"获取 release 信息失败: {e},将直接使用提供的 URL")
download_url = url
# 应用代理
if proxy:
download_url = f"{proxy}/{download_url}"
# 下载并解压
with httpx.Client(
proxy=proxy if proxy else None, follow_redirects=True
) as client:
resp = client.get(download_url)
if (
resp.status_code == 404
and "archive/refs/heads/master.zip" in download_url
):
alt_url = download_url.replace("master.zip", "main.zip")
click.echo("master 分支不存在,尝试下载 main 分支")
resp = client.get(alt_url)
resp.raise_for_status()
else:
resp.raise_for_status()
zip_content = BytesIO(resp.content)
with ZipFile(zip_content) as z:
z.extractall(temp_dir)
namelist = z.namelist()
root_dir = Path(namelist[0]).parts[0] if namelist else ""
if target_path.exists():
shutil.rmtree(target_path)
shutil.move(temp_dir / root_dir, target_path)
finally:
if temp_dir.exists():
shutil.rmtree(temp_dir, ignore_errors=True)
def load_yaml_metadata(plugin_dir: Path) -> dict:
"""从 metadata.yaml 文件加载插件元数据
Args:
plugin_dir: 插件目录路径
Returns:
dict: 包含元数据的字典,如果读取失败则返回空字典
"""
yaml_path = plugin_dir / "metadata.yaml"
if yaml_path.exists():
try:
return yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
except Exception as e:
click.echo(f"读取 {yaml_path} 失败: {e}", err=True)
return {}
def build_plug_list(plugins_dir: Path) -> list:
"""构建插件列表,包含本地和在线插件信息
Args:
plugins_dir (Path): 插件目录路径
Returns:
list: 包含插件信息的字典列表
"""
# 获取本地插件信息
result = []
if plugins_dir.exists():
for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]:
plugin_dir = plugins_dir / plugin_name
# 从 metadata.yaml 加载元数据
metadata = load_yaml_metadata(plugin_dir)
if "desc" not in metadata and "description" in metadata:
metadata["desc"] = metadata["description"]
# 如果成功加载元数据,添加到结果列表
if metadata and all(
k in metadata for k in ["name", "desc", "version", "author", "repo"]
):
result.append(
{
"name": str(metadata.get("name", "")),
"desc": str(metadata.get("desc", "")),
"version": str(metadata.get("version", "")),
"author": str(metadata.get("author", "")),
"repo": str(metadata.get("repo", "")),
"status": PluginStatus.INSTALLED,
"local_path": str(plugin_dir),
}
)
# 获取在线插件列表
online_plugins = []
try:
with httpx.Client() as client:
resp = client.get("https://api.soulter.top/astrbot/plugins")
resp.raise_for_status()
data = resp.json()
for plugin_id, plugin_info in data.items():
online_plugins.append(
{
"name": str(plugin_id),
"desc": str(plugin_info.get("desc", "")),
"version": str(plugin_info.get("version", "")),
"author": str(plugin_info.get("author", "")),
"repo": str(plugin_info.get("repo", "")),
"status": PluginStatus.NOT_INSTALLED,
"local_path": None,
}
)
except Exception as e:
click.echo(f"获取在线插件列表失败: {e}", err=True)
# 与在线插件比对,更新状态
online_plugin_names = {plugin["name"] for plugin in online_plugins}
for local_plugin in result:
if local_plugin["name"] in online_plugin_names:
# 查找对应的在线插件
online_plugin = next(
p for p in online_plugins if p["name"] == local_plugin["name"]
)
if (
VersionComparator.compare_version(
local_plugin["version"], online_plugin["version"]
)
< 0
):
local_plugin["status"] = PluginStatus.NEED_UPDATE
else:
# 本地插件未在线上发布
local_plugin["status"] = PluginStatus.NOT_PUBLISHED
# 添加未安装的在线插件
for online_plugin in online_plugins:
if not any(plugin["name"] == online_plugin["name"] for plugin in result):
result.append(online_plugin)
return result
def manage_plugin(
plugin: dict, plugins_dir: Path, is_update: bool = False, proxy: str | None = None
) -> None:
"""安装或更新插件
Args:
plugin (dict): 插件信息字典
plugins_dir (Path): 插件目录
is_update (bool, optional): 是否为更新操作. 默认为 False
proxy (str, optional): 代理服务器地址
"""
plugin_name = plugin["name"]
repo_url = plugin["repo"]
# 如果是更新且有本地路径,直接使用本地路径
if is_update and plugin.get("local_path"):
target_path = Path(plugin["local_path"])
else:
target_path = plugins_dir / plugin_name
backup_path = Path(f"{target_path}_backup") if is_update else None
# 检查插件是否存在
if is_update and not target_path.exists():
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
# 备份现有插件
if is_update and backup_path.exists():
shutil.rmtree(backup_path)
if is_update:
shutil.copytree(target_path, backup_path)
try:
click.echo(
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}..."
)
get_git_repo(repo_url, target_path, proxy)
# 更新成功,删除备份
if is_update and backup_path.exists():
shutil.rmtree(backup_path)
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
except Exception as e:
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
if is_update and backup_path.exists():
shutil.move(backup_path, target_path)
raise click.ClickException(
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}"
)

View File

@@ -1,92 +0,0 @@
"""
拷贝自 astrbot.core.utils.version_comparator
"""
import re
class VersionComparator:
@staticmethod
def compare_version(v1: str, v2: str) -> int:
"""根据 Semver 语义版本规范来比较版本号的大小。支持不仅局限于 3 个数字的版本号,并处理预发布标签。
参考: https://semver.org/lang/zh-CN/
返回 1 表示 v1 > v2返回 -1 表示 v1 < v2返回 0 表示 v1 = v2。
"""
v1 = v1.lower().replace("v", "")
v2 = v2.lower().replace("v", "")
def split_version(version):
match = re.match(
r"^([0-9]+(?:\.[0-9]+)*)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(.+))?$",
version,
)
if not match:
return [], None
major_minor_patch = match.group(1).split(".")
prerelease = match.group(2)
# buildmetadata = match.group(3) # 构建元数据在比较时忽略
parts = [int(x) for x in major_minor_patch]
prerelease = VersionComparator._split_prerelease(prerelease)
return parts, prerelease
v1_parts, v1_prerelease = split_version(v1)
v2_parts, v2_prerelease = split_version(v2)
# 比较数字部分
length = max(len(v1_parts), len(v2_parts))
v1_parts.extend([0] * (length - len(v1_parts)))
v2_parts.extend([0] * (length - len(v2_parts)))
for i in range(length):
if v1_parts[i] > v2_parts[i]:
return 1
elif v1_parts[i] < v2_parts[i]:
return -1
# 比较预发布标签
if v1_prerelease is None and v2_prerelease is not None:
return 1 # 没有预发布标签的版本高于有预发布标签的版本
elif v1_prerelease is not None and v2_prerelease is None:
return -1 # 有预发布标签的版本低于没有预发布标签的版本
elif v1_prerelease is not None and v2_prerelease is not None:
len_pre = max(len(v1_prerelease), len(v2_prerelease))
for i in range(len_pre):
p1 = v1_prerelease[i] if i < len(v1_prerelease) else None
p2 = v2_prerelease[i] if i < len(v2_prerelease) else None
if p1 is None and p2 is not None:
return -1
elif p1 is not None and p2 is None:
return 1
elif isinstance(p1, int) and isinstance(p2, str):
return -1
elif isinstance(p1, str) and isinstance(p2, int):
return 1
elif isinstance(p1, int) and isinstance(p2, int):
if p1 > p2:
return 1
elif p1 < p2:
return -1
elif isinstance(p1, str) and isinstance(p2, str):
if p1 > p2:
return 1
elif p1 < p2:
return -1
return 0 # 预发布标签完全相同
return 0 # 数字部分和预发布标签都相同
@staticmethod
def _split_prerelease(prerelease):
if not prerelease:
return None
parts = prerelease.split(".")
result = []
for part in parts:
if part.isdigit():
result.append(int(part))
else:
result.append(part)
return result

View File

@@ -1,29 +1,26 @@
import os
from .log import LogManager, LogBroker # noqa
import asyncio
from .log import LogManager, LogBroker
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config.default import DB_PATH
from astrbot.core.config import AstrBotConfig
from astrbot.core.file_token_service import FileTokenService
from .utils.astrbot_path import get_astrbot_data_path
# 初始化数据存储文件夹
os.makedirs(get_astrbot_data_path(), exist_ok=True)
DEMO_MODE = os.getenv("DEMO_MODE", False)
os.makedirs("data", exist_ok=True)
astrbot_config = AstrBotConfig()
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
t2i_base_url = astrbot_config.get('t2i_endpoint', 'https://t2i.soulter.top/text2img')
html_renderer = HtmlRenderer(t2i_base_url)
logger = LogManager.GetLogger(log_name="astrbot")
logger = LogManager.GetLogger(log_name='astrbot')
if os.environ.get('TESTING', ""):
logger.setLevel('DEBUG')
db_helper = SQLiteDatabase(DB_PATH)
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
sp = SharedPreferences(db_helper=db_helper)
# 文件令牌服务
file_token_service = FileTokenService()
pip_installer = PipInstaller(
astrbot_config.get("pip_install_arg", ""),
astrbot_config.get("pypi_index_url", None),
)
sp = SharedPreferences() # 简单的偏好设置存储
pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', ''))
web_chat_queue = asyncio.Queue(maxsize=32)
web_chat_back_queue = asyncio.Queue(maxsize=32)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"

View File

@@ -1,13 +0,0 @@
from dataclasses import dataclass
from .tool import FunctionTool
from typing import Generic
from .run_context import TContext
from .hooks import BaseAgentRunHooks
@dataclass
class Agent(Generic[TContext]):
name: str
instructions: str | None = None
tools: list[str, FunctionTool] | None = None
run_hooks: BaseAgentRunHooks[TContext] | None = None

View File

@@ -1,34 +0,0 @@
from typing import Generic
from .tool import FunctionTool
from .agent import Agent
from .run_context import TContext
class HandoffTool(FunctionTool, Generic[TContext]):
"""Handoff tool for delegating tasks to another agent."""
def __init__(
self, agent: Agent[TContext], parameters: dict | None = None, **kwargs
):
self.agent = agent
super().__init__(
name=f"transfer_to_{agent.name}",
parameters=parameters or self.default_parameters(),
description=agent.instructions or self.default_description(agent.name),
**kwargs,
)
def default_parameters(self) -> dict:
return {
"type": "object",
"properties": {
"input": {
"type": "string",
"description": "The input to be handed off to another agent. This should be a clear and concise request or task.",
},
},
}
def default_description(self, agent_name: str | None) -> str:
agent_name = agent_name or "another"
return f"Delegate tasks to {self.name} agent to handle the request."

View File

@@ -1,27 +0,0 @@
import mcp
from dataclasses import dataclass
from .run_context import ContextWrapper, TContext
from typing import Generic
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.agent.tool import FunctionTool
@dataclass
class BaseAgentRunHooks(Generic[TContext]):
async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
async def on_tool_start(
self,
run_context: ContextWrapper[TContext],
tool: FunctionTool,
tool_args: dict | None,
): ...
async def on_tool_end(
self,
run_context: ContextWrapper[TContext],
tool: FunctionTool,
tool_args: dict | None,
tool_result: mcp.types.CallToolResult | None,
): ...
async def on_agent_done(
self, run_context: ContextWrapper[TContext], llm_response: LLMResponse
): ...

View File

@@ -1,208 +0,0 @@
import asyncio
import logging
from datetime import timedelta
from typing import Optional
from contextlib import AsyncExitStack
from astrbot import logger
from astrbot.core.utils.log_pipe import LogPipe
try:
import mcp
from mcp.client.sse import sse_client
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
try:
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
)
def _prepare_config(config: dict) -> dict:
"""准备配置,处理嵌套格式"""
if "mcpServers" in config and config["mcpServers"]:
first_key = next(iter(config["mcpServers"]))
config = config["mcpServers"][first_key]
config.pop("active", None)
return config
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"""快速测试 MCP 服务器可达性"""
import aiohttp
cfg = _prepare_config(config.copy())
url = cfg["url"]
headers = cfg.get("headers", {})
timeout = cfg.get("timeout", 10)
try:
async with aiohttp.ClientSession() as session:
if cfg.get("transport") == "streamable_http":
test_payload = {
"jsonrpc": "2.0",
"method": "initialize",
"id": 0,
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.2.3"},
},
}
async with session.post(
url,
headers={
**headers,
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
},
json=test_payload,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
else:
return False, f"HTTP {response.status}: {response.reason}"
else:
async with session.get(
url,
headers={
**headers,
"Accept": "application/json, text/event-stream",
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
else:
return False, f"HTTP {response.status}: {response.reason}"
except asyncio.TimeoutError:
return False, f"连接超时: {timeout}"
except Exception as e:
return False, f"{e!s}"
class MCPClient:
def __init__(self):
# Initialize session and client objects
self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack()
self.name = None
self.active: bool = True
self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = []
self.running_event = asyncio.Event()
async def connect_to_server(self, mcp_server_config: dict, name: str):
"""连接到 MCP 服务器
如果 `url` 参数存在:
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
cfg = _prepare_config(mcp_server_config.copy())
def logging_callback(msg: str):
# 处理 MCP 服务的错误日志
print(f"MCP Server {name} Error: {msg}")
self.server_errlogs.append(msg)
if "url" in cfg:
success, error_msg = await _quick_test_mcp_connection(cfg)
if not success:
raise Exception(error_msg)
if cfg.get("transport") != "streamable_http":
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=cfg.get("timeout", 5),
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
)
streams = await self.exit_stack.enter_async_context(
self._streams_context
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
*streams,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
)
)
else:
timeout = timedelta(seconds=cfg.get("timeout", 30))
sse_read_timeout = timedelta(
seconds=cfg.get("sse_read_timeout", 60 * 5)
)
self._streams_context = streamablehttp_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
read_s, write_s, _ = await self.exit_stack.enter_async_context(
self._streams_context
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
read_stream=read_s,
write_stream=write_s,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
)
)
else:
server_params = mcp.StdioServerParameters(
**cfg,
)
def callback(msg: str):
# 处理 MCP 服务的错误日志
self.server_errlogs.append(msg)
stdio_transport = await self.exit_stack.enter_async_context(
mcp.stdio_client(
server_params,
errlog=LogPipe(
level=logging.ERROR,
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
), # type: ignore
),
)
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport)
)
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
response = await self.session.list_tools()
self.tools = response.tools
return response
async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()
self.running_event.set() # Set the running event to indicate cleanup is done

View File

@@ -1,13 +0,0 @@
from dataclasses import dataclass
import typing as T
from astrbot.core.message.message_event_result import MessageChain
class AgentResponseData(T.TypedDict):
chain: MessageChain
@dataclass
class AgentResponse:
type: str
data: AgentResponseData

View File

@@ -1,18 +0,0 @@
from dataclasses import dataclass
from typing import Any, Generic
from typing_extensions import TypeVar
from astrbot.core.platform.astr_message_event import AstrMessageEvent
TContext = TypeVar("TContext", default=Any)
@dataclass
class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state."""
context: TContext
event: AstrMessageEvent
NoContext = ContextWrapper[None]

View File

@@ -1,3 +0,0 @@
from .base import BaseAgentRunner
__all__ = ["BaseAgentRunner"]

View File

@@ -1,58 +0,0 @@
import abc
import typing as T
from enum import Enum, auto
from ..run_context import ContextWrapper, TContext
from ..response import AgentResponse
from ..hooks import BaseAgentRunHooks
from ..tool_executor import BaseFunctionToolExecutor
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import LLMResponse
class AgentState(Enum):
"""Defines the state of the agent."""
IDLE = auto() # Initial state
RUNNING = auto() # Currently processing
DONE = auto() # Completed
ERROR = auto() # Error state
class BaseAgentRunner(T.Generic[TContext]):
@abc.abstractmethod
async def reset(
self,
provider: Provider,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
**kwargs: T.Any,
) -> None:
"""
Reset the agent to its initial state.
This method should be called before starting a new run.
"""
...
@abc.abstractmethod
async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
"""
Process a single step of the agent.
"""
...
@abc.abstractmethod
def done(self) -> bool:
"""
Check if the agent has completed its task.
Returns True if the agent is done, False otherwise.
"""
...
@abc.abstractmethod
def get_final_llm_resp(self) -> LLMResponse | None:
"""
Get the final observation from the agent.
This method should be called after the agent is done.
"""
...

View File

@@ -1,334 +0,0 @@
import sys
import traceback
import typing as T
from .base import BaseAgentRunner, AgentResponse, AgentState
from ..hooks import BaseAgentRunHooks
from ..tool_executor import BaseFunctionToolExecutor
from ..run_context import ContextWrapper, TContext
from ..response import AgentResponseData
from astrbot.core.provider.provider import Provider
from astrbot.core.message.message_event_result import (
MessageChain,
)
from astrbot.core.provider.entities import (
ProviderRequest,
LLMResponse,
ToolCallMessageSegment,
AssistantMessageSegment,
ToolCallsResult,
)
from mcp.types import (
TextContent,
ImageContent,
EmbeddedResource,
TextResourceContents,
BlobResourceContents,
CallToolResult,
)
from astrbot import logger
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
@override
async def reset(
self,
provider: Provider,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.provider = provider
self.final_llm_resp = None
self._state = AgentState.IDLE
self.tool_executor = tool_executor
self.agent_hooks = agent_hooks
self.run_context = run_context
def _transition_state(self, new_state: AgentState) -> None:
"""转换 Agent 状态"""
if self._state != new_state:
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
self._state = new_state
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
if self.streaming:
stream = self.provider.text_chat_stream(**self.req.__dict__)
async for resp in stream: # type: ignore
yield resp
else:
yield await self.provider.text_chat(**self.req.__dict__)
@override
async def step(self):
"""
Process a single step of the agent.
This method should return the result of the step.
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
if self._state == AgentState.IDLE:
try:
await self.agent_hooks.on_agent_begin(self.run_context)
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# 开始处理,转换到运行状态
self._transition_state(AgentState.RUNNING)
llm_resp_result = None
async for llm_response in self._iter_llm_responses():
assert isinstance(llm_response, LLMResponse)
if llm_response.is_chunk:
if llm_response.result_chain:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(chain=llm_response.result_chain),
)
else:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(llm_response.completion_text)
),
)
continue
llm_resp_result = llm_response
break # got final response
if not llm_resp_result:
return
# 处理 LLM 响应
llm_resp = llm_resp_result
if llm_resp.role == "err":
# 如果 LLM 响应错误,转换到错误状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.ERROR)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(
f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
)
),
)
if not llm_resp.tools_call_name:
# 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE)
try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
# 返回 LLM 结果
if llm_resp.result_chain:
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=llm_resp.result_chain),
)
elif llm_resp.completion_text:
yield AgentResponse(
type="llm_result",
data=AgentResponseData(
chain=MessageChain().message(llm_resp.completion_text)
),
)
# 如果有工具调用,还需处理工具调用
if llm_resp.tools_call_name:
tool_call_result_blocks = []
for tool_call_name in llm_resp.tools_call_name:
yield AgentResponse(
type="tool_call",
data=AgentResponseData(
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}")
),
)
async for result in self._handle_function_tools(self.req, llm_resp):
if isinstance(result, list):
tool_call_result_blocks = result
elif isinstance(result, MessageChain):
yield AgentResponse(
type="tool_call_result",
data=AgentResponseData(chain=result),
)
# 将结果添加到上下文中
tool_calls_result = ToolCallsResult(
tool_calls_info=AssistantMessageSegment(
role="assistant",
tool_calls=llm_resp.to_openai_tool_calls(),
content=llm_resp.completion_text,
),
tool_calls_result=tool_call_result_blocks,
)
self.req.append_tool_calls_result(tool_calls_result)
async def _handle_function_tools(
self,
req: ProviderRequest,
llm_response: LLMResponse,
) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]:
"""处理函数工具调用。"""
tool_call_result_blocks: list[ToolCallMessageSegment] = []
logger.info(f"Agent 使用工具: {llm_response.tools_call_name}")
# 执行函数调用
for func_tool_name, func_tool_args, func_tool_id in zip(
llm_response.tools_call_name,
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
try:
if not req.func_tool:
return
func_tool = req.func_tool.get_func(func_tool_name)
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
try:
await self.agent_hooks.on_tool_start(
self.run_context, func_tool, func_tool_args
)
except Exception as e:
logger.error(f"Error in on_tool_start hook: {e}", exc_info=True)
executor = self.tool_executor.execute(
tool=func_tool,
run_context=self.run_context,
**func_tool_args,
)
async for resp in executor:
if isinstance(resp, CallToolResult):
res = resp
if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
)
)
yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resource.text,
)
)
yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain(
type="tool_direct_result"
).base64_image(res.content[0].data)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回的数据类型不受支持",
)
)
yield MessageChain().message("返回的数据类型不受支持。")
try:
await self.agent_hooks.on_tool_end(
self.run_context,
func_tool_name,
func_tool_args,
resp,
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
elif resp is None:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
self._transition_state(AgentState.DONE)
if res := self.run_context.event.get_result():
if res.chain:
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool_name, func_tool_args, None
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
else:
logger.warning(
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool_name, func_tool_args, None
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
self.run_context.event.clear_result()
except Exception as e:
logger.warning(traceback.format_exc())
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: {str(e)}",
)
)
# 处理函数调用响应
if tool_call_result_blocks:
yield tool_call_result_blocks
def done(self) -> bool:
"""检查 Agent 是否已完成工作"""
return self._state in (AgentState.DONE, AgentState.ERROR)
def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp

View File

@@ -1,256 +0,0 @@
from dataclasses import dataclass
from deprecated import deprecated
from typing import Awaitable, Literal, Any, Optional
from .mcp_client import MCPClient
@dataclass
class FunctionTool:
"""A class representing a function tool that can be used in function calling."""
name: str | None = None
parameters: dict | None = None
description: str | None = None
handler: Awaitable | None = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str | None = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
"""
active: bool = True
"""是否激活"""
origin: Literal["local", "mcp"] = "local"
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
# MCP 相关字段
mcp_server_name: str | None = None
"""MCP 服务名称,当 origin 为 mcp 时有效"""
mcp_client: MCPClient | None = None
"""MCP 客户端,当 origin 为 mcp 时有效"""
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
def __dict__(self) -> dict[str, Any]:
"""将 FunctionTool 转换为字典格式"""
return {
"name": self.name,
"parameters": self.parameters,
"description": self.description,
"active": self.active,
"origin": self.origin,
"mcp_server_name": self.mcp_server_name,
}
class ToolSet:
"""A set of function tools that can be used in function calling.
This class provides methods to add, remove, and retrieve tools, as well as
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
def __init__(self, tools: list[FunctionTool] = None):
self.tools: list[FunctionTool] = tools or []
def empty(self) -> bool:
"""Check if the tool set is empty."""
return len(self.tools) == 0
def add_tool(self, tool: FunctionTool):
"""Add a tool to the set."""
# 检查是否已存在同名工具
for i, existing_tool in enumerate(self.tools):
if existing_tool.name == tool.name:
self.tools[i] = tool
return
self.tools.append(tool)
def remove_tool(self, name: str):
"""Remove a tool by its name."""
self.tools = [tool for tool in self.tools if tool.name != name]
def get_tool(self, name: str) -> Optional[FunctionTool]:
"""Get a tool by its name."""
for tool in self.tools:
if tool.name == name:
return tool
return None
@deprecated(reason="Use add_tool() instead", version="4.0.0")
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
"""Add a function tool to the set."""
params = {
"type": "object", # hard-coded here
"properties": {},
}
for param in func_args:
params["properties"][param["name"]] = {
"type": param["type"],
"description": param["description"],
}
_func = FunctionTool(
name=name,
parameters=params,
description=desc,
handler=handler,
)
self.add_tool(_func)
@deprecated(reason="Use remove_tool() instead", version="4.0.0")
def remove_func(self, name: str):
"""Remove a function tool by its name."""
self.remove_tool(name)
@deprecated(reason="Use get_tool() instead", version="4.0.0")
def get_func(self, name: str) -> list[FunctionTool]:
"""Get all function tools."""
return self.get_tool(name)
@property
def func_list(self) -> list[FunctionTool]:
"""Get the list of function tools."""
return self.tools
def openai_schema(self, omit_empty_parameter_field: bool = False) -> list[dict]:
"""Convert tools to OpenAI API function calling schema format."""
result = []
for tool in self.tools:
func_def = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
},
}
if tool.parameters.get("properties") or not omit_empty_parameter_field:
func_def["function"]["parameters"] = tool.parameters
result.append(func_def)
return result
def anthropic_schema(self) -> list[dict]:
"""Convert tools to Anthropic API format."""
result = []
for tool in self.tools:
tool_def = {
"name": tool.name,
"description": tool.description,
"input_schema": {
"type": "object",
"properties": tool.parameters.get("properties", {}),
"required": tool.parameters.get("required", []),
},
}
result.append(tool_def)
return result
def google_schema(self) -> dict:
"""Convert tools to Google GenAI API format."""
def convert_schema(schema: dict) -> dict:
"""Convert schema to Gemini API format."""
supported_types = {
"string",
"number",
"integer",
"boolean",
"array",
"object",
"null",
}
supported_formats = {
"string": {"enum", "date-time"},
"integer": {"int32", "int64"},
"number": {"float", "double"},
}
if "anyOf" in schema:
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
result = {}
if "type" in schema and schema["type"] in supported_types:
result["type"] = schema["type"]
if "format" in schema and schema["format"] in supported_formats.get(
result["type"], set()
):
result["format"] = schema["format"]
else:
result["type"] = "null"
support_fields = {
"title",
"description",
"enum",
"minimum",
"maximum",
"maxItems",
"minItems",
"nullable",
"required",
}
result.update({k: schema[k] for k in support_fields if k in schema})
if "properties" in schema:
properties = {}
for key, value in schema["properties"].items():
prop_value = convert_schema(value)
if "default" in prop_value:
del prop_value["default"]
properties[key] = prop_value
if properties:
result["properties"] = properties
if "items" in schema:
result["items"] = convert_schema(schema["items"])
return result
tools = [
{
"name": tool.name,
"description": tool.description,
"parameters": convert_schema(tool.parameters),
}
for tool in self.tools
]
declarations = {}
if tools:
declarations["function_declarations"] = tools
return declarations
@deprecated(reason="Use openai_schema() instead", version="4.0.0")
def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False):
return self.openai_schema(omit_empty_parameter_field)
@deprecated(reason="Use anthropic_schema() instead", version="4.0.0")
def get_func_desc_anthropic_style(self):
return self.anthropic_schema()
@deprecated(reason="Use google_schema() instead", version="4.0.0")
def get_func_desc_google_genai_style(self):
return self.google_schema()
def names(self) -> list[str]:
"""获取所有工具的名称列表"""
return [tool.name for tool in self.tools]
def __len__(self):
return len(self.tools)
def __bool__(self):
return len(self.tools) > 0
def __iter__(self):
return iter(self.tools)
def __repr__(self):
return f"ToolSet(tools={self.tools})"
def __str__(self):
return f"ToolSet(tools={self.tools})"

View File

@@ -1,11 +0,0 @@
import mcp
from typing import Any, Generic, AsyncGenerator
from .run_context import TContext, ContextWrapper
from .tool import FunctionTool
class BaseFunctionToolExecutor(Generic[TContext]):
@classmethod
async def execute(
cls, tool: FunctionTool, run_context: ContextWrapper[TContext], **tool_args
) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ...

View File

@@ -1,11 +0,0 @@
from dataclasses import dataclass
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest
@dataclass
class AstrAgentContext:
provider: Provider
first_provider_request: ProviderRequest
curr_provider_request: ProviderRequest
streaming: bool

View File

@@ -1,283 +0,0 @@
import os
import uuid
from astrbot.core import AstrBotConfig, logger
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
from astrbot.core.config.default import DEFAULT_CONFIG
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.utils.astrbot_path import get_astrbot_config_path
from typing import TypeVar, TypedDict
_VT = TypeVar("_VT")
class ConfInfo(TypedDict):
"""Configuration information for a specific session or platform."""
id: str # UUID of the configuration or "default"
umop: list[str] # Unified Message Origin Pattern
name: str
path: str # File name to the configuration file
DEFAULT_CONFIG_CONF_INFO = ConfInfo(
id="default",
umop=["::"],
name="default",
path=ASTRBOT_CONFIG_PATH,
)
class AstrBotConfigManager:
"""A class to manage the system configuration of AstrBot, aka ACM"""
def __init__(self, default_config: AstrBotConfig, sp: SharedPreferences):
self.sp = sp
self.confs: dict[str, AstrBotConfig] = {}
"""uuid / "default" -> AstrBotConfig"""
self.confs["default"] = default_config
self.abconf_data = None
self._load_all_configs()
def _get_abconf_data(self) -> dict:
"""获取所有的 abconf 数据"""
if self.abconf_data is None:
self.abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
return self.abconf_data
def _load_all_configs(self):
"""Load all configurations from the shared preferences."""
abconf_data = self._get_abconf_data()
self.abconf_data = abconf_data
for uuid_, meta in abconf_data.items():
filename = meta["path"]
conf_path = os.path.join(get_astrbot_config_path(), filename)
if os.path.exists(conf_path):
conf = AstrBotConfig(config_path=conf_path)
self.confs[uuid_] = conf
else:
logger.warning(
f"Config file {conf_path} for UUID {uuid_} does not exist, skipping."
)
continue
def _is_umo_match(self, p1: str, p2: str) -> bool:
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
p1_ls = p1.split(":")
p2_ls = p2.split(":")
if len(p1_ls) != 3 or len(p2_ls) != 3:
return False # 非法格式
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo:
"""获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default")
Returns:
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
"""
# uuid -> { "umop": list, "path": str, "name": str }
abconf_data = self._get_abconf_data()
if isinstance(umo, MessageSession):
umo = str(umo)
else:
try:
umo = str(MessageSession.from_str(umo)) # validate
except Exception:
return DEFAULT_CONFIG_CONF_INFO
for uuid_, meta in abconf_data.items():
for pattern in meta["umop"]:
if self._is_umo_match(pattern, umo):
return ConfInfo(**meta, id=uuid_)
return DEFAULT_CONFIG_CONF_INFO
def _save_conf_mapping(
self,
abconf_path: str,
abconf_id: str,
umo_parts: list[str] | list[MessageSession],
abconf_name: str | None = None,
) -> None:
"""保存配置文件的映射关系"""
for part in umo_parts:
if isinstance(part, MessageSession):
part = str(part)
elif not isinstance(part, str):
raise ValueError(
"umo_parts must be a list of strings or MessageSession instances"
)
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
random_word = abconf_name or uuid.uuid4().hex[:8]
abconf_data[abconf_id] = {
"umop": umo_parts,
"path": abconf_path,
"name": random_word,
}
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig:
"""获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。"""
if not umo:
return self.confs["default"]
if isinstance(umo, MessageSession):
umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}"
uuid_ = self._load_conf_mapping(umo)["id"]
conf = self.confs.get(uuid_)
if not conf:
conf = self.confs["default"] # default MUST exists
return conf
@property
def default_conf(self) -> AstrBotConfig:
"""获取默认配置文件"""
return self.confs["default"]
def get_conf_info(self, umo: str | MessageSession) -> ConfInfo:
"""获取指定 umo 的配置文件元数据"""
if isinstance(umo, MessageSession):
umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}"
return self._load_conf_mapping(umo)
def get_conf_list(self) -> list[ConfInfo]:
"""获取所有配置文件的元数据列表"""
conf_list = []
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
abconf_mapping = self._get_abconf_data()
for uuid_, meta in abconf_mapping.items():
conf_list.append(ConfInfo(**meta, id=uuid_))
return conf_list
def create_conf(
self,
umo_parts: list[str] | list[MessageSession],
config: dict = DEFAULT_CONFIG,
name: str | None = None,
) -> str:
"""
umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
umo_parts 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。
"""
conf_uuid = str(uuid.uuid4())
conf_file_name = f"abconf_{conf_uuid}.json"
conf_path = os.path.join(get_astrbot_config_path(), conf_file_name)
conf = AstrBotConfig(config_path=conf_path, default_config=config)
conf.save_config()
self._save_conf_mapping(conf_file_name, conf_uuid, umo_parts, abconf_name=name)
self.confs[conf_uuid] = conf
return conf_uuid
def delete_conf(self, conf_id: str) -> bool:
"""删除指定配置文件
Args:
conf_id: 配置文件的 UUID
Returns:
bool: 删除是否成功
Raises:
ValueError: 如果试图删除默认配置文件
"""
if conf_id == "default":
raise ValueError("不能删除默认配置文件")
# 从映射中移除
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
return False
# 获取配置文件路径
conf_path = os.path.join(
get_astrbot_config_path(), abconf_data[conf_id]["path"]
)
# 删除配置文件
try:
if os.path.exists(conf_path):
os.remove(conf_path)
logger.info(f"已删除配置文件: {conf_path}")
except Exception as e:
logger.error(f"删除配置文件 {conf_path} 失败: {e}")
return False
# 从内存中移除
if conf_id in self.confs:
del self.confs[conf_id]
# 从映射中移除
del abconf_data[conf_id]
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
logger.info(f"成功删除配置文件 {conf_id}")
return True
def update_conf_info(
self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None
) -> bool:
"""更新配置文件信息
Args:
conf_id: 配置文件的 UUID
name: 新的配置文件名称 (可选)
umo_parts: 新的 UMO 部分列表 (可选)
Returns:
bool: 更新是否成功
"""
if conf_id == "default":
raise ValueError("不能更新默认配置文件的信息")
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
return False
# 更新名称
if name is not None:
abconf_data[conf_id]["name"] = name
# 更新 UMO 部分
if umo_parts is not None:
# 验证 UMO 部分格式
for part in umo_parts:
if isinstance(part, MessageSession):
part = str(part)
elif not isinstance(part, str):
raise ValueError(
"umo_parts must be a list of strings or MessageSession instances"
)
abconf_data[conf_id]["umop"] = umo_parts
# 保存更新
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
logger.info(f"成功更新配置文件 {conf_id} 的信息")
return True
def g(
self, umo: str | None = None, key: str | None = None, default: _VT = None
) -> _VT:
"""获取配置项。umo 为 None 时使用默认配置"""
if umo is None:
return self.confs["default"].get(key, default)
conf = self.get_conf(umo)
return conf.get(key, default)

View File

@@ -1,9 +1,2 @@
from .default import DEFAULT_CONFIG, VERSION, DB_PATH
from .astrbot_config import *
__all__ = [
"DEFAULT_CONFIG",
"VERSION",
"DB_PATH",
"AstrBotConfig",
]
from .astrbot_config import *

View File

@@ -4,158 +4,115 @@ import logging
import enum
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
from typing import Dict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
logger = logging.getLogger("astrbot")
class RateLimitStrategy(enum.Enum):
STALL = "stall"
DISCARD = "discard"
class AstrBotConfig(dict):
"""从配置文件中加载的配置,支持直接通过点号操作符访问根配置项。
'''从配置文件中加载的配置,支持直接通过点号操作符访问根配置项。
- 初始化时会将传入的 default_config 与配置文件进行比对,如果配置文件中缺少配置项则会自动插入默认值并进行一次写入操作。会递归检查配置项。
- 如果配置文件路径对应的文件不存在,则会自动创建并写入默认配置。
- 如果传入了 schema将会通过 schema 解析出 default_config此时传入的 default_config 会被忽略。
"""
'''
def __init__(
self,
config_path: str = ASTRBOT_CONFIG_PATH,
self,
config_path: str = ASTRBOT_CONFIG_PATH,
default_config: dict = DEFAULT_CONFIG,
schema: dict = None,
schema: dict = None
):
super().__init__()
# 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件
object.__setattr__(self, "config_path", config_path)
object.__setattr__(self, "default_config", default_config)
object.__setattr__(self, "schema", schema)
object.__setattr__(self, 'config_path', config_path)
object.__setattr__(self, 'default_config', default_config)
object.__setattr__(self, 'schema', schema)
if schema:
default_config = self._config_schema_to_default_config(schema)
if not self.check_exist():
"""不存在时载入默认配置"""
'''不存在时载入默认配置'''
with open(config_path, "w", encoding="utf-8-sig") as f:
json.dump(default_config, f, indent=4, ensure_ascii=False)
object.__setattr__(self, "first_deploy", True) # 标记第一次部署
with open(config_path, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
if conf_str.startswith(u'/ufeff'): # remove BOM
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
conf = json.loads(conf_str)
# 检查配置完整性,并插入
has_new = self.check_config_integrity(default_config, conf)
self.update(conf)
if has_new:
self.save_config()
self.update(conf)
def _config_schema_to_default_config(self, schema: dict) -> dict:
"""将 Schema 转换成 Config"""
'''将 Schema 转换成 Config'''
conf = {}
def _parse_schema(schema: dict, conf: dict):
for k, v in schema.items():
if v["type"] not in DEFAULT_VALUE_MAP:
raise TypeError(
f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}"
)
if "default" in v:
default = v["default"]
if v['type'] not in DEFAULT_VALUE_MAP:
raise TypeError(f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}")
if 'default' in v:
default = v['default']
else:
default = DEFAULT_VALUE_MAP[v["type"]]
if v["type"] == "object":
default = DEFAULT_VALUE_MAP[v['type']]
if v['type'] == 'object':
conf[k] = {}
_parse_schema(v["items"], conf[k])
_parse_schema(v['items'], conf[k])
else:
conf[k] = default
_parse_schema(schema, conf)
return conf
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
"""检查配置完整性,如果有新的配置项或顺序不一致则返回 True"""
'''检查配置完整性,如果有新的配置项则返回 True'''
has_new = False
# 创建一个新的有序字典以保持参考配置的顺序
new_conf = {}
# 先按照参考配置的顺序添加配置项
for key, value in refer_conf.items():
if key not in conf:
# 配置项不存在,插入默认值
# logger.info(f"检查到配置项 {path + "." + key if path else key} 不存在,插入默认值 {value}")
path_ = path + "." + key if path else key
logger.info(f"检查到配置项 {path_} 不存在,已插入默认值 {value}")
new_conf[key] = value
conf[key] = value
has_new = True
else:
if conf[key] is None:
# 配置项为 None使用默认值
new_conf[key] = value
conf[key] = value
has_new = True
elif isinstance(value, dict):
# 递归检查子配置项
if not isinstance(conf[key], dict):
# 类型不匹配,使用默认值
new_conf[key] = value
has_new = True
else:
# 递归检查并同步顺序
child_has_new = self.check_config_integrity(
value, conf[key], path + "." + key if path else key
)
new_conf[key] = conf[key]
has_new |= child_has_new
else:
# 直接使用现有配置
new_conf[key] = conf[key]
# 检查是否存在参考配置中没有的配置项
for key in list(conf.keys()):
if key not in refer_conf:
path_ = path + "." + key if path else key
logger.info(f"检查到配置项 {path_} 不存在,将从当前配置中删除")
has_new = True
# 顺序不一致也算作变更
if list(conf.keys()) != list(new_conf.keys()):
if path:
logger.info(f"检查到配置项 {path} 的子项顺序不一致,已重新排序")
else:
logger.info("检查到配置项顺序不一致,已重新排序")
has_new = True
# 更新原始配置
conf.clear()
conf.update(new_conf)
has_new |= self.check_config_integrity(value, conf[key], path + "." + key if path else key)
return has_new
def save_config(self, replace_config: Dict = None):
"""将配置写入文件
'''将配置写入文件
如果传入 replace_config则将配置替换为 replace_config
"""
'''
if replace_config:
self.update(replace_config)
with open(self.config_path, "w", encoding="utf-8-sig") as f:
json.dump(self, f, indent=2, ensure_ascii=False)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
return None
def __delattr__(self, key):
try:
del self[key]
@@ -167,4 +124,4 @@ class AstrBotConfig(dict):
self[key] = value
def check_exist(self) -> bool:
return os.path.exists(self.config_path)
return os.path.exists(self.config_path)

File diff suppressed because it is too large Load Diff

View File

@@ -1,287 +1,109 @@
"""
AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库
在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
"""
import uuid
import json
import asyncio
from astrbot.core import sp
from typing import Dict, List
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation, ConversationV2
class ConversationManager:
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
from astrbot.core.db.po import Conversation
class ConversationManager():
'''负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。'''
def __init__(self, db_helper: BaseDatabase):
self.session_conversations: Dict[str, str] = {}
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
self._start_periodic_save()
def _start_periodic_save(self):
asyncio.create_task(self._periodic_save())
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
"""将 ConversationV2 对象转换为 Conversation 对象"""
created_at = int(conv_v2.created_at.timestamp())
updated_at = int(conv_v2.updated_at.timestamp())
return Conversation(
platform_id=conv_v2.platform_id,
user_id=conv_v2.user_id,
cid=conv_v2.conversation_id,
history=json.dumps(conv_v2.content or []),
title=conv_v2.title,
persona_id=conv_v2.persona_id,
created_at=created_at,
updated_at=updated_at,
)
async def _periodic_save(self):
while True:
await asyncio.sleep(self.save_interval)
self._save_to_storage()
async def new_conversation(
self,
unified_msg_origin: str,
platform_id: str | None = None,
content: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
) -> str:
"""新建对话,并将当前会话的对话转移到新对话
def _save_to_storage(self):
sp.put("session_conversation", self.session_conversations)
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
if not platform_id:
# 如果没有提供 platform_id则从 unified_msg_origin 中解析
parts = unified_msg_origin.split(":")
if len(parts) >= 3:
platform_id = parts[0]
if not platform_id:
platform_id = "unknown"
conv = await self.db.create_conversation(
async def new_conversation(self, unified_msg_origin: str) -> str:
'''新建对话,并将当前会话的对话转移到新对话'''
conversation_id = str(uuid.uuid4())
self.db.new_conversation(
user_id=unified_msg_origin,
platform_id=platform_id,
content=content,
title=title,
persona_id=persona_id,
cid=conversation_id
)
self.session_conversations[unified_msg_origin] = conv.conversation_id
await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
return conv.conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
"""切换会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
self.session_conversations[unified_msg_origin] = conversation_id
await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)
async def delete_conversation(
self, unified_msg_origin: str, conversation_id: str | None = None
):
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
f = False
if not conversation_id:
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
f = True
sp.put("session_conversation", self.session_conversations)
return conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
'''切换会话的对话'''
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
async def delete_conversation(self, unified_msg_origin: str, conversation_id: str=None):
'''删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
await self.db.delete_conversation(cid=conversation_id)
if f:
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
"""获取会话当前的对话 ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
ret = self.session_conversations.get(unified_msg_origin, None)
if not ret:
ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None)
if ret:
self.session_conversations[unified_msg_origin] = ret
return ret
async def get_conversation(
self,
unified_msg_origin: str,
conversation_id: str,
create_if_not_exists: bool = False,
) -> Conversation | None:
"""获取会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
Returns:
conversation (Conversation): 对话对象
"""
conv = await self.db.get_conversation_by_id(cid=conversation_id)
if not conv and create_if_not_exists:
# 如果对话不存在且需要创建,则新建一个对话
conversation_id = await self.new_conversation(unified_msg_origin)
conv = await self.db.get_conversation_by_id(cid=conversation_id)
conv_res = None
if conv:
conv_res = self._convert_conv_from_v2_to_v1(conv)
return conv_res
async def get_conversations(
self, unified_msg_origin: str | None = None, platform_id: str | None = None
) -> List[Conversation]:
"""获取对话列表
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id可选
platform_id (str): 平台 ID, 可选参数, 用于过滤对话
Returns:
conversations (List[Conversation]): 对话对象列表
"""
convs = await self.db.get_conversations(
user_id=unified_msg_origin, platform_id=platform_id
)
convs_res = []
for conv in convs:
conv_res = self._convert_conv_from_v2_to_v1(conv)
convs_res.append(conv_res)
return convs_res
async def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platform_ids: list[str] | None = None,
search_query: str = "",
**kwargs,
) -> tuple[list[Conversation], int]:
"""获取过滤后的对话列表
Args:
page (int): 页码, 默认为 1
page_size (int): 每页大小, 默认为 20
platform_ids (list[str]): 平台 ID 列表, 可选
search_query (str): 搜索查询字符串, 可选
Returns:
conversations (list[Conversation]): 对话对象列表
"""
convs, cnt = await self.db.get_filtered_conversations(
page=page,
page_size=page_size,
platform_ids=platform_ids,
search_query=search_query,
**kwargs,
)
convs_res = []
for conv in convs:
conv_res = self._convert_conv_from_v2_to_v1(conv)
convs_res.append(conv_res)
return convs_res, cnt
async def update_conversation(
self,
unified_msg_origin: str,
conversation_id: str | None = None,
history: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
):
"""更新会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
"""
if not conversation_id:
# 如果没有提供 conversation_id则获取当前的
conversation_id = await self.get_curr_conversation_id(unified_msg_origin)
if conversation_id:
await self.db.update_conversation(
cid=conversation_id,
title=title,
persona_id=persona_id,
content=history,
self.db.delete_conversation(
user_id=unified_msg_origin,
cid=conversation_id
)
async def update_conversation_title(
self, unified_msg_origin: str, title: str, conversation_id: str | None = None
):
"""更新会话的对话标题
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
title (str): 对话标题
Deprecated:
Use `update_conversation` with `title` parameter instead.
"""
await self.update_conversation(
unified_msg_origin=unified_msg_origin,
conversation_id=conversation_id,
title=title,
)
async def update_conversation_persona_id(
self,
unified_msg_origin: str,
persona_id: str,
conversation_id: str | None = None,
):
"""更新会话的对话 Persona ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
persona_id (str): 对话 Persona ID
Deprecated:
Use `update_conversation` with `persona_id` parameter instead.
"""
await self.update_conversation(
unified_msg_origin=unified_msg_origin,
conversation_id=conversation_id,
persona_id=persona_id,
)
async def get_human_readable_context(
self, unified_msg_origin, conversation_id, page=1, page_size=10
):
"""获取人类可读的上下文
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
page (int): 页码
page_size (int): 每页大小
"""
del self.session_conversations[unified_msg_origin]
sp.put("session_conversation", self.session_conversations)
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
'''获取会话当前的对话 ID'''
return self.session_conversations.get(unified_msg_origin, None)
async def get_conversation(self, unified_msg_origin: str, conversation_id: str) -> Conversation:
'''获取会话的对话'''
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
'''获取会话的所有对话'''
return self.db.get_conversations(unified_msg_origin)
async def update_conversation(self, unified_msg_origin: str, conversation_id: str, history: List[Dict]):
'''更新会话的对话'''
if conversation_id:
self.db.update_conversation(
user_id=unified_msg_origin,
cid=conversation_id,
history=json.dumps(history)
)
async def update_conversation_title(self, unified_msg_origin: str, title: str):
'''更新会话的对话标题'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_title(
user_id=unified_msg_origin,
cid=conversation_id,
title=title
)
async def update_conversation_persona_id(self, unified_msg_origin: str, persona_id: str):
'''更新会话的对话 Persona ID'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_persona_id(
user_id=unified_msg_origin,
cid=conversation_id,
persona_id=persona_id
)
async def get_human_readable_context(self, unified_msg_origin, conversation_id, page=1, page_size=10):
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
history = json.loads(conversation.history)
contexts = []
temp_contexts = []
for record in history:
if record["role"] == "user":
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record["role"] == "assistant":
if "content" in record and record["content"]:
temp_contexts.append(f"Assistant: {record['content']}")
elif "tool_calls" in record:
tool_calls_str = json.dumps(
record["tool_calls"], ensure_ascii=False
)
temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
else:
temp_contexts.append("Assistant: [未知的内容]")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
@@ -289,9 +111,9 @@ class ConversationManager:
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page - 1) * page_size : page * page_size]
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
return paged_contexts, total_pages

View File

@@ -1,241 +1,122 @@
"""
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
工作流程:
1. 初始化所有组件
2. 启动事件总线和任务, 所有任务都在这里运行
3. 执行启动完成事件钩子
"""
import traceback
import asyncio
import time
import threading
import os
from .event_bus import EventBus
from . import astrbot_config, html_renderer
from . import astrbot_config
from asyncio import Queue
from typing import List
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.star.context import Context
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core import LogBroker
from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger, sp
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map
class AstrBotCoreLifecycle:
"""
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、
EventBus 等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
"""
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker # 初始化日志代理
self.astrbot_config = astrbot_config # 初始化配置
self.db = db # 初始化数据库
# 设置代理
proxy_config = self.astrbot_config.get("http_proxy", "")
if proxy_config != "":
os.environ["https_proxy"] = proxy_config
os.environ["http_proxy"] = proxy_config
logger.debug(f"Using proxy: {proxy_config}")
# 设置 no_proxy
no_proxy_list = self.astrbot_config.get("no_proxy", [])
os.environ["no_proxy"] = ",".join(no_proxy_list)
else:
# 清空代理环境变量
if "https_proxy" in os.environ:
del os.environ["https_proxy"]
if "http_proxy" in os.environ:
del os.environ["http_proxy"]
if "no_proxy" in os.environ:
del os.environ["no_proxy"]
logger.debug("HTTP proxy cleared")
self.log_broker = log_broker
self.astrbot_config = astrbot_config
self.db = db
os.environ['https_proxy'] = self.astrbot_config['http_proxy']
os.environ['http_proxy'] = self.astrbot_config['http_proxy']
os.environ['no_proxy'] = 'localhost,127.0.0.1'
async def initialize(self):
"""
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
"""
# 初始化日志代理
logger.info("AstrBot v" + VERSION)
logger.info("AstrBot v"+ VERSION)
if os.environ.get("TESTING", ""):
logger.setLevel("DEBUG") # 测试模式下设置日志级别为 DEBUG
logger.setLevel("DEBUG")
else:
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
await self.db.initialize()
await html_renderer.initialize()
# 初始化 AstrBot 配置管理器
self.astrbot_config_mgr = AstrBotConfigManager(
default_config=self.astrbot_config, sp=sp
)
# 初始化事件队列
logger.setLevel(self.astrbot_config['log_level'])
self.event_queue = Queue()
# 初始化人格管理器
self.persona_mgr = PersonaManager(self.db, self.astrbot_config_mgr)
await self.persona_mgr.initialize()
# 初始化供应商管理器
self.provider_manager = ProviderManager(
self.astrbot_config_mgr, self.db, self.persona_mgr
)
# 初始化平台管理器
self.event_queue.closed = False
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
# 初始化对话管理器
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
self.conversation_manager = ConversationManager(self.db)
# 初始化平台消息历史管理器
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
# 初始化提供给插件的上下文
self.star_context = Context(
self.event_queue,
self.astrbot_config,
self.event_queue,
self.astrbot_config,
self.db,
self.provider_manager,
self.platform_manager,
self.conversation_manager,
self.platform_message_history_manager,
self.persona_mgr,
self.astrbot_config_mgr,
self.knowledge_db_manager
)
# 初始化插件管理器
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
# 扫描、注册插件、实例化插件类
await self.plugin_manager.reload()
# 根据配置实例化各个 Provider
'''扫描、注册插件、实例化插件类'''
await self.provider_manager.initialize()
'''根据配置实例化各个 Provider'''
await self.platform_manager.initialize()
'''根据配置实例化各个平台适配器'''
# 初始化消息事件流水线调度器
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
# 初始化更新器
self.astrbot_updator = AstrBotUpdator()
# 初始化事件总线
self.event_bus = EventBus(
self.event_queue, self.pipeline_scheduler_mapping, self.astrbot_config_mgr
)
# 记录启动时间
self.pipeline_scheduler = PipelineScheduler(PipelineContext(self.astrbot_config, self.plugin_manager))
await self.pipeline_scheduler.initialize()
'''初始化消息事件流水线调度器'''
self.astrbot_updator = AstrBotUpdator(self.astrbot_config['plugin_repo_mirror'])
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
self.start_time = int(time.time())
# 初始化当前任务列表
self.curr_tasks: List[asyncio.Task] = []
# 根据配置实例化各个平台适配器
await self.platform_manager.initialize()
# 初始化关闭控制面板的事件
self.dashboard_shutdown_event = asyncio.Event()
def _load(self):
"""加载事件总线和任务并初始化"""
# 创建一个异步任务来执行事件总线的 dispatch() 方法
# dispatch是一个无限循环的协程, 从事件队列中获取事件并处理
event_bus_task = asyncio.create_task(
self.event_bus.dispatch(), name="event_bus"
)
# 把插件中注册的所有协程函数注册到事件总线中并执行
platform_tasks = self.load_platform()
event_bus_task = asyncio.create_task(self.event_bus.dispatch(), name="event_bus")
extra_tasks = []
for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
tasks_ = [event_bus_task, *extra_tasks]
# self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks]
tasks_ = [event_bus_task, *platform_tasks, *extra_tasks]
for task in tasks_:
self.curr_tasks.append(
asyncio.create_task(self._task_wrapper(task), name=task.get_name())
)
self.curr_tasks.append(asyncio.create_task(self._task_wrapper(task), name=task.get_name()))
self.start_time = int(time.time())
async def _task_wrapper(self, task: asyncio.Task):
"""异步任务包装器, 用于处理异步任务执行中出现的各种异常
Args:
task (asyncio.Task): 要执行的异步任务
"""
try:
await task
except asyncio.CancelledError:
pass # 任务被取消, 静默处理
pass
except Exception as e:
# 获取完整的异常堆栈信息, 按行分割并记录到日志中
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
for line in traceback.format_exc().split("\n"):
logger.error(f"| {line}")
logger.error("-------")
async def start(self):
"""启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子"""
self._load()
logger.info("AstrBot 启动完成。")
# 执行启动完成事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAstrBotLoadedEvent
)
for handler in handlers:
try:
logger.info(
f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler()
except BaseException:
logger.error(traceback.format_exc())
# 同时运行curr_tasks中的所有任务
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
async def stop(self):
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器"""
# 请求停止所有正在运行的异步任务
self.event_queue.closed = True
for task in self.curr_tasks:
task.cancel()
for plugin in self.plugin_manager.context.get_all_stars():
try:
await self.plugin_manager._terminate_plugin(plugin)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。"
)
await self.provider_manager.terminate()
await self.platform_manager.terminate()
self.dashboard_shutdown_event.set()
# 再次遍历curr_tasks等待每个任务真正结束
for task in self.curr_tasks:
try:
await task
@@ -243,55 +124,14 @@ class AstrBotCoreLifecycle:
pass
except Exception as e:
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
async def restart(self):
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
await self.provider_manager.terminate()
await self.platform_manager.terminate()
self.dashboard_shutdown_event.set()
threading.Thread(
target=self.astrbot_updator._reboot, name="restart", daemon=True
).start()
def restart(self):
self.event_queue.closed = True
threading.Thread(target=self.astrbot_updator._reboot, name="restart", daemon=True).start()
def load_platform(self) -> List[asyncio.Task]:
"""加载平台实例并返回所有平台实例的异步任务列表"""
tasks = []
platform_insts = self.platform_manager.get_insts()
for platform_inst in platform_insts:
tasks.append(
asyncio.create_task(
platform_inst.run(),
name=f"{platform_inst.meta().id}({platform_inst.meta().name})",
)
)
return tasks
async def load_pipeline_scheduler(self) -> dict[str, PipelineScheduler]:
"""加载消息事件流水线调度器
Returns:
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
"""
mapping = {}
for conf_id, ab_config in self.astrbot_config_mgr.confs.items():
scheduler = PipelineScheduler(
PipelineContext(ab_config, self.plugin_manager, conf_id)
)
await scheduler.initialize()
mapping[conf_id] = scheduler
return mapping
async def reload_pipeline_scheduler(self, conf_id: str):
"""重新加载消息事件流水线调度器
Returns:
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
"""
ab_config = self.astrbot_config_mgr.confs.get(conf_id)
if not ab_config:
raise ValueError(f"配置文件 {conf_id} 不存在")
scheduler = PipelineScheduler(
PipelineContext(ab_config, self.plugin_manager, conf_id)
)
await scheduler.initialize()
self.pipeline_scheduler_mapping[conf_id] = scheduler
tasks.append(asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name))
return tasks

View File

@@ -1,284 +1,113 @@
import abc
import datetime
import typing as T
from deprecated import deprecated
from dataclasses import dataclass
from astrbot.core.db.po import (
Stats,
PlatformStat,
ConversationV2,
PlatformMessageHistory,
Attachment,
Persona,
Preference,
)
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from typing import List
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
@dataclass
class BaseDatabase(abc.ABC):
"""
'''
数据库基类
"""
DATABASE_URL = ""
'''
def __init__(self) -> None:
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
future=True,
)
self.AsyncSessionLocal = sessionmaker(
self.engine, class_=AsyncSession, expire_on_commit=False
)
async def initialize(self):
"""初始化数据库连接"""
pass
@asynccontextmanager
async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
"""Get a database session."""
if not self.inited:
await self.initialize()
self.inited = True
async with self.AsyncSessionLocal() as session:
yield session
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
def insert_base_metrics(self, metrics: dict):
'''插入基础指标数据'''
self.insert_platform_metrics(metrics['platform_stats'])
self.insert_plugin_metrics(metrics['plugin_stats'])
self.insert_command_metrics(metrics['command_stats'])
self.insert_llm_metrics(metrics['llm_stats'])
@abc.abstractmethod
def insert_platform_metrics(self, metrics: dict):
'''插入平台指标数据'''
raise NotImplementedError
@abc.abstractmethod
def insert_plugin_metrics(self, metrics: dict):
'''插入插件指标数据'''
raise NotImplementedError
@abc.abstractmethod
def insert_command_metrics(self, metrics: dict):
'''插入指令指标数据'''
raise NotImplementedError
@abc.abstractmethod
def insert_llm_metrics(self, metrics: dict):
'''插入 LLM 指标数据'''
raise NotImplementedError
@abc.abstractmethod
def update_llm_history(self, session_id: str, content: str, provider_type: str):
'''更新 LLM 历史记录。当不存在 session_id 时插入'''
raise NotImplementedError
@abc.abstractmethod
def get_llm_history(self, session_id: str = None, provider_type: str = None) -> List[LLMHistory]:
'''获取 LLM 历史记录, 如果 session_id 为 None, 返回所有'''
raise NotImplementedError
@abc.abstractmethod
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取基础统计数据"""
'''获取基础统计数据'''
raise NotImplementedError
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
@abc.abstractmethod
def get_total_message_count(self) -> int:
"""获取总消息数"""
'''获取总消息数'''
raise NotImplementedError
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
@abc.abstractmethod
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取基础统计数据(合并)"""
'''获取基础统计数据(合并)'''
raise NotImplementedError
# New methods in v4.0.0
@abc.abstractmethod
def insert_atri_vision_data(self, vision_data: ATRIVision):
'''插入 ATRI 视觉数据'''
raise NotImplementedError
@abc.abstractmethod
def get_atri_vision_data(self) -> List[ATRIVision]:
'''获取 ATRI 视觉数据'''
raise NotImplementedError
@abc.abstractmethod
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
'''通过 url 或 path 获取 ATRI 视觉数据'''
raise NotImplementedError
@abc.abstractmethod
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
'''通过 user_id 和 cid 获取 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def new_conversation(self, user_id: str, cid: str):
'''新建 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def get_conversations(self, user_id: str) -> List[Conversation]:
raise NotImplementedError
@abc.abstractmethod
async def insert_platform_stats(
self,
platform_id: str,
platform_type: str,
count: int = 1,
timestamp: datetime.datetime | None = None,
) -> None:
"""Insert a new platform statistic record."""
...
def update_conversation(self, user_id: str, cid: str, history: str):
'''更新 Conversation'''
raise NotImplementedError
@abc.abstractmethod
async def count_platform_stats(self) -> int:
"""Count the number of platform statistics records."""
...
def delete_conversation(self, user_id: str, cid: str):
'''删除 Conversation'''
raise NotImplementedError
@abc.abstractmethod
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
...
def update_conversation_title(self, user_id: str, cid: str, title: str):
'''更新 Conversation 标题'''
raise NotImplementedError
@abc.abstractmethod
async def get_conversations(
self, user_id: str | None = None, platform_id: str | None = None
) -> list[ConversationV2]:
"""Get all conversations for a specific user and platform_id(optional).
content is not included in the result.
"""
...
@abc.abstractmethod
async def get_conversation_by_id(self, cid: str) -> ConversationV2:
"""Get a specific conversation by its ID."""
...
@abc.abstractmethod
async def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> list[ConversationV2]:
"""Get all conversations with pagination."""
...
@abc.abstractmethod
async def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platform_ids: list[str] | None = None,
search_query: str = "",
**kwargs,
) -> tuple[list[ConversationV2], int]:
"""Get conversations filtered by platform IDs and search query."""
...
@abc.abstractmethod
async def create_conversation(
self,
user_id: str,
platform_id: str,
content: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
cid: str | None = None,
created_at: datetime.datetime | None = None,
updated_at: datetime.datetime | None = None,
) -> ConversationV2:
"""Create a new conversation."""
...
@abc.abstractmethod
async def update_conversation(
self,
cid: str,
title: str | None = None,
persona_id: str | None = None,
content: list[dict] | None = None,
) -> None:
"""Update a conversation's history."""
...
@abc.abstractmethod
async def delete_conversation(self, cid: str) -> None:
"""Delete a conversation by its ID."""
...
@abc.abstractmethod
async def insert_platform_message_history(
self,
platform_id: str,
user_id: str,
content: list[dict],
sender_id: str | None = None,
sender_name: str | None = None,
) -> None:
"""Insert a new platform message history record."""
...
@abc.abstractmethod
async def delete_platform_message_offset(
self, platform_id: str, user_id: str, offset_sec: int = 86400
) -> None:
"""Delete platform message history records older than the specified offset."""
...
@abc.abstractmethod
async def get_platform_message_history(
self,
platform_id: str,
user_id: str,
page: int = 1,
page_size: int = 20,
) -> list[PlatformMessageHistory]:
"""Get platform message history for a specific user."""
...
@abc.abstractmethod
async def insert_attachment(
self,
path: str,
type: str,
mime_type: str,
):
"""Insert a new attachment record."""
...
@abc.abstractmethod
async def get_attachment_by_id(self, attachment_id: str) -> Attachment:
"""Get an attachment by its ID."""
...
@abc.abstractmethod
async def insert_persona(
self,
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] | None = None,
tools: list[str] | None = None,
) -> Persona:
"""Insert a new persona record."""
...
@abc.abstractmethod
async def get_persona_by_id(self, persona_id: str) -> Persona:
"""Get a persona by its ID."""
...
@abc.abstractmethod
async def get_personas(self) -> list[Persona]:
"""Get all personas for a specific bot."""
...
@abc.abstractmethod
async def update_persona(
self,
persona_id: str,
system_prompt: str | None = None,
begin_dialogs: list[str] | None = None,
tools: list[str] | None = None,
) -> Persona | None:
"""Update a persona's system prompt or begin dialogs."""
...
@abc.abstractmethod
async def delete_persona(self, persona_id: str) -> None:
"""Delete a persona by its ID."""
...
@abc.abstractmethod
async def insert_preference_or_update(
self, scope: str, scope_id: str, key: str, value: dict
) -> Preference:
"""Insert a new preference record."""
...
@abc.abstractmethod
async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference:
"""Get a preference by scope ID and key."""
...
@abc.abstractmethod
async def get_preferences(
self, scope: str, scope_id: str | None = None, key: str | None = None
) -> list[Preference]:
"""Get all preferences for a specific scope ID or key."""
...
@abc.abstractmethod
async def remove_preference(self, scope: str, scope_id: str, key: str) -> None:
"""Remove a preference by scope ID and key."""
...
@abc.abstractmethod
async def clear_preferences(self, scope: str, scope_id: str) -> None:
"""Clear all preferences for a specific scope ID."""
...
# @abc.abstractmethod
# async def insert_llm_message(
# self,
# cid: str,
# role: str,
# content: list,
# tool_calls: list = None,
# tool_call_id: str = None,
# parent_id: str = None,
# ) -> LLMMessage:
# """Insert a new LLM message into the conversation."""
# ...
# @abc.abstractmethod
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
# """Get all LLM messages for a specific conversation."""
# ...
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
'''更新 Conversation Persona ID'''
raise NotImplementedError

View File

@@ -1,64 +0,0 @@
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.db import BaseDatabase
from astrbot.core.config import AstrBotConfig
from astrbot.api import logger, sp
from .migra_3_to_4 import (
migration_conversation_table,
migration_platform_table,
migration_webchat_data,
migration_persona_data,
migration_preferences,
)
async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
"""
检查是否需要进行数据库迁移
如果存在 data_v3.db 并且 preference 中没有 migration_done_v4则需要进行迁移。
"""
data_v3_exists = os.path.exists(get_astrbot_data_path())
if not data_v3_exists:
return False
migration_done = await db_helper.get_preference(
"global", "global", "migration_done_v4"
)
if migration_done:
return False
return True
async def do_migration_v4(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
astrbot_config: AstrBotConfig,
):
"""
执行数据库迁移
迁移旧的 webchat_conversation 表到新的 conversation 表。
迁移旧的 platform 到新的 platform_stats 表。
"""
if not await check_migration_needed_v4(db_helper):
return
logger.info("开始执行数据库迁移...")
# 执行会话表迁移
await migration_conversation_table(db_helper, platform_id_map)
# 执行人格数据迁移
await migration_persona_data(db_helper, astrbot_config)
# 执行 WebChat 数据迁移
await migration_webchat_data(db_helper, platform_id_map)
# 执行偏好设置迁移
await migration_preferences(db_helper, platform_id_map)
# 执行平台统计表迁移
await migration_platform_table(db_helper, platform_id_map)
# 标记迁移完成
await sp.put_async("global", "global", "migration_done_v4", True)
logger.info("数据库迁移完成。")

View File

@@ -1,338 +0,0 @@
import json
import datetime
from .. import BaseDatabase
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
from .shared_preferences_v3 import sp as sp_v3
from astrbot.core.config.default import DB_PATH
from astrbot.api import logger, sp
from astrbot.core.config import AstrBotConfig
from astrbot.core.platform.astr_message_event import MessageSesion
from sqlalchemy.ext.asyncio import AsyncSession
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
from sqlalchemy import text
"""
1. 迁移旧的 webchat_conversation 表到新的 conversation 表。
2. 迁移旧的 platform 到新的 platform_stats 表。
"""
def get_platform_id(
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
) -> str:
return platform_id_map.get(
old_platform_name,
{"platform_id": old_platform_name, "platform_type": old_platform_name},
).get("platform_id", old_platform_name)
def get_platform_type(
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
) -> str:
return platform_id_map.get(
old_platform_name,
{"platform_id": old_platform_name, "platform_type": old_platform_name},
).get("platform_type", old_platform_name)
async def migration_conversation_table(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
)
conversations, total_cnt = db_helper_v3.get_all_conversations(
page=1, page_size=10000000
)
logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...")
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
for idx, conversation in enumerate(conversations):
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
progress = int((idx + 1) / total_cnt * 100)
if progress % 10 == 0:
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
try:
conv = db_helper_v3.get_conversation_by_user_id(
user_id=conversation.get("user_id", "unknown"),
cid=conversation.get("cid", "unknown"),
)
if not conv:
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" not in conv.user_id:
continue
session = MessageSesion.from_str(session_str=conv.user_id)
platform_id = get_platform_id(
platform_id_map, session.platform_name
)
session.platform_id = platform_id # 更新平台名称为新的 ID
conv_v2 = ConversationV2(
user_id=str(session),
content=json.loads(conv.history) if conv.history else [],
platform_id=platform_id,
title=conv.title,
persona_id=conv.persona_id,
conversation_id=conv.cid,
created_at=datetime.datetime.fromtimestamp(conv.created_at),
updated_at=datetime.datetime.fromtimestamp(conv.updated_at),
)
dbsession.add(conv_v2)
except Exception as e:
logger.error(
f"迁移旧会话 {conversation.get('cid', 'unknown')} 失败: {e}",
exc_info=True,
)
logger.info(f"成功迁移 {total_cnt} 条旧的会话数据到新表。")
async def migration_platform_table(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
)
secs_from_2023_4_10_to_now = (
datetime.datetime.now(datetime.timezone.utc)
- datetime.datetime(2023, 4, 10, tzinfo=datetime.timezone.utc)
).total_seconds()
offset_sec = int(secs_from_2023_4_10_to_now)
logger.info(f"迁移旧平台数据offset_sec: {offset_sec} 秒。")
stats = db_helper_v3.get_base_stats(offset_sec=offset_sec)
logger.info(f"迁移 {len(stats.platform)} 条旧的平台数据到新的表中...")
platform_stats_v3 = stats.platform
if not platform_stats_v3:
logger.info("没有找到旧平台数据,跳过迁移。")
return
first_time_stamp = platform_stats_v3[0].timestamp
end_time_stamp = platform_stats_v3[-1].timestamp
start_time = first_time_stamp - (first_time_stamp % 3600) # 向下取整到小时
end_time = end_time_stamp + (3600 - (end_time_stamp % 3600)) # 向上取整到小时
idx = 0
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
total_buckets = (end_time - start_time) // 3600
for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)):
if bucket_idx % 500 == 0:
progress = int((bucket_idx + 1) / total_buckets * 100)
logger.info(f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})")
cnt = 0
while (
idx < len(platform_stats_v3)
and platform_stats_v3[idx].timestamp < bucket_end
):
cnt += platform_stats_v3[idx].count
idx += 1
if cnt == 0:
continue
platform_id = get_platform_id(
platform_id_map, platform_stats_v3[idx].name
)
platform_type = get_platform_type(
platform_id_map, platform_stats_v3[idx].name
)
try:
await dbsession.execute(
text("""
INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
VALUES (:timestamp, :platform_id, :platform_type, :count)
ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET
count = platform_stats.count + EXCLUDED.count
"""),
{
"timestamp": datetime.datetime.fromtimestamp(
bucket_end, tz=datetime.timezone.utc
),
"platform_id": platform_id,
"platform_type": platform_type,
"count": cnt,
},
)
except Exception:
logger.error(
f"迁移平台统计数据失败: {platform_id}, {platform_type}, 时间戳: {bucket_end}",
exc_info=True,
)
logger.info(f"成功迁移 {len(platform_stats_v3)} 条旧的平台数据到新表。")
async def migration_webchat_data(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
"""迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中"""
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
)
conversations, total_cnt = db_helper_v3.get_all_conversations(
page=1, page_size=10000000
)
logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...")
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
for idx, conversation in enumerate(conversations):
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
progress = int((idx + 1) / total_cnt * 100)
if progress % 10 == 0:
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
try:
conv = db_helper_v3.get_conversation_by_user_id(
user_id=conversation.get("user_id", "unknown"),
cid=conversation.get("cid", "unknown"),
)
if not conv:
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" in conv.user_id:
continue
platform_id = "webchat"
history = json.loads(conv.history) if conv.history else []
for msg in history:
type_ = msg.get("type") # user type, "bot" or "user"
new_history = PlatformMessageHistory(
platform_id=platform_id,
user_id=conv.cid, # we use conv.cid as user_id for webchat
content=msg,
sender_id=type_,
sender_name=type_,
)
dbsession.add(new_history)
except Exception:
logger.error(
f"迁移旧 WebChat 会话 {conversation.get('cid', 'unknown')} 失败",
exc_info=True,
)
logger.info(f"成功迁移 {total_cnt} 条旧的 WebChat 会话数据到新表。")
async def migration_persona_data(
db_helper: BaseDatabase, astrbot_config: AstrBotConfig
):
"""
迁移 Persona 数据到新的表中。
旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
"""
v3_persona_config: list[dict] = astrbot_config.get("persona", [])
total_personas = len(v3_persona_config)
logger.info(f"迁移 {total_personas} 个 Persona 配置到新表中...")
for idx, persona in enumerate(v3_persona_config):
if total_personas > 0 and (idx + 1) % max(1, total_personas // 10) == 0:
progress = int((idx + 1) / total_personas * 100)
if progress % 10 == 0:
logger.info(f"进度: {progress}% ({idx + 1}/{total_personas})")
try:
begin_dialogs = persona.get("begin_dialogs", [])
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
mood_prompt = ""
user_turn = True
for mood_dialog in mood_imitation_dialogs:
if user_turn:
mood_prompt += f"A: {mood_dialog}\n"
else:
mood_prompt += f"B: {mood_dialog}\n"
user_turn = not user_turn
system_prompt = persona.get("prompt", "")
if mood_prompt:
system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
persona_new = await db_helper.insert_persona(
persona_id=persona["name"],
system_prompt=system_prompt,
begin_dialogs=begin_dialogs,
)
logger.info(
f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
async def migration_preferences(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
# 1. global scope migration
keys = [
"inactivated_llm_tools",
"inactivated_plugins",
"curr_provider",
"curr_provider_tts",
"curr_provider_stt",
"alter_cmd",
]
for key in keys:
value = sp_v3.get(key)
if value is not None:
await sp.put_async("global", "global", key, value)
logger.info(f"迁移全局偏好设置 {key} 成功,值: {value}")
# 2. umo scope migration
session_conversation = sp_v3.get("session_conversation", default={})
for umo, conversation_id in session_conversation.items():
if not umo or not conversation_id:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
await sp.put_async("umo", str(session), "sel_conv_id", conversation_id)
logger.info(f"迁移会话 {umo} 的对话数据到新表成功,平台 ID: {platform_id}")
except Exception as e:
logger.error(f"迁移会话 {umo} 的对话数据失败: {e}", exc_info=True)
session_service_config = sp_v3.get("session_service_config", default={})
for umo, config in session_service_config.items():
if not umo or not config:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
await sp.put_async("umo", str(session), "session_service_config", config)
logger.info(f"迁移会话 {umo} 的服务配置到新表成功,平台 ID: {platform_id}")
except Exception as e:
logger.error(f"迁移会话 {umo} 的服务配置失败: {e}", exc_info=True)
session_variables = sp_v3.get("session_variables", default={})
for umo, variables in session_variables.items():
if not umo or not variables:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
await sp.put_async("umo", str(session), "session_variables", variables)
except Exception as e:
logger.error(f"迁移会话 {umo} 的变量失败: {e}", exc_info=True)
session_provider_perf = sp_v3.get("session_provider_perf", default={})
for umo, perf in session_provider_perf.items():
if not umo or not perf:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
for provider_type, provider_id in perf.items():
await sp.put_async(
"umo", str(session), f"provider_perf_{provider_type}", provider_id
)
logger.info(
f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}"
)
except Exception as e:
logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True)

View File

@@ -1,47 +0,0 @@
import json
import os
from typing import TypeVar
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
_VT = TypeVar("_VT")
class SharedPreferences:
def __init__(self, path=None):
if path is None:
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
self.path = path
self._data = self._load_preferences()
def _load_preferences(self):
if os.path.exists(self.path):
try:
with open(self.path, "r") as f:
return json.load(f)
except json.JSONDecodeError:
os.remove(self.path)
return {}
def _save_preferences(self):
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4, ensure_ascii=False)
f.flush()
def get(self, key, default: _VT = None) -> _VT:
return self._data.get(key, default)
def put(self, key, value):
self._data[key] = value
self._save_preferences()
def remove(self, key):
if key in self._data:
del self._data[key]
self._save_preferences()
def clear(self):
self._data.clear()
self._save_preferences()
sp = SharedPreferences()

View File

@@ -1,494 +0,0 @@
import sqlite3
import time
from astrbot.core.db.po import Platform, Stats
from typing import Tuple, List, Dict, Any
from dataclasses import dataclass
@dataclass
class Conversation:
"""LLM 对话存储
对于网页聊天history 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
"""
user_id: str
cid: str
history: str = ""
"""字符串格式的列表。"""
created_at: int = 0
updated_at: int = 0
title: str = ""
persona_id: str = ""
INIT_SQL = """
CREATE TABLE IF NOT EXISTS platform(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS plugin(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS command(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm_history(
provider_type VARCHAR(32),
session_id VARCHAR(32),
content TEXT
);
-- ATRI
CREATE TABLE IF NOT EXISTS atri_vision(
id TEXT,
url_or_path TEXT,
caption TEXT,
is_meme BOOLEAN,
keywords TEXT,
platform_name VARCHAR(32),
session_id VARCHAR(32),
sender_nickname VARCHAR(32),
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT, -- 会话 id
cid TEXT, -- 对话 id
history TEXT,
created_at INTEGER,
updated_at INTEGER,
title TEXT,
persona_id TEXT
);
PRAGMA encoding = 'UTF-8';
"""
class SQLiteDatabase:
def __init__(self, db_path: str) -> None:
super().__init__()
self.db_path = db_path
sql = INIT_SQL
# 初始化数据库
self.conn = self._get_conn(self.db_path)
c = self.conn.cursor()
c.executescript(sql)
self.conn.commit()
# 检查 webchat_conversation 的 title 字段是否存在
c.execute(
"""
PRAGMA table_info(webchat_conversation)
"""
)
res = c.fetchall()
has_title = False
has_persona_id = False
for row in res:
if row[1] == "title":
has_title = True
if row[1] == "persona_id":
has_persona_id = True
if not has_title:
c.execute(
"""
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
"""
)
self.conn.commit()
if not has_persona_id:
c.execute(
"""
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
"""
)
self.conn.commit()
c.close()
def _get_conn(self, db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
conn.text_factory = str
return conn
def _exec_sql(self, sql: str, params: Tuple = None):
conn = self.conn
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
conn = self._get_conn(self.db_path)
c = conn.cursor()
if params:
c.execute(sql, params)
c.close()
else:
c.execute(sql)
c.close()
conn.commit()
def insert_platform_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
"""
INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
""",
(k, v, int(time.time())),
)
def insert_llm_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
"""
INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
""",
(k, v, int(time.time())),
)
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取 offset_sec 秒前到现在的基础统计数据"""
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT * FROM platform
"""
+ where_clause
)
platform = []
for row in c.fetchall():
platform.append(Platform(*row))
c.close()
return Stats(platform=platform)
def get_total_message_count(self) -> int:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT SUM(count) FROM platform
"""
)
res = c.fetchone()
c.close()
return res[0]
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取 offset_sec 秒前到现在的基础统计数据(合并)"""
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT name, SUM(count), timestamp FROM platform
"""
+ where_clause
+ " GROUP BY name"
)
platform = []
for row in c.fetchall():
platform.append(Platform(*row))
c.close()
return Stats(platform, [], [])
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
""",
(user_id, cid),
)
res = c.fetchone()
c.close()
if not res:
return
return Conversation(*res)
def new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
self._exec_sql(
"""
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
""",
(user_id, cid, history, updated_at, created_at),
)
def get_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
""",
(user_id,),
)
res = c.fetchall()
c.close()
conversations = []
for row in res:
cid = row[0]
created_at = row[1]
updated_at = row[2]
title = row[3]
persona_id = row[4]
conversations.append(
Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
)
return conversations
def update_conversation(self, user_id: str, cid: str, history: str):
"""更新对话,并且同时更新时间"""
updated_at = int(time.time())
self._exec_sql(
"""
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
""",
(history, updated_at, user_id, cid),
)
def update_conversation_title(self, user_id: str, cid: str, title: str):
self._exec_sql(
"""
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
""",
(title, user_id, cid),
)
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
self._exec_sql(
"""
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
""",
(persona_id, user_id, cid),
)
def delete_conversation(self, user_id: str, cid: str):
self._exec_sql(
"""
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
""",
(user_id, cid),
)
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页,按更新时间降序排序"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 获取总记录数
c.execute("""
SELECT COUNT(*) FROM webchat_conversation
""")
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 获取分页数据,按更新时间降序排序
c.execute(
"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
""",
(page_size, offset),
)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型且至少有8个字符否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0确保即使出错也有有效的返回值
return [], 0
finally:
c.close()
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 构建查询条件
where_clauses = []
params = []
# 平台筛选
if platforms and len(platforms) > 0:
platform_conditions = []
for platform in platforms:
platform_conditions.append("user_id LIKE ?")
params.append(f"{platform}:%")
if platform_conditions:
where_clauses.append(f"({' OR '.join(platform_conditions)})")
# 消息类型筛选
if message_types and len(message_types) > 0:
message_type_conditions = []
for msg_type in message_types:
message_type_conditions.append("user_id LIKE ?")
params.append(f"%:{msg_type}:%")
if message_type_conditions:
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
# 搜索关键词
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
where_clauses.append(
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
)
search_param = f"%{search_query}%"
params.extend([search_param, search_param, search_param, search_param])
# 排除特定用户ID
if exclude_ids and len(exclude_ids) > 0:
for exclude_id in exclude_ids:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_id}%")
# 排除特定平台
if exclude_platforms and len(exclude_platforms) > 0:
for exclude_platform in exclude_platforms:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_platform}:%")
# 构建完整的 WHERE 子句
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
# 构建计数查询
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
# 获取总记录数
c.execute(count_sql, params)
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 构建分页数据查询
data_sql = f"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
{where_sql}
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
"""
query_params = params + [page_size, offset]
# 获取分页数据
c.execute(data_sql, query_params)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型,否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0确保即使出错也有有效的返回值
return [], 0
finally:
c.close()

View File

@@ -1,244 +1,70 @@
import uuid
'''指标数据'''
from datetime import datetime, timezone
from dataclasses import dataclass, field
from sqlmodel import (
SQLModel,
Text,
JSON,
UniqueConstraint,
Field,
)
from typing import Optional, TypedDict
class PlatformStat(SQLModel, table=True):
"""This class represents the statistics of bot usage across different platforms.
Note: In astrbot v4, we moved `platform` table to here.
"""
__tablename__ = "platform_stats"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
timestamp: datetime = Field(nullable=False)
platform_id: str = Field(nullable=False)
platform_type: str = Field(nullable=False) # such as "aiocqhttp", "slack", etc.
count: int = Field(default=0, nullable=False)
__table_args__ = (
UniqueConstraint(
"timestamp",
"platform_id",
"platform_type",
name="uix_platform_stats",
),
)
class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations"
inner_conversation_id: int = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}
)
conversation_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False)
content: Optional[list] = Field(default=None, sa_type=JSON)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
title: Optional[str] = Field(default=None, max_length=255)
persona_id: Optional[str] = Field(default=None)
__table_args__ = (
UniqueConstraint(
"conversation_id",
name="uix_conversation_id",
),
)
class Persona(SQLModel, table=True):
"""Persona is a set of instructions for LLMs to follow.
It can be used to customize the behavior of LLMs.
"""
__tablename__ = "personas"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
persona_id: str = Field(max_length=255, nullable=False)
system_prompt: str = Field(sa_type=Text, nullable=False)
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
"""a list of strings, each representing a dialog to start with"""
tools: Optional[list] = Field(default=None, sa_type=JSON)
"""None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"persona_id",
name="uix_persona_id",
),
)
class Preference(SQLModel, table=True):
"""This class represents preferences for bots."""
__tablename__ = "preferences"
id: int | None = Field(
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
)
scope: str = Field(nullable=False)
"""Scope of the preference, such as 'global', 'umo', 'plugin'."""
scope_id: str = Field(nullable=False)
"""ID of the scope, such as 'global', 'umo', 'plugin_name'."""
key: str = Field(nullable=False)
value: dict = Field(sa_type=JSON, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"scope",
"scope_id",
"key",
name="uix_preference_scope_scope_id_key",
),
)
class PlatformMessageHistory(SQLModel, table=True):
"""This class represents the message history for a specific platform.
It is used to store messages that are not LLM-generated, such as user messages
or platform-specific messages.
"""
__tablename__ = "platform_message_history"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False) # An id of group, user in platform
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
sender_name: Optional[str] = Field(
default=None
) # Name of the sender in the platform
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
class Attachment(SQLModel, table=True):
"""This class represents attachments for messages in AstrBot.
Attachments can be images, files, or other media types.
"""
__tablename__ = "attachments"
inner_attachment_id: int = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}
)
attachment_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
)
path: str = Field(nullable=False) # Path to the file on disk
type: str = Field(nullable=False) # Type of the file (e.g., 'image', 'file')
mime_type: str = Field(nullable=False) # MIME type of the file
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"attachment_id",
name="uix_attachment_id",
),
)
from typing import List
@dataclass
class Conversation:
"""LLM 对话类
对于 WebChathistory 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
在 v4.0.0 版本及之后WebChat 的历史记录被迁移至 `PlatformMessageHistory` 表中,
"""
platform_id: str
user_id: str
cid: str
"""对话 ID, 是 uuid 格式的字符串"""
history: str = ""
"""字符串格式的对话列表。"""
title: str | None = ""
persona_id: str | None = ""
created_at: int = 0
updated_at: int = 0
class Personality(TypedDict):
"""LLM 人格类。
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
"""
prompt: str = ""
name: str = ""
begin_dialogs: list[str] = []
mood_imitation_dialogs: list[str] = []
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
tools: list[str] | None = None
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
# cache
_begin_dialogs_processed: list[dict] = []
_mood_imitation_dialogs_processed: str = ""
# ====
# Deprecated, and will be removed in future versions.
# ====
class Platform():
name: str
count: int
timestamp: int
@dataclass
class Platform:
"""平台使用统计数据"""
class Provider():
name: str
count: int
timestamp: int
@dataclass
class Plugin():
name: str
count: int
timestamp: int
@dataclass
class Command():
name: str
count: int
timestamp: int
@dataclass
class Stats:
platform: list[Platform] = field(default_factory=list)
class Stats():
platform: List[Platform] = field(default_factory=list)
command: List[Command] = field(default_factory=list)
llm: List[Provider] = field(default_factory=list)
@dataclass
class LLMHistory():
'''LLM 聊天时持久化的信息'''
provider_type: str
session_id: str
content: str
@dataclass
class ATRIVision():
'''Deprecated'''
id: str
url_or_path: str
caption: str
is_meme: bool
keywords: List[str]
platform_name: str
session_id: str
sender_nickname: str
timestamp: int = -1
@dataclass
class Conversation():
'''LLM 对话存储
对于网页聊天history 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
'''
user_id: str
cid: str
history: str = ""
'''字符串格式的列表。'''
created_at: int = 0
updated_at: int = 0
title: str = ""
persona_id: str = ""

View File

@@ -1,542 +1,365 @@
import asyncio
import typing as T
import threading
from datetime import datetime, timedelta
from astrbot.core.db import BaseDatabase
import sqlite3
import os
import time
from astrbot.core.db.po import (
ConversationV2,
PlatformStat,
PlatformMessageHistory,
Attachment,
Persona,
Preference,
Stats as DeprecatedStats,
Platform as DeprecatedPlatformStat,
SQLModel,
Platform,
Stats,
LLMHistory,
ATRIVision,
Conversation
)
from sqlalchemy import select, update, delete, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import func
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
from . import BaseDatabase
from typing import Tuple
class SQLiteDatabase(BaseDatabase):
def __init__(self, db_path: str) -> None:
self.db_path = db_path
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
self.inited = False
super().__init__()
async def initialize(self) -> None:
"""Initialize the database by creating tables if they do not exist."""
async with self.engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
await conn.commit()
# ====
# Platform Statistics
# ====
async def insert_platform_stats(
self,
platform_id: str,
platform_type: str,
count: int = 1,
timestamp: datetime = None,
) -> None:
"""Insert a new platform statistic record."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
if timestamp is None:
timestamp = datetime.now().replace(
minute=0, second=0, microsecond=0
)
current_hour = timestamp
await session.execute(
text("""
INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
VALUES (:timestamp, :platform_id, :platform_type, :count)
ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET
count = platform_stats.count + EXCLUDED.count
"""),
{
"timestamp": current_hour,
"platform_id": platform_id,
"platform_type": platform_type,
"count": count,
},
)
async def count_platform_stats(self) -> int:
"""Count the number of platform statistics records."""
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(
select(func.count(PlatformStat.platform_id)).select_from(PlatformStat)
self.db_path = db_path
with open(os.path.dirname(__file__) + "/sqlite_init.sql", "r") as f:
sql = f.read()
# 初始化数据库
self.conn = self._get_conn(self.db_path)
c = self.conn.cursor()
c.executescript(sql)
self.conn.commit()
# 检查 webchat_conversation 的 title 字段是否存在
c.execute(
'''
PRAGMA table_info(webchat_conversation)
'''
)
res = c.fetchall()
has_title = False
has_persona_id = False
for row in res:
if row[1] == "title":
has_title = True
if row[1] == "persona_id":
has_persona_id = True
if not has_title:
c.execute(
'''
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
'''
)
count = result.scalar_one_or_none()
return count if count is not None else 0
async def get_platform_stats(self, offset_sec: int = 86400) -> T.List[PlatformStat]:
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
async with self.get_db() as session:
session: AsyncSession
now = datetime.now()
start_time = now - timedelta(seconds=offset_sec)
result = await session.execute(
text("""
SELECT * FROM platform_stats
WHERE timestamp >= :start_time
ORDER BY timestamp DESC
GROUP BY platform_id
"""),
{"start_time": start_time},
self.conn.commit()
if not has_persona_id:
c.execute(
'''
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
'''
)
return result.scalars().all()
# ====
# Conversation Management
# ====
async def get_conversations(self, user_id=None, platform_id=None):
async with self.get_db() as session:
session: AsyncSession
query = select(ConversationV2)
if user_id:
query = query.where(ConversationV2.user_id == user_id)
if platform_id:
query = query.where(ConversationV2.platform_id == platform_id)
# order by
query = query.order_by(ConversationV2.created_at.desc())
result = await session.execute(query)
return result.scalars().all()
async def get_conversation_by_id(self, cid):
async with self.get_db() as session:
session: AsyncSession
query = select(ConversationV2).where(ConversationV2.conversation_id == cid)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_all_conversations(self, page=1, page_size=20):
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
result = await session.execute(
select(ConversationV2)
.order_by(ConversationV2.created_at.desc())
.offset(offset)
.limit(page_size)
self.conn.commit()
c.close()
def _get_conn(self, db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
conn.text_factory = str
return conn
def _exec_sql(self, sql: str, params: Tuple = None):
conn = self.conn
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
conn = self._get_conn(self.db_path)
c = conn.cursor()
if params:
c.execute(sql, params)
c.close()
else:
c.execute(sql)
c.close()
conn.commit()
def insert_platform_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
'''
INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
''', (k, v, int(time.time()))
)
return result.scalars().all()
async def get_filtered_conversations(
self,
page=1,
page_size=20,
platform_ids=None,
search_query="",
**kwargs,
):
async with self.get_db() as session:
session: AsyncSession
# Build the base query with filters
base_query = select(ConversationV2)
def insert_plugin_metrics(self, metrics: dict):
pass
if platform_ids:
base_query = base_query.where(
ConversationV2.platform_id.in_(platform_ids)
)
if search_query:
base_query = base_query.where(
ConversationV2.title.ilike(f"%{search_query}%")
)
# Get total count matching the filters
count_query = select(func.count()).select_from(base_query.subquery())
total_count = await session.execute(count_query)
total = total_count.scalar_one()
# Get paginated results
offset = (page - 1) * page_size
result_query = (
base_query.order_by(ConversationV2.created_at.desc())
.offset(offset)
.limit(page_size)
def insert_command_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
'''
INSERT INTO command(name, count, timestamp) VALUES (?, ?, ?)
''', (k, v, int(time.time()))
)
result = await session.execute(result_query)
conversations = result.scalars().all()
return conversations, total
async def create_conversation(
self,
user_id,
platform_id,
content=None,
title=None,
persona_id=None,
cid=None,
created_at=None,
updated_at=None,
):
kwargs = {}
if cid:
kwargs["conversation_id"] = cid
if created_at:
kwargs["created_at"] = created_at
if updated_at:
kwargs["updated_at"] = updated_at
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_conversation = ConversationV2(
user_id=user_id,
content=content or [],
platform_id=platform_id,
title=title,
persona_id=persona_id,
**kwargs,
)
session.add(new_conversation)
return new_conversation
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = update(ConversationV2).where(
ConversationV2.conversation_id == cid
)
values = {}
if title is not None:
values["title"] = title
if persona_id is not None:
values["persona_id"] = persona_id
if content is not None:
values["content"] = content
if not values:
return
query = query.values(**values)
await session.execute(query)
return await self.get_conversation_by_id(cid)
async def delete_conversation(self, cid):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(ConversationV2.conversation_id == cid)
)
async def insert_platform_message_history(
self,
platform_id,
user_id,
content,
sender_id=None,
sender_name=None,
):
"""Insert a new platform message history record."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_history = PlatformMessageHistory(
platform_id=platform_id,
user_id=user_id,
content=content,
sender_id=sender_id,
sender_name=sender_name,
)
session.add(new_history)
return new_history
async def delete_platform_message_offset(
self, platform_id, user_id, offset_sec=86400
):
"""Delete platform message history records older than the specified offset."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
now = datetime.now()
cutoff_time = now - timedelta(seconds=offset_sec)
await session.execute(
delete(PlatformMessageHistory).where(
PlatformMessageHistory.platform_id == platform_id,
PlatformMessageHistory.user_id == user_id,
PlatformMessageHistory.created_at < cutoff_time,
)
)
async def get_platform_message_history(
self, platform_id, user_id, page=1, page_size=20
):
"""Get platform message history records."""
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
query = (
select(PlatformMessageHistory)
.where(
PlatformMessageHistory.platform_id == platform_id,
PlatformMessageHistory.user_id == user_id,
)
.order_by(PlatformMessageHistory.created_at.desc())
def insert_llm_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
'''
INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
''', (k, v, int(time.time()))
)
result = await session.execute(query.offset(offset).limit(page_size))
return result.scalars().all()
async def insert_attachment(self, path, type, mime_type):
"""Insert a new attachment record."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_attachment = Attachment(
path=path,
type=type,
mime_type=mime_type,
)
session.add(new_attachment)
return new_attachment
async def get_attachment_by_id(self, attachment_id):
"""Get an attachment by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(Attachment).where(Attachment.id == attachment_id)
result = await session.execute(query)
return result.scalar_one_or_none()
async def insert_persona(
self, persona_id, system_prompt, begin_dialogs=None, tools=None
):
"""Insert a new persona record."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
new_persona = Persona(
persona_id=persona_id,
system_prompt=system_prompt,
begin_dialogs=begin_dialogs or [],
tools=tools,
)
session.add(new_persona)
return new_persona
async def get_persona_by_id(self, persona_id):
"""Get a persona by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(Persona).where(Persona.persona_id == persona_id)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_personas(self):
"""Get all personas for a specific bot."""
async with self.get_db() as session:
session: AsyncSession
query = select(Persona)
result = await session.execute(query)
return result.scalars().all()
async def update_persona(
self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN
):
"""Update a persona's system prompt or begin dialogs."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = update(Persona).where(Persona.persona_id == persona_id)
values = {}
if system_prompt is not None:
values["system_prompt"] = system_prompt
if begin_dialogs is not None:
values["begin_dialogs"] = begin_dialogs
if tools is not NOT_GIVEN:
values["tools"] = tools
if not values:
return
query = query.values(**values)
await session.execute(query)
return await self.get_persona_by_id(persona_id)
async def delete_persona(self, persona_id):
"""Delete a persona by its ID."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(Persona).where(Persona.persona_id == persona_id)
)
async def insert_preference_or_update(self, scope, scope_id, key, value):
"""Insert a new preference record or update if it exists."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = select(Preference).where(
Preference.scope == scope,
Preference.scope_id == scope_id,
Preference.key == key,
)
result = await session.execute(query)
existing_preference = result.scalar_one_or_none()
if existing_preference:
existing_preference.value = value
else:
new_preference = Preference(
scope=scope, scope_id=scope_id, key=key, value=value
)
session.add(new_preference)
return existing_preference or new_preference
async def get_preference(self, scope, scope_id, key):
"""Get a preference by key."""
async with self.get_db() as session:
session: AsyncSession
query = select(Preference).where(
Preference.scope == scope,
Preference.scope_id == scope_id,
Preference.key == key,
def update_llm_history(self, session_id: str, content: str, provider_type: str):
res = self.get_llm_history(session_id, provider_type)
if res:
self._exec_sql(
'''
UPDATE llm_history SET content = ? WHERE session_id = ? AND provider_type = ?
''', (content, session_id, provider_type)
)
else:
self._exec_sql(
'''
INSERT INTO llm_history(provider_type, session_id, content) VALUES (?, ?, ?)
''', (provider_type, session_id, content)
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_preferences(self, scope, scope_id=None, key=None):
"""Get all preferences for a specific scope ID or key."""
async with self.get_db() as session:
session: AsyncSession
query = select(Preference).where(Preference.scope == scope)
if scope_id is not None:
query = query.where(Preference.scope_id == scope_id)
if key is not None:
query = query.where(Preference.key == key)
result = await session.execute(query)
return result.scalars().all()
def get_llm_history(self, session_id: str = None, provider_type: str = None) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
where_clause = ""
if session_id or provider_type:
where_clause += " WHERE "
has = False
if session_id:
where_clause += f"session_id = '{session_id}'"
has = True
if provider_type:
if has:
where_clause += " AND "
where_clause += f"provider_type = '{provider_type}'"
c.execute(
'''
SELECT * FROM llm_history
''' + where_clause
)
res = c.fetchall()
histories = []
for row in res:
histories.append(LLMHistory(*row))
c.close()
return histories
async def remove_preference(self, scope, scope_id, key):
"""Remove a preference by scope ID and key."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(Preference).where(
Preference.scope == scope,
Preference.scope_id == scope_id,
Preference.key == key,
)
)
await session.commit()
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
'''获取 offset_sec 秒前到现在的基础统计数据'''
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT * FROM platform
''' + where_clause
)
platform = []
for row in c.fetchall():
platform.append(Platform(*row))
# c.execute(
# '''
# SELECT * FROM command
# ''' + where_clause
# )
# command = []
# for row in c.fetchall():
# command.append(Command(*row))
# c.execute(
# '''
# SELECT * FROM llm
# ''' + where_clause
# )
# llm = []
# for row in c.fetchall():
# llm.append(Provider(*row))
c.close()
return Stats(platform, [], [])
def get_total_message_count(self) -> int:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT SUM(count) FROM platform
'''
)
res = c.fetchone()
c.close()
return res[0]
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
'''获取 offset_sec 秒前到现在的基础统计数据(合并)'''
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT name, SUM(count), timestamp FROM platform
''' + where_clause + " GROUP BY name"
)
platform = []
for row in c.fetchall():
platform.append(Platform(*row))
c.close()
return Stats(platform, [], [])
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
res = c.fetchone()
c.close()
if not res:
return
return Conversation(*res)
def new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
self._exec_sql(
'''
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
''', (user_id, cid, history, updated_at, created_at)
)
def get_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
''', (user_id,)
)
res = c.fetchall()
c.close()
conversations = []
for row in res:
cid = row[0]
created_at = row[1]
updated_at = row[2]
title = row[3]
persona_id = row[4]
conversations.append(Conversation("", cid, '[]', created_at, updated_at, title, persona_id))
return conversations
def update_conversation(self, user_id: str, cid: str, history: str):
'''更新对话,并且同时更新时间'''
updated_at = int(time.time())
self._exec_sql(
'''
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
''', (history, updated_at, user_id, cid)
)
def update_conversation_title(self, user_id: str, cid: str, title: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
''', (title, user_id, cid)
)
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
''', (persona_id, user_id, cid)
)
def delete_conversation(self, user_id: str, cid: str):
self._exec_sql(
'''
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
async def clear_preferences(self, scope, scope_id):
"""Clear all preferences for a specific scope ID."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(Preference).where(
Preference.scope == scope, Preference.scope_id == scope_id
)
)
await session.commit()
# ====
# Deprecated Methods
# ====
def get_base_stats(self, offset_sec=86400):
"""Get base statistics within the specified offset in seconds."""
async def _inner():
async with self.get_db() as session:
session: AsyncSession
now = datetime.now()
start_time = now - timedelta(seconds=offset_sec)
result = await session.execute(
select(PlatformStat).where(PlatformStat.timestamp >= start_time)
)
all_datas = result.scalars().all()
deprecated_stats = DeprecatedStats()
for data in all_datas:
deprecated_stats.platform.append(
DeprecatedPlatformStat(
name=data.platform_id,
count=data.count,
timestamp=data.timestamp.timestamp(),
)
)
return deprecated_stats
result = None
def runner():
nonlocal result
result = asyncio.run(_inner())
t = threading.Thread(target=runner)
t.start()
t.join()
return result
def get_total_message_count(self):
"""Get the total message count from platform statistics."""
async def _inner():
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(
select(func.sum(PlatformStat.count)).select_from(PlatformStat)
)
total_count = result.scalar_one_or_none()
return total_count if total_count is not None else 0
result = None
def runner():
nonlocal result
result = asyncio.run(_inner())
t = threading.Thread(target=runner)
t.start()
t.join()
return result
def get_grouped_base_stats(self, offset_sec=86400):
# group by platform_id
async def _inner():
async with self.get_db() as session:
session: AsyncSession
now = datetime.now()
start_time = now - timedelta(seconds=offset_sec)
result = await session.execute(
select(PlatformStat.platform_id, func.sum(PlatformStat.count))
.where(PlatformStat.timestamp >= start_time)
.group_by(PlatformStat.platform_id)
)
grouped_stats = result.all()
deprecated_stats = DeprecatedStats()
for platform_id, count in grouped_stats:
deprecated_stats.platform.append(
DeprecatedPlatformStat(
name=platform_id,
count=count,
timestamp=start_time.timestamp(),
)
)
return deprecated_stats
result = None
def runner():
nonlocal result
result = asyncio.run(_inner())
t = threading.Thread(target=runner)
t.start()
t.join()
return result
def insert_atri_vision_data(self, vision: ATRIVision):
ts = int(time.time())
keywords = ",".join(vision.keywords)
self._exec_sql(
'''
INSERT INTO atri_vision(id, url_or_path, caption, is_meme, keywords, platform_name, session_id, sender_nickname, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (vision.id, vision.url_or_path, vision.caption, vision.is_meme, keywords, vision.platform_name, vision.session_id, vision.sender_nickname, ts)
)
def get_atri_vision_data(self) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT * FROM atri_vision
'''
)
res = c.fetchall()
visions = []
for row in res:
visions.append(ATRIVision(*row))
c.close()
return visions
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT * FROM atri_vision WHERE url_or_path = ? OR id = ?
''', (url_or_path, id)
)
res = c.fetchone()
c.close()
if res:
return ATRIVision(*res)
return None

View File

@@ -0,0 +1,48 @@
CREATE TABLE IF NOT EXISTS platform(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS plugin(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS command(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm_history(
provider_type VARCHAR(32),
session_id VARCHAR(32),
content TEXT
);
-- ATRI
CREATE TABLE IF NOT EXISTS atri_vision(
id TEXT,
url_or_path TEXT,
caption TEXT,
is_meme BOOLEAN,
keywords TEXT,
platform_name VARCHAR(32),
session_id VARCHAR(32),
sender_nickname VARCHAR(32),
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT,
cid TEXT,
history TEXT,
created_at INTEGER,
updated_at INTEGER,
title TEXT,
persona_id TEXT
);

View File

@@ -1,46 +0,0 @@
import abc
from dataclasses import dataclass
@dataclass
class Result:
similarity: float
data: dict
class BaseVecDB:
async def initialize(self):
"""
初始化向量数据库
"""
pass
@abc.abstractmethod
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
...
@abc.abstractmethod
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
"""
搜索最相似的文档。
Args:
query (str): 查询文本
top_k (int): 返回的最相似文档的数量
Returns:
List[Result]: 查询结果
"""
...
@abc.abstractmethod
async def delete(self, doc_id: str) -> bool:
"""
删除指定文档。
Args:
doc_id (str): 要删除的文档 ID
Returns:
bool: 删除是否成功
"""
...

View File

@@ -1,3 +0,0 @@
from .vec_db import FaissVecDB
__all__ = ["FaissVecDB"]

View File

@@ -1,121 +0,0 @@
import aiosqlite
import os
class DocumentStorage:
def __init__(self, db_path: str):
self.db_path = db_path
self.connection = None
self.sqlite_init_path = os.path.join(
os.path.dirname(__file__), "sqlite_init.sql"
)
async def initialize(self):
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
if not os.path.exists(self.db_path):
await self.connect()
async with self.connection.cursor() as cursor:
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
sql_script = f.read()
await cursor.executescript(sql_script)
await self.connection.commit()
else:
await self.connect()
async def connect(self):
"""Connect to the SQLite database."""
self.connection = await aiosqlite.connect(self.db_path)
async def get_documents(self, metadata_filters: dict, ids: list = None):
"""Retrieve documents by metadata filters and ids.
Args:
metadata_filters (dict): The metadata filters to apply.
Returns:
list: The list of document IDs(primary key, not doc_id) that match the filters.
"""
# metadata filter -> SQL WHERE clause
where_clauses = []
values = []
for key, val in metadata_filters.items():
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
values.append(val)
if ids is not None and len(ids) > 0:
ids = [str(i) for i in ids if i != -1]
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
values.extend(ids)
where_sql = " AND ".join(where_clauses) or "1=1"
result = []
async with self.connection.cursor() as cursor:
sql = "SELECT * FROM documents WHERE " + where_sql
await cursor.execute(sql, values)
for row in await cursor.fetchall():
result.append(await self.tuple_to_dict(row))
return result
async def get_document_by_doc_id(self, doc_id: str):
"""Retrieve a document by its doc_id.
Args:
doc_id (str): The doc_id of the document to retrieve.
Returns:
dict: The document data.
"""
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
row = await cursor.fetchone()
if row:
return await self.tuple_to_dict(row)
else:
return None
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
"""Retrieve a document by its doc_id.
Args:
doc_id (str): The doc_id.
new_text (str): The new text to update the document with.
"""
async with self.connection.cursor() as cursor:
await cursor.execute(
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
)
await self.connection.commit()
async def get_user_ids(self) -> list[str]:
"""Retrieve all user IDs from the documents table.
Returns:
list: A list of user IDs.
"""
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT DISTINCT user_id FROM documents")
rows = await cursor.fetchall()
return [row[0] for row in rows]
async def tuple_to_dict(self, row):
"""Convert a tuple to a dictionary.
Args:
row (tuple): The row to convert.
Returns:
dict: The converted dictionary.
"""
return {
"id": row[0],
"doc_id": row[1],
"text": row[2],
"metadata": row[3],
"created_at": row[4],
"updated_at": row[5],
}
async def close(self):
"""Close the connection to the SQLite database."""
if self.connection:
await self.connection.close()
self.connection = None

View File

@@ -1,59 +0,0 @@
try:
import faiss
except ModuleNotFoundError:
raise ImportError(
"faiss 未安装。请使用 'pip install faiss-cpu''pip install faiss-gpu' 安装。"
)
import os
import numpy as np
class EmbeddingStorage:
def __init__(self, dimension: int, path: str = None):
self.dimension = dimension
self.path = path
self.index = None
if path and os.path.exists(path):
self.index = faiss.read_index(path)
else:
base_index = faiss.IndexFlatL2(dimension)
self.index = faiss.IndexIDMap(base_index)
self.storage = {}
async def insert(self, vector: np.ndarray, id: int):
"""插入向量
Args:
vector (np.ndarray): 要插入的向量
id (int): 向量的ID
Raises:
ValueError: 如果向量的维度与存储的维度不匹配
"""
if vector.shape[0] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
)
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
self.storage[id] = vector
await self.save_index()
async def search(self, vector: np.ndarray, k: int) -> tuple:
"""搜索最相似的向量
Args:
vector (np.ndarray): 查询向量
k (int): 返回的最相似向量的数量
Returns:
tuple: (距离, 索引)
"""
faiss.normalize_L2(vector)
distances, indices = self.index.search(vector, k)
return distances, indices
async def save_index(self):
"""保存索引
Args:
path (str): 保存索引的路径
"""
faiss.write_index(self.index, self.path)

View File

@@ -1,17 +0,0 @@
-- 创建文档存储表,包含 faiss 中文档的 id文档文本create_atupdated_at
CREATE TABLE documents (
id INTEGER PRIMARY KEY AUTOINCREMENT,
doc_id TEXT NOT NULL,
text TEXT NOT NULL,
metadata TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
ALTER TABLE documents
ADD COLUMN group_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.group_id')) STORED;
ALTER TABLE documents
ADD COLUMN user_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED;
CREATE INDEX idx_documents_user_id ON documents(user_id);
CREATE INDEX idx_documents_group_id ON documents(group_id);

View File

@@ -1,141 +0,0 @@
import uuid
import json
import numpy as np
from .document_storage import DocumentStorage
from .embedding_storage import EmbeddingStorage
from ..base import Result, BaseVecDB
from astrbot.core.provider.provider import EmbeddingProvider
from astrbot.core.provider.provider import RerankProvider
class FaissVecDB(BaseVecDB):
"""
A class to represent a vector database.
"""
def __init__(
self,
doc_store_path: str,
index_store_path: str,
embedding_provider: EmbeddingProvider,
rerank_provider: RerankProvider | None = None,
):
self.doc_store_path = doc_store_path
self.index_store_path = index_store_path
self.embedding_provider = embedding_provider
self.document_storage = DocumentStorage(doc_store_path)
self.embedding_storage = EmbeddingStorage(
embedding_provider.get_dim(), index_store_path
)
self.embedding_provider = embedding_provider
self.rerank_provider = rerank_provider
async def initialize(self):
await self.document_storage.initialize()
async def insert(
self, content: str, metadata: dict | None = None, id: str | None = None
) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
metadata = metadata or {}
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
vector = await self.embedding_provider.get_embedding(content)
vector = np.array(vector, dtype=np.float32)
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute(
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
(str_id, content, json.dumps(metadata)),
)
await self.document_storage.connection.commit()
result = await self.document_storage.get_document_by_doc_id(str_id)
int_id = result["id"]
# 插入向量到 FAISS
await self.embedding_storage.insert(vector, int_id)
return int_id
async def retrieve(
self,
query: str,
k: int = 5,
fetch_k: int = 20,
rerank: bool = False,
metadata_filters: dict | None = None,
) -> list[Result]:
"""
搜索最相似的文档。
Args:
query (str): 查询文本
k (int): 返回的最相似文档的数量
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
rerank (bool): 是否使用重排序。这需要在实例化时提供 rerank_provider, 如果未提供并且 rerank 为 True, 不会抛出异常。
metadata_filters (dict): 元数据过滤器
Returns:
List[Result]: 查询结果
"""
embedding = await self.embedding_provider.get_embedding(query)
scores, indices = await self.embedding_storage.search(
vector=np.array([embedding]).astype("float32"),
k=fetch_k if metadata_filters else k,
)
if len(indices[0]) == 0 or indices[0][0] == -1:
return []
# normalize scores
scores[0] = 1.0 - (scores[0] / 2.0)
# NOTE: maybe the size is less than k.
fetched_docs = await self.document_storage.get_documents(
metadata_filters=metadata_filters or {}, ids=indices[0]
)
if not fetched_docs:
return []
result_docs: list[Result] = []
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
for i, indice_idx in enumerate(indices[0]):
pos = idx_pos.get(indice_idx)
if pos is None:
continue
fetch_doc = fetched_docs[pos]
score = scores[0][i]
result_docs.append(Result(similarity=float(score), data=fetch_doc))
top_k_results = result_docs[:k]
if rerank and self.rerank_provider:
documents = [doc.data["text"] for doc in top_k_results]
reranked_results = await self.rerank_provider.rerank(query, documents)
reranked_results = sorted(
reranked_results, key=lambda x: x.relevance_score, reverse=True
)
top_k_results = [
top_k_results[reranked_result.index]
for reranked_result in reranked_results
]
return top_k_results
async def delete(self, doc_id: int):
"""
删除一条文档
"""
await self.document_storage.connection.execute(
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
)
await self.document_storage.connection.commit()
async def close(self):
await self.document_storage.close()
async def count_documents(self) -> int:
"""
计算文档数量
"""
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute("SELECT COUNT(*) FROM documents")
count = await cursor.fetchone()
return count[0] if count else 0

View File

@@ -1,59 +1,23 @@
"""
事件总线, 用于处理事件的分发和处理
事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑
class:
EventBus: 事件总线, 用于处理事件的分发和处理
工作流程:
1. 维护一个异步队列, 来接受各种消息事件
2. 无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑
"""
import asyncio
from asyncio import Queue
from astrbot.core.pipeline.scheduler import PipelineScheduler
from astrbot.core import logger
from .platform import AstrMessageEvent
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
class EventBus:
"""用于处理事件的分发和处理"""
def __init__(
self,
event_queue: Queue,
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
astrbot_config_mgr: AstrBotConfigManager = None,
):
self.event_queue = event_queue # 事件队列
# abconf uuid -> scheduler
self.pipeline_scheduler_mapping = pipeline_scheduler_mapping
self.astrbot_config_mgr = astrbot_config_mgr
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
self.event_queue = event_queue
self.pipeline_scheduler = pipeline_scheduler
async def dispatch(self):
logger.info("事件总线已打开。")
while True:
event: AstrMessageEvent = await self.event_queue.get()
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
self._print_event(event, conf_info["name"])
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
asyncio.create_task(scheduler.execute(event))
def _print_event(self, event: AstrMessageEvent, conf_name: str):
"""用于记录事件信息
Args:
event (AstrMessageEvent): 事件对象
"""
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
self._print_event(event)
asyncio.create_task(self.pipeline_scheduler.execute(event))
def _print_event(self, event: AstrMessageEvent):
if event.get_sender_name():
logger.info(
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
)
# 没有发送者名称: [平台名] 发送者ID: 消息概要
logger.info(f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}")
else:
logger.info(
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}"
)
logger.info(f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}")

View File

@@ -1,92 +0,0 @@
import asyncio
import os
import uuid
import time
from urllib.parse import urlparse, unquote
import platform
class FileTokenService:
"""维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。"""
def __init__(self, default_timeout: float = 300):
self.lock = asyncio.Lock()
self.staged_files = {} # token: (file_path, expire_time)
self.default_timeout = default_timeout
async def _cleanup_expired_tokens(self):
"""清理过期的令牌"""
now = time.time()
expired_tokens = [
token for token, (_, expire) in self.staged_files.items() if expire < now
]
for token in expired_tokens:
self.staged_files.pop(token, None)
async def register_file(self, file_path: str, timeout: float = None) -> str:
"""向令牌服务注册一个文件。
Args:
file_path(str): 文件路径
timeout(float): 超时时间,单位秒(可选)
Returns:
str: 一个单次令牌
Raises:
FileNotFoundError: 当路径不存在时抛出
"""
# 处理 file:///
try:
parsed_uri = urlparse(file_path)
if parsed_uri.scheme == "file":
local_path = unquote(parsed_uri.path)
if platform.system() == "Windows" and local_path.startswith("/"):
local_path = local_path[1:]
else:
# 如果没有 file:/// 前缀,则认为是普通路径
local_path = file_path
except Exception:
# 解析失败时,按原路径处理
local_path = file_path
async with self.lock:
await self._cleanup_expired_tokens()
if not os.path.exists(local_path):
raise FileNotFoundError(
f"文件不存在: {local_path} (原始输入: {file_path})"
)
file_token = str(uuid.uuid4())
expire_time = time.time() + (
timeout if timeout is not None else self.default_timeout
)
# 存储转换后的真实路径
self.staged_files[file_token] = (local_path, expire_time)
return file_token
async def handle_file(self, file_token: str) -> str:
"""根据令牌获取文件路径,使用后令牌失效。
Args:
file_token(str): 注册时返回的令牌
Returns:
str: 文件路径
Raises:
KeyError: 当令牌不存在或已过期时抛出
FileNotFoundError: 当文件本身已被删除时抛出
"""
async with self.lock:
await self._cleanup_expired_tokens()
if file_token not in self.staged_files:
raise KeyError(f"无效或过期的文件 token: {file_token}")
file_path, _ = self.staged_files.pop(file_token)
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
return file_path

View File

@@ -1,52 +0,0 @@
"""
AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。
工作流程:
1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期
2. 运行核心生命周期任务和仪表板服务器
"""
import asyncio
import traceback
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core import LogBroker
from astrbot.dashboard.server import AstrBotDashboard
class InitialLoader:
"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。"""
def __init__(self, db: BaseDatabase, log_broker: LogBroker):
self.db = db
self.logger = logger
self.log_broker = log_broker
self.webui_dir: str | None = None
async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
try:
await core_lifecycle.initialize()
except Exception as e:
logger.critical(traceback.format_exc())
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
return
core_task = core_lifecycle.start()
webui_dir = self.webui_dir
self.dashboard_server = AstrBotDashboard(
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event, webui_dir
)
task = asyncio.gather(
core_task, self.dashboard_server.run()
) # 启动核心任务和仪表板服务器
try:
await task # 整个AstrBot在这里运行
except asyncio.CancelledError:
logger.info("🌈 正在关闭 AstrBot...")
await core_lifecycle.stop()

View File

@@ -1,119 +1,37 @@
"""
日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
const:
CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量
log_color_config: 日志颜色配置, 定义了不同日志级别的颜色
class:
LogBroker: 日志代理类, 用于缓存和分发日志消息
LogQueueHandler: 日志处理器, 用于将日志消息发送到 LogBroker
LogManager: 日志管理器, 用于创建和配置日志记录器
function:
is_plugin_path: 检查文件路径是否来自插件目录
get_short_level_name: 将日志级别名称转换为四个字母的缩写
工作流程:
1. 通过 LogManager.GetLogger() 获取日志器, 配置了控制台输出和多个格式化过滤器
2. 通过 set_queue_handler() 设置日志处理器, 将日志消息发送到 LogBroker
3. logBroker 维护一个订阅者列表, 负责将日志分发给所有订阅者
4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流
"""
import logging
import colorlog
import asyncio
import os
import sys
from collections import deque
from asyncio import Queue
from typing import List
# 日志缓存大小
CACHED_SIZE = 200
# 日志颜色配置
log_color_config = {
"DEBUG": "green",
"INFO": "bold_cyan",
"WARNING": "bold_yellow",
"ERROR": "red",
"CRITICAL": "bold_red",
"RESET": "reset",
"asctime": "green",
'DEBUG': 'bold_blue', 'INFO': 'bold_cyan',
'WARNING': 'bold_yellow', 'ERROR': 'red',
'CRITICAL': 'bold_red', 'RESET': 'reset',
'asctime': 'green'
}
def is_plugin_path(pathname):
"""检查文件路径是否来自插件目录
Args:
pathname (str): 文件路径
Returns:
bool: 如果路径来自插件目录,则返回 True否则返回 False
"""
if not pathname:
return False
norm_path = os.path.normpath(pathname)
return ("data/plugins" in norm_path) or ("packages/" in norm_path)
def get_short_level_name(level_name):
"""将日志级别名称转换为四个字母的缩写
Args:
level_name (str): 日志级别名称, 如 "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
Returns:
str: 四个字母的日志级别缩写
"""
level_map = {
"DEBUG": "DBUG",
"INFO": "INFO",
"WARNING": "WARN",
"ERROR": "ERRO",
"CRITICAL": "CRIT",
}
return level_map.get(level_name, level_name[:4].upper())
class LogBroker:
"""日志代理类, 用于缓存和分发日志消息
发布-订阅模式
"""
def __init__(self):
self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
self.subscribers: List[Queue] = [] # 订阅者列表
self.log_cache = deque(maxlen=CACHED_SIZE)
self.subscribers: List[Queue] = []
def register(self) -> Queue:
"""注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
Returns:
Queue: 订阅者的队列, 可用于接收日志消息
"""
'''给每个订阅者返回一个带有日志缓存的队列'''
q = Queue(maxsize=CACHED_SIZE + 10)
for log in self.log_cache:
q.put_nowait(log)
self.subscribers.append(q)
return q
def unregister(self, q: Queue):
"""取消订阅
Args:
q (Queue): 需要取消订阅的队列
"""
'''取消订阅'''
self.subscribers.remove(q)
def publish(self, log_entry: dict):
"""发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统
Args:
log_entry (dict): 日志消息, 包含日志级别和日志内容.
example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
"""
def publish(self, log_entry: str):
'''发布消息'''
self.log_cache.append(log_entry)
for q in self.subscribers:
try:
@@ -121,126 +39,41 @@ class LogBroker:
except asyncio.QueueFull:
pass
class LogQueueHandler(logging.Handler):
"""日志处理器, 用于将日志消息发送到 LogBroker
继承自 logging.Handler
"""
def __init__(self, log_broker: LogBroker):
super().__init__()
self.log_broker = log_broker
def emit(self, record):
"""日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布
这个方法会在每次日志记录时被调用
Args:
record (logging.LogRecord): 日志记录对象, 包含日志信息
"""
log_entry = self.format(record)
self.log_broker.publish(
{
"level": record.levelname,
"time": record.asctime,
"data": log_entry,
}
)
self.log_broker.publish(log_entry)
class LogManager:
"""日志管理器, 用于创建和配置日志记录器
提供了获取默认日志记录器logger和设置队列处理器的方法
"""
@classmethod
def GetLogger(cls, log_name: str = "default"):
"""获取指定名称的日志记录器logger
Args:
log_name (str): 日志记录器的名称, 默认为 "default"
Returns:
logging.Logger: 返回配置好的日志记录器
"""
def GetLogger(cls, log_name: str = 'default'):
logger = logging.getLogger(log_name)
# 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置
if logger.hasHandlers():
return logger
# 如果logger没有处理器
console_handler = logging.StreamHandler(
sys.stdout
) # 创建一个StreamHandler用于控制台输出
console_handler.setLevel(
logging.DEBUG
) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
# 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s",
datefmt="%H:%M:%S",
log_colors=log_color_config,
fmt='%(log_color)s [%(asctime)s| %(levelname)s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s',
datefmt='%H:%M:%S',
log_colors=log_color_config
)
class PluginFilter(logging.Filter):
"""插件过滤器类, 用于标记日志来源是插件还是核心组件"""
def filter(self, record):
record.plugin_tag = (
"[Plug]" if is_plugin_path(record.pathname) else "[Core]"
)
return True
class FileNameFilter(logging.Filter):
"""文件名过滤器类, 用于修改日志记录的文件名格式
例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式"""
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
def filter(self, record):
dirname = os.path.dirname(record.pathname)
record.filename = (
os.path.basename(dirname)
+ "."
+ os.path.basename(record.pathname).replace(".py", "")
)
return True
class LevelNameFilter(logging.Filter):
"""短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写"""
# 添加短日志级别名称
def filter(self, record):
record.short_levelname = get_short_level_name(record.levelname)
return True
console_handler.setFormatter(console_formatter) # 设置处理器的格式化器
logger.addFilter(PluginFilter()) # 添加插件过滤器
logger.addFilter(FileNameFilter()) # 添加文件名过滤器
logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器
logger.setLevel(logging.DEBUG) # 设置日志级别为DEBUG
logger.addHandler(console_handler) # 添加处理器到logger
console_handler.setFormatter(console_formatter)
logger.setLevel(logging.DEBUG)
logger.addHandler(console_handler)
return logger
@classmethod
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
"""设置队列处理器, 用于将日志消息发送到 LogBroker
Args:
logger (logging.Logger): 日志记录器
log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息
"""
handler = LogQueueHandler(log_broker)
handler.setLevel(logging.DEBUG)
if logger.handlers:
handler.setFormatter(logger.handlers[0].formatter)
else:
# 为队列处理器设置相同格式的formatter
handler.setFormatter(
logging.Formatter(
"[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s"
)
)
logger.addHandler(handler)
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

View File

@@ -1,4 +1,4 @@
"""
'''
MIT License
Copyright (c) 2021 Lxns-Network
@@ -20,37 +20,21 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
'''
import asyncio
import base64
import json
import os
import typing as T
import uuid
from enum import Enum
from pydantic.v1 import BaseModel
from astrbot.core import astrbot_config, file_token_service, logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
class ComponentType(str, Enum):
Plain = "Plain" # 纯文本消息
Face = "Face" # QQ表情
Record = "Record" # 语音
Video = "Video" # 视频
At = "At" # At
Node = "Node" # 转发消息的一个节点
Nodes = "Nodes" # 转发消息的多个节点
Poke = "Poke" # QQ 戳一戳
Image = "Image" # 图片
Reply = "Reply" # 回复
Forward = "Forward" # 转发消息
File = "File" # 文件
class ComponentType(Enum):
Plain = "Plain"
Face = "Face"
Record = "Record"
Video = "Video"
At = "At"
RPS = "RPS" # TODO
Dice = "Dice" # TODO
Shake = "Shake" # TODO
@@ -59,14 +43,18 @@ class ComponentType(str, Enum):
Contact = "Contact" # TODO
Location = "Location" # TODO
Music = "Music"
Image = "Image"
Reply = "Reply"
RedBag = "RedBag"
Poke = "Poke"
Forward = "Forward"
Node = "Node"
Xml = "Xml"
Json = "Json"
CardImage = "CardImage"
TTS = "TTS"
Unknown = "Unknown"
WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包
File = "File"
class BaseMessageComponent(BaseModel):
@@ -81,34 +69,29 @@ class BaseMessageComponent(BaseModel):
k = "type"
if isinstance(v, bool):
v = 1 if v else 0
output += ",%s=%s" % (
k,
str(v)
.replace("&", "&amp;")
.replace(",", "&#44;")
.replace("[", "&#91;")
.replace("]", "&#93;"),
)
output += ",%s=%s" % (k, str(v).replace("&", "&amp;") \
.replace(",", "&#44;") \
.replace("[", "&#91;") \
.replace("]", "&#93;"))
output += "]"
return output
def toDict(self):
data = {}
data = dict()
for k, v in self.__dict__.items():
if k == "type" or v is None:
continue
if k == "_type":
k = "type"
data[k] = v
return {"type": self.type.lower(), "data": data}
async def to_dict(self) -> dict:
# 默认情况下,回退到旧的同步 toDict()
return self.toDict()
return {
"type": self.type.lower(),
"data": data
}
class Plain(BaseMessageComponent):
type = ComponentType.Plain
type: ComponentType = "Plain"
text: str
convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息
@@ -118,19 +101,13 @@ class Plain(BaseMessageComponent):
def toString(self): # 没有 [CQ:plain] 这种东西,所以直接导出纯文本
if not self.convert:
return self.text
return (
self.text.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;")
)
def toDict(self):
return {"type": "text", "data": {"text": self.text.strip()}}
async def to_dict(self):
return {"type": "text", "data": {"text": self.text}}
return self.text.replace("&", "&amp;") \
.replace("[", "&#91;") \
.replace("]", "&#93;")
class Face(BaseMessageComponent):
type = ComponentType.Face
type: ComponentType = "Face"
id: int
def __init__(self, **_):
@@ -138,7 +115,7 @@ class Face(BaseMessageComponent):
class Record(BaseMessageComponent):
type = ComponentType.Record
type: ComponentType = "Record"
file: T.Optional[str] = ""
magic: T.Optional[bool] = False
url: T.Optional[str] = ""
@@ -165,85 +142,9 @@ class Record(BaseMessageComponent):
return Record(file=url, **_)
raise Exception("not a valid url")
@staticmethod
def fromBase64(bs64_data: str, **_):
return Record(file=f"base64://{bs64_data}", **_)
async def convert_to_file_path(self) -> str:
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
Returns:
str: 语音的本地路径,以绝对路径表示。
"""
if not self.file:
raise Exception(f"not a valid file: {self.file}")
if self.file.startswith("file:///"):
return self.file[8:]
elif self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
return os.path.abspath(file_path)
elif self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
with open(file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(file_path)
elif os.path.exists(self.file):
return os.path.abspath(self.file)
else:
raise Exception(f"not a valid file: {self.file}")
async def convert_to_base64(self) -> str:
"""将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。
Returns:
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
if not self.file:
raise Exception(f"not a valid file: {self.file}")
if self.file.startswith("file:///"):
bs64_data = file_to_base64(self.file[8:])
elif self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
bs64_data = file_to_base64(file_path)
elif self.file.startswith("base64://"):
bs64_data = self.file
elif os.path.exists(self.file):
bs64_data = file_to_base64(self.file)
else:
raise Exception(f"not a valid file: {self.file}")
bs64_data = bs64_data.removeprefix("base64://")
return bs64_data
async def register_to_file_service(self) -> str:
"""
将语音注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
if not callback_host:
raise Exception("未配置 callback_api_base文件服务不可用")
file_path = await self.convert_to_file_path()
token = await file_token_service.register_file(file_path)
logger.debug(f"已注册:{callback_host}/api/file/{token}")
return f"{callback_host}/api/file/{token}"
class Video(BaseMessageComponent):
type = ComponentType.Video
type: ComponentType = "Video"
file: str
cover: T.Optional[str] = ""
c: T.Optional[int] = 2
@@ -251,6 +152,9 @@ class Video(BaseMessageComponent):
path: T.Optional[str] = ""
def __init__(self, file: str, **_):
# for k in _.keys():
# if k == "c" and _[k] not in [2, 3]:
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
super().__init__(file=file, **_)
@staticmethod
@@ -263,85 +167,15 @@ class Video(BaseMessageComponent):
return Video(file=url, **_)
raise Exception("not a valid url")
async def convert_to_file_path(self) -> str:
"""将这个视频统一转换为本地文件路径。这个方法避免了手动判断视频数据类型,直接返回视频数据的本地路径(如果是网络 URL则会自动进行下载
Returns:
str: 视频的本地路径,以绝对路径表示。
"""
url = self.file
if url and url.startswith("file:///"):
return url[8:]
elif url and url.startswith("http"):
download_dir = os.path.join(get_astrbot_data_path(), "temp")
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
await download_file(url, video_file_path)
if os.path.exists(video_file_path):
return os.path.abspath(video_file_path)
else:
raise Exception(f"download failed: {url}")
elif os.path.exists(url):
return os.path.abspath(url)
else:
raise Exception(f"not a valid file: {url}")
async def register_to_file_service(self):
"""
将视频注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
if not callback_host:
raise Exception("未配置 callback_api_base文件服务不可用")
file_path = await self.convert_to_file_path()
token = await file_token_service.register_file(file_path)
logger.debug(f"已注册:{callback_host}/api/file/{token}")
return f"{callback_host}/api/file/{token}"
async def to_dict(self):
"""需要和 toDict 区分开toDict 是同步方法"""
url_or_path = self.file
if url_or_path.startswith("http"):
payload_file = url_or_path
elif callback_host := astrbot_config.get("callback_api_base"):
callback_host = str(callback_host).removesuffix("/")
token = await file_token_service.register_file(url_or_path)
payload_file = f"{callback_host}/api/file/{token}"
logger.debug(f"Generated video file callback link: {payload_file}")
else:
payload_file = url_or_path
return {
"type": "video",
"data": {
"file": payload_file,
},
}
class At(BaseMessageComponent):
type = ComponentType.At
type: ComponentType = "At"
qq: T.Union[int, str] # 此处str为all时代表所有人
name: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
def toDict(self):
return {
"type": "at",
"data": {"qq": str(self.qq)},
}
class AtAll(At):
qq: str = "all"
@@ -351,28 +185,28 @@ class AtAll(At):
class RPS(BaseMessageComponent): # TODO
type = ComponentType.RPS
type: ComponentType = "RPS"
def __init__(self, **_):
super().__init__(**_)
class Dice(BaseMessageComponent): # TODO
type = ComponentType.Dice
type: ComponentType = "Dice"
def __init__(self, **_):
super().__init__(**_)
class Shake(BaseMessageComponent): # TODO
type = ComponentType.Shake
type: ComponentType = "Shake"
def __init__(self, **_):
super().__init__(**_)
class Anonymous(BaseMessageComponent): # TODO
type = ComponentType.Anonymous
type: ComponentType = "Anonymous"
ignore: T.Optional[bool] = False
def __init__(self, **_):
@@ -380,7 +214,7 @@ class Anonymous(BaseMessageComponent): # TODO
class Share(BaseMessageComponent):
type = ComponentType.Share
type: ComponentType = "Share"
url: str
title: str
content: T.Optional[str] = ""
@@ -391,7 +225,7 @@ class Share(BaseMessageComponent):
class Contact(BaseMessageComponent): # TODO
type = ComponentType.Contact
type: ComponentType = "Contact"
_type: str # type 字段冲突
id: T.Optional[int] = 0
@@ -400,7 +234,7 @@ class Contact(BaseMessageComponent): # TODO
class Location(BaseMessageComponent): # TODO
type = ComponentType.Location
type: ComponentType = "Location"
lat: float
lon: float
title: T.Optional[str] = ""
@@ -411,7 +245,7 @@ class Location(BaseMessageComponent): # TODO
class Music(BaseMessageComponent):
type = ComponentType.Music
type: ComponentType = "Music"
_type: str
id: T.Optional[int] = 0
url: T.Optional[str] = ""
@@ -428,7 +262,7 @@ class Music(BaseMessageComponent):
class Image(BaseMessageComponent):
type = ComponentType.Image
type: ComponentType = "Image"
file: T.Optional[str] = ""
_type: T.Optional[str] = ""
subType: T.Optional[int] = 0
@@ -438,9 +272,13 @@ class Image(BaseMessageComponent):
c: T.Optional[int] = 2
# 额外
path: T.Optional[str] = ""
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
def __init__(self, file: T.Optional[str], **_):
# for k in _.keys():
# if (k == "_type" and _[k] not in ["flash", "show", None]) or \
# (k == "c" and _[k] not in [2, 3]):
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
super().__init__(file=file, **_)
@staticmethod
@@ -465,109 +303,21 @@ class Image(BaseMessageComponent):
def fromIO(IO):
return Image.fromBytes(IO.read())
async def convert_to_file_path(self) -> str:
"""将这个图片统一转换为本地文件路径。这个方法避免了手动判断图片数据类型,直接返回图片数据的本地路径(如果是网络 URL, 则会自动进行下载)。
Returns:
str: 图片的本地路径,以绝对路径表示。
"""
url = self.url or self.file
if not url:
raise ValueError("No valid file or URL provided")
if url.startswith("file:///"):
return url[8:]
elif url.startswith("http"):
image_file_path = await download_image_by_url(url)
return os.path.abspath(image_file_path)
elif url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
with open(image_file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(image_file_path)
elif os.path.exists(url):
return os.path.abspath(url)
else:
raise Exception(f"not a valid file: {url}")
async def convert_to_base64(self) -> str:
"""将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。
Returns:
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
url = self.url or self.file
if not url:
raise ValueError("No valid file or URL provided")
if url.startswith("file:///"):
bs64_data = file_to_base64(url[8:])
elif url.startswith("http"):
image_file_path = await download_image_by_url(url)
bs64_data = file_to_base64(image_file_path)
elif url.startswith("base64://"):
bs64_data = url
elif os.path.exists(url):
bs64_data = file_to_base64(url)
else:
raise Exception(f"not a valid file: {url}")
bs64_data = bs64_data.removeprefix("base64://")
return bs64_data
async def register_to_file_service(self) -> str:
"""
将图片注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
if not callback_host:
raise Exception("未配置 callback_api_base文件服务不可用")
file_path = await self.convert_to_file_path()
token = await file_token_service.register_file(file_path)
logger.debug(f"已注册:{callback_host}/api/file/{token}")
return f"{callback_host}/api/file/{token}"
class Reply(BaseMessageComponent):
type = ComponentType.Reply
type: ComponentType = "Reply"
id: T.Union[str, int]
"""所引用的消息 ID"""
chain: T.Optional[T.List["BaseMessageComponent"]] = []
"""被引用的消息段列表"""
sender_id: T.Optional[int] | T.Optional[str] = 0
"""被引用的消息对应的发送者的 ID"""
sender_nickname: T.Optional[str] = ""
"""被引用的消息对应的发送者的昵称"""
time: T.Optional[int] = 0
"""被引用的消息发送时间"""
message_str: T.Optional[str] = ""
"""被引用的消息解析后的纯文本消息字符串"""
text: T.Optional[str] = ""
"""deprecated"""
qq: T.Optional[int] = 0
"""deprecated"""
time: T.Optional[int] = 0
seq: T.Optional[int] = 0
"""deprecated"""
def __init__(self, **_):
super().__init__(**_)
class RedBag(BaseMessageComponent):
type = ComponentType.RedBag
type: ComponentType = "RedBag"
title: str
def __init__(self, **_):
@@ -575,7 +325,7 @@ class RedBag(BaseMessageComponent):
class Poke(BaseMessageComponent):
type: str = ComponentType.Poke
type: str = ""
id: T.Optional[int] = 0
qq: T.Optional[int] = 0
@@ -585,104 +335,46 @@ class Poke(BaseMessageComponent):
class Forward(BaseMessageComponent):
type = ComponentType.Forward
type: ComponentType = "Forward"
id: str
def __init__(self, **_):
super().__init__(**_)
class Node(BaseMessageComponent):
"""群合并转发消息"""
'''群合并转发消息'''
type: ComponentType = "Node"
id: T.Optional[int] = 0 # 忽略
name: T.Optional[str] = "" # qq昵称
uin: T.Optional[int] = 0 # qq号
content: T.Optional[T.Union[str, list]] = "" # 子消息段列表
seq: T.Optional[T.Union[str, list]] = "" # 忽略
time: T.Optional[int] = 0
type = ComponentType.Node
id: T.Optional[int] = 0 # 忽略
name: T.Optional[str] = "" # qq昵称
uin: T.Optional[str] = "0" # qq号
content: T.Optional[list[BaseMessageComponent]] = []
seq: T.Optional[T.Union[str, list]] = "" # 忽略
time: T.Optional[int] = 0 # 忽略
def __init__(self, content: list[BaseMessageComponent], **_):
if isinstance(content, Node):
# back
content = [content]
def __init__(self, content: T.Union[str, list], **_):
if isinstance(content, list):
_content = ""
for chain in content:
_content += chain.toString()
content = _content
super().__init__(content=content, **_)
async def to_dict(self):
data_content = []
for comp in self.content:
if isinstance(comp, (Image, Record)):
# For Image and Record segments, we convert them to base64
bs64 = await comp.convert_to_base64()
data_content.append(
{
"type": comp.type.lower(),
"data": {"file": f"base64://{bs64}"},
}
)
elif isinstance(comp, Plain):
# For Plain segments, we need to handle the plain differently
d = await comp.to_dict()
data_content.append(d)
elif isinstance(comp, File):
# For File segments, we need to handle the file differently
d = await comp.to_dict()
data_content.append(d)
elif isinstance(comp, (Node, Nodes)):
# For Node segments, we recursively convert them to dict
d = await comp.to_dict()
data_content.append(d)
else:
d = comp.toDict()
data_content.append(d)
return {
"type": "node",
"data": {
"user_id": str(self.uin),
"nickname": self.name,
"content": data_content,
},
}
class Nodes(BaseMessageComponent):
type = ComponentType.Nodes
nodes: T.List[Node]
def __init__(self, nodes: T.List[Node], **_):
super().__init__(nodes=nodes, **_)
def toDict(self):
"""Deprecated. Use to_dict instead"""
ret = {
"messages": [],
}
for node in self.nodes:
d = node.toDict()
ret["messages"].append(d)
return ret
async def to_dict(self):
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
ret = {"messages": []}
for node in self.nodes:
d = await node.to_dict()
ret["messages"].append(d)
return ret
def toString(self):
# logger.warn("Protocol: node doesn't support stringify")
return ""
class Xml(BaseMessageComponent):
type = ComponentType.Xml
type: ComponentType = "Xml"
data: str
resid: T.Optional[int] = 0
def __init__(self, **_):
super().__init__(**_)
class Json(BaseMessageComponent):
type = ComponentType.Json
type: ComponentType = "Json"
data: T.Union[str, dict]
resid: T.Optional[int] = 0
@@ -693,7 +385,7 @@ class Json(BaseMessageComponent):
class CardImage(BaseMessageComponent):
type = ComponentType.CardImage
type: ComponentType = "CardImage"
file: str
cache: T.Optional[bool] = True
minwidth: T.Optional[int] = 400
@@ -712,7 +404,7 @@ class CardImage(BaseMessageComponent):
class TTS(BaseMessageComponent):
type = ComponentType.TTS
type: ComponentType = "TTS"
text: str
def __init__(self, **_):
@@ -720,155 +412,22 @@ class TTS(BaseMessageComponent):
class Unknown(BaseMessageComponent):
type = ComponentType.Unknown
type: ComponentType = "Unknown"
text: str
def toString(self):
return ""
class File(BaseMessageComponent):
"""
文件消息段
"""
type = ComponentType.File
name: T.Optional[str] = "" # 名字
file_: T.Optional[str] = "" # 本地路径
url: T.Optional[str] = "" # url
def __init__(self, name: str, file: str = "", url: str = ""):
"""文件消息段。"""
super().__init__(name=name, file_=file, url=url)
@property
def file(self) -> str:
"""
获取文件路径如果文件不存在但有URL则同步下载文件
Returns:
str: 文件路径
"""
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
if self.url:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
logger.warning(
(
"不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段"
)
)
return ""
else:
# 等待下载完成
loop.run_until_complete(self._download_file())
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
except Exception as e:
logger.error(f"文件下载失败: {e}")
return ""
@file.setter
def file(self, value: str):
"""
向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
Args:
value (str): 文件路径或URL
"""
if value.startswith("http://") or value.startswith("https://"):
self.url = value
else:
self.file_ = value
async def get_file(self, allow_return_url: bool = False) -> str:
"""异步获取文件。请注意在使用后清理下载的文件, 以免占用过多空间
Args:
allow_return_url: 是否允许以文件 http 下载链接的形式返回,这允许您自行控制是否需要下载文件。
注意,如果为 True也可能返回文件路径。
Returns:
str: 文件路径或者 http 下载链接
"""
if allow_return_url and self.url:
return self.url
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
if self.url:
await self._download_file()
return os.path.abspath(self.file_)
return ""
async def _download_file(self):
"""下载文件"""
download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True)
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path)
async def register_to_file_service(self):
"""
将文件注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
if not callback_host:
raise Exception("未配置 callback_api_base文件服务不可用")
file_path = await self.get_file()
token = await file_token_service.register_file(file_path)
logger.debug(f"已注册:{callback_host}/api/file/{token}")
return f"{callback_host}/api/file/{token}"
async def to_dict(self):
"""需要和 toDict 区分开toDict 是同步方法"""
url_or_path = await self.get_file(allow_return_url=True)
if url_or_path.startswith("http"):
payload_file = url_or_path
elif callback_host := astrbot_config.get("callback_api_base"):
callback_host = str(callback_host).removesuffix("/")
token = await file_token_service.register_file(url_or_path)
payload_file = f"{callback_host}/api/file/{token}"
logger.debug(f"Generated file callback link: {payload_file}")
else:
payload_file = url_or_path
return {
"type": "file",
"data": {
"name": self.name,
"file": payload_file,
},
}
class WechatEmoji(BaseMessageComponent):
type = ComponentType.WechatEmoji
md5: T.Optional[str] = ""
md5_len: T.Optional[int] = 0
cdnurl: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
'''
目前此消息段只适配了 Napcat。
'''
type: ComponentType = "File"
name: T.Optional[str] = "" # 名字
file: T.Optional[str] = "" # url本地路径
def __init__(self, name: str, file: str):
super().__init__(name=name, file=file)
ComponentTypes = {
@@ -892,12 +451,10 @@ ComponentTypes = {
"poke": Poke,
"forward": Forward,
"node": Node,
"nodes": Nodes,
"xml": Xml,
"json": Json,
"cardimage": CardImage,
"tts": TTS,
"unknown": Unknown,
"file": File,
"WechatEmoji": WechatEmoji,
'file': File,
}

View File

@@ -1,233 +1,149 @@
import enum
from typing import List, Optional, Union, AsyncGenerator
from typing import List, Optional
from dataclasses import dataclass, field
from astrbot.core.message.components import (
BaseMessageComponent,
Plain,
Image,
At,
AtAll,
)
from astrbot.core.message.components import BaseMessageComponent, Plain, Image
from typing_extensions import deprecated
@dataclass
class MessageChain:
"""MessageChain 描述了一整条消息中带有的所有组件。
class MessageChain():
'''MessageChain 描述了一整条消息中带有的所有组件。
现代消息平台的一条富文本消息中可能由多个组件构成如文本、图片、At 等,并且保留了顺序。
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
"""
'''
chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: Optional[bool] = None # None 为跟随用户设置
type: Optional[str] = None
"""消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。"""
use_t2i_: Optional[bool] = None # None 为跟随用户设置
def message(self, message: str):
"""添加一条文本消息到消息链 `chain` 中。
'''添加一条文本消息到消息链 `chain` 中。
Example:
CommandResult().message("Hello ").message("world!")
# 输出 Hello world!
"""
'''
self.chain.append(Plain(message))
return self
def at(self, name: str, qq: Union[str, int]):
"""添加一条 At 消息到消息链 `chain` 中。
Example:
CommandResult().at("张三", "12345678910")
# 输出 @张三
"""
self.chain.append(At(name=name, qq=qq))
return self
def at_all(self):
"""添加一条 AtAll 消息到消息链 `chain` 中。
Example:
CommandResult().at_all()
# 输出 @所有人
"""
self.chain.append(AtAll())
return self
@deprecated("请使用 message 方法代替。")
def error(self, message: str):
"""添加一条错误消息到消息链 `chain` 中
'''添加一条错误消息到消息链 `chain` 中
Example:
CommandResult().error("解析失败")
"""
'''
self.chain.append(Plain(message))
return self
def url_image(self, url: str):
"""添加一条图片消息https 链接)到消息链 `chain` 中。
'''添加一条图片消息https 链接)到消息链 `chain` 中。
Note:
如果需要发送本地图片,请使用 `file_image` 方法。
Example:
CommandResult().image("https://example.com/image.jpg")
"""
'''
self.chain.append(Image.fromURL(url))
return self
def file_image(self, path: str):
"""添加一条图片消息(本地文件路径)到消息链 `chain` 中。
'''添加一条图片消息(本地文件路径)到消息链 `chain` 中。
Note:
如果需要发送网络图片,请使用 `url_image` 方法。
CommandResult().image("image.jpg")
"""
'''
self.chain.append(Image.fromFileSystem(path))
return self
def base64_image(self, base64_str: str):
"""添加一条图片消息base64 编码字符串)到消息链 `chain` 中。
Example:
CommandResult().base64_image("iVBORw0KGgoAAAANSUhEUgAAAAUA...")
"""
self.chain.append(Image.fromBase64(base64_str))
return self
def use_t2i(self, use_t2i: bool):
"""设置是否使用文本转图片服务。
'''设置是否使用文本转图片服务。
Args:
use_t2i (bool): 是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
"""
'''
self.use_t2i_ = use_t2i
return self
def get_plain_text(self) -> str:
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
def squash_plain(self):
"""将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。"""
if not self.chain:
return
new_chain = []
first_plain = None
plain_texts = []
for comp in self.chain:
if isinstance(comp, Plain):
if first_plain is None:
first_plain = comp
new_chain.append(comp)
plain_texts.append(comp.text)
else:
new_chain.append(comp)
if first_plain is not None:
first_plain.text = "".join(plain_texts)
self.chain = new_chain
return self
class EventResultType(enum.Enum):
"""用于描述事件处理的结果类型。
'''用于描述事件处理的结果类型。
Attributes:
CONTINUE: 事件将会继续传播
STOP: 事件将会终止传播
"""
'''
CONTINUE = enum.auto()
STOP = enum.auto()
class ResultContentType(enum.Enum):
"""用于描述事件结果的内容的类型。"""
'''用于描述事件结果的内容的类型。
'''
LLM_RESULT = enum.auto()
"""调用 LLM 产生的结果"""
'''调用 LLM 产生的结果'''
GENERAL_RESULT = enum.auto()
"""普通的消息结果"""
STREAMING_RESULT = enum.auto()
"""调用 LLM 产生的流式结果"""
STREAMING_FINISH = enum.auto()
"""流式输出完成"""
'''普通的消息结果'''
@dataclass
class MessageEventResult(MessageChain):
"""MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。
'''MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。
现代消息平台的一条富文本消息中可能由多个组件构成如文本、图片、At 等,并且保留了顺序。
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`result_type` (EventResultType): 事件处理的结果类型。
"""
result_type: Optional[EventResultType] = field(
default_factory=lambda: EventResultType.CONTINUE
)
result_content_type: Optional[ResultContentType] = field(
default_factory=lambda: ResultContentType.GENERAL_RESULT
)
async_stream: Optional[AsyncGenerator] = None
"""异步流"""
def stop_event(self) -> "MessageEventResult":
"""终止事件传播。"""
'''
result_type: Optional[EventResultType] = field(default_factory=lambda: EventResultType.CONTINUE)
result_content_type: Optional[ResultContentType] = field(default_factory=lambda: ResultContentType.GENERAL_RESULT)
def stop_event(self) -> 'MessageEventResult':
'''终止事件传播。
'''
self.result_type = EventResultType.STOP
return self
def continue_event(self) -> "MessageEventResult":
"""继续事件传播。"""
def continue_event(self) -> 'MessageEventResult':
'''继续事件传播。
'''
self.result_type = EventResultType.CONTINUE
return self
def is_stopped(self) -> bool:
"""
'''
是否终止事件传播。
"""
'''
return self.result_type == EventResultType.STOP
def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult":
"""设置异步流。"""
self.async_stream = stream
return self
def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult":
"""设置事件处理的结果类型。
def set_result_content_type(self, typ: ResultContentType) -> 'MessageEventResult':
'''设置事件处理的结果类型。
Args:
result_type (EventResultType): 事件处理的结果类型。
"""
'''
self.result_content_type = typ
return self
def is_llm_result(self) -> bool:
"""是否为 LLM 结果。"""
'''是否为 LLM 结果。
'''
return self.result_content_type == ResultContentType.LLM_RESULT
# 为了兼容旧版代码,保留 CommandResult 的别名
CommandResult = MessageEventResult
def get_plain_text(self) -> str:
'''获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。
'''
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
CommandResult = MessageEventResult

View File

@@ -1,183 +0,0 @@
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Persona, Personality
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.platform.message_session import MessageSession
from astrbot import logger
DEFAULT_PERSONALITY = Personality(
prompt="You are a helpful and friendly assistant.",
name="default",
begin_dialogs=[],
mood_imitation_dialogs=[],
tools=None,
_begin_dialogs_processed=[],
_mood_imitation_dialogs_processed="",
)
class PersonaManager:
def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager):
self.db = db_helper
self.acm = acm
default_ps = acm.default_conf.get("provider_settings", {})
self.default_persona: str = default_ps.get("default_personality", "default")
self.personas: list[Persona] = []
self.selected_default_persona: Persona | None = None
self.personas_v3: list[Personality] = []
self.selected_default_persona_v3: Personality | None = None
self.persona_v3_config: list[dict] = []
async def initialize(self):
self.personas = await self.get_all_personas()
self.get_v3_persona_data()
logger.info(f"已加载 {len(self.personas)} 个人格。")
async def get_persona(self, persona_id: str):
"""获取指定 persona 的信息"""
persona = await self.db.get_persona_by_id(persona_id)
if not persona:
raise ValueError(f"Persona with ID {persona_id} does not exist.")
return persona
async def get_default_persona_v3(
self, umo: str | MessageSession | None = None
) -> Personality:
"""获取默认 persona"""
cfg = self.acm.get_conf(umo)
default_persona_id = cfg.get("provider_settings", {}).get(
"default_personality", "default"
)
if not default_persona_id or default_persona_id == "default":
return DEFAULT_PERSONALITY
try:
return next(p for p in self.personas_v3 if p["name"] == default_persona_id)
except Exception:
return DEFAULT_PERSONALITY
async def delete_persona(self, persona_id: str):
"""删除指定 persona"""
if not await self.db.get_persona_by_id(persona_id):
raise ValueError(f"Persona with ID {persona_id} does not exist.")
await self.db.delete_persona(persona_id)
self.personas = [p for p in self.personas if p.persona_id != persona_id]
self.get_v3_persona_data()
async def update_persona(
self,
persona_id: str,
system_prompt: str = None,
begin_dialogs: list[str] = None,
tools: list[str] = None,
):
"""更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
existing_persona = await self.db.get_persona_by_id(persona_id)
if not existing_persona:
raise ValueError(f"Persona with ID {persona_id} does not exist.")
persona = await self.db.update_persona(
persona_id, system_prompt, begin_dialogs, tools=tools
)
if persona:
for i, p in enumerate(self.personas):
if p.persona_id == persona_id:
self.personas[i] = persona
break
self.get_v3_persona_data()
return persona
async def get_all_personas(self) -> list[Persona]:
"""获取所有 personas"""
return await self.db.get_personas()
async def create_persona(
self,
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] = None,
tools: list[str] = None,
) -> Persona:
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
if await self.db.get_persona_by_id(persona_id):
raise ValueError(f"Persona with ID {persona_id} already exists.")
new_persona = await self.db.insert_persona(
persona_id, system_prompt, begin_dialogs, tools=tools
)
self.personas.append(new_persona)
self.get_v3_persona_data()
return new_persona
def get_v3_persona_data(
self,
) -> tuple[list[dict], list[Personality], Personality]:
"""获取 AstrBot <4.0.0 版本的 persona 数据。
Returns:
- list[dict]: 包含 persona 配置的字典列表。
- list[Personality]: 包含 Personality 对象的列表。
- Personality: 默认选择的 Personality 对象。
"""
v3_persona_config = [
{
"prompt": persona.system_prompt,
"name": persona.persona_id,
"begin_dialogs": persona.begin_dialogs or [],
"mood_imitation_dialogs": [], # deprecated
"tools": persona.tools,
}
for persona in self.personas
]
personas_v3: list[Personality] = []
selected_default_persona: Personality | None = None
for persona_cfg in v3_persona_config:
begin_dialogs = persona_cfg.get("begin_dialogs", [])
bd_processed = []
if begin_dialogs:
if len(begin_dialogs) % 2 != 0:
logger.error(
f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。"
)
begin_dialogs = []
user_turn = True
for dialog in begin_dialogs:
bd_processed.append(
{
"role": "user" if user_turn else "assistant",
"content": dialog,
"_no_save": None, # 不持久化到 db
}
)
user_turn = not user_turn
try:
persona = Personality(
**persona_cfg,
_begin_dialogs_processed=bd_processed,
_mood_imitation_dialogs_processed="", # deprecated
)
if persona["name"] == self.default_persona:
selected_default_persona = persona
personas_v3.append(persona)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
if not selected_default_persona and len(personas_v3) > 0:
# 默认选择第一个
selected_default_persona = personas_v3[0]
if not selected_default_persona:
selected_default_persona = DEFAULT_PERSONALITY
personas_v3.append(selected_default_persona)
self.personas_v3 = personas_v3
self.selected_default_persona_v3 = selected_default_persona
self.persona_v3_config = v3_persona_config
self.selected_default_persona = Persona(
persona_id=selected_default_persona["name"],
system_prompt=selected_default_persona["prompt"],
begin_dialogs=selected_default_persona["begin_dialogs"],
tools=selected_default_persona["tools"] or None,
)
return v3_persona_config, personas_v3, selected_default_persona

View File

@@ -1,35 +1,28 @@
from astrbot.core.message.message_event_result import (
EventResultType,
MessageEventResult,
)
from astrbot.core.message.message_event_result import MessageEventResult, EventResultType
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .rate_limit_check.stage import RateLimitStage
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage
from .result_decorate.stage import ResultDecorateStage
from .session_status_check.stage import SessionStatusCheckStage
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .respond.stage import RespondStage
# 管道阶段顺序
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果比如添加回复前缀、t2i、转换为语音 等
"RespondStage", # 发送消息
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果比如添加回复前缀、t2i、转换为语音 等
"RespondStage" # 发送消息
]
__all__ = [
"WakingCheckStage",
"WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
@@ -37,5 +30,5 @@ __all__ = [
"ResultDecorateStage",
"RespondStage",
"MessageEventResult",
"EventResultType",
]
"EventResultType"
]

View File

@@ -6,32 +6,26 @@ from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core import logger
from .strategies.strategy import StrategySelector
@register_stage
class ContentSafetyCheckStage(Stage):
"""检查内容安全
'''检查内容安全
当前只会检查文本的。
"""
'''
async def initialize(self, ctx: PipelineContext):
config = ctx.astrbot_config["content_safety"]
config = ctx.astrbot_config['content_safety']
self.strategy_selector = StrategySelector(config)
async def process(
self, event: AstrMessageEvent, check_text: str = None
) -> Union[None, AsyncGenerator[None, None]]:
"""检查内容安全"""
async def process(self, event: AstrMessageEvent, check_text: str = None) -> Union[None, AsyncGenerator[None, None]]:
'''检查内容安全'''
text = check_text if check_text else event.get_message_str()
ok, info = self.strategy_selector.check(text)
if not ok:
if event.is_at_or_wake_command:
event.set_result(
MessageEventResult().message(
"你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。"
)
)
yield
event.set_result(MessageEventResult().message("你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。"))
yield
event.stop_event()
logger.info(f"内容安全检查不通过,原因:{info}")
return
event.continue_event()

View File

@@ -1,8 +1,8 @@
import abc
from typing import Tuple
class ContentSafetyStrategy(abc.ABC):
@abc.abstractmethod
def check(self, content: str) -> Tuple[bool, str]:
raise NotImplementedError
raise NotImplementedError

View File

@@ -1,30 +1,30 @@
"""
'''
使用此功能应该先 pip install baidu-aip
"""
'''
from . import ContentSafetyStrategy
from aip import AipContentCensor
class BaiduAipStrategy(ContentSafetyStrategy):
def __init__(self, appid: str, ak: str, sk: str) -> None:
self.app_id = appid
self.api_key = ak
self.secret_key = sk
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
self.client = AipContentCensor(self.app_id,
self.api_key,
self.secret_key)
def check(self, content: str):
res = self.client.textCensorUserDefined(content)
if "conclusionType" not in res:
if 'conclusionType' not in res:
return False, ""
if res["conclusionType"] == 1:
if res['conclusionType'] == 1:
return True, ""
else:
if "data" not in res:
if 'data' not in res:
return False, ""
count = len(res["data"])
count = len(res['data'])
info = f"百度审核服务发现 {count} 处违规:\n"
for i in res["data"]:
for i in res['data']:
info += f"{i['msg']}\n"
info += "\n判断结果:" + res["conclusion"]
return False, info
info += "\n判断结果:"+res['conclusion']
return False, info

View File

@@ -1,23 +1,23 @@
import re
import os
import json
import base64
from . import ContentSafetyStrategy
class KeywordsStrategy(ContentSafetyStrategy):
def __init__(self, extra_keywords: list) -> None:
self.keywords = []
if extra_keywords is None:
extra_keywords = []
self.keywords.extend(extra_keywords)
# keywords_path = os.path.join(os.path.dirname(__file__), "unfit_words")
keywords_path = os.path.join(os.path.dirname(__file__), 'unfit_words')
# internal keywords
# if os.path.exists(keywords_path):
# with open(keywords_path, "r", encoding="utf-8") as f:
# self.keywords.extend(
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
# )
if os.path.exists(keywords_path):
with open(keywords_path, "r", encoding="utf-8") as f:
self.keywords.extend(json.loads(base64.b64decode(f.read()).decode("utf-8"))['keywords'])
def check(self, content: str) -> bool:
for keyword in self.keywords:
if re.search(keyword, content):
return False, "内容安全检查不通过,匹配到敏感词。"
return True, ""
return True, ""

View File

@@ -2,7 +2,6 @@ from . import ContentSafetyStrategy
from typing import List, Tuple
from astrbot import logger
class StrategySelector:
def __init__(self, config: dict) -> None:
self.enabled_strategies: List[ContentSafetyStrategy] = []

View File

@@ -0,0 +1 @@
ewogICAgImtleXdvcmRzIjogWwogICAgICAgICLkuaDov5HlubMiLAogICAgICAgICLog6HplKbmtpsiLAogICAgICAgICLmsZ/ms73msJEiLAogICAgICAgICLmuKnlrrblrp0iLAogICAgICAgICLmnY7lhYvlvLoiLAogICAgICAgICLmnY7plb/mmKUiLAogICAgICAgICLmr5vms73kuJwiLAogICAgICAgICLpgpPlsI/lubMiLAogICAgICAgICLlkajmganmnaUiLAogICAgICAgICLnpL7kvJrkuLvkuYkiLAogICAgICAgICLlhbHkuqflhZoiLAogICAgICAgICLlhbHkuqfkuLvkuYkiLAogICAgICAgICLlpKfpmYblrpjmlrkiLAogICAgICAgICLljJfkuqzmlL/mnYMiLAogICAgICAgICLkuK3ljY7luJ3lm70iLAogICAgICAgICLkuK3lm73mlL/lupwiLAogICAgICAgICLlhbHni5ciLAogICAgICAgICLlha3lm5vkuovku7YiLAogICAgICAgICLlpKnlronpl6giLAogICAgICAgICLlha3lm5siLAogICAgICAgICLmlL/msrvlsYDluLjlp5QiLAogICAgICAgICLlrabmva4iLAogICAgICAgICLlhavkuZ0iLAogICAgICAgICLkuozljYHlpKciLAogICAgICAgICLmsJHov5vlhZoiLAogICAgICAgICLlj7Dni6wiLAogICAgICAgICLlj7Dmub7ni6znq4siLAogICAgICAgICLlj7Dmub7lm70iLAogICAgICAgICLlm73msJHlhZoiLAogICAgICAgICLlj7Dmub7msJHlm70iLAogICAgICAgICLkuK3ljY7msJHlm70iLAogICAgICAgICJwb3JuaHViIiwKICAgICAgICAiUG9ybmh1YiIsCiAgICAgICAgIuS9nOeIsSIsCiAgICAgICAgIuWBmueIsSIsCiAgICAgICAgIuaAp+S6pCIsCiAgICAgICAgIuiHquaFsCIsCiAgICAgICAgIumYtOiMjiIsCiAgICAgICAgIua3q+WmhyIsCiAgICAgICAgIuiCm+S6pCIsCiAgICAgICAgIuS6pOmFjSIsCiAgICAgICAgIuaAp+WFs+ezuyIsCiAgICAgICAgIuaAp+a0u+WKqCIsCiAgICAgICAgIuiJsuaDhSIsCiAgICAgICAgIuiJsuWbviIsCiAgICAgICAgIuijuOS9kyIsCiAgICAgICAgIuWwj+eptCIsCiAgICAgICAgIua3q+iNoSIsCiAgICAgICAgIuaAp+eIsSIsCiAgICAgICAgIua4r+eLrCIsCiAgICAgICAgIuazlei9ruWKnyIsCiAgICAgICAgIuWFreWbmyIKICAgIF0KfQ==

View File

@@ -1,15 +1,8 @@
from dataclasses import dataclass
from astrbot.core.config import AstrBotConfig
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star import PluginManager
from .context_utils import call_handler, call_event_hook
@dataclass
class PipelineContext:
"""上下文对象,包含管道执行所需的上下文信息"""
astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: PluginManager # 插件管理器对象
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook
astrbot_config: AstrBotConfig
plugin_manager: PluginManager

View File

@@ -1,98 +0,0 @@
import inspect
import traceback
import typing as T
from astrbot import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
async def call_handler(
event: AstrMessageEvent,
handler: T.Awaitable,
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
2. 协程: 执行一次并处理返回值
Args:
event (AstrMessageEvent): 事件对象
handler (Awaitable): 事件处理函数
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret
async def call_event_hook(
event: AstrMessageEvent,
hook_type: EventType,
*args,
**kwargs,
) -> bool:
"""调用事件钩子函数
Returns:
bool: 如果事件被终止,返回 True
#"""
handlers = star_handlers_registry.get_handlers_by_event_type(
hook_type, plugins_name=event.plugins_name
)
for handler in handlers:
try:
logger.debug(
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, *args, **kwargs)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return event.is_stopped()

View File

@@ -7,67 +7,64 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger
from astrbot.core.message.components import Plain, Record, Image
@register_stage
class PreProcessStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.plugin_manager = ctx.plugin_manager
self.stt_settings: dict = self.config.get('provider_stt_settings', {})
self.platform_settings: dict = self.config.get('platform_settings', {})
self.stt_settings: dict = self.config.get("provider_stt_settings", {})
self.platform_settings: dict = self.config.get("platform_settings", {})
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""在处理事件之前的预处理"""
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''在处理事件之前的预处理'''
# 路径映射
if mappings := self.platform_settings.get("path_mapping", []):
if mappings := self.platform_settings.get('path_mapping', []):
# 支持 RecordImage 消息段的路径映射。
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, (Record, Image)) and component.url:
for mapping in mappings:
from_, to_ = mapping.split(":")
from_ = from_.removesuffix("/")
to_ = to_.removesuffix("/")
url = component.url.removeprefix("file://")
if url.startswith(from_):
component.url = url.replace(from_, to_, 1)
logger.debug(f"路径映射: {url} -> {component.url}")
message_chain[idx] = component
# STT
if self.stt_settings.get("enable", False):
if self.stt_settings.get('enable', False):
# TODO: 独立
ctx = self.plugin_manager.context
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
if not stt_provider:
return
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record) and component.url:
path = component.url.removeprefix("file://")
retry = 5
for i in range(retry):
try:
result = await stt_provider.get_text(audio_url=path)
if result:
logger.info("语音转文本结果: " + result)
message_chain[idx] = Plain(result)
event.message_str += result
event.message_obj.message_str += result
break
except FileNotFoundError as e:
# napcat workaround
logger.warning(e)
logger.warning(f"重试中: {i + 1}/{retry}")
await asyncio.sleep(0.5)
continue
except BaseException as e:
logger.error(traceback.format_exc())
logger.error(f"语音转文本失败: {e}")
break
stt_provider = self.plugin_manager.context.provider_manager.curr_stt_provider_inst
if stt_provider:
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record) and component.url:
path = component.url.removeprefix("file://")
retry = 5
for i in range(retry):
try:
result = await stt_provider.get_text(audio_url=path)
if result:
logger.info("语音转文本结果: " + result)
message_chain[idx] = Plain(result)
event.message_str += result
event.message_obj.message_str += result
break
except FileNotFoundError as e:
# napcat workaround
logger.warning(e)
logger.warning(f"重试中: {i + 1}/{retry}")
await asyncio.sleep(0.5)
continue
except BaseException as e:
logger.error(traceback.format_exc())
logger.error(f"语音转文本失败: {e}")
break

View File

@@ -1,594 +1,159 @@
"""
'''
本地 Agent 模式的 LLM 调用 Stage
"""
import asyncio
import copy
import json
'''
import traceback
from typing import AsyncGenerator, Union
from astrbot.core import logger
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
ResultContentType,
)
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolSet, FunctionTool
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from ...context import PipelineContext, call_event_hook, call_handler
import json
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
from astrbot.core.provider.register import llm_tools
from astrbot.core.star.star_handler import star_map
from astrbot.core.astr_agent_context import AstrAgentContext
try:
import mcp
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
AgentContextWrapper = ContextWrapper[AstrAgentContext]
AgentRunner = ToolLoopAgentRunner[AgentContextWrapper]
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
async def execute(cls, tool, run_context, **tool_args):
"""执行函数调用。
Args:
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
**kwargs: 函数调用的参数。
Returns:
AsyncGenerator[None | mcp.types.CallToolResult, None]
"""
if isinstance(tool, HandoffTool):
async for r in cls._execute_handoff(tool, run_context, **tool_args):
yield r
return
if tool.origin == "local":
async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
elif tool.origin == "mcp":
async for r in cls._execute_mcp(tool, run_context, **tool_args):
yield r
return
raise Exception(f"Unknown function origin: {tool.origin}")
@classmethod
async def _execute_handoff(
cls,
tool: HandoffTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
input_ = tool_args.get("input", "agent")
agent_runner = AgentRunner()
# make toolset for the agent
tools = tool.agent.tools
if tools:
toolset = ToolSet()
for t in tools:
if isinstance(t, str):
_t = llm_tools.get_func(t)
if _t:
toolset.add_tool(_t)
elif isinstance(t, FunctionTool):
toolset.add_tool(t)
else:
toolset = None
request = ProviderRequest(
prompt=input_,
system_prompt=tool.description,
image_urls=[], # 暂时不传递原始 agent 的上下文
contexts=[], # 暂时不传递原始 agent 的上下文
func_tool=toolset,
)
astr_agent_ctx = AstrAgentContext(
provider=run_context.context.provider,
first_provider_request=run_context.context.first_provider_request,
curr_provider_request=request,
streaming=run_context.context.streaming,
)
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
await run_context.event.send(
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name)
)
await agent_runner.reset(
provider=run_context.context.provider,
request=request,
run_context=AgentContextWrapper(
context=astr_agent_ctx, event=run_context.event
),
tool_executor=FunctionToolExecutor(),
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
streaming=run_context.context.streaming,
)
async for _ in run_agent(agent_runner, 15, True):
pass
if agent_runner.done():
llm_response = agent_runner.get_final_llm_resp()
logger.debug(
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
)
result = (
f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n"
"Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)."
)
text_content = mcp.types.TextContent(
type="text",
text=result,
)
yield mcp.types.CallToolResult(content=[text_content])
else:
yield mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
@classmethod
async def _execute_local(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
if not run_context.event:
raise ValueError("Event must be provided for local function tools.")
# 检查 tool 下有没有 run 方法
if not tool.handler and not hasattr(tool, "run"):
raise ValueError("Tool must have a valid handler or 'run' method.")
awaitable = tool.handler or getattr(tool, "run")
wrapper = call_handler(
event=run_context.event,
handler=awaitable,
**tool_args,
)
async for resp in wrapper:
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
yield None
@classmethod
async def _execute_mcp(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
if not tool.mcp_client:
raise ValueError("MCP client is not available for MCP function tools.")
res = await tool.mcp_client.session.call_tool(
name=tool.name,
arguments=tool_args,
)
if not res:
return
yield res
class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
run_context.event, EventType.OnLLMResponseEvent, llm_response
)
MAIN_AGENT_HOOKS = MainAgentHooks()
async def run_agent(
agent_runner: AgentRunner, max_step: int = 30, show_tool_use: bool = True
) -> AsyncGenerator[MessageChain, None]:
step_idx = 0
astr_event = agent_runner.run_context.event
while step_idx < max_step:
step_idx += 1
try:
async for resp in agent_runner.step():
if astr_event.is_stopped():
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
resp.data["chain"].type = "tool_call_result"
await astr_event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if agent_runner.streaming:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if show_tool_use or astr_event.get_platform_name() == "webchat":
resp.data["chain"].type = "tool_call"
await astr_event.send(resp.data["chain"])
continue
if not agent_runner.streaming:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
astr_event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
)
)
yield
astr_event.clear_result()
else:
if resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
if agent_runner.done():
break
except Exception as e:
logger.error(traceback.format_exc())
astr_event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
)
)
return
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
)
)
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest, LLMResponse
from astrbot.core.star.star_handler import star_handlers_registry, EventType
class LLMRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
conf = ctx.astrbot_config
settings = conf["provider_settings"]
self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
self.provider_wake_prefix: str = settings["wake_prefix"] # str
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
)
self.streaming_response: bool = settings["streaming_response"]
self.max_step: int = settings.get("max_agent_step", 30)
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
self.bot_wake_prefixs = ctx.astrbot_config['wake_prefix'] # list
self.provider_wake_prefix = ctx.astrbot_config['provider_settings']['wake_prefix'] # str
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
logger.info(
f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。"
)
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :]
logger.info(f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。")
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp):]
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
"""选择使用的 LLM 提供商"""
sel_provider = event.get_extra("selected_provider")
_ctx = self.ctx.plugin_manager.context
if sel_provider and isinstance(sel_provider, str):
provider = _ctx.get_provider_by_id(sel_provider)
if not provider:
logger.error(f"未找到指定的提供商: {sel_provider}")
return provider
return _ctx.get_using_provider(umo=event.unified_msg_origin)
async def _get_session_conv(self, event: AstrMessageEvent):
umo = event.unified_msg_origin
conv_mgr = self.conv_manager
# 获取对话上下文
cid = await conv_mgr.get_curr_conversation_id(umo)
if not cid:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
return conversation
async def process(
self, event: AstrMessageEvent, _nested: bool = False
) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
# 检查会话级别的LLM启停状态
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM跳过处理。")
return
provider = self._select_provider(event)
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
if provider is None:
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), (
"provider_request 必须是 ProviderRequest 类型。"
)
if req.conversation:
req.contexts = json.loads(req.conversation.history)
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
else:
req = ProviderRequest(prompt="", image_urls=[])
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if self.provider_wake_prefix:
if not event.message_str.startswith(self.provider_wake_prefix):
return
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
req.prompt = event.message_str[len(self.provider_wake_prefix):]
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
conversation = await self._get_session_conv(event)
image_url = comp.url if comp.url else comp.file
req.image_urls.append(image_url)
# 获取对话上下文
conversation_id = await self.conv_manager.get_curr_conversation_id(event.unified_msg_origin)
if not conversation_id:
conversation_id = await self.conv_manager.new_conversation(event.unified_msg_origin)
req.session_id = event.unified_msg_origin
conversation = await self.conv_manager.get_conversation(event.unified_msg_origin, conversation_id)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
event.set_extra("provider_request", req)
if not req.prompt and not req.image_urls:
return
# 执行请求 LLM 前事件钩子。
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
# 装饰 system_prompt 等功能
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
for handler in handlers:
try:
await handler.handler(event, req)
except BaseException:
logger.error(traceback.format_exc())
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
try:
logger.debug(f"提供商请求 Payload: {req}")
if _nested:
req.func_tool = None # 暂时不支持递归工具调用
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
# 执行 LLM 响应后的事件。
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent)
for handler in handlers:
try:
await handler.handler(event, llm_response)
except BaseException:
logger.error(traceback.format_exc())
# 保存到历史记录
await self._save_to_history(event, req, llm_response)
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
# max context length
if (
self.max_context_length != -1 # -1 为不限制
and len(req.contexts) // 2 > self.max_context_length
):
logger.debug("上下文长度超过限制,将截断。")
req.contexts = req.contexts[
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
]
# 找到第一个role 为 user 的索引,确保上下文格式正确
index = next(
(
i
for i, item in enumerate(req.contexts)
if item.get("role") == "user"
),
None,
)
if index is not None and index > 0:
req.contexts = req.contexts[index:]
# session_id
if not req.session_id:
req.session_id = event.unified_msg_origin
# fix messages
req.contexts = self.fix_messages(req.contexts)
# check provider modalities
# 如果提供商不支持图像/工具使用,但请求中包含图像/工具列表,则清空。图片转述等的检测和调用发生在这之前,因此这里可以这样处理。
if req.image_urls:
provider_cfg = provider.provider_config.get("modalities", ["image"])
if "image" not in provider_cfg:
logger.debug(f"用户设置提供商 {provider} 不支持图像,清空图像列表。")
req.image_urls = []
if req.func_tool:
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
if "tool_use" not in provider_cfg:
logger.debug(
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。"
)
req.func_tool = None
# 插件可用性设置
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
plugin = star_map.get(tool.handler_module_path)
if not plugin:
continue
if plugin.name in event.plugins_name or plugin.reserved:
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set
# run agent
agent_runner = AgentRunner()
logger.debug(
f"handle provider[id: {provider.provider_config['id']}] request: {req}"
)
astr_agent_ctx = AstrAgentContext(
provider=provider,
first_provider_request=req,
curr_provider_request=req,
streaming=self.streaming_response,
)
await agent_runner.reset(
provider=provider,
request=req,
run_context=AgentContextWrapper(context=astr_agent_ctx, event=event),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=self.streaming_response,
)
if self.streaming_response:
# 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_agent(agent_runner, self.max_step, self.show_tool_use)
)
)
yield
if agent_runner.done():
if final_llm_resp := agent_runner.get_final_llm_resp():
if final_llm_resp.completion_text:
chain = (
MessageChain().message(final_llm_resp.completion_text).chain
)
else:
chain = final_llm_resp.result_chain.chain
event.set_result(
MessageEventResult(
chain=chain,
result_content_type=ResultContentType.STREAMING_FINISH,
)
)
else:
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
yield
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
# 异步处理 WebChat 特殊情况
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
async def _handle_webchat(
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, req.conversation.cid
)
if conversation and not req.conversation.title:
messages = json.loads(conversation.history)
latest_pair = messages[-2:]
if not latest_pair:
return
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
llm_resp = await prov.text_chat(
system_prompt="You are expert in summarizing user's query.",
prompt=(
f"Please summarize the following query of user:\n"
f"{cleaned_text}\n"
"Only output the summary within 10 words, DO NOT INCLUDE any other text."
"You must use the same language as the user."
"If you think the dialog is too short to summarize, only output a special mark: `<None>`"
),
)
if llm_resp and llm_resp.completion_text:
logger.debug(
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}"
)
title = llm_resp.completion_text.strip()
if not title or "<None>" in title:
return
await self.conv_manager.update_conversation_title(
unified_msg_origin=event.unified_msg_origin,
title=title,
conversation_id=req.conversation.cid,
)
async def _save_to_history(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse | None,
):
if (
not req
or not req.conversation
or not llm_response
or llm_response.role != "assistant"
):
return
if not llm_response.completion_text and not req.tool_calls_result:
logger.debug("LLM 响应为空,不保存记录。")
return
# 历史上下文
messages = copy.deepcopy(req.contexts)
# 这一轮对话请求的用户输入
messages.append(await req.assemble_context())
# 这一轮对话的 LLM 响应
if req.tool_calls_result:
if not isinstance(req.tool_calls_result, list):
messages.extend(req.tool_calls_result.to_openai_messages())
elif isinstance(req.tool_calls_result, list):
for tcr in req.tool_calls_result:
messages.extend(tcr.to_openai_messages())
messages.append({"role": "assistant", "content": llm_response.completion_text})
messages = list(filter(lambda item: "_no_save" not in item, messages))
await self.conv_manager.update_conversation(
event.unified_msg_origin, req.conversation.cid, history=messages
)
def fix_messages(self, messages: list[dict]) -> list[dict]:
"""验证并且修复上下文"""
fixed_messages = []
for message in messages:
if message.get("role") == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
if llm_response.role == 'assistant':
# text completion
event.set_result(MessageEventResult().message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT))
elif llm_response.role == 'err':
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"))
elif llm_response.role == 'tool':
# function calling
function_calling_result = {}
for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args):
func_tool = req.func_tool.get_func(func_tool_name)
logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
try:
# 尝试调用工具函数
wrapper = self._call_handler(self.ctx, event, func_tool.handler, **func_tool_args)
async for resp in wrapper:
if resp is not None: # 有 return 返回
function_calling_result[func_tool_name] = resp
else:
yield # 有生成器返回
event.clear_result() # 清除上一个 handler 的结果
except BaseException as e:
logger.warning(traceback.format_exc())
function_calling_result[func_tool_name] = "When calling the function, an error occurred: " + str(e)
if function_calling_result:
# 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。
# 我们重新执行一遍这个 stage
req.func_tool = None # 暂时不支持递归工具调用
extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n"
for tool_name, tool_result in function_calling_result.items():
extra_prompt += f"Tool: {tool_name}\nTool Result: {tool_result}\n"
req.prompt += extra_prompt
async for _ in self.process(event, _nested=True):
yield
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages
if llm_response.completion_text:
event.set_result(MessageEventResult().message(llm_response.completion_text))
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"))
return
async def _save_to_history(self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse):
if llm_response.role == "assistant":
# 文本回复
contexts = req.contexts
new_record = {
"role": "user",
"content": req.prompt
}
contexts.append(new_record)
contexts.append({
"role": "assistant",
"content": llm_response.completion_text
})
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
await self.conv_manager.update_conversation(
event.unified_msg_origin,
req.conversation.cid,
history=contexts_to_save
)

View File

@@ -1,8 +1,7 @@
"""
'''
本地 Agent 模式的 AstrBot 插件调用 Stage
"""
from ...context import PipelineContext, call_handler
'''
from ...context import PipelineContext
from ..stage import Stage
from typing import Dict, Any, List, AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -12,46 +11,39 @@ from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.star.star import star_map
import traceback
class StarRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"]
self.identifier = ctx.astrbot_config["provider_settings"]["identifier"]
self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix']
self.identifier = ctx.astrbot_config['provider_settings']['identifier']
self.ctx = ctx
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
activated_handlers: List[StarHandlerMetadata] = event.get_extra(
"activated_handlers"
)
handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra(
"handlers_parsed_params"
)
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra("handlers_parsed_params")
if not handlers_parsed_params:
handlers_parsed_params = {}
for handler in activated_handlers:
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_path not in star_map:
# 孤立无援的 star handler
continue
logger.debug(
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
)
wrapper = call_handler(event, handler.handler, **params)
logger.debug(f"执行插件 handler {handler.handler_full_name}")
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
async for ret in wrapper:
yield ret
event.clear_result() # 清除上一个 handler 的结果
event.clear_result() # 清除上一个 handler 的结果
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
if event.is_at_or_wake_command:
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
event.stop_event()
event.stop_event()

View File

@@ -5,29 +5,26 @@ from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core import logger
@register_stage
class ProcessStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.plugin_manager = ctx.plugin_manager
self.llm_request_sub_stage = LLMRequestSubStage()
await self.llm_request_sub_stage.initialize(ctx)
self.star_request_sub_stage = StarRequestSubStage()
await self.star_request_sub_stage.initialize(ctx)
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""处理事件"""
activated_handlers: List[StarHandlerMetadata] = event.get_extra(
"activated_handlers"
)
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''处理事件
'''
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
# 有插件 Handler 被激活
if activated_handlers:
async for resp in self.star_request_sub_stage.process(event):
@@ -43,26 +40,20 @@ class ProcessStage(Stage):
yield
else:
yield
# 调用 LLM 相关请求
if not self.ctx.astrbot_config["provider_settings"].get("enable", True):
if not self.ctx.astrbot_config['provider_settings'].get('enable', True):
return
if (
not event._has_send_oper
and event.is_at_or_wake_command
and not event.call_llm
):
if not event._has_send_oper and event.is_at_or_wake_command:
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
if (
event.get_result() and not event.get_result().is_stopped()
) or not event.get_result():
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
# 事件没有终止传播
provider = self.ctx.plugin_manager.context.get_using_provider()
if not provider:
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
return
async for _ in self.llm_request_sub_stage.process(event):
yield
yield

View File

@@ -5,6 +5,7 @@ from typing import DefaultDict, Deque, Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core import logger
from astrbot.core.config.astrbot_config import RateLimitStrategy
@@ -31,19 +32,11 @@ class RateLimitStage(Stage):
"""
初始化限流器,根据配置设置限流参数。
"""
self.rate_limit_count = ctx.astrbot_config["platform_settings"]["rate_limit"][
"count"
]
self.rate_limit_time = timedelta(
seconds=ctx.astrbot_config["platform_settings"]["rate_limit"]["time"]
)
self.rl_strategy = ctx.astrbot_config["platform_settings"]["rate_limit"][
"strategy"
] # stall or discard
self.rate_limit_count = ctx.astrbot_config['platform_settings']['rate_limit']['count']
self.rate_limit_time = timedelta(seconds=ctx.astrbot_config['platform_settings']['rate_limit']['time'])
self.rl_strategy = ctx.astrbot_config['platform_settings']['rate_limit']['strategy'] # stall or discard
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
"""
检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。
@@ -58,34 +51,31 @@ class RateLimitStage(Stage):
now = datetime.now()
async with self.locks[session_id]: # 确保同一会话不会并发修改队列
# 检查并处理限流,可能需要多次检查直到满足条件
while True:
timestamps = self.event_timestamps[session_id]
self._remove_expired_timestamps(timestamps, now)
timestamps = self.event_timestamps[session_id]
if len(timestamps) < self.rate_limit_count:
timestamps.append(now)
break
else:
next_window_time = timestamps[0] + self.rate_limit_time
stall_duration = (next_window_time - now).total_seconds() + 0.3
self._remove_expired_timestamps(timestamps, now)
match self.rl_strategy:
case RateLimitStrategy.STALL.value:
logger.info(
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
)
await asyncio.sleep(stall_duration)
now = datetime.now()
case RateLimitStrategy.DISCARD.value:
logger.info(
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
)
return event.stop_event()
if len(timestamps) >= self.rate_limit_count:
# 达到限流阈值,计算下一个窗口的时间
next_window_time = timestamps[0] + self.rate_limit_time
stall_duration = (next_window_time - now).total_seconds()
match self.rl_strategy:
case RateLimitStrategy.STALL.value:
logger.info(f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。")
await asyncio.sleep(stall_duration)
case RateLimitStrategy.DISCARD.value:
# event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
logger.info(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。")
return event.stop_event()
self._remove_expired_timestamps(timestamps, now + timedelta(seconds=stall_duration))
def _remove_expired_timestamps(
self, timestamps: Deque[datetime], now: datetime
) -> None:
timestamps.append(now)
return event.continue_event()
def _remove_expired_timestamps(self, timestamps: Deque[datetime], now: datetime) -> None:
"""
移除时间窗口外的时间戳。
@@ -95,4 +85,4 @@ class RateLimitStage(Stage):
"""
expiry_threshold: datetime = now - self.rate_limit_time
while timestamps and timestamps[0] < expiry_threshold:
timestamps.popleft()
timestamps.popleft()

Some files were not shown because too many files have changed in this diff Show More