Compare commits

..

1 Commits

Author SHA1 Message Date
Soulter
3b54a24037 feat: 初步实现可视化编辑 object 配置项 2025-02-22 17:12:35 +08:00
425 changed files with 9286 additions and 51022 deletions

3
.codecov.yml Normal file
View File

@@ -0,0 +1,3 @@
comment:
layout: "condensed_header, condensed_files, condensed_footer"
hide_project_coverage: TRUE

5
.coveragerc Normal file
View File

@@ -0,0 +1,5 @@
[run]
omit =
*/site-packages/*
*/dist-packages/*
your_package_name/tests/*

View File

@@ -17,8 +17,4 @@ ENV/
.conda/
README*.md
dashboard/
data/
changelogs/
tests/
.ruff_cache/
.astrbot
data/

15
.github/FUNDING.yml vendored
View File

@@ -1,15 +0,0 @@
# These are supported funding model platforms
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: astrbot
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
polar: # Replace with a single Polar username
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
thanks_dev: # Replace with a single thanks.dev username
custom: ['https://afdian.com/a/astrbot_team']

View File

@@ -1,31 +0,0 @@
---
name: '🥳 发布插件'
title: "[Plugin] 插件名"
about: 提交插件到插件市场
labels: [ "plugin-publish" ]
assignees: ''
---
欢迎发布插件到插件市场!
## 插件基本信息
请将插件信息填写到下方的 Json 代码块中。`tags`(插件标签)和 `social_link`(社交链接)选填。
```json
{
"name": "插件名",
"desc": "插件介绍",
"author": "作者名",
"repo": "插件仓库链接",
"tags": [],
"social_link": ""
}
```
## 检查
- [ ] 我的插件经过完整的测试
- [ ] 我的插件不包含恶意代码
- [ ] 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。

View File

@@ -28,7 +28,7 @@ body:
- type: textarea
attributes:
label: AstrBot 版本部署方式(如 Windows Docker Desktop 部署)、使用的提供商、使用的消息平台适配器
label: AstrBot 版本部署方式
description: >
请提供您的 AstrBot 版本和部署方式。
placeholder: >
@@ -53,9 +53,9 @@ body:
- type: textarea
attributes:
label: 报错日志
label: 额外信息
description: >
如报错日志、截图等。请提供完整的 Debug 级别的日志,不要介意它很长!
任何额外信息,如报错日志、截图等。
placeholder: >
请提供完整的报错日志或截图。
validations:
@@ -65,7 +65,7 @@ body:
attributes:
label: 你愿意提交 PR 吗?
description: >
这不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
绝对不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
options:
- label: 是的,我愿意提交 PR!
@@ -79,4 +79,4 @@ body:
- type: markdown
attributes:
value: "感谢您填写我们的表单!"
value: "感谢您填写我们的表单!"

View File

@@ -1,5 +1,5 @@
<!-- 如果有的话,指定这个 PR 要解决的 ISSUE -->
解决#XYZ
修复#XYZ
### Motivation
@@ -8,12 +8,3 @@
### Modifications
<!--简单解释你的改动-->
### Check
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容-->
- [ ] 😊 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
- [ ] 👀 我的更改经过良好的测试
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt``pyproject.toml` 文件相应位置。
- [ ] 😮 我的更改没有引入恶意代码

View File

@@ -1,13 +0,0 @@
# Keep GitHub Actions up to date with GitHub's Dependabot...
# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
version: 2
updates:
- package-ecosystem: github-actions
directory: /
groups:
github-actions:
patterns:
- "*" # Group all Actions updates into a single larger pull request
schedule:
interval: weekly

View File

@@ -7,7 +7,7 @@ on:
name: Auto Release
jobs:
build-and-publish-to-github-release:
build:
runs-on: ubuntu-latest
permissions:
contents: write
@@ -23,70 +23,13 @@ jobs:
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
echo ${{ github.ref_name }} > dist/assets/version
zip -r dist.zip dist
- name: Upload to Cloudflare R2
env:
R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }}
R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
R2_BUCKET_NAME: "astrbot"
R2_OBJECT_NAME: "astrbot-webui-latest.zip"
VERSION_TAG: ${{ github.ref_name }}
run: |
echo "Installing rclone..."
curl https://rclone.org/install.sh | sudo bash
echo "Configuring rclone remote..."
mkdir -p ~/.config/rclone
cat <<EOF > ~/.config/rclone/rclone.conf
[r2]
type = s3
provider = Cloudflare
access_key_id = $R2_ACCESS_KEY_ID
secret_access_key = $R2_SECRET_ACCESS_KEY
endpoint = https://${R2_ACCOUNT_ID}.r2.cloudflarestorage.com
EOF
echo "Uploading dist.zip to R2 bucket: $R2_BUCKET_NAME/$R2_OBJECT_NAME"
mv dashboard/dist.zip dashboard/$R2_OBJECT_NAME
rclone copy dashboard/$R2_OBJECT_NAME r2:$R2_BUCKET_NAME --progress
mv dashboard/$R2_OBJECT_NAME dashboard/astrbot-webui-${VERSION_TAG}.zip
rclone copy dashboard/astrbot-webui-${VERSION_TAG}.zip r2:$R2_BUCKET_NAME --progress
mv dashboard/astrbot-webui-${VERSION_TAG}.zip dashboard/dist.zip
- name: Fetch Changelog
run: |
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
- name: Create GitHub Release
- name: Create Release
uses: ncipollo/release-action@v1
with:
bodyFile: ${{ env.changelog }}
artifacts: "dashboard/dist.zip"
build-and-publish-to-pypi:
# 构建并发布到 PyPI
runs-on: ubuntu-latest
needs: build-and-publish-to-github-release
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install uv
run: |
python -m pip install uv
- name: Build package
run: |
uv build
- name: Publish to PyPI
env:
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
run: |
uv publish
artifacts: "dashboard/dist.zip"

View File

@@ -1,6 +1,6 @@
name: Run tests and upload coverage
on:
on:
push:
branches:
- master
@@ -8,7 +8,6 @@ on:
- 'README.md'
- 'changelogs/**'
- 'dashboard/**'
pull_request:
workflow_dispatch:
jobs:
@@ -22,24 +21,25 @@ jobs:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v4
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pytest-asyncio pytest-cov
pip install --editable .
pip install -r requirements.txt
pip install pytest pytest-cov pytest-asyncio
- name: Run tests
run: |
mkdir -p data/plugins
mkdir -p data/config
mkdir -p data/temp
mkdir data
mkdir data/plugins
mkdir data/config
mkdir data/temp
export TESTING=true
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
pytest --cov=. -v -o log_cli=true -o log_level=DEBUG
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
token: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -1,35 +0,0 @@
name: AstrBot Dashboard CI
on:
push:
branches: [ "master" ]
pull_request:
branches: [ "master" ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: npm install, build
run: |
cd dashboard
npm install
npm run build
- name: Inject Commit SHA
id: get_sha
run: |
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
mkdir -p dashboard/dist/assets
echo $COMMIT_SHA > dashboard/dist/assets/version
- name: Archive production artifacts
uses: actions/upload-artifact@v4
with:
name: dist-without-markdown
path: |
dashboard/dist
!dist/**/*.md

View File

@@ -11,42 +11,24 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Pull The Codes
uses: actions/checkout@v4
- name: 拉取源码
uses: actions/checkout@v3
with:
fetch-depth: 0 # Must be 0 so we can fetch tags
fetch-depth: 1
- name: Get latest tag (only on manual trigger)
id: get-latest-tag
if: github.event_name == 'workflow_dispatch'
run: |
tag=$(git describe --tags --abbrev=0)
echo "latest_tag=$tag" >> $GITHUB_OUTPUT
- name: Checkout to latest tag (only on manual trigger)
if: github.event_name == 'workflow_dispatch'
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
- name: Set QEMU
- name: 设置 QEMU
uses: docker/setup-qemu-action@v3
- name: Set Docker Buildx
- name: 设置 Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to DockerHub
- name: 登录到 DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: Soulter
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
- name: Build and Push Docker to DockerHub and Github GHCR
- name: 构建和推送 Docker hub
uses: docker/build-push-action@v6
with:
context: .
@@ -54,9 +36,8 @@ jobs:
push: true
tags: |
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
ghcr.io/soulter/astrbot:latest
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.ref_name }}
- name: Post build notifications
run: echo "Docker image has been built and pushed successfully"

View File

@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/stale@v9
- uses: actions/stale@v5
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'Stale issue message'

11
.gitignore vendored
View File

@@ -1,8 +1,6 @@
__pycache__
botpy.log
.vscode
.venv*
.idea
data_v2.db
data_v3.db
configs/session
@@ -19,15 +17,12 @@ addons/plugins
tests/astrbot_plugin_openai
chroma
dashboard/node_modules/
dashboard/dist/
node_modules/
.DS_Store
package-lock.json
package.json
venv/*
packages/python_interpreter/workplace
.venv/*
.conda/
.idea
pytest.ini
.astrbot
.conda/

View File

@@ -1,13 +0,0 @@
default_install_hook_types: [pre-commit, prepare-commit-msg]
ci:
autofix_commit_msg: ":balloon: auto fixes by pre-commit hooks"
autofix_prs: true
autoupdate_branch: master
autoupdate_schedule: weekly
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.2
hooks:
- id: ruff
- id: ruff-format

View File

@@ -1 +0,0 @@
3.10

View File

@@ -1,35 +1,22 @@
FROM python:3.11-slim
FROM python:3.10-slim
WORKDIR /AstrBot
COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \
nodejs \
npm \
gcc \
build-essential \
python3-dev \
libffi-dev \
libssl-dev \
ca-certificates \
bash \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
RUN python -m pip install -r requirements.txt --no-cache-dir
# 释出 ffmpeg
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
RUN python -m pip install socksio wechatpy cryptography --no-cache-dir
EXPOSE 6185
EXPOSE 6186
CMD [ "python", "main.py" ]

View File

@@ -1,35 +0,0 @@
FROM python:3.10-slim
WORKDIR /AstrBot
COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
build-essential \
python3-dev \
libffi-dev \
libssl-dev \
curl \
unzip \
ca-certificates \
bash \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Installation of Node.js
ENV NVM_DIR="/root/.nvm"
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
. "$NVM_DIR/nvm.sh" && \
nvm install 22 && \
nvm use 22
RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system
EXPOSE 6185
EXPOSE 6186
CMD ["python", "main.py"]

View File

@@ -629,8 +629,8 @@ to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
AstrBot is a llm-powered chatbot and develop framework.
Copyright (C) 2022-2099 Soulter
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published

218
README.md
View File

@@ -1,6 +1,6 @@
<p align="center">
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
![6e1279651f16d7fdf4727558b72bbaf1](https://github.com/user-attachments/assets/ead4c551-fc3c-48f7-a6f7-afbfdb820512)
</p>
@@ -10,16 +10,14 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot?style=for-the-badge&color=76bad9)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7日消息量&cacheSeconds=3600&style=for-the-badge&color=3b618e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600)
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://astrbot.app/">查看文档</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
@@ -27,42 +25,19 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型LLM接入功能的聊天机器人及开发框架。
<!-- [![codecov](https://img.shields.io/codecov/c/github/soulter/astrbot?style=for-the-badge)](https://codecov.io/gh/Soulter/AstrBot)
-->
> [!WARNING]
>
> 请务必修改默认密码以及保证 AstrBot 版本 >= 3.5.13。
## ✨ 近期更新
<details><summary>1. AstrBot 现已自带知识库能力</summary>
📚 详见[文档](https://astrbot.app/use/knowledge-base.html)
![image](https://github.com/user-attachments/assets/28b639b0-bb5c-4958-8e94-92ae8cfd1ab4)
</details>
2. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
## ✨ 主要功能
> [!NOTE]
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力支持图片理解、语音转文字Whisper
2. **多消息平台接入**。支持接入 QQOneBot、QQ 官方机器人平台、QQ 频道、企业微信、微信公众号、飞书、Telegram钉钉、Discord、KOOK、VoceChat。支持速率限制、白名单、关键词过滤、百度内容审核。
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
2. **多消息平台接入**。支持接入 QQOneBot、QQ 频道、微信Gewechat、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://astrbot.app/others/dify.html),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat可在面板上与大模型对话。
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
> [!TIP]
> WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> 用户名: `astrbot`, 密码: `astrbot`。
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM无法在聊天页使用大模型。不要再修改 demo 的登录密码了 😭)
## ✨ 使用方式
@@ -72,165 +47,96 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
#### Windows 一键安装器部署
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
需要电脑上安装有 Python>3.10)。请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
#### 宝塔面板部署
#### Replit 部署
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### CasaOS 部署
社区贡献的部署方式。
请参阅官方文档 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html) 。
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/casaos.html) 。
#### 手动部署
> 推荐使用 `uv`
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html)
首先,安装 uv
```bash
pip install uv
```
通过 Git Clone 安装 AstrBot
```bash
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
uv run main.py
```
或者,直接通过 uvx 安装 AstrBot
```bash
mkdir astrbot && cd astrbot
uvx astrbot init
# uvx astrbot run
```
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
#### 在 Replit 上部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### 在 雨云 上部署
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
## ⚡ 消息平台支持情况
| 平台 | 支持性 |
| -------- | ------- |
| QQ(官方机器人接口) | ✔ |
| QQ(OneBot) | ✔ |
| Telegram | ✔ |
| 企业微信 | ✔ |
| 微信客服 | ✔ |
| 微信公众号 | ✔ |
| 飞书 | ✔ |
| 钉钉 | ✔ |
| Slack | |
| Discord | |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
| 微信对话开放平台 | 🚧 |
| WhatsApp | 🚧 |
| 小爱音响 | 🚧 |
| 平台 | 支持性 | 详情 | 消息类型 |
| -------- | ------- | ------- | ------ |
| QQ(官方机器人接口) | ✔ | 私聊、群聊QQ 频道私聊、群聊 | 文字、图片 |
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 |
| 飞书 | ✔ | 群聊 | 文字、图片 |
| 微信对话开放平台 | 🚧 | 计划内 | - |
| Discord | 🚧 | 计划内 | - |
| WhatsApp | 🚧 | 计划内 | - |
| 小爱音响 | 🚧 | 计划内 | - |
## ⚡ 提供商支持情况
# 🦌 接下来的路线图
| 名称 | 支持性 | 类型 | 备注 |
| -------- | ------- | ------- | ------- |
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Gemini、Kimi、xAI 等兼容 OpenAI API 的服务 |
| Claude API | ✔ | 文本生成 | |
| Google Gemini API | ✔ | 文本生成 | |
| Dify | ✔ | LLMOps | |
| 阿里云百炼应用 | ✔ | LLMOps | |
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
| 硅基流动 | ✔ | 模型 API 服务平台 | |
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
| OneAPI | ✔ | LLM 分发系统 | |
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
| OpenAI TTS API | ✔ | 文本转语音 | |
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | 文本转语音 | GPT-Sovits-Inference |
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
> [!TIP]
> 欢迎在 Issue 提出更多建议 <3
- [ ] 完善并保证目前所有平台适配器的功能一致性
- [ ] 优化插件接口
- [ ] 默认支持更多 TTS 服务,如 GPT-Sovits
- [ ] 完善“聊天增强”部分,支持持久化记忆
- [ ] 规划 i18n
## ❤️ 贡献
欢迎任何 Issues/Pull Requests只需要将你的更改提交到此项目 )
### 如何贡献
你可以通过查看问题或帮助审核 PR拉取请求来贡献。任何问题或 PR 都欢迎参与,以促进社区贡献。当然,这些只是建议,你可以以任何方式进行贡献。对于新功能的添加,请先通过 Issue 讨论。
### 开发环境
AstrBot 使用 `ruff` 进行代码格式化和检查。
```bash
git clone https://github.com/Soulter/AstrBot
pip install pre-commit
pre-commit install
```
对于新功能的添加,请先通过 Issue 讨论。
## 🌟 支持
- Star 这个项目!
- 在[爱发电](https://afdian.com/a/soulter)支持我!
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
## ✨ Demo
<details><summary>👉 点击展开多张 Demo 截图 👈</summary>
> [!NOTE]
> 代码执行器的文件输入/输出目前仅测试了 Napcat(QQ), Lagrange(QQ)
<div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
_✨基于 Docker 的沙箱化代码执行器Beta 测试)✨_
_✨基于 Docker 的沙箱化代码执行器Beta 测试)✨_
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
_✨ 自然语言待办事项 ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ 插件系统——部分插件展示 ✨_
<img src="https://github.com/user-attachments/assets/0cdbf564-2f59-4da5-b524-ce0e7ef3d978" width=600>
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
_WebUI_
_管理面板_
![webchat](https://drive.soulter.top/f/vlsA/ezgif-5-fb044b2542.gif)
_✨ 内置 Web Chat在线与机器人交互 ✨_
</div>
</details>
## ❤️ Special Thanks
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>
此外,本项目的诞生离不开以下开源项目:
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
## ⭐ Star History
> [!TIP]
@@ -242,7 +148,21 @@ _✨ WebUI ✨_
</div>
![10k-star-banner-credit-by-kevin](https://github.com/user-attachments/assets/c97fc5fb-20b9-4bc8-9998-c20b930ab097)
## Disclaimer
1. The project is protected under the `AGPL-v3` opensource license.
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
3. Please ensure compliance with local laws and regulations when using this project.
<!-- ## ✨ ATRI [Beta 测试]
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. 基于《ATRI ~ My Dear Moments》主角 ATRI 角色台词作为微调数据集的 `Qwen1.5-7B-Chat Lora` 微调模型。
2. 长期记忆
3. 表情包理解与回复
4. TTS
-->
_私は、高性能ですから!_

View File

@@ -1,182 +0,0 @@
<p align="center">
![6e1279651f16d7fdf4727558b72bbaf1](https://github.com/user-attachments/assets/ead4c551-fc3c-48f7-a6f7-afbfdb820512)
</p>
<div align="center">
_✨ Easy-to-use Multi-platform LLM Chatbot & Development Framework ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
<a href="https://astrbot.app/">Documentation</a>
<a href="https://github.com/Soulter/AstrBot/issues">Issue Tracking</a>
</div>
AstrBot is a loosely coupled, asynchronous chatbot and development framework that supports multi-platform deployment, featuring an easy-to-use plugin system and comprehensive Large Language Model (LLM) integration capabilities.
## ✨ Key Features
1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper).
2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation.
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://dify.ai/) for easy access to Dify assistants/knowledge bases/workflows.
4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins.
5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction.
6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling.
> [!TIP]
> Dashboard Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
> Username: `astrbot`, Password: `astrbot` (LLM not configured for chat page)
## ✨ Deployment
#### Docker Deployment
See docs: [Deploy with Docker](https://astrbot.app/deploy/astrbot/docker.html#docker-deployment)
#### Windows Installer
Requires Python (>3.10). See docs: [Windows Installer Guide](https://astrbot.app/deploy/astrbot/windows.html)
#### Replit Deployment
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### CasaOS Deployment
Community-contributed method.
See docs: [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html)
#### Manual Deployment
See docs: [Source Code Deployment](https://astrbot.app/deploy/astrbot/cli.html)
## ⚡ Platform Support
| Platform | Status | Details | Message Types |
| -------------------------------------------------------------- | ------ | ------------------- | ------------------- |
| QQ (Official Bot) | ✔ | Private/Group chats | Text, Images |
| QQ (OneBot) | ✔ | Private/Group chats | Text, Images, Voice |
| WeChat (Personal) | ✔ | Private/Group chats | Text, Images, Voice |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
| [WeChat Work](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
| Feishu | ✔ | Group chats | Text, Images |
| WeChat Open Platform | 🚧 | Planned | - |
| Discord | 🚧 | Planned | - |
| WhatsApp | 🚧 | Planned | - |
| Xiaomi Speakers | 🚧 | Planned | - |
## Provider Support Status
| Name | Support | Type | Notes |
|---------------------------|---------|------------------------|-----------------------------------------------------------------------|
| OpenAI API | ✔ | Text Generation | Supports all OpenAI API-compatible services including DeepSeek, Google Gemini, GLM, Moonshot, Alibaba Cloud Bailian, Silicon Flow, xAI, etc. |
| Claude API | ✔ | Text Generation | |
| Google Gemini API | ✔ | Text Generation | |
| Dify | ✔ | LLMOps | |
| DashScope (Alibaba Cloud) | ✔ | LLMOps | |
| Ollama | ✔ | Model Loader | Local deployment for open-source LLMs (DeepSeek, Llama, etc.) |
| LM Studio | ✔ | Model Loader | Local deployment for open-source LLMs (DeepSeek, Llama, etc.) |
| LLMTuner | ✔ | Model Loader | Local loading of fine-tuned models (e.g. LoRA) |
| OneAPI | ✔ | LLM Distribution | |
| Whisper | ✔ | Speech-to-Text | Supports API and local deployment |
| SenseVoice | ✔ | Speech-to-Text | Local deployment |
| OpenAI TTS API | ✔ | Text-to-Speech | |
| Fishaudio | ✔ | Text-to-Speech | Project involving GPT-Sovits author |
# 🦌 Roadmap
> [!TIP]
> Suggestions welcome via Issues <3
- [ ] Ensure feature parity across all platform adapters
- [ ] Optimize plugin APIs
- [ ] Add default TTS services (e.g., GPT-Sovits)
- [ ] Enhance chat features with persistent memory
- [ ] i18n Planning
## ❤️ Contributions
All Issues/PRs welcome! Simply submit your changes to this project :)
For major features, please discuss via Issues first.
## 🌟 Support
- Star this project!
- Support via [Afdian](https://afdian.com/a/soulter)
- WeChat support: [QR Code](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)
## ✨ Demos
> [!NOTE]
> Code executor file I/O currently tested with Napcat(QQ)/Lagrange(QQ)
<div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
_✨ Docker-based Sandboxed Code Executor (Beta) ✨_
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ Multimodal Input, Web Search, Text-to-Image ✨_
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
_✨ Natural Language TODO Lists ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ Plugin System Showcase ✨_
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
_✨ Web Dashboard ✨_
![webchat](https://drive.soulter.top/f/vlsA/ezgif-5-fb044b2542.gif)
_✨ Built-in Web Chat Interface ✨_
</div>
## ⭐ Star History
> [!TIP]
> If this project helps you, please give it a star <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div>
## Disclaimer
1. Licensed under `AGPL-v3`.
2. WeChat integration uses [Gewechat](https://github.com/Devo919/Gewechat). Use at your own risk with non-critical accounts.
3. Users must comply with local laws and regulations.
<!-- ## ✨ ATRI [Beta]
Available as plugin: [astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. Qwen1.5-7B-Chat Lora model fine-tuned with ATRI character data
2. Long-term memory
3. Meme understanding & responses
4. TTS integration
-->
_私は、高性能ですから!_

View File

@@ -1,5 +1,5 @@
<p align="center">
![6e1279651f16d7fdf4727558b72bbaf1](https://github.com/user-attachments/assets/ead4c551-fc3c-48f7-a6f7-afbfdb820512)
</p>
@@ -27,15 +27,15 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
## ✨ 主な機能
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換Whisperをサポートします。
2. **複数のメッセージプラットフォームの接続**。QQOneBot、QQ チャンネル、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
2. **複数のメッセージプラットフォームの接続**。QQOneBot、QQ チャンネル、WeChatGewechatFeishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://astrbot.app/others/dify.html)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
> [!TIP]
> 管理パネルのオンラインデモを体験する: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
>
> ユーザー名: `astrbot`, パスワード: `astrbot`。LLM が設定されていないため、チャットページで大規模モデルを使用することはできません。(デモのログインパスワードを変更しないでください 😭)
## ✨ 使用方法
@@ -136,11 +136,11 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
## ⭐ Star History
> [!TIP]
> [!TIP]
> このプロジェクトがあなたの生活や仕事に役立った場合、またはこのプロジェクトの将来の発展に関心がある場合は、プロジェクトに Star を付けてください。これはこのオープンソースプロジェクトを維持するためのモチベーションです <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div>
@@ -152,7 +152,8 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
## 免責事項
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
2. このプロジェクトを使用する際は、現地の法律および規制を遵守してください
2. WeChat個人アカウントのデプロイメントには [Gewechat](https://github.com/Devo919/Gewechat) サービスを利用しています。AstrBot は Gewechat との接続を保証するだけであり、アカウントのリスク管理に関しては、このプロジェクトの著者は一切の責任を負いません
3. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
<!-- ## ✨ ATRI [ベータテスト]
@@ -164,4 +165,6 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
4. TTS
-->
_私は、高性能ですから!_

View File

@@ -1,3 +1,2 @@
from .core.log import LogManager
logger = LogManager.GetLogger(log_name="astrbot")
logger = LogManager.GetLogger(log_name='astrbot')

View File

@@ -4,4 +4,10 @@ from astrbot.core import html_renderer
from astrbot.core import sp
from astrbot.core.star.register import register_llm_tool as llm_tool
__all__ = ["AstrBotConfig", "logger", "html_renderer", "llm_tool", "sp"]
__all__ = [
"AstrBotConfig",
"logger",
"html_renderer",
"llm_tool",
"sp"
]

View File

@@ -1,3 +1,4 @@
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core import html_renderer
@@ -5,11 +6,8 @@ from astrbot.core.star.register import register_llm_tool as llm_tool
# event
from astrbot.core.message.message_event_result import (
MessageEventResult,
MessageChain,
CommandResult,
EventResultType,
)
MessageEventResult, MessageChain, CommandResult, EventResultType
)
from astrbot.core.platform import AstrMessageEvent
# star register
@@ -20,16 +18,10 @@ from astrbot.core.star.register import (
register_regex as regex,
register_platform_adapter_type as platform_adapter_type,
)
from astrbot.core.star.filter.event_message_type import (
EventMessageTypeFilter,
EventMessageType,
)
from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterTypeFilter,
PlatformAdapterType,
)
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
from astrbot.core.star.register import (
register_star as register, # 注册插件Star
register_star as register # 注册插件Star
)
from astrbot.core.star import Context, Star
from astrbot.core.star.config import *
@@ -40,14 +32,9 @@ from astrbot.core.provider import Provider, Personality, ProviderMetaData
# platform
from astrbot.core.platform import (
AstrMessageEvent,
Platform,
AstrBotMessage,
MessageMember,
MessageType,
PlatformMetadata,
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
)
from astrbot.core.platform.register import register_platform_adapter
from .message_components import *
from .message_components import *

View File

@@ -6,44 +6,36 @@ from astrbot.core.star.register import (
register_platform_adapter_type as platform_adapter_type,
register_permission_type as permission_type,
register_custom_filter as custom_filter,
register_on_astrbot_loaded as on_astrbot_loaded,
register_on_llm_request as on_llm_request,
register_on_llm_response as on_llm_response,
register_llm_tool as llm_tool,
register_on_decorating_result as on_decorating_result,
register_after_message_sent as after_message_sent,
register_after_message_sent as after_message_sent
)
from astrbot.core.star.filter.event_message_type import (
EventMessageTypeFilter,
EventMessageType,
)
from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterTypeFilter,
PlatformAdapterType,
)
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
from astrbot.core.star.filter.permission import PermissionTypeFilter, PermissionType
from astrbot.core.star.filter.custom_filter import CustomFilter
__all__ = [
"command",
"command_group",
"event_message_type",
"regex",
"platform_adapter_type",
"permission_type",
"EventMessageTypeFilter",
"EventMessageType",
"PlatformAdapterTypeFilter",
"PlatformAdapterType",
"PermissionTypeFilter",
"CustomFilter",
"custom_filter",
"PermissionType",
"on_astrbot_loaded",
"on_llm_request",
"llm_tool",
"on_decorating_result",
"after_message_sent",
"on_llm_response",
]
'command',
'command_group',
'event_message_type',
'regex',
'platform_adapter_type',
'permission_type',
'EventMessageTypeFilter',
'EventMessageType',
'PlatformAdapterTypeFilter',
'PlatformAdapterType',
'PermissionTypeFilter',
'CustomFilter',
'custom_filter',
'PermissionType',
'on_llm_request',
'llm_tool',
'on_decorating_result',
'after_message_sent',
'on_llm_response'
]

View File

@@ -1 +1 @@
from astrbot.core.message.components import *
from astrbot.core.message.components import *

View File

@@ -1,23 +1,6 @@
from astrbot.core.platform import (
AstrMessageEvent,
Platform,
AstrBotMessage,
MessageMember,
MessageType,
PlatformMetadata,
Group,
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
)
from astrbot.core.platform.register import register_platform_adapter
from astrbot.core.message.components import *
__all__ = [
"AstrMessageEvent",
"Platform",
"AstrBotMessage",
"MessageMember",
"MessageType",
"PlatformMetadata",
"register_platform_adapter",
"Group",
]
from astrbot.core.message.components import *

View File

@@ -1,17 +1,2 @@
from astrbot.core.provider import Provider, STTProvider, Personality
from astrbot.core.provider.entities import (
ProviderRequest,
ProviderType,
ProviderMetaData,
LLMResponse,
)
__all__ = [
"Provider",
"STTProvider",
"Personality",
"ProviderRequest",
"ProviderType",
"ProviderMetaData",
"LLMResponse",
]
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData, LLMResponse

View File

@@ -1,8 +1,6 @@
from astrbot.core.star.register import (
register_star as register, # 注册插件Star
register_star as register # 注册插件Star
)
from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star import Context, Star
from astrbot.core.star.config import *
__all__ = ["register", "Context", "Star", "StarTools"]

View File

@@ -1,7 +0,0 @@
from astrbot.core.utils.session_waiter import (
SessionWaiter,
SessionController,
session_waiter,
)
__all__ = ["SessionWaiter", "SessionController", "session_waiter"]

View File

@@ -1 +0,0 @@
__version__ = "3.5.23"

View File

@@ -1,59 +0,0 @@
"""
AstrBot CLI入口
"""
import click
import sys
from . import __version__
from .commands import init, run, plug, conf
logo_tmpl = r"""
___ _______.___________..______ .______ ______ .___________.
/ \ / | || _ \ | _ \ / __ \ | |
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
/ /_\ \ \ \ | | | / | _ < | | | | | |
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
"""
@click.group()
@click.version_option(__version__, prog_name="AstrBot")
def cli() -> None:
"""The AstrBot CLI"""
click.echo(logo_tmpl)
click.echo("Welcome to AstrBot CLI!")
click.echo(f"AstrBot CLI version: {__version__}")
@click.command()
@click.argument("command_name", required=False, type=str)
def help(command_name: str | None) -> None:
"""显示命令的帮助信息
如果提供了 COMMAND_NAME则显示该命令的详细帮助信息。
否则,显示通用帮助信息。
"""
ctx = click.get_current_context()
if command_name:
# 查找指定命令
command = cli.get_command(ctx, command_name)
if command:
# 显示特定命令的帮助信息
click.echo(command.get_help(ctx))
else:
click.echo(f"Unknown command: {command_name}")
sys.exit(1)
else:
# 显示通用帮助信息
click.echo(cli.get_help(ctx))
cli.add_command(init)
cli.add_command(run)
cli.add_command(help)
cli.add_command(plug)
cli.add_command(conf)
if __name__ == "__main__":
cli()

View File

@@ -1,6 +0,0 @@
from .cmd_init import init
from .cmd_run import run
from .cmd_plug import plug
from .cmd_conf import conf
__all__ = ["init", "run", "plug", "conf"]

View File

@@ -1,206 +0,0 @@
import json
import click
import hashlib
import zoneinfo
from typing import Any, Callable
from ..utils import get_astrbot_root, check_astrbot_root
def _validate_log_level(value: str) -> str:
"""验证日志级别"""
value = value.upper()
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
raise click.ClickException(
"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一"
)
return value
def _validate_dashboard_port(value: str) -> int:
"""验证 Dashboard 端口"""
try:
port = int(value)
if port < 1 or port > 65535:
raise click.ClickException("端口必须在 1-65535 范围内")
return port
except ValueError:
raise click.ClickException("端口必须是数字")
def _validate_dashboard_username(value: str) -> str:
"""验证 Dashboard 用户名"""
if not value:
raise click.ClickException("用户名不能为空")
return value
def _validate_dashboard_password(value: str) -> str:
"""验证 Dashboard 密码"""
if not value:
raise click.ClickException("密码不能为空")
return hashlib.md5(value.encode()).hexdigest()
def _validate_timezone(value: str) -> str:
"""验证时区"""
try:
zoneinfo.ZoneInfo(value)
except Exception:
raise click.ClickException(f"无效的时区: {value}请使用有效的IANA时区名称")
return value
def _validate_callback_api_base(value: str) -> str:
"""验证回调接口基址"""
if not value.startswith("http://") and not value.startswith("https://"):
raise click.ClickException("回调接口基址必须以 http:// 或 https:// 开头")
return value
# 可通过CLI设置的配置项配置键到验证器函数的映射
CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
"timezone": _validate_timezone,
"log_level": _validate_log_level,
"dashboard.port": _validate_dashboard_port,
"dashboard.username": _validate_dashboard_username,
"dashboard.password": _validate_dashboard_password,
"callback_api_base": _validate_callback_api_base,
}
def _load_config() -> dict[str, Any]:
"""加载或初始化配置文件"""
root = get_astrbot_root()
if not check_astrbot_root(root):
raise click.ClickException(
f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
)
config_path = root / "data" / "cmd_config.json"
if not config_path.exists():
from astrbot.core.config.default import DEFAULT_CONFIG
config_path.write_text(
json.dumps(DEFAULT_CONFIG, ensure_ascii=False, indent=2),
encoding="utf-8-sig",
)
try:
return json.loads(config_path.read_text(encoding="utf-8-sig"))
except json.JSONDecodeError as e:
raise click.ClickException(f"配置文件解析失败: {str(e)}")
def _save_config(config: dict[str, Any]) -> None:
"""保存配置文件"""
config_path = get_astrbot_root() / "data" / "cmd_config.json"
config_path.write_text(
json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8-sig"
)
def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
"""设置嵌套字典中的值"""
parts = path.split(".")
for part in parts[:-1]:
if part not in obj:
obj[part] = {}
elif not isinstance(obj[part], dict):
raise click.ClickException(
f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典"
)
obj = obj[part]
obj[parts[-1]] = value
def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
"""获取嵌套字典中的值"""
parts = path.split(".")
for part in parts:
obj = obj[part]
return obj
@click.group(name="conf")
def conf():
"""配置管理命令
支持的配置项:
- timezone: 时区设置 (例如: Asia/Shanghai)
- log_level: 日志级别 (DEBUG/INFO/WARNING/ERROR/CRITICAL)
- dashboard.port: Dashboard 端口
- dashboard.username: Dashboard 用户名
- dashboard.password: Dashboard 密码
- callback_api_base: 回调接口基址
"""
pass
@conf.command(name="set")
@click.argument("key")
@click.argument("value")
def set_config(key: str, value: str):
"""设置配置项的值"""
if key not in CONFIG_VALIDATORS.keys():
raise click.ClickException(f"不支持的配置项: {key}")
config = _load_config()
try:
old_value = _get_nested_item(config, key)
validated_value = CONFIG_VALIDATORS[key](value)
_set_nested_item(config, key, validated_value)
_save_config(config)
click.echo(f"配置已更新: {key}")
if key == "dashboard.password":
click.echo(" 原值: ********")
click.echo(" 新值: ********")
else:
click.echo(f" 原值: {old_value}")
click.echo(f" 新值: {validated_value}")
except KeyError:
raise click.ClickException(f"未知的配置项: {key}")
except Exception as e:
raise click.UsageError(f"设置配置失败: {str(e)}")
@conf.command(name="get")
@click.argument("key", required=False)
def get_config(key: str = None):
"""获取配置项的值不提供key则显示所有可配置项"""
config = _load_config()
if key:
if key not in CONFIG_VALIDATORS.keys():
raise click.ClickException(f"不支持的配置项: {key}")
try:
value = _get_nested_item(config, key)
if key == "dashboard.password":
value = "********"
click.echo(f"{key}: {value}")
except KeyError:
raise click.ClickException(f"未知的配置项: {key}")
except Exception as e:
raise click.UsageError(f"获取配置失败: {str(e)}")
else:
click.echo("当前配置:")
for key in CONFIG_VALIDATORS.keys():
try:
value = (
"********"
if key == "dashboard.password"
else _get_nested_item(config, key)
)
click.echo(f" {key}: {value}")
except (KeyError, TypeError):
pass

View File

@@ -1,55 +0,0 @@
import asyncio
import click
from filelock import FileLock, Timeout
from ..utils import check_dashboard, get_astrbot_root
async def initialize_astrbot(astrbot_root) -> None:
"""执行 AstrBot 初始化逻辑"""
dot_astrbot = astrbot_root / ".astrbot"
if not dot_astrbot.exists():
click.echo(f"Current Directory: {astrbot_root}")
click.echo(
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。"
)
if click.confirm(
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",
default=True,
abort=True,
):
dot_astrbot.touch()
click.echo(f"Created {dot_astrbot}")
paths = {
"data": astrbot_root / "data",
"config": astrbot_root / "data" / "config",
"plugins": astrbot_root / "data" / "plugins",
"temp": astrbot_root / "data" / "temp",
}
for name, path in paths.items():
path.mkdir(parents=True, exist_ok=True)
click.echo(f"{'Created' if not path.exists() else 'Directory exists'}: {path}")
await check_dashboard(astrbot_root / "data")
@click.command()
def init() -> None:
"""初始化 AstrBot"""
click.echo("Initializing AstrBot...")
astrbot_root = get_astrbot_root()
lock_file = astrbot_root / "astrbot.lock"
lock = FileLock(lock_file, timeout=5)
try:
with lock.acquire():
asyncio.run(initialize_astrbot(astrbot_root))
except Timeout:
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
except Exception as e:
raise click.ClickException(f"初始化失败: {e!s}")

View File

@@ -1,247 +0,0 @@
import re
from pathlib import Path
import click
import shutil
from ..utils import (
get_git_repo,
build_plug_list,
manage_plugin,
PluginStatus,
check_astrbot_root,
get_astrbot_root,
)
@click.group()
def plug():
"""插件管理"""
pass
def _get_data_path() -> Path:
base = get_astrbot_root()
if not check_astrbot_root(base):
raise click.ClickException(
f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
)
return (base / "data").resolve()
def display_plugins(plugins, title=None, color=None):
if title:
click.echo(click.style(title, fg=color, bold=True))
click.echo(f"{'名称':<20} {'版本':<10} {'状态':<10} {'作者':<15} {'描述':<30}")
click.echo("-" * 85)
for p in plugins:
desc = p["desc"][:30] + ("..." if len(p["desc"]) > 30 else "")
click.echo(
f"{p['name']:<20} {p['version']:<10} {p['status']:<10} "
f"{p['author']:<15} {desc:<30}"
)
@plug.command()
@click.argument("name")
def new(name: str):
"""创建新插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins" / name
if plug_path.exists():
raise click.ClickException(f"插件 {name} 已存在")
author = click.prompt("请输入插件作者", type=str)
desc = click.prompt("请输入插件描述", type=str)
version = click.prompt("请输入插件版本", type=str)
if not re.match(r"^\d+\.\d+(\.\d+)?$", version.lower().lstrip("v")):
raise click.ClickException("版本号必须为 x.y 或 x.y.z 格式")
repo = click.prompt("请输入插件仓库:", type=str)
if not repo.startswith("http"):
raise click.ClickException("仓库地址必须以 http 开头")
click.echo("下载插件模板...")
get_git_repo(
"https://github.com/Soulter/helloworld",
plug_path,
)
click.echo("重写插件信息...")
# 重写 metadata.yaml
with open(plug_path / "metadata.yaml", "w", encoding="utf-8") as f:
f.write(
f"name: {name}\n"
f"desc: {desc}\n"
f"version: {version}\n"
f"author: {author}\n"
f"repo: {repo}\n"
)
# 重写 README.md
with open(plug_path / "README.md", "w", encoding="utf-8") as f:
f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n")
# 重写 main.py
with open(plug_path / "main.py", "r", encoding="utf-8") as f:
content = f.read()
new_content = content.replace(
'@register("helloworld", "YourName", "一个简单的 Hello World 插件", "1.0.0")',
f'@register("{name}", "{author}", "{desc}", "{version}")',
)
with open(plug_path / "main.py", "w", encoding="utf-8") as f:
f.write(new_content)
click.echo(f"插件 {name} 创建成功")
@plug.command()
@click.option("--all", "-a", is_flag=True, help="列出未安装的插件")
def list(all: bool):
"""列出插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
# 未发布的插件
not_published_plugins = [
p for p in plugins if p["status"] == PluginStatus.NOT_PUBLISHED
]
if not_published_plugins:
display_plugins(not_published_plugins, "未发布的插件", "red")
# 需要更新的插件
need_update_plugins = [
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
]
if need_update_plugins:
display_plugins(need_update_plugins, "需要更新的插件", "yellow")
# 已安装的插件
installed_plugins = [p for p in plugins if p["status"] == PluginStatus.INSTALLED]
if installed_plugins:
display_plugins(installed_plugins, "已安装的插件", "green")
# 未安装的插件
not_installed_plugins = [
p for p in plugins if p["status"] == PluginStatus.NOT_INSTALLED
]
if not_installed_plugins and all:
display_plugins(not_installed_plugins, "未安装的插件", "blue")
if (
not any([not_published_plugins, need_update_plugins, installed_plugins])
and not all
):
click.echo("未安装任何插件")
@plug.command()
@click.argument("name")
@click.option("--proxy", help="代理服务器地址")
def install(name: str, proxy: str | None):
"""安装插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins"
plugins = build_plug_list(base_path / "plugins")
plugin = next(
(
p
for p in plugins
if p["name"] == name and p["status"] == PluginStatus.NOT_INSTALLED
),
None,
)
if not plugin:
raise click.ClickException(f"未找到可安装的插件 {name},可能是不存在或已安装")
manage_plugin(plugin, plug_path, is_update=False, proxy=proxy)
@plug.command()
@click.argument("name")
def remove(name: str):
"""卸载插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
plugin = next((p for p in plugins if p["name"] == name), None)
if not plugin or not plugin.get("local_path"):
raise click.ClickException(f"插件 {name} 不存在或未安装")
plugin_path = plugin["local_path"]
click.confirm(f"确定要卸载插件 {name} 吗?", default=False, abort=True)
try:
shutil.rmtree(plugin_path)
click.echo(f"插件 {name} 已卸载")
except Exception as e:
raise click.ClickException(f"卸载插件 {name} 失败: {e}")
@plug.command()
@click.argument("name", required=False)
@click.option("--proxy", help="Github代理地址")
def update(name: str, proxy: str | None):
"""更新插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins"
plugins = build_plug_list(base_path / "plugins")
if name:
plugin = next(
(
p
for p in plugins
if p["name"] == name and p["status"] == PluginStatus.NEED_UPDATE
),
None,
)
if not plugin:
raise click.ClickException(f"插件 {name} 不需要更新或无法更新")
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
else:
need_update_plugins = [
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
]
if not need_update_plugins:
click.echo("没有需要更新的插件")
return
click.echo(f"发现 {len(need_update_plugins)} 个插件需要更新")
for plugin in need_update_plugins:
plugin_name = plugin["name"]
click.echo(f"正在更新插件 {plugin_name}...")
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
@plug.command()
@click.argument("query")
def search(query: str):
"""搜索插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
matched_plugins = [
p
for p in plugins
if query.lower() in p["name"].lower()
or query.lower() in p["desc"].lower()
or query.lower() in p["author"].lower()
]
if not matched_plugins:
click.echo(f"未找到匹配 '{query}' 的插件")
return
display_plugins(matched_plugins, f"搜索结果: '{query}'", "cyan")

View File

@@ -1,63 +0,0 @@
import os
import sys
from pathlib import Path
import click
import asyncio
import traceback
from filelock import FileLock, Timeout
from ..utils import check_dashboard, check_astrbot_root, get_astrbot_root
async def run_astrbot(astrbot_root: Path):
"""运行 AstrBot"""
from astrbot.core import logger, LogManager, LogBroker, db_helper
from astrbot.core.initial_loader import InitialLoader
await check_dashboard(astrbot_root / "data")
log_broker = LogBroker()
LogManager.set_queue_handler(logger, log_broker)
db = db_helper
core_lifecycle = InitialLoader(db, log_broker)
await core_lifecycle.start()
@click.option("--reload", "-r", is_flag=True, help="插件自动重载")
@click.option("--port", "-p", help="Astrbot Dashboard端口", required=False, type=str)
@click.command()
def run(reload: bool, port: str) -> None:
"""运行 AstrBot"""
try:
os.environ["ASTRBOT_CLI"] = "1"
astrbot_root = get_astrbot_root()
if not check_astrbot_root(astrbot_root):
raise click.ClickException(
f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
)
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
sys.path.insert(0, str(astrbot_root))
if port:
os.environ["DASHBOARD_PORT"] = port
if reload:
click.echo("启用插件自动重载")
os.environ["ASTRBOT_RELOAD"] = "1"
lock_file = astrbot_root / "astrbot.lock"
lock = FileLock(lock_file, timeout=5)
with lock.acquire():
asyncio.run(run_astrbot(astrbot_root))
except KeyboardInterrupt:
click.echo("AstrBot 已关闭...")
except Timeout:
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
except Exception as e:
raise click.ClickException(f"运行时出现错误: {e}\n{traceback.format_exc()}")

View File

@@ -1,18 +0,0 @@
from .basic import (
get_astrbot_root,
check_astrbot_root,
check_dashboard,
)
from .plugin import get_git_repo, manage_plugin, build_plug_list, PluginStatus
from .version_comparator import VersionComparator
__all__ = [
"get_astrbot_root",
"check_astrbot_root",
"check_dashboard",
"get_git_repo",
"manage_plugin",
"build_plug_list",
"VersionComparator",
"PluginStatus",
]

View File

@@ -1,67 +0,0 @@
from pathlib import Path
import click
def check_astrbot_root(path: str | Path) -> bool:
"""检查路径是否为 AstrBot 根目录"""
if not isinstance(path, Path):
path = Path(path)
if not path.exists() or not path.is_dir():
return False
if not (path / ".astrbot").exists():
return False
return True
def get_astrbot_root() -> Path:
"""获取Astrbot根目录路径"""
return Path.cwd()
async def check_dashboard(astrbot_root: Path) -> None:
"""检查是否安装了dashboard"""
from astrbot.core.utils.io import get_dashboard_version, download_dashboard
from astrbot.core.config.default import VERSION
from .version_comparator import VersionComparator
try:
dashboard_version = await get_dashboard_version()
match dashboard_version:
case None:
click.echo("未安装管理面板")
if click.confirm(
"是否安装管理面板?",
default=True,
abort=True,
):
click.echo("正在安装管理面板...")
await download_dashboard(
path="data/dashboard.zip", extract_path=str(astrbot_root)
)
click.echo("管理面板安装完成")
case str():
if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
click.echo("管理面板已是最新版本")
return
else:
try:
version = dashboard_version.split("v")[1]
click.echo(f"管理面板版本: {version}")
await download_dashboard(
path="data/dashboard.zip", extract_path=str(astrbot_root)
)
except Exception as e:
click.echo(f"下载管理面板失败: {e}")
return
except FileNotFoundError:
click.echo("初始化管理面板目录...")
try:
await download_dashboard(
path=str(astrbot_root / "dashboard.zip"), extract_path=str(astrbot_root)
)
click.echo("管理面板初始化完成")
except Exception as e:
click.echo(f"下载管理面板失败: {e}")
return

View File

@@ -1,230 +0,0 @@
import shutil
import tempfile
import httpx
import yaml
from enum import Enum
from io import BytesIO
from pathlib import Path
from zipfile import ZipFile
import click
from .version_comparator import VersionComparator
class PluginStatus(str, Enum):
INSTALLED = "已安装"
NEED_UPDATE = "需更新"
NOT_INSTALLED = "未安装"
NOT_PUBLISHED = "未发布"
def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
"""从 Git 仓库下载代码并解压到指定路径"""
temp_dir = Path(tempfile.mkdtemp())
try:
# 解析仓库信息
repo_namespace = url.split("/")[-2:]
author = repo_namespace[0]
repo = repo_namespace[1]
# 尝试获取最新的 release
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
try:
with httpx.Client(
proxy=proxy if proxy else None, follow_redirects=True
) as client:
resp = client.get(release_url)
resp.raise_for_status()
releases = resp.json()
if releases:
# 使用最新的 release
download_url = releases[0]["zipball_url"]
else:
# 没有 release使用默认分支
click.echo(f"正在从默认分支下载 {author}/{repo}")
download_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
except Exception as e:
click.echo(f"获取 release 信息失败: {e},将直接使用提供的 URL")
download_url = url
# 应用代理
if proxy:
download_url = f"{proxy}/{download_url}"
# 下载并解压
with httpx.Client(
proxy=proxy if proxy else None, follow_redirects=True
) as client:
resp = client.get(download_url)
if (
resp.status_code == 404
and "archive/refs/heads/master.zip" in download_url
):
alt_url = download_url.replace("master.zip", "main.zip")
click.echo("master 分支不存在,尝试下载 main 分支")
resp = client.get(alt_url)
resp.raise_for_status()
else:
resp.raise_for_status()
zip_content = BytesIO(resp.content)
with ZipFile(zip_content) as z:
z.extractall(temp_dir)
namelist = z.namelist()
root_dir = Path(namelist[0]).parts[0] if namelist else ""
if target_path.exists():
shutil.rmtree(target_path)
shutil.move(temp_dir / root_dir, target_path)
finally:
if temp_dir.exists():
shutil.rmtree(temp_dir, ignore_errors=True)
def load_yaml_metadata(plugin_dir: Path) -> dict:
"""从 metadata.yaml 文件加载插件元数据
Args:
plugin_dir: 插件目录路径
Returns:
dict: 包含元数据的字典,如果读取失败则返回空字典
"""
yaml_path = plugin_dir / "metadata.yaml"
if yaml_path.exists():
try:
return yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
except Exception as e:
click.echo(f"读取 {yaml_path} 失败: {e}", err=True)
return {}
def build_plug_list(plugins_dir: Path) -> list:
"""构建插件列表,包含本地和在线插件信息
Args:
plugins_dir (Path): 插件目录路径
Returns:
list: 包含插件信息的字典列表
"""
# 获取本地插件信息
result = []
if plugins_dir.exists():
for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]:
plugin_dir = plugins_dir / plugin_name
# 从 metadata.yaml 加载元数据
metadata = load_yaml_metadata(plugin_dir)
# 如果成功加载元数据,添加到结果列表
if metadata and all(
k in metadata for k in ["name", "desc", "version", "author", "repo"]
):
result.append({
"name": str(metadata.get("name", "")),
"desc": str(metadata.get("desc", "")),
"version": str(metadata.get("version", "")),
"author": str(metadata.get("author", "")),
"repo": str(metadata.get("repo", "")),
"status": PluginStatus.INSTALLED,
"local_path": str(plugin_dir),
})
# 获取在线插件列表
online_plugins = []
try:
with httpx.Client() as client:
resp = client.get("https://api.soulter.top/astrbot/plugins")
resp.raise_for_status()
data = resp.json()
for plugin_id, plugin_info in data.items():
online_plugins.append({
"name": str(plugin_id),
"desc": str(plugin_info.get("desc", "")),
"version": str(plugin_info.get("version", "")),
"author": str(plugin_info.get("author", "")),
"repo": str(plugin_info.get("repo", "")),
"status": PluginStatus.NOT_INSTALLED,
"local_path": None,
})
except Exception as e:
click.echo(f"获取在线插件列表失败: {e}", err=True)
# 与在线插件比对,更新状态
online_plugin_names = {plugin["name"] for plugin in online_plugins}
for local_plugin in result:
if local_plugin["name"] in online_plugin_names:
# 查找对应的在线插件
online_plugin = next(
p for p in online_plugins if p["name"] == local_plugin["name"]
)
if (
VersionComparator.compare_version(
local_plugin["version"], online_plugin["version"]
)
< 0
):
local_plugin["status"] = PluginStatus.NEED_UPDATE
else:
# 本地插件未在线上发布
local_plugin["status"] = PluginStatus.NOT_PUBLISHED
# 添加未安装的在线插件
for online_plugin in online_plugins:
if not any(plugin["name"] == online_plugin["name"] for plugin in result):
result.append(online_plugin)
return result
def manage_plugin(
plugin: dict, plugins_dir: Path, is_update: bool = False, proxy: str | None = None
) -> None:
"""安装或更新插件
Args:
plugin (dict): 插件信息字典
plugins_dir (Path): 插件目录
is_update (bool, optional): 是否为更新操作. 默认为 False
proxy (str, optional): 代理服务器地址
"""
plugin_name = plugin["name"]
repo_url = plugin["repo"]
# 如果是更新且有本地路径,直接使用本地路径
if is_update and plugin.get("local_path"):
target_path = Path(plugin["local_path"])
else:
target_path = plugins_dir / plugin_name
backup_path = Path(f"{target_path}_backup") if is_update else None
# 检查插件是否存在
if is_update and not target_path.exists():
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
# 备份现有插件
if is_update and backup_path.exists():
shutil.rmtree(backup_path)
if is_update:
shutil.copytree(target_path, backup_path)
try:
click.echo(
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}..."
)
get_git_repo(repo_url, target_path, proxy)
# 更新成功,删除备份
if is_update and backup_path.exists():
shutil.rmtree(backup_path)
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
except Exception as e:
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
if is_update and backup_path.exists():
shutil.move(backup_path, target_path)
raise click.ClickException(
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}"
)

View File

@@ -1,92 +0,0 @@
"""
拷贝自 astrbot.core.utils.version_comparator
"""
import re
class VersionComparator:
@staticmethod
def compare_version(v1: str, v2: str) -> int:
"""根据 Semver 语义版本规范来比较版本号的大小。支持不仅局限于 3 个数字的版本号,并处理预发布标签。
参考: https://semver.org/lang/zh-CN/
返回 1 表示 v1 > v2返回 -1 表示 v1 < v2返回 0 表示 v1 = v2。
"""
v1 = v1.lower().replace("v", "")
v2 = v2.lower().replace("v", "")
def split_version(version):
match = re.match(
r"^([0-9]+(?:\.[0-9]+)*)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(.+))?$",
version,
)
if not match:
return [], None
major_minor_patch = match.group(1).split(".")
prerelease = match.group(2)
# buildmetadata = match.group(3) # 构建元数据在比较时忽略
parts = [int(x) for x in major_minor_patch]
prerelease = VersionComparator._split_prerelease(prerelease)
return parts, prerelease
v1_parts, v1_prerelease = split_version(v1)
v2_parts, v2_prerelease = split_version(v2)
# 比较数字部分
length = max(len(v1_parts), len(v2_parts))
v1_parts.extend([0] * (length - len(v1_parts)))
v2_parts.extend([0] * (length - len(v2_parts)))
for i in range(length):
if v1_parts[i] > v2_parts[i]:
return 1
elif v1_parts[i] < v2_parts[i]:
return -1
# 比较预发布标签
if v1_prerelease is None and v2_prerelease is not None:
return 1 # 没有预发布标签的版本高于有预发布标签的版本
elif v1_prerelease is not None and v2_prerelease is None:
return -1 # 有预发布标签的版本低于没有预发布标签的版本
elif v1_prerelease is not None and v2_prerelease is not None:
len_pre = max(len(v1_prerelease), len(v2_prerelease))
for i in range(len_pre):
p1 = v1_prerelease[i] if i < len(v1_prerelease) else None
p2 = v2_prerelease[i] if i < len(v2_prerelease) else None
if p1 is None and p2 is not None:
return -1
elif p1 is not None and p2 is None:
return 1
elif isinstance(p1, int) and isinstance(p2, str):
return -1
elif isinstance(p1, str) and isinstance(p2, int):
return 1
elif isinstance(p1, int) and isinstance(p2, int):
if p1 > p2:
return 1
elif p1 < p2:
return -1
elif isinstance(p1, str) and isinstance(p2, str):
if p1 > p2:
return 1
elif p1 < p2:
return -1
return 0 # 预发布标签完全相同
return 0 # 数字部分和预发布标签都相同
@staticmethod
def _split_prerelease(prerelease):
if not prerelease:
return None
parts = prerelease.split(".")
result = []
for part in parts:
if part.isdigit():
result.append(int(part))
else:
result.append(part)
return result

View File

@@ -1,30 +1,26 @@
import os
import asyncio
from .log import LogManager, LogBroker # noqa
from .log import LogManager, LogBroker
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config.default import DB_PATH
from astrbot.core.config import AstrBotConfig
from astrbot.core.file_token_service import FileTokenService
from .utils.astrbot_path import get_astrbot_data_path
# 初始化数据存储文件夹
os.makedirs(get_astrbot_data_path(), exist_ok=True)
DEMO_MODE = os.getenv("DEMO_MODE", False)
os.makedirs("data", exist_ok=True)
astrbot_config = AstrBotConfig()
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
t2i_base_url = astrbot_config.get('t2i_endpoint', 'https://t2i.soulter.top/text2img')
html_renderer = HtmlRenderer(t2i_base_url)
logger = LogManager.GetLogger(log_name="astrbot")
logger = LogManager.GetLogger(log_name='astrbot')
if os.environ.get('TESTING', ""):
logger.setLevel('DEBUG')
db_helper = SQLiteDatabase(DB_PATH)
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
sp = SharedPreferences()
# 文件令牌服务
file_token_service = FileTokenService()
pip_installer = PipInstaller(
astrbot_config.get("pip_install_arg", ""),
astrbot_config.get("pypi_index_url", None),
)
sp = SharedPreferences() # 简单的偏好设置存储
pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', ''))
web_chat_queue = asyncio.Queue(maxsize=32)
web_chat_back_queue = asyncio.Queue(maxsize=32)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"

View File

@@ -1,9 +1,2 @@
from .default import DEFAULT_CONFIG, VERSION, DB_PATH
from .astrbot_config import *
__all__ = [
"DEFAULT_CONFIG",
"VERSION",
"DB_PATH",
"AstrBotConfig",
]
from .astrbot_config import *

View File

@@ -4,158 +4,115 @@ import logging
import enum
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
from typing import Dict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
logger = logging.getLogger("astrbot")
class RateLimitStrategy(enum.Enum):
STALL = "stall"
DISCARD = "discard"
class AstrBotConfig(dict):
"""从配置文件中加载的配置,支持直接通过点号操作符访问根配置项。
'''从配置文件中加载的配置,支持直接通过点号操作符访问根配置项。
- 初始化时会将传入的 default_config 与配置文件进行比对,如果配置文件中缺少配置项则会自动插入默认值并进行一次写入操作。会递归检查配置项。
- 如果配置文件路径对应的文件不存在,则会自动创建并写入默认配置。
- 如果传入了 schema将会通过 schema 解析出 default_config此时传入的 default_config 会被忽略。
"""
'''
def __init__(
self,
config_path: str = ASTRBOT_CONFIG_PATH,
self,
config_path: str = ASTRBOT_CONFIG_PATH,
default_config: dict = DEFAULT_CONFIG,
schema: dict = None,
schema: dict = None
):
super().__init__()
# 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件
object.__setattr__(self, "config_path", config_path)
object.__setattr__(self, "default_config", default_config)
object.__setattr__(self, "schema", schema)
object.__setattr__(self, 'config_path', config_path)
object.__setattr__(self, 'default_config', default_config)
object.__setattr__(self, 'schema', schema)
if schema:
default_config = self._config_schema_to_default_config(schema)
if not self.check_exist():
"""不存在时载入默认配置"""
'''不存在时载入默认配置'''
with open(config_path, "w", encoding="utf-8-sig") as f:
json.dump(default_config, f, indent=4, ensure_ascii=False)
object.__setattr__(self, "first_deploy", True) # 标记第一次部署
with open(config_path, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
if conf_str.startswith(u'/ufeff'): # remove BOM
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
conf = json.loads(conf_str)
# 检查配置完整性,并插入
has_new = self.check_config_integrity(default_config, conf)
self.update(conf)
if has_new:
self.save_config()
self.update(conf)
def _config_schema_to_default_config(self, schema: dict) -> dict:
"""将 Schema 转换成 Config"""
'''将 Schema 转换成 Config'''
conf = {}
def _parse_schema(schema: dict, conf: dict):
for k, v in schema.items():
if v["type"] not in DEFAULT_VALUE_MAP:
raise TypeError(
f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}"
)
if "default" in v:
default = v["default"]
if v['type'] not in DEFAULT_VALUE_MAP:
raise TypeError(f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}")
if 'default' in v:
default = v['default']
else:
default = DEFAULT_VALUE_MAP[v["type"]]
if v["type"] == "object":
default = DEFAULT_VALUE_MAP[v['type']]
if v['type'] == 'object':
conf[k] = {}
_parse_schema(v["items"], conf[k])
_parse_schema(v['items'], conf[k])
else:
conf[k] = default
_parse_schema(schema, conf)
return conf
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
"""检查配置完整性,如果有新的配置项或顺序不一致则返回 True"""
'''检查配置完整性,如果有新的配置项则返回 True'''
has_new = False
# 创建一个新的有序字典以保持参考配置的顺序
new_conf = {}
# 先按照参考配置的顺序添加配置项
for key, value in refer_conf.items():
if key not in conf:
# 配置项不存在,插入默认值
# logger.info(f"检查到配置项 {path + "." + key if path else key} 不存在,插入默认值 {value}")
path_ = path + "." + key if path else key
logger.info(f"检查到配置项 {path_} 不存在,已插入默认值 {value}")
new_conf[key] = value
conf[key] = value
has_new = True
else:
if conf[key] is None:
# 配置项为 None使用默认值
new_conf[key] = value
conf[key] = value
has_new = True
elif isinstance(value, dict):
# 递归检查子配置项
if not isinstance(conf[key], dict):
# 类型不匹配,使用默认值
new_conf[key] = value
has_new = True
else:
# 递归检查并同步顺序
child_has_new = self.check_config_integrity(
value, conf[key], path + "." + key if path else key
)
new_conf[key] = conf[key]
has_new |= child_has_new
else:
# 直接使用现有配置
new_conf[key] = conf[key]
# 检查是否存在参考配置中没有的配置项
for key in list(conf.keys()):
if key not in refer_conf:
path_ = path + "." + key if path else key
logger.info(f"检查到配置项 {path_} 不存在,将从当前配置中删除")
has_new = True
# 顺序不一致也算作变更
if list(conf.keys()) != list(new_conf.keys()):
if path:
logger.info(f"检查到配置项 {path} 的子项顺序不一致,已重新排序")
else:
logger.info("检查到配置项顺序不一致,已重新排序")
has_new = True
# 更新原始配置
conf.clear()
conf.update(new_conf)
has_new |= self.check_config_integrity(value, conf[key], path + "." + key if path else key)
return has_new
def save_config(self, replace_config: Dict = None):
"""将配置写入文件
'''将配置写入文件
如果传入 replace_config则将配置替换为 replace_config
"""
'''
if replace_config:
self.update(replace_config)
with open(self.config_path, "w", encoding="utf-8-sig") as f:
json.dump(self, f, indent=2, ensure_ascii=False)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
return None
def __delattr__(self, key):
try:
del self[key]
@@ -167,4 +124,4 @@ class AstrBotConfig(dict):
self[key] = value
def check_exist(self) -> bool:
return os.path.exists(self.config_path)
return os.path.exists(self.config_path)

File diff suppressed because it is too large Load Diff

View File

@@ -1,10 +1,3 @@
"""
AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库
在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
"""
import uuid
import json
import asyncio
@@ -13,187 +6,104 @@ from typing import Dict, List
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation
class ConversationManager:
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
class ConversationManager():
'''负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。'''
def __init__(self, db_helper: BaseDatabase):
# session_conversations 字典记录会话ID-对话ID 映射关系
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
self._start_periodic_save()
def _start_periodic_save(self):
"""启动定时保存任务"""
asyncio.create_task(self._periodic_save())
async def _periodic_save(self):
"""定时保存会话对话映射关系到存储中"""
while True:
await asyncio.sleep(self.save_interval)
self._save_to_storage()
def _save_to_storage(self):
"""保存会话对话映射关系到存储中"""
sp.put("session_conversation", self.session_conversations)
async def new_conversation(self, unified_msg_origin: str) -> str:
"""新建对话,并将当前会话的对话转移到新对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
'''新建对话,并将当前会话的对话转移到新对话'''
conversation_id = str(uuid.uuid4())
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
self.db.new_conversation(
user_id=unified_msg_origin,
cid=conversation_id
)
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
return conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
"""切换会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
'''切换会话的对话'''
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
async def delete_conversation(
self, unified_msg_origin: str, conversation_id: str = None
):
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
async def delete_conversation(self, unified_msg_origin: str, conversation_id: str=None):
'''删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
self.db.delete_conversation(
user_id=unified_msg_origin,
cid=conversation_id
)
del self.session_conversations[unified_msg_origin]
sp.put("session_conversation", self.session_conversations)
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
"""获取会话当前的对话 ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
'''获取会话当前的对话 ID'''
return self.session_conversations.get(unified_msg_origin, None)
async def get_conversation(
self,
unified_msg_origin: str,
conversation_id: str,
create_if_not_exists: bool = False,
) -> Conversation:
"""获取会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
Returns:
conversation (Conversation): 对话对象
"""
conv = self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
if not conv and create_if_not_exists:
# 如果对话不存在且需要创建,则新建一个对话
conversation_id = await self.new_conversation(unified_msg_origin)
return self.db.get_conversation_by_user_id(
unified_msg_origin, conversation_id
)
async def get_conversation(self, unified_msg_origin: str, conversation_id: str) -> Conversation:
'''获取会话的对话'''
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
"""获取会话的所有对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversations (List[Conversation]): 对话对象列表
"""
'''获取会话的所有对话'''
return self.db.get_conversations(unified_msg_origin)
async def update_conversation(
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
):
"""更新会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
"""
async def update_conversation(self, unified_msg_origin: str, conversation_id: str, history: List[Dict]):
'''更新会话的对话'''
if conversation_id:
self.db.update_conversation(
user_id=unified_msg_origin,
cid=conversation_id,
history=json.dumps(history),
history=json.dumps(history)
)
async def update_conversation_title(self, unified_msg_origin: str, title: str):
"""更新会话的对话标题
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
title (str): 对话标题
"""
'''更新会话的对话标题'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_title(
user_id=unified_msg_origin, cid=conversation_id, title=title
user_id=unified_msg_origin,
cid=conversation_id,
title=title
)
async def update_conversation_persona_id(
self, unified_msg_origin: str, persona_id: str
):
"""更新会话的对话 Persona ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
persona_id (str): 对话 Persona ID
"""
async def update_conversation_persona_id(self, unified_msg_origin: str, persona_id: str):
'''更新会话的对话 Persona ID'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_persona_id(
user_id=unified_msg_origin, cid=conversation_id, persona_id=persona_id
user_id=unified_msg_origin,
cid=conversation_id,
persona_id=persona_id
)
async def get_human_readable_context(
self, unified_msg_origin, conversation_id, page=1, page_size=10
):
"""获取人类可读的上下文
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
page (int): 页码
page_size (int): 每页大小
"""
async def get_human_readable_context(self, unified_msg_origin, conversation_id, page=1, page_size=10):
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
history = json.loads(conversation.history)
contexts = []
temp_contexts = []
for record in history:
if record["role"] == "user":
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record["role"] == "assistant":
if "content" in record and record["content"]:
temp_contexts.append(f"Assistant: {record['content']}")
elif "tool_calls" in record:
tool_calls_str = json.dumps(
record["tool_calls"], ensure_ascii=False
)
temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
else:
temp_contexts.append("Assistant: [未知的内容]")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
@@ -201,9 +111,9 @@ class ConversationManager:
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page - 1) * page_size : page * page_size]
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
return paged_contexts, total_pages

View File

@@ -1,14 +1,3 @@
"""
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
工作流程:
1. 初始化所有组件
2. 启动事件总线和任务, 所有任务都在这里运行
3. 执行启动完成事件钩子
"""
import traceback
import asyncio
import time
@@ -28,179 +17,106 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map
class AstrBotCoreLifecycle:
"""
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、
EventBus 等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
"""
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker # 初始化日志代理
self.astrbot_config = astrbot_config # 初始化配置
self.db = db # 初始化数据库
# 设置代理
if self.astrbot_config.get("http_proxy", ""):
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
if proxy := os.environ.get("https_proxy"):
logger.debug(f"Using proxy: {proxy}")
os.environ["no_proxy"] = "localhost"
self.log_broker = log_broker
self.astrbot_config = astrbot_config
self.db = db
os.environ['https_proxy'] = self.astrbot_config['http_proxy']
os.environ['http_proxy'] = self.astrbot_config['http_proxy']
os.environ['no_proxy'] = 'localhost,127.0.0.1'
async def initialize(self):
"""
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
"""
# 初始化日志代理
logger.info("AstrBot v" + VERSION)
logger.info("AstrBot v"+ VERSION)
if os.environ.get("TESTING", ""):
logger.setLevel("DEBUG") # 测试模式下设置日志级别为 DEBUG
logger.setLevel("DEBUG")
else:
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
# 初始化事件队列
logger.setLevel(self.astrbot_config['log_level'])
self.event_queue = Queue()
# 初始化供应商管理器
self.event_queue.closed = False
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
# 初始化平台管理器
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
# 初始化对话管理器
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
self.conversation_manager = ConversationManager(self.db)
# 初始化提供给插件的上下文
self.star_context = Context(
self.event_queue,
self.astrbot_config,
self.event_queue,
self.astrbot_config,
self.db,
self.provider_manager,
self.platform_manager,
self.conversation_manager,
self.knowledge_db_manager
)
# 初始化插件管理器
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
# 扫描、注册插件、实例化插件类
await self.plugin_manager.reload()
# 根据配置实例化各个 Provider
'''扫描、注册插件、实例化插件类'''
await self.provider_manager.initialize()
'''根据配置实例化各个 Provider'''
await self.platform_manager.initialize()
'''根据配置实例化各个平台适配器'''
# 初始化消息事件流水线调度器
self.pipeline_scheduler = PipelineScheduler(
PipelineContext(self.astrbot_config, self.plugin_manager)
)
self.pipeline_scheduler = PipelineScheduler(PipelineContext(self.astrbot_config, self.plugin_manager))
await self.pipeline_scheduler.initialize()
# 初始化更新器
self.astrbot_updator = AstrBotUpdator()
# 初始化事件总线
'''初始化消息事件流水线调度器'''
self.astrbot_updator = AstrBotUpdator(self.astrbot_config['plugin_repo_mirror'])
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
# 记录启动时间
self.start_time = int(time.time())
# 初始化当前任务列表
self.curr_tasks: List[asyncio.Task] = []
# 根据配置实例化各个平台适配器
await self.platform_manager.initialize()
# 初始化关闭控制面板的事件
self.dashboard_shutdown_event = asyncio.Event()
def _load(self):
"""加载事件总线和任务并初始化"""
# 创建一个异步任务来执行事件总线的 dispatch() 方法
# dispatch是一个无限循环的协程, 从事件队列中获取事件并处理
event_bus_task = asyncio.create_task(
self.event_bus.dispatch(), name="event_bus"
)
# 把插件中注册的所有协程函数注册到事件总线中并执行
platform_tasks = self.load_platform()
event_bus_task = asyncio.create_task(self.event_bus.dispatch(), name="event_bus")
extra_tasks = []
for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
tasks_ = [event_bus_task, *extra_tasks]
# self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks]
tasks_ = [event_bus_task, *platform_tasks, *extra_tasks]
for task in tasks_:
self.curr_tasks.append(
asyncio.create_task(self._task_wrapper(task), name=task.get_name())
)
self.curr_tasks.append(asyncio.create_task(self._task_wrapper(task), name=task.get_name()))
self.start_time = int(time.time())
async def _task_wrapper(self, task: asyncio.Task):
"""异步任务包装器, 用于处理异步任务执行中出现的各种异常
Args:
task (asyncio.Task): 要执行的异步任务
"""
try:
await task
except asyncio.CancelledError:
pass # 任务被取消, 静默处理
pass
except Exception as e:
# 获取完整的异常堆栈信息, 按行分割并记录到日志中
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
for line in traceback.format_exc().split("\n"):
logger.error(f"| {line}")
logger.error("-------")
async def start(self):
"""启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子"""
self._load()
logger.info("AstrBot 启动完成。")
# 执行启动完成事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAstrBotLoadedEvent
)
for handler in handlers:
try:
logger.info(
f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler()
except BaseException:
logger.error(traceback.format_exc())
# 同时运行curr_tasks中的所有任务
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
async def stop(self):
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器"""
# 请求停止所有正在运行的异步任务
self.event_queue.closed = True
for task in self.curr_tasks:
task.cancel()
for plugin in self.plugin_manager.context.get_all_stars():
try:
await self.plugin_manager._terminate_plugin(plugin)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。"
)
await self.provider_manager.terminate()
await self.platform_manager.terminate()
self.dashboard_shutdown_event.set()
# 再次遍历curr_tasks等待每个任务真正结束
for task in self.curr_tasks:
try:
await task
@@ -208,22 +124,14 @@ class AstrBotCoreLifecycle:
pass
except Exception as e:
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
async def restart(self):
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
await self.provider_manager.terminate()
await self.platform_manager.terminate()
self.dashboard_shutdown_event.set()
threading.Thread(
target=self.astrbot_updator._reboot, name="restart", daemon=True
).start()
def restart(self):
self.event_queue.closed = True
threading.Thread(target=self.astrbot_updator._reboot, name="restart", daemon=True).start()
def load_platform(self) -> List[asyncio.Task]:
"""加载平台实例并返回所有平台实例的异步任务列表"""
tasks = []
platform_insts = self.platform_manager.get_insts()
for platform_inst in platform_insts:
tasks.append(
asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name)
)
return tasks
tasks.append(asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name))
return tasks

View File

@@ -1,161 +1,113 @@
import abc
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple
from typing import List
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
@dataclass
class BaseDatabase(abc.ABC):
"""
'''
数据库基类
"""
'''
def __init__(self) -> None:
pass
def insert_base_metrics(self, metrics: dict):
"""插入基础指标数据"""
self.insert_platform_metrics(metrics["platform_stats"])
self.insert_plugin_metrics(metrics["plugin_stats"])
self.insert_command_metrics(metrics["command_stats"])
self.insert_llm_metrics(metrics["llm_stats"])
'''插入基础指标数据'''
self.insert_platform_metrics(metrics['platform_stats'])
self.insert_plugin_metrics(metrics['plugin_stats'])
self.insert_command_metrics(metrics['command_stats'])
self.insert_llm_metrics(metrics['llm_stats'])
@abc.abstractmethod
def insert_platform_metrics(self, metrics: dict):
"""插入平台指标数据"""
'''插入平台指标数据'''
raise NotImplementedError
@abc.abstractmethod
def insert_plugin_metrics(self, metrics: dict):
"""插入插件指标数据"""
'''插入插件指标数据'''
raise NotImplementedError
@abc.abstractmethod
def insert_command_metrics(self, metrics: dict):
"""插入指令指标数据"""
'''插入指令指标数据'''
raise NotImplementedError
@abc.abstractmethod
def insert_llm_metrics(self, metrics: dict):
"""插入 LLM 指标数据"""
'''插入 LLM 指标数据'''
raise NotImplementedError
@abc.abstractmethod
def update_llm_history(self, session_id: str, content: str, provider_type: str):
"""更新 LLM 历史记录。当不存在 session_id 时插入"""
'''更新 LLM 历史记录。当不存在 session_id 时插入'''
raise NotImplementedError
@abc.abstractmethod
def get_llm_history(
self, session_id: str = None, provider_type: str = None
) -> List[LLMHistory]:
"""获取 LLM 历史记录, 如果 session_id 为 None, 返回所有"""
def get_llm_history(self, session_id: str = None, provider_type: str = None) -> List[LLMHistory]:
'''获取 LLM 历史记录, 如果 session_id 为 None, 返回所有'''
raise NotImplementedError
@abc.abstractmethod
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取基础统计数据"""
'''获取基础统计数据'''
raise NotImplementedError
@abc.abstractmethod
def get_total_message_count(self) -> int:
"""获取总消息数"""
'''获取总消息数'''
raise NotImplementedError
@abc.abstractmethod
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取基础统计数据(合并)"""
'''获取基础统计数据(合并)'''
raise NotImplementedError
@abc.abstractmethod
def insert_atri_vision_data(self, vision_data: ATRIVision):
"""插入 ATRI 视觉数据"""
'''插入 ATRI 视觉数据'''
raise NotImplementedError
@abc.abstractmethod
def get_atri_vision_data(self) -> List[ATRIVision]:
"""获取 ATRI 视觉数据"""
'''获取 ATRI 视觉数据'''
raise NotImplementedError
@abc.abstractmethod
def get_atri_vision_data_by_path_or_id(
self, url_or_path: str, id: str
) -> ATRIVision:
"""通过 url 或 path 获取 ATRI 视觉数据"""
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
'''通过 url 或 path 获取 ATRI 视觉数据'''
raise NotImplementedError
@abc.abstractmethod
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
"""通过 user_id 和 cid 获取 Conversation"""
'''通过 user_id 和 cid 获取 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def new_conversation(self, user_id: str, cid: str):
"""新建 Conversation"""
'''新建 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def get_conversations(self, user_id: str) -> List[Conversation]:
raise NotImplementedError
@abc.abstractmethod
def update_conversation(self, user_id: str, cid: str, history: str):
"""更新 Conversation"""
'''更新 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def delete_conversation(self, user_id: str, cid: str):
"""删除 Conversation"""
'''删除 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def update_conversation_title(self, user_id: str, cid: str, title: str):
"""更新 Conversation 标题"""
'''更新 Conversation 标题'''
raise NotImplementedError
@abc.abstractmethod
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
"""更新 Conversation Persona ID"""
raise NotImplementedError
@abc.abstractmethod
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页
Args:
page: 页码从1开始
page_size: 每页数量
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
"""
raise NotImplementedError
@abc.abstractmethod
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表
Args:
page: 页码
page_size: 每页数量
platforms: 平台筛选列表
message_types: 消息类型筛选列表
search_query: 搜索关键词
exclude_ids: 排除的用户ID列表
exclude_platforms: 排除的平台列表
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
"""
raise NotImplementedError
'''更新 Conversation Persona ID'''
raise NotImplementedError

View File

@@ -1,65 +1,48 @@
"""指标数据"""
'''指标数据'''
from dataclasses import dataclass, field
from typing import List
@dataclass
class Platform:
"""平台使用统计数据"""
class Platform():
name: str
count: int
timestamp: int
@dataclass
class Provider():
name: str
count: int
timestamp: int
@dataclass
class Plugin():
name: str
count: int
timestamp: int
@dataclass
class Command():
name: str
count: int
timestamp: int
@dataclass
class Provider:
"""供应商使用统计数据"""
name: str
count: int
timestamp: int
@dataclass
class Plugin:
"""插件使用统计数据"""
name: str
count: int
timestamp: int
@dataclass
class Command:
"""命令使用统计数据"""
name: str
count: int
timestamp: int
@dataclass
class Stats:
class Stats():
platform: List[Platform] = field(default_factory=list)
command: List[Command] = field(default_factory=list)
llm: List[Provider] = field(default_factory=list)
@dataclass
class LLMHistory:
"""LLM 聊天时持久化的信息"""
class LLMHistory():
'''LLM 聊天时持久化的信息'''
provider_type: str
session_id: str
content: str
@dataclass
class ATRIVision:
"""Deprecated"""
class ATRIVision():
'''Deprecated'''
id: str
url_or_path: str
caption: str
@@ -69,21 +52,19 @@ class ATRIVision:
session_id: str
sender_nickname: str
timestamp: int = -1
@dataclass
class Conversation:
"""LLM 对话存储
class Conversation():
'''LLM 对话存储
对于网页聊天history 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
"""
'''
user_id: str
cid: str
history: str = ""
"""字符串格式的列表。"""
'''字符串格式的列表。'''
created_at: int = 0
updated_at: int = 0
title: str = ""
persona_id: str = ""
persona_id: str = ""

View File

@@ -1,32 +1,36 @@
import sqlite3
import os
import time
from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation
from astrbot.core.db.po import (
Platform,
Stats,
LLMHistory,
ATRIVision,
Conversation
)
from . import BaseDatabase
from typing import Tuple, List, Dict, Any
from typing import Tuple
class SQLiteDatabase(BaseDatabase):
def __init__(self, db_path: str) -> None:
super().__init__()
self.db_path = db_path
with open(
os.path.dirname(__file__) + "/sqlite_init.sql", "r", encoding="utf-8"
) as f:
with open(os.path.dirname(__file__) + "/sqlite_init.sql", "r") as f:
sql = f.read()
# 初始化数据库
self.conn = self._get_conn(self.db_path)
c = self.conn.cursor()
c.executescript(sql)
self.conn.commit()
# 检查 webchat_conversation 的 title 字段是否存在
c.execute(
"""
'''
PRAGMA table_info(webchat_conversation)
"""
'''
)
res = c.fetchall()
has_title = False
@@ -38,26 +42,26 @@ class SQLiteDatabase(BaseDatabase):
has_persona_id = True
if not has_title:
c.execute(
"""
'''
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
"""
'''
)
self.conn.commit()
if not has_persona_id:
c.execute(
"""
'''
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
"""
'''
)
self.conn.commit()
c.close()
def _get_conn(self, db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
conn.text_factory = str
return conn
def _exec_sql(self, sql: str, params: Tuple = None):
conn = self.conn
try:
@@ -65,23 +69,22 @@ class SQLiteDatabase(BaseDatabase):
except sqlite3.ProgrammingError:
conn = self._get_conn(self.db_path)
c = conn.cursor()
if params:
c.execute(sql, params)
c.close()
else:
c.execute(sql)
c.close()
conn.commit()
def insert_platform_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
"""
'''
INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
""",
(k, v, int(time.time())),
''', (k, v, int(time.time()))
)
def insert_plugin_metrics(self, metrics: dict):
@@ -90,63 +93,57 @@ class SQLiteDatabase(BaseDatabase):
def insert_command_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
"""
'''
INSERT INTO command(name, count, timestamp) VALUES (?, ?, ?)
""",
(k, v, int(time.time())),
''', (k, v, int(time.time()))
)
def insert_llm_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
"""
'''
INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
""",
(k, v, int(time.time())),
''', (k, v, int(time.time()))
)
def update_llm_history(self, session_id: str, content: str, provider_type: str):
res = self.get_llm_history(session_id, provider_type)
if res:
self._exec_sql(
"""
'''
UPDATE llm_history SET content = ? WHERE session_id = ? AND provider_type = ?
""",
(content, session_id, provider_type),
''', (content, session_id, provider_type)
)
else:
self._exec_sql(
"""
'''
INSERT INTO llm_history(provider_type, session_id, content) VALUES (?, ?, ?)
""",
(provider_type, session_id, content),
''', (provider_type, session_id, content)
)
def get_llm_history(
self, session_id: str = None, provider_type: str = None
) -> Tuple:
def get_llm_history(self, session_id: str = None, provider_type: str = None) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
conditions = []
params = []
if session_id:
conditions.append("session_id = ?")
params.append(session_id)
if provider_type:
conditions.append("provider_type = ?")
params.append(provider_type)
sql = "SELECT * FROM llm_history"
if conditions:
sql += " WHERE " + " AND ".join(conditions)
c.execute(sql, params)
where_clause = ""
if session_id or provider_type:
where_clause += " WHERE "
has = False
if session_id:
where_clause += f"session_id = '{session_id}'"
has = True
if provider_type:
if has:
where_clause += " AND "
where_clause += f"provider_type = '{provider_type}'"
c.execute(
'''
SELECT * FROM llm_history
''' + where_clause
)
res = c.fetchall()
histories = []
for row in res:
@@ -155,134 +152,129 @@ class SQLiteDatabase(BaseDatabase):
return histories
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取 offset_sec 秒前到现在的基础统计数据"""
'''获取 offset_sec 秒前到现在的基础统计数据'''
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
'''
SELECT * FROM platform
"""
+ where_clause
''' + where_clause
)
platform = []
for row in c.fetchall():
platform.append(Platform(*row))
# c.execute(
# '''
# SELECT * FROM command
# ''' + where_clause
# )
# command = []
# for row in c.fetchall():
# command.append(Command(*row))
# c.execute(
# '''
# SELECT * FROM llm
# ''' + where_clause
# )
# llm = []
# for row in c.fetchall():
# llm.append(Provider(*row))
c.close()
return Stats(platform, [], [])
def get_total_message_count(self) -> int:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
'''
SELECT SUM(count) FROM platform
"""
'''
)
res = c.fetchone()
c.close()
return res[0]
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取 offset_sec 秒前到现在的基础统计数据(合并)"""
'''获取 offset_sec 秒前到现在的基础统计数据(合并)'''
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
'''
SELECT name, SUM(count), timestamp FROM platform
"""
+ where_clause
+ " GROUP BY name"
''' + where_clause + " GROUP BY name"
)
platform = []
for row in c.fetchall():
platform.append(Platform(*row))
c.close()
return Stats(platform, [], [])
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
'''
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
""",
(user_id, cid),
''', (user_id, cid)
)
res = c.fetchone()
c.close()
if not res:
return
return Conversation(*res)
def new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
self._exec_sql(
"""
'''
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
""",
(user_id, cid, history, updated_at, created_at),
''', (user_id, cid, history, updated_at, created_at)
)
def get_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
'''
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
""",
(user_id,),
''', (user_id,)
)
res = c.fetchall()
c.close()
conversations = []
@@ -292,276 +284,82 @@ class SQLiteDatabase(BaseDatabase):
updated_at = row[2]
title = row[3]
persona_id = row[4]
conversations.append(
Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
)
conversations.append(Conversation("", cid, '[]', created_at, updated_at, title, persona_id))
return conversations
def update_conversation(self, user_id: str, cid: str, history: str):
"""更新对话,并且同时更新时间"""
'''更新对话,并且同时更新时间'''
updated_at = int(time.time())
self._exec_sql(
"""
'''
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
""",
(history, updated_at, user_id, cid),
''', (history, updated_at, user_id, cid)
)
def update_conversation_title(self, user_id: str, cid: str, title: str):
self._exec_sql(
"""
'''
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
""",
(title, user_id, cid),
''', (title, user_id, cid)
)
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
self._exec_sql(
"""
'''
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
""",
(persona_id, user_id, cid),
''', (persona_id, user_id, cid)
)
def delete_conversation(self, user_id: str, cid: str):
self._exec_sql(
"""
'''
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
""",
(user_id, cid),
''', (user_id, cid)
)
def insert_atri_vision_data(self, vision: ATRIVision):
ts = int(time.time())
keywords = ",".join(vision.keywords)
self._exec_sql(
"""
'''
INSERT INTO atri_vision(id, url_or_path, caption, is_meme, keywords, platform_name, session_id, sender_nickname, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
vision.id,
vision.url_or_path,
vision.caption,
vision.is_meme,
keywords,
vision.platform_name,
vision.session_id,
vision.sender_nickname,
ts,
),
''', (vision.id, vision.url_or_path, vision.caption, vision.is_meme, keywords, vision.platform_name, vision.session_id, vision.sender_nickname, ts)
)
def get_atri_vision_data(self) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
'''
SELECT * FROM atri_vision
"""
'''
)
res = c.fetchall()
visions = []
for row in res:
visions.append(ATRIVision(*row))
c.close()
return visions
def get_atri_vision_data_by_path_or_id(
self, url_or_path: str, id: str
) -> ATRIVision:
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
'''
SELECT * FROM atri_vision WHERE url_or_path = ? OR id = ?
""",
(url_or_path, id),
''', (url_or_path, id)
)
res = c.fetchone()
c.close()
if res:
return ATRIVision(*res)
return None
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页,按更新时间降序排序"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 获取总记录数
c.execute("""
SELECT COUNT(*) FROM webchat_conversation
""")
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 获取分页数据,按更新时间降序排序
c.execute(
"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
""",
(page_size, offset),
)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型且至少有8个字符否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0确保即使出错也有有效的返回值
return [], 0
finally:
c.close()
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 构建查询条件
where_clauses = []
params = []
# 平台筛选
if platforms and len(platforms) > 0:
platform_conditions = []
for platform in platforms:
platform_conditions.append("user_id LIKE ?")
params.append(f"{platform}:%")
if platform_conditions:
where_clauses.append(f"({' OR '.join(platform_conditions)})")
# 消息类型筛选
if message_types and len(message_types) > 0:
message_type_conditions = []
for msg_type in message_types:
message_type_conditions.append("user_id LIKE ?")
params.append(f"%:{msg_type}:%")
if message_type_conditions:
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
# 搜索关键词
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
where_clauses.append(
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
)
search_param = f"%{search_query}%"
params.extend([search_param, search_param, search_param, search_param])
# 排除特定用户ID
if exclude_ids and len(exclude_ids) > 0:
for exclude_id in exclude_ids:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_id}%")
# 排除特定平台
if exclude_platforms and len(exclude_platforms) > 0:
for exclude_platform in exclude_platforms:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_platform}:%")
# 构建完整的 WHERE 子句
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
# 构建计数查询
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
# 获取总记录数
c.execute(count_sql, params)
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 构建分页数据查询
data_sql = f"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
{where_sql}
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
"""
query_params = params + [page_size, offset]
# 获取分页数据
c.execute(data_sql, query_params)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型,否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0确保即使出错也有有效的返回值
return [], 0
finally:
c.close()
return None

View File

@@ -38,13 +38,11 @@ CREATE TABLE IF NOT EXISTS atri_vision(
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT, -- 会话 id
cid TEXT, -- 对话 id
user_id TEXT,
cid TEXT,
history TEXT,
created_at INTEGER,
updated_at INTEGER,
title TEXT,
persona_id TEXT
);
PRAGMA encoding = 'UTF-8';
);

View File

@@ -1,46 +0,0 @@
import abc
from dataclasses import dataclass
@dataclass
class Result:
similarity: float
data: dict
class BaseVecDB:
async def initialize(self):
"""
初始化向量数据库
"""
pass
@abc.abstractmethod
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
...
@abc.abstractmethod
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
"""
搜索最相似的文档。
Args:
query (str): 查询文本
top_k (int): 返回的最相似文档的数量
Returns:
List[Result]: 查询结果
"""
...
@abc.abstractmethod
async def delete(self, doc_id: str) -> bool:
"""
删除指定文档。
Args:
doc_id (str): 要删除的文档 ID
Returns:
bool: 删除是否成功
"""
...

View File

@@ -1,3 +0,0 @@
from .vec_db import FaissVecDB
__all__ = ["FaissVecDB"]

View File

@@ -1,121 +0,0 @@
import aiosqlite
import os
class DocumentStorage:
def __init__(self, db_path: str):
self.db_path = db_path
self.connection = None
self.sqlite_init_path = os.path.join(
os.path.dirname(__file__), "sqlite_init.sql"
)
async def initialize(self):
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
if not os.path.exists(self.db_path):
await self.connect()
async with self.connection.cursor() as cursor:
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
sql_script = f.read()
await cursor.executescript(sql_script)
await self.connection.commit()
else:
await self.connect()
async def connect(self):
"""Connect to the SQLite database."""
self.connection = await aiosqlite.connect(self.db_path)
async def get_documents(self, metadata_filters: dict, ids: list = None):
"""Retrieve documents by metadata filters and ids.
Args:
metadata_filters (dict): The metadata filters to apply.
Returns:
list: The list of document IDs(primary key, not doc_id) that match the filters.
"""
# metadata filter -> SQL WHERE clause
where_clauses = []
values = []
for key, val in metadata_filters.items():
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
values.append(val)
if ids is not None and len(ids) > 0:
ids = [str(i) for i in ids if i != -1]
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
values.extend(ids)
where_sql = " AND ".join(where_clauses) or "1=1"
result = []
async with self.connection.cursor() as cursor:
sql = "SELECT * FROM documents WHERE " + where_sql
await cursor.execute(sql, values)
for row in await cursor.fetchall():
result.append(await self.tuple_to_dict(row))
return result
async def get_document_by_doc_id(self, doc_id: str):
"""Retrieve a document by its doc_id.
Args:
doc_id (str): The doc_id of the document to retrieve.
Returns:
dict: The document data.
"""
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
row = await cursor.fetchone()
if row:
return await self.tuple_to_dict(row)
else:
return None
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
"""Retrieve a document by its doc_id.
Args:
doc_id (str): The doc_id.
new_text (str): The new text to update the document with.
"""
async with self.connection.cursor() as cursor:
await cursor.execute(
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
)
await self.connection.commit()
async def get_user_ids(self) -> list[str]:
"""Retrieve all user IDs from the documents table.
Returns:
list: A list of user IDs.
"""
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT DISTINCT user_id FROM documents")
rows = await cursor.fetchall()
return [row[0] for row in rows]
async def tuple_to_dict(self, row):
"""Convert a tuple to a dictionary.
Args:
row (tuple): The row to convert.
Returns:
dict: The converted dictionary.
"""
return {
"id": row[0],
"doc_id": row[1],
"text": row[2],
"metadata": row[3],
"created_at": row[4],
"updated_at": row[5],
}
async def close(self):
"""Close the connection to the SQLite database."""
if self.connection:
await self.connection.close()
self.connection = None

View File

@@ -1,59 +0,0 @@
try:
import faiss
except ModuleNotFoundError:
raise ImportError(
"faiss 未安装。请使用 'pip install faiss-cpu''pip install faiss-gpu' 安装。"
)
import os
import numpy as np
class EmbeddingStorage:
def __init__(self, dimension: int, path: str = None):
self.dimension = dimension
self.path = path
self.index = None
if path and os.path.exists(path):
self.index = faiss.read_index(path)
else:
base_index = faiss.IndexFlatL2(dimension)
self.index = faiss.IndexIDMap(base_index)
self.storage = {}
async def insert(self, vector: np.ndarray, id: int):
"""插入向量
Args:
vector (np.ndarray): 要插入的向量
id (int): 向量的ID
Raises:
ValueError: 如果向量的维度与存储的维度不匹配
"""
if vector.shape[0] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
)
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
self.storage[id] = vector
await self.save_index()
async def search(self, vector: np.ndarray, k: int) -> tuple:
"""搜索最相似的向量
Args:
vector (np.ndarray): 查询向量
k (int): 返回的最相似向量的数量
Returns:
tuple: (距离, 索引)
"""
faiss.normalize_L2(vector)
distances, indices = self.index.search(vector, k)
return distances, indices
async def save_index(self):
"""保存索引
Args:
path (str): 保存索引的路径
"""
faiss.write_index(self.index, self.path)

View File

@@ -1,17 +0,0 @@
-- 创建文档存储表,包含 faiss 中文档的 id文档文本create_atupdated_at
CREATE TABLE documents (
id INTEGER PRIMARY KEY AUTOINCREMENT,
doc_id TEXT NOT NULL,
text TEXT NOT NULL,
metadata TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
ALTER TABLE documents
ADD COLUMN group_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.group_id')) STORED;
ALTER TABLE documents
ADD COLUMN user_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED;
CREATE INDEX idx_documents_user_id ON documents(user_id);
CREATE INDEX idx_documents_group_id ON documents(group_id);

View File

@@ -1,117 +0,0 @@
import uuid
import json
import numpy as np
from .document_storage import DocumentStorage
from .embedding_storage import EmbeddingStorage
from ..base import Result, BaseVecDB
from astrbot.core.provider.provider import EmbeddingProvider
class FaissVecDB(BaseVecDB):
"""
A class to represent a vector database.
"""
def __init__(
self,
doc_store_path: str,
index_store_path: str,
embedding_provider: EmbeddingProvider,
):
self.doc_store_path = doc_store_path
self.index_store_path = index_store_path
self.embedding_provider = embedding_provider
self.document_storage = DocumentStorage(doc_store_path)
self.embedding_storage = EmbeddingStorage(
embedding_provider.get_dim(), index_store_path
)
self.embedding_provider = embedding_provider
async def initialize(self):
await self.document_storage.initialize()
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
metadata = metadata or {}
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
vector = await self.embedding_provider.get_embedding(content)
vector = np.array(vector, dtype=np.float32)
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute(
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
(str_id, content, json.dumps(metadata)),
)
await self.document_storage.connection.commit()
result = await self.document_storage.get_document_by_doc_id(str_id)
int_id = result["id"]
# 插入向量到 FAISS
await self.embedding_storage.insert(vector, int_id)
return int_id
async def retrieve(
self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None
) -> list[Result]:
"""
搜索最相似的文档。
Args:
query (str): 查询文本
k (int): 返回的最相似文档的数量
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
metadata_filters (dict): 元数据过滤器
Returns:
List[Result]: 查询结果
"""
embedding = await self.embedding_provider.get_embedding(query)
scores, indices = await self.embedding_storage.search(
vector=np.array([embedding]).astype("float32"),
k=fetch_k if metadata_filters else k,
)
# TODO: rerank
if len(indices[0]) == 0 or indices[0][0] == -1:
return []
# normalize scores
scores[0] = 1.0 - (scores[0] / 2.0)
# NOTE: maybe the size is less than k.
fetched_docs = await self.document_storage.get_documents(
metadata_filters=metadata_filters or {}, ids=indices[0]
)
if not fetched_docs:
return []
result_docs = []
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
for i, indice_idx in enumerate(indices[0]):
pos = idx_pos.get(indice_idx)
if pos is None:
continue
fetch_doc = fetched_docs[pos]
score = scores[0][i]
result_docs.append(Result(similarity=float(score), data=fetch_doc))
return result_docs[:k]
async def delete(self, doc_id: int):
"""
删除一条文档
"""
await self.document_storage.connection.execute(
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
)
await self.document_storage.connection.commit()
async def close(self):
await self.document_storage.close()
async def count_documents(self) -> int:
"""
计算文档数量
"""
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute("SELECT COUNT(*) FROM documents")
count = await cursor.fetchone()
return count[0] if count else 0

View File

@@ -1,57 +1,23 @@
"""
事件总线, 用于处理事件的分发和处理
事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑
class:
EventBus: 事件总线, 用于处理事件的分发和处理
工作流程:
1. 维护一个异步队列, 来接受各种消息事件
2. 无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑
"""
import asyncio
from asyncio import Queue
from astrbot.core.pipeline.scheduler import PipelineScheduler
from astrbot.core import logger
from .platform import AstrMessageEvent
class EventBus:
"""事件总线: 用于处理事件的分发和处理
维护一个异步队列, 来接受各种消息事件
"""
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
self.event_queue = event_queue # 事件队列
self.pipeline_scheduler = pipeline_scheduler # 管道调度器
self.event_queue = event_queue
self.pipeline_scheduler = pipeline_scheduler
async def dispatch(self):
"""无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑"""
logger.info("事件总线已打开。")
while True:
event: AstrMessageEvent = (
await self.event_queue.get()
) # 从事件队列中获取新的事件
self._print_event(event) # 打印日志
asyncio.create_task(
self.pipeline_scheduler.execute(event)
) # 创建新的异步任务来执行管道调度器的处理逻辑
def _print_event(self, event: AstrMessageEvent):
"""用于记录事件信息
Args:
event (AstrMessageEvent): 事件对象
"""
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
event: AstrMessageEvent = await self.event_queue.get()
self._print_event(event)
asyncio.create_task(self.pipeline_scheduler.execute(event))
def _print_event(self, event: AstrMessageEvent):
if event.get_sender_name():
logger.info(
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
)
# 没有发送者名称: [平台名] 发送者ID: 消息概要
logger.info(f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}")
else:
logger.info(
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
)
logger.info(f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}")

View File

@@ -1,68 +0,0 @@
import asyncio
import os
import uuid
import time
class FileTokenService:
"""维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。"""
def __init__(self, default_timeout: float = 300):
self.lock = asyncio.Lock()
self.staged_files = {} # token: (file_path, expire_time)
self.default_timeout = default_timeout
async def _cleanup_expired_tokens(self):
"""清理过期的令牌"""
now = time.time()
expired_tokens = [token for token, (_, expire) in self.staged_files.items() if expire < now]
for token in expired_tokens:
self.staged_files.pop(token, None)
async def register_file(self, file_path: str, timeout: float = None) -> str:
"""向令牌服务注册一个文件。
Args:
file_path(str): 文件路径
timeout(float): 超时时间,单位秒(可选)
Returns:
str: 一个单次令牌
Raises:
FileNotFoundError: 当路径不存在时抛出
"""
async with self.lock:
await self._cleanup_expired_tokens()
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
file_token = str(uuid.uuid4())
expire_time = time.time() + (timeout if timeout is not None else self.default_timeout)
self.staged_files[file_token] = (file_path, expire_time)
return file_token
async def handle_file(self, file_token: str) -> str:
"""根据令牌获取文件路径,使用后令牌失效。
Args:
file_token(str): 注册时返回的令牌
Returns:
str: 文件路径
Raises:
KeyError: 当令牌不存在或已过期时抛出
FileNotFoundError: 当文件本身已被删除时抛出
"""
async with self.lock:
await self._cleanup_expired_tokens()
if file_token not in self.staged_files:
raise KeyError(f"无效或过期的文件 token: {file_token}")
file_path, _ = self.staged_files.pop(file_token)
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
return file_path

View File

@@ -1,49 +0,0 @@
"""
AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。
工作流程:
1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期
2. 运行核心生命周期任务和仪表板服务器
"""
import asyncio
import traceback
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core import LogBroker
from astrbot.dashboard.server import AstrBotDashboard
class InitialLoader:
"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。"""
def __init__(self, db: BaseDatabase, log_broker: LogBroker):
self.db = db
self.logger = logger
self.log_broker = log_broker
async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
try:
await core_lifecycle.initialize()
except Exception as e:
logger.critical(traceback.format_exc())
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
return
core_task = core_lifecycle.start()
self.dashboard_server = AstrBotDashboard(
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
)
task = asyncio.gather(
core_task, self.dashboard_server.run()
) # 启动核心任务和仪表板服务器
try:
await task # 整个AstrBot在这里运行
except asyncio.CancelledError:
logger.info("🌈 正在关闭 AstrBot...")
await core_lifecycle.stop()

View File

@@ -1,119 +1,37 @@
"""
日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
const:
CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量
log_color_config: 日志颜色配置, 定义了不同日志级别的颜色
class:
LogBroker: 日志代理类, 用于缓存和分发日志消息
LogQueueHandler: 日志处理器, 用于将日志消息发送到 LogBroker
LogManager: 日志管理器, 用于创建和配置日志记录器
function:
is_plugin_path: 检查文件路径是否来自插件目录
get_short_level_name: 将日志级别名称转换为四个字母的缩写
工作流程:
1. 通过 LogManager.GetLogger() 获取日志器, 配置了控制台输出和多个格式化过滤器
2. 通过 set_queue_handler() 设置日志处理器, 将日志消息发送到 LogBroker
3. logBroker 维护一个订阅者列表, 负责将日志分发给所有订阅者
4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流
"""
import logging
import colorlog
import asyncio
import os
import sys
from collections import deque
from asyncio import Queue
from typing import List
# 日志缓存大小
CACHED_SIZE = 200
# 日志颜色配置
log_color_config = {
"DEBUG": "green",
"INFO": "bold_cyan",
"WARNING": "bold_yellow",
"ERROR": "red",
"CRITICAL": "bold_red",
"RESET": "reset",
"asctime": "green",
'DEBUG': 'bold_blue', 'INFO': 'bold_cyan',
'WARNING': 'bold_yellow', 'ERROR': 'red',
'CRITICAL': 'bold_red', 'RESET': 'reset',
'asctime': 'green'
}
def is_plugin_path(pathname):
"""检查文件路径是否来自插件目录
Args:
pathname (str): 文件路径
Returns:
bool: 如果路径来自插件目录,则返回 True否则返回 False
"""
if not pathname:
return False
norm_path = os.path.normpath(pathname)
return ("data/plugins" in norm_path) or ("packages/" in norm_path)
def get_short_level_name(level_name):
"""将日志级别名称转换为四个字母的缩写
Args:
level_name (str): 日志级别名称, 如 "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
Returns:
str: 四个字母的日志级别缩写
"""
level_map = {
"DEBUG": "DBUG",
"INFO": "INFO",
"WARNING": "WARN",
"ERROR": "ERRO",
"CRITICAL": "CRIT",
}
return level_map.get(level_name, level_name[:4].upper())
class LogBroker:
"""日志代理类, 用于缓存和分发日志消息
发布-订阅模式
"""
def __init__(self):
self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
self.subscribers: List[Queue] = [] # 订阅者列表
self.log_cache = deque(maxlen=CACHED_SIZE)
self.subscribers: List[Queue] = []
def register(self) -> Queue:
"""注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
Returns:
Queue: 订阅者的队列, 可用于接收日志消息
"""
'''给每个订阅者返回一个带有日志缓存的队列'''
q = Queue(maxsize=CACHED_SIZE + 10)
for log in self.log_cache:
q.put_nowait(log)
self.subscribers.append(q)
return q
def unregister(self, q: Queue):
"""取消订阅
Args:
q (Queue): 需要取消订阅的队列
"""
'''取消订阅'''
self.subscribers.remove(q)
def publish(self, log_entry: dict):
"""发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统
Args:
log_entry (dict): 日志消息, 包含日志级别和日志内容.
example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
"""
def publish(self, log_entry: str):
'''发布消息'''
self.log_cache.append(log_entry)
for q in self.subscribers:
try:
@@ -121,126 +39,41 @@ class LogBroker:
except asyncio.QueueFull:
pass
class LogQueueHandler(logging.Handler):
"""日志处理器, 用于将日志消息发送到 LogBroker
继承自 logging.Handler
"""
def __init__(self, log_broker: LogBroker):
super().__init__()
self.log_broker = log_broker
def emit(self, record):
"""日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布
这个方法会在每次日志记录时被调用
Args:
record (logging.LogRecord): 日志记录对象, 包含日志信息
"""
log_entry = self.format(record)
self.log_broker.publish(
{
"level": record.levelname,
"time": record.asctime,
"data": log_entry,
}
)
self.log_broker.publish(log_entry)
class LogManager:
"""日志管理器, 用于创建和配置日志记录器
提供了获取默认日志记录器logger和设置队列处理器的方法
"""
@classmethod
def GetLogger(cls, log_name: str = "default"):
"""获取指定名称的日志记录器logger
Args:
log_name (str): 日志记录器的名称, 默认为 "default"
Returns:
logging.Logger: 返回配置好的日志记录器
"""
def GetLogger(cls, log_name: str = 'default'):
logger = logging.getLogger(log_name)
# 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置
if logger.hasHandlers():
return logger
# 如果logger没有处理器
console_handler = logging.StreamHandler(
sys.stdout
) # 创建一个StreamHandler用于控制台输出
console_handler.setLevel(
logging.DEBUG
) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
# 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s",
datefmt="%H:%M:%S",
log_colors=log_color_config,
fmt='%(log_color)s [%(asctime)s| %(levelname)s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s',
datefmt='%H:%M:%S',
log_colors=log_color_config
)
class PluginFilter(logging.Filter):
"""插件过滤器类, 用于标记日志来源是插件还是核心组件"""
def filter(self, record):
record.plugin_tag = (
"[Plug]" if is_plugin_path(record.pathname) else "[Core]"
)
return True
class FileNameFilter(logging.Filter):
"""文件名过滤器类, 用于修改日志记录的文件名格式
例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式"""
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
def filter(self, record):
dirname = os.path.dirname(record.pathname)
record.filename = (
os.path.basename(dirname)
+ "."
+ os.path.basename(record.pathname).replace(".py", "")
)
return True
class LevelNameFilter(logging.Filter):
"""短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写"""
# 添加短日志级别名称
def filter(self, record):
record.short_levelname = get_short_level_name(record.levelname)
return True
console_handler.setFormatter(console_formatter) # 设置处理器的格式化器
logger.addFilter(PluginFilter()) # 添加插件过滤器
logger.addFilter(FileNameFilter()) # 添加文件名过滤器
logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器
logger.setLevel(logging.DEBUG) # 设置日志级别为DEBUG
logger.addHandler(console_handler) # 添加处理器到logger
console_handler.setFormatter(console_formatter)
logger.setLevel(logging.DEBUG)
logger.addHandler(console_handler)
return logger
@classmethod
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
"""设置队列处理器, 用于将日志消息发送到 LogBroker
Args:
logger (logging.Logger): 日志记录器
log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息
"""
handler = LogQueueHandler(log_broker)
handler.setLevel(logging.DEBUG)
if logger.handlers:
handler.setFormatter(logger.handlers[0].formatter)
else:
# 为队列处理器设置相同格式的formatter
handler.setFormatter(
logging.Formatter(
"[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s"
)
)
logger.addHandler(handler)
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

View File

@@ -1,4 +1,4 @@
"""
'''
MIT License
Copyright (c) 2021 Lxns-Network
@@ -20,37 +20,21 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
'''
import asyncio
import base64
import json
import os
import typing as T
import uuid
from enum import Enum
from pydantic.v1 import BaseModel
from astrbot.core import astrbot_config, file_token_service, logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
class ComponentType(Enum):
Plain = "Plain" # 纯文本消息
Face = "Face" # QQ表情
Record = "Record" # 语音
Video = "Video" # 视频
At = "At" # At
Node = "Node" # 转发消息的一个节点
Nodes = "Nodes" # 转发消息的多个节点
Poke = "Poke" # QQ 戳一戳
Image = "Image" # 图片
Reply = "Reply" # 回复
Forward = "Forward" # 转发消息
File = "File" # 文件
Plain = "Plain"
Face = "Face"
Record = "Record"
Video = "Video"
At = "At"
RPS = "RPS" # TODO
Dice = "Dice" # TODO
Shake = "Shake" # TODO
@@ -59,14 +43,18 @@ class ComponentType(Enum):
Contact = "Contact" # TODO
Location = "Location" # TODO
Music = "Music"
Image = "Image"
Reply = "Reply"
RedBag = "RedBag"
Poke = "Poke"
Forward = "Forward"
Node = "Node"
Xml = "Xml"
Json = "Json"
CardImage = "CardImage"
TTS = "TTS"
Unknown = "Unknown"
WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包
File = "File"
class BaseMessageComponent(BaseModel):
@@ -81,30 +69,25 @@ class BaseMessageComponent(BaseModel):
k = "type"
if isinstance(v, bool):
v = 1 if v else 0
output += ",%s=%s" % (
k,
str(v)
.replace("&", "&amp;")
.replace(",", "&#44;")
.replace("[", "&#91;")
.replace("]", "&#93;"),
)
output += ",%s=%s" % (k, str(v).replace("&", "&amp;") \
.replace(",", "&#44;") \
.replace("[", "&#91;") \
.replace("]", "&#93;"))
output += "]"
return output
def toDict(self):
data = {}
data = dict()
for k, v in self.__dict__.items():
if k == "type" or v is None:
continue
if k == "_type":
k = "type"
data[k] = v
return {"type": self.type.lower(), "data": data}
async def to_dict(self) -> dict:
# 默认情况下,回退到旧的同步 toDict()
return self.toDict()
return {
"type": self.type.lower(),
"data": data
}
class Plain(BaseMessageComponent):
@@ -118,15 +101,10 @@ class Plain(BaseMessageComponent):
def toString(self): # 没有 [CQ:plain] 这种东西,所以直接导出纯文本
if not self.convert:
return self.text
return (
self.text.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;")
)
return self.text.replace("&", "&amp;") \
.replace("[", "&#91;") \
.replace("]", "&#93;")
def toDict(self):
return {"type": "text", "data": {"text": self.text.strip()}}
async def to_dict(self):
return {"type": "text", "data": {"text": self.text}}
class Face(BaseMessageComponent):
type: ComponentType = "Face"
@@ -164,76 +142,6 @@ class Record(BaseMessageComponent):
return Record(file=url, **_)
raise Exception("not a valid url")
async def convert_to_file_path(self) -> str:
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
Returns:
str: 语音的本地路径,以绝对路径表示。
"""
if self.file and self.file.startswith("file:///"):
file_path = self.file[8:]
return file_path
elif self.file and self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
return os.path.abspath(file_path)
elif self.file and self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
with open(file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(file_path)
elif os.path.exists(self.file):
file_path = self.file
return os.path.abspath(file_path)
else:
raise Exception(f"not a valid file: {self.file}")
async def convert_to_base64(self) -> str:
"""将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。
Returns:
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
if self.file and self.file.startswith("file:///"):
bs64_data = file_to_base64(self.file[8:])
elif self.file and self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
bs64_data = file_to_base64(file_path)
elif self.file and self.file.startswith("base64://"):
bs64_data = self.file
elif os.path.exists(self.file):
bs64_data = file_to_base64(self.file)
else:
raise Exception(f"not a valid file: {self.file}")
bs64_data = bs64_data.removeprefix("base64://")
return bs64_data
async def register_to_file_service(self) -> str:
"""
将语音注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
if not callback_host:
raise Exception("未配置 callback_api_base文件服务不可用")
file_path = await self.convert_to_file_path()
token = await file_token_service.register_file(file_path)
logger.debug(f"已注册:{callback_host}/api/file/{token}")
return f"{callback_host}/api/file/{token}"
class Video(BaseMessageComponent):
type: ComponentType = "Video"
@@ -244,6 +152,9 @@ class Video(BaseMessageComponent):
path: T.Optional[str] = ""
def __init__(self, file: str, **_):
# for k in _.keys():
# if k == "c" and _[k] not in [2, 3]:
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
super().__init__(file=file, **_)
@staticmethod
@@ -256,70 +167,6 @@ class Video(BaseMessageComponent):
return Video(file=url, **_)
raise Exception("not a valid url")
async def convert_to_file_path(self) -> str:
"""将这个视频统一转换为本地文件路径。这个方法避免了手动判断视频数据类型,直接返回视频数据的本地路径(如果是网络 URL则会自动进行下载
Returns:
str: 视频的本地路径,以绝对路径表示。
"""
url = self.file
if url and url.startswith("file:///"):
return url[8:]
elif url and url.startswith("http"):
download_dir = os.path.join(get_astrbot_data_path(), "temp")
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
await download_file(url, video_file_path)
if os.path.exists(video_file_path):
return os.path.abspath(video_file_path)
else:
raise Exception(f"download failed: {url}")
elif os.path.exists(url):
return os.path.abspath(url)
else:
raise Exception(f"not a valid file: {url}")
async def register_to_file_service(self):
"""
将视频注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
if not callback_host:
raise Exception("未配置 callback_api_base文件服务不可用")
file_path = await self.convert_to_file_path()
token = await file_token_service.register_file(file_path)
logger.debug(f"已注册:{callback_host}/api/file/{token}")
return f"{callback_host}/api/file/{token}"
async def to_dict(self):
"""需要和 toDict 区分开toDict 是同步方法"""
url_or_path = self.file
if url_or_path.startswith("http"):
payload_file = url_or_path
elif callback_host := astrbot_config.get("callback_api_base"):
callback_host = str(callback_host).removesuffix("/")
token = await file_token_service.register_file(url_or_path)
payload_file = f"{callback_host}/api/file/{token}"
logger.debug(f"Generated video file callback link: {payload_file}")
else:
payload_file = url_or_path
return {
"type": "video",
"data": {
"file": payload_file,
},
}
class At(BaseMessageComponent):
type: ComponentType = "At"
@@ -329,12 +176,6 @@ class At(BaseMessageComponent):
def __init__(self, **_):
super().__init__(**_)
def toDict(self):
return {
"type": "at",
"data": {"qq": str(self.qq)},
}
class AtAll(At):
qq: str = "all"
@@ -431,9 +272,13 @@ class Image(BaseMessageComponent):
c: T.Optional[int] = 2
# 额外
path: T.Optional[str] = ""
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
def __init__(self, file: T.Optional[str], **_):
# for k in _.keys():
# if (k == "_type" and _[k] not in ["flash", "show", None]) or \
# (k == "c" and _[k] not in [2, 3]):
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
super().__init__(file=file, **_)
@staticmethod
@@ -458,100 +303,14 @@ class Image(BaseMessageComponent):
def fromIO(IO):
return Image.fromBytes(IO.read())
async def convert_to_file_path(self) -> str:
"""将这个图片统一转换为本地文件路径。这个方法避免了手动判断图片数据类型,直接返回图片数据的本地路径(如果是网络 URL, 则会自动进行下载)。
Returns:
str: 图片的本地路径,以绝对路径表示。
"""
url = self.url if self.url else self.file
if url and url.startswith("file:///"):
image_file_path = url[8:]
return image_file_path
elif url and url.startswith("http"):
image_file_path = await download_image_by_url(url)
return os.path.abspath(image_file_path)
elif url and url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
with open(image_file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(image_file_path)
elif os.path.exists(url):
image_file_path = url
return os.path.abspath(image_file_path)
else:
raise Exception(f"not a valid file: {url}")
async def convert_to_base64(self) -> str:
"""将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。
Returns:
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
url = self.url if self.url else self.file
if url and url.startswith("file:///"):
bs64_data = file_to_base64(url[8:])
elif url and url.startswith("http"):
image_file_path = await download_image_by_url(url)
bs64_data = file_to_base64(image_file_path)
elif url and url.startswith("base64://"):
bs64_data = url
elif os.path.exists(url):
bs64_data = file_to_base64(url)
else:
raise Exception(f"not a valid file: {url}")
bs64_data = bs64_data.removeprefix("base64://")
return bs64_data
async def register_to_file_service(self) -> str:
"""
将图片注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
if not callback_host:
raise Exception("未配置 callback_api_base文件服务不可用")
file_path = await self.convert_to_file_path()
token = await file_token_service.register_file(file_path)
logger.debug(f"已注册:{callback_host}/api/file/{token}")
return f"{callback_host}/api/file/{token}"
class Reply(BaseMessageComponent):
type: ComponentType = "Reply"
id: T.Union[str, int]
"""所引用的消息 ID"""
chain: T.Optional[T.List["BaseMessageComponent"]] = []
"""被引用的消息段列表"""
sender_id: T.Optional[int] | T.Optional[str] = 0
"""被引用的消息对应的发送者的 ID"""
sender_nickname: T.Optional[str] = ""
"""被引用的消息对应的发送者的昵称"""
time: T.Optional[int] = 0
"""被引用的消息发送时间"""
message_str: T.Optional[str] = ""
"""被引用的消息解析后的纯文本消息字符串"""
text: T.Optional[str] = ""
"""deprecated"""
qq: T.Optional[int] = 0
"""deprecated"""
time: T.Optional[int] = 0
seq: T.Optional[int] = 0
"""deprecated"""
def __init__(self, **_):
super().__init__(**_)
@@ -582,92 +341,34 @@ class Forward(BaseMessageComponent):
def __init__(self, **_):
super().__init__(**_)
class Node(BaseMessageComponent):
"""群合并转发消息"""
'''群合并转发消息'''
type: ComponentType = "Node"
id: T.Optional[int] = 0 # 忽略
name: T.Optional[str] = "" # qq昵称
uin: T.Optional[str] = "0" # qq号
content: T.Optional[list[BaseMessageComponent]] = []
seq: T.Optional[T.Union[str, list]] = "" # 忽略
time: T.Optional[int] = 0 # 忽略
id: T.Optional[int] = 0 # 忽略
name: T.Optional[str] = "" # qq昵称
uin: T.Optional[int] = 0 # qq号
content: T.Optional[T.Union[str, list]] = "" # 子消息段列表
seq: T.Optional[T.Union[str, list]] = "" # 忽略
time: T.Optional[int] = 0
def __init__(self, content: list[BaseMessageComponent], **_):
if isinstance(content, Node):
# back
content = [content]
def __init__(self, content: T.Union[str, list], **_):
if isinstance(content, list):
_content = ""
for chain in content:
_content += chain.toString()
content = _content
super().__init__(content=content, **_)
async def to_dict(self):
data_content = []
for comp in self.content:
if isinstance(comp, (Image, Record)):
# For Image and Record segments, we convert them to base64
bs64 = await comp.convert_to_base64()
data_content.append(
{
"type": comp.type.lower(),
"data": {"file": f"base64://{bs64}"},
}
)
elif isinstance(comp, Plain):
# For Plain segments, we need to handle the plain differently
d = await comp.to_dict()
data_content.append(d)
elif isinstance(comp, File):
# For File segments, we need to handle the file differently
d = await comp.to_dict()
data_content.append(d)
elif isinstance(comp, (Node, Nodes)):
# For Node segments, we recursively convert them to dict
d = await comp.to_dict()
data_content.append(d)
else:
d = comp.toDict()
data_content.append(d)
return {
"type": "node",
"data": {
"user_id": str(self.uin),
"nickname": self.name,
"content": data_content,
},
}
class Nodes(BaseMessageComponent):
type: ComponentType = "Nodes"
nodes: T.List[Node]
def __init__(self, nodes: T.List[Node], **_):
super().__init__(nodes=nodes, **_)
def toDict(self):
"""Deprecated. Use to_dict instead"""
ret = {
"messages": [],
}
for node in self.nodes:
d = node.toDict()
ret["messages"].append(d)
return ret
async def to_dict(self):
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
ret = {"messages": []}
for node in self.nodes:
d = await node.to_dict()
ret["messages"].append(d)
return ret
def toString(self):
# logger.warn("Protocol: node doesn't support stringify")
return ""
class Xml(BaseMessageComponent):
type: ComponentType = "Xml"
data: str
resid: T.Optional[int] = 0
def __init__(self, **_):
super().__init__(**_)
@@ -717,149 +418,16 @@ class Unknown(BaseMessageComponent):
def toString(self):
return ""
class File(BaseMessageComponent):
"""
文件消息段
"""
'''
目前此消息段只适配了 Napcat。
'''
type: ComponentType = "File"
name: T.Optional[str] = "" # 名字
file_: T.Optional[str] = "" # 本地路径
url: T.Optional[str] = "" # url
def __init__(self, name: str, file: str = "", url: str = ""):
"""文件消息段。"""
super().__init__(name=name, file_=file, url=url)
@property
def file(self) -> str:
"""
获取文件路径如果文件不存在但有URL则同步下载文件
Returns:
str: 文件路径
"""
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
if self.url:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
logger.warning(
(
"不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段"
)
)
return ""
else:
# 等待下载完成
loop.run_until_complete(self._download_file())
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
except Exception as e:
logger.error(f"文件下载失败: {e}")
return ""
@file.setter
def file(self, value: str):
"""
向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
Args:
value (str): 文件路径或URL
"""
if value.startswith("http://") or value.startswith("https://"):
self.url = value
else:
self.file_ = value
async def get_file(self, allow_return_url: bool = False) -> str:
"""异步获取文件。请注意在使用后清理下载的文件, 以免占用过多空间
Args:
allow_return_url: 是否允许以文件 http 下载链接的形式返回,这允许您自行控制是否需要下载文件。
注意,如果为 True也可能返回文件路径。
Returns:
str: 文件路径或者 http 下载链接
"""
if allow_return_url and self.url:
return self.url
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
if self.url:
await self._download_file()
return os.path.abspath(self.file_)
return ""
async def _download_file(self):
"""下载文件"""
download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True)
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path)
async def register_to_file_service(self):
"""
将文件注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
if not callback_host:
raise Exception("未配置 callback_api_base文件服务不可用")
file_path = await self.get_file()
token = await file_token_service.register_file(file_path)
logger.debug(f"已注册:{callback_host}/api/file/{token}")
return f"{callback_host}/api/file/{token}"
async def to_dict(self):
"""需要和 toDict 区分开toDict 是同步方法"""
url_or_path = await self.get_file(allow_return_url=True)
if url_or_path.startswith("http"):
payload_file = url_or_path
elif callback_host := astrbot_config.get("callback_api_base"):
callback_host = str(callback_host).removesuffix("/")
token = await file_token_service.register_file(url_or_path)
payload_file = f"{callback_host}/api/file/{token}"
logger.debug(f"Generated file callback link: {payload_file}")
else:
payload_file = url_or_path
return {
"type": "file",
"data": {
"name": self.name,
"file": payload_file,
},
}
class WechatEmoji(BaseMessageComponent):
type: ComponentType = "WechatEmoji"
md5: T.Optional[str] = ""
md5_len: T.Optional[int] = 0
cdnurl: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
name: T.Optional[str] = "" # 名字
file: T.Optional[str] = "" # url本地路径
def __init__(self, name: str, file: str):
super().__init__(name=name, file=file)
ComponentTypes = {
@@ -883,12 +451,10 @@ ComponentTypes = {
"poke": Poke,
"forward": Forward,
"node": Node,
"nodes": Nodes,
"xml": Xml,
"json": Json,
"cardimage": CardImage,
"tts": TTS,
"unknown": Unknown,
"file": File,
"WechatEmoji": WechatEmoji,
'file': File,
}

View File

@@ -1,233 +1,149 @@
import enum
from typing import List, Optional, Union, AsyncGenerator
from typing import List, Optional
from dataclasses import dataclass, field
from astrbot.core.message.components import (
BaseMessageComponent,
Plain,
Image,
At,
AtAll,
)
from astrbot.core.message.components import BaseMessageComponent, Plain, Image
from typing_extensions import deprecated
@dataclass
class MessageChain:
"""MessageChain 描述了一整条消息中带有的所有组件。
class MessageChain():
'''MessageChain 描述了一整条消息中带有的所有组件。
现代消息平台的一条富文本消息中可能由多个组件构成如文本、图片、At 等,并且保留了顺序。
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
"""
'''
chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: Optional[bool] = None # None 为跟随用户设置
type: Optional[str] = None
"""消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。"""
use_t2i_: Optional[bool] = None # None 为跟随用户设置
def message(self, message: str):
"""添加一条文本消息到消息链 `chain` 中。
'''添加一条文本消息到消息链 `chain` 中。
Example:
CommandResult().message("Hello ").message("world!")
# 输出 Hello world!
"""
'''
self.chain.append(Plain(message))
return self
def at(self, name: str, qq: Union[str, int]):
"""添加一条 At 消息到消息链 `chain` 中。
Example:
CommandResult().at("张三", "12345678910")
# 输出 @张三
"""
self.chain.append(At(name=name, qq=qq))
return self
def at_all(self):
"""添加一条 AtAll 消息到消息链 `chain` 中。
Example:
CommandResult().at_all()
# 输出 @所有人
"""
self.chain.append(AtAll())
return self
@deprecated("请使用 message 方法代替。")
def error(self, message: str):
"""添加一条错误消息到消息链 `chain` 中
'''添加一条错误消息到消息链 `chain` 中
Example:
CommandResult().error("解析失败")
"""
'''
self.chain.append(Plain(message))
return self
def url_image(self, url: str):
"""添加一条图片消息https 链接)到消息链 `chain` 中。
'''添加一条图片消息https 链接)到消息链 `chain` 中。
Note:
如果需要发送本地图片,请使用 `file_image` 方法。
Example:
CommandResult().image("https://example.com/image.jpg")
"""
'''
self.chain.append(Image.fromURL(url))
return self
def file_image(self, path: str):
"""添加一条图片消息(本地文件路径)到消息链 `chain` 中。
'''添加一条图片消息(本地文件路径)到消息链 `chain` 中。
Note:
如果需要发送网络图片,请使用 `url_image` 方法。
CommandResult().image("image.jpg")
"""
'''
self.chain.append(Image.fromFileSystem(path))
return self
def base64_image(self, base64_str: str):
"""添加一条图片消息base64 编码字符串)到消息链 `chain` 中。
Example:
CommandResult().base64_image("iVBORw0KGgoAAAANSUhEUgAAAAUA...")
"""
self.chain.append(Image.fromBase64(base64_str))
return self
def use_t2i(self, use_t2i: bool):
"""设置是否使用文本转图片服务。
'''设置是否使用文本转图片服务。
Args:
use_t2i (bool): 是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
"""
'''
self.use_t2i_ = use_t2i
return self
def get_plain_text(self) -> str:
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
def squash_plain(self):
"""将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。"""
if not self.chain:
return
new_chain = []
first_plain = None
plain_texts = []
for comp in self.chain:
if isinstance(comp, Plain):
if first_plain is None:
first_plain = comp
new_chain.append(comp)
plain_texts.append(comp.text)
else:
new_chain.append(comp)
if first_plain is not None:
first_plain.text = "".join(plain_texts)
self.chain = new_chain
return self
class EventResultType(enum.Enum):
"""用于描述事件处理的结果类型。
'''用于描述事件处理的结果类型。
Attributes:
CONTINUE: 事件将会继续传播
STOP: 事件将会终止传播
"""
'''
CONTINUE = enum.auto()
STOP = enum.auto()
class ResultContentType(enum.Enum):
"""用于描述事件结果的内容的类型。"""
'''用于描述事件结果的内容的类型。
'''
LLM_RESULT = enum.auto()
"""调用 LLM 产生的结果"""
'''调用 LLM 产生的结果'''
GENERAL_RESULT = enum.auto()
"""普通的消息结果"""
STREAMING_RESULT = enum.auto()
"""调用 LLM 产生的流式结果"""
STREAMING_FINISH = enum.auto()
"""流式输出完成"""
'''普通的消息结果'''
@dataclass
class MessageEventResult(MessageChain):
"""MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。
'''MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。
现代消息平台的一条富文本消息中可能由多个组件构成如文本、图片、At 等,并且保留了顺序。
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`result_type` (EventResultType): 事件处理的结果类型。
"""
result_type: Optional[EventResultType] = field(
default_factory=lambda: EventResultType.CONTINUE
)
result_content_type: Optional[ResultContentType] = field(
default_factory=lambda: ResultContentType.GENERAL_RESULT
)
async_stream: Optional[AsyncGenerator] = None
"""异步流"""
def stop_event(self) -> "MessageEventResult":
"""终止事件传播。"""
'''
result_type: Optional[EventResultType] = field(default_factory=lambda: EventResultType.CONTINUE)
result_content_type: Optional[ResultContentType] = field(default_factory=lambda: ResultContentType.GENERAL_RESULT)
def stop_event(self) -> 'MessageEventResult':
'''终止事件传播。
'''
self.result_type = EventResultType.STOP
return self
def continue_event(self) -> "MessageEventResult":
"""继续事件传播。"""
def continue_event(self) -> 'MessageEventResult':
'''继续事件传播。
'''
self.result_type = EventResultType.CONTINUE
return self
def is_stopped(self) -> bool:
"""
'''
是否终止事件传播。
"""
'''
return self.result_type == EventResultType.STOP
def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult":
"""设置异步流。"""
self.async_stream = stream
return self
def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult":
"""设置事件处理的结果类型。
def set_result_content_type(self, typ: ResultContentType) -> 'MessageEventResult':
'''设置事件处理的结果类型。
Args:
result_type (EventResultType): 事件处理的结果类型。
"""
'''
self.result_content_type = typ
return self
def is_llm_result(self) -> bool:
"""是否为 LLM 结果。"""
'''是否为 LLM 结果。
'''
return self.result_content_type == ResultContentType.LLM_RESULT
# 为了兼容旧版代码,保留 CommandResult 的别名
CommandResult = MessageEventResult
def get_plain_text(self) -> str:
'''获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。
'''
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
CommandResult = MessageEventResult

View File

@@ -1,44 +1,34 @@
from astrbot.core.message.message_event_result import (
EventResultType,
MessageEventResult,
)
from astrbot.core.message.message_event_result import MessageEventResult, EventResultType
from .content_safety_check.stage import ContentSafetyCheckStage
from .platform_compatibility.stage import PlatformCompatibilityStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage
from .result_decorate.stage import ResultDecorateStage
from .session_status_check.stage import SessionStatusCheckStage
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .rate_limit_check.stage import RateLimitStage
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage
from .respond.stage import RespondStage
# 管道阶段顺序
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
"PreProcessStage", # 处理
"ProcessStage", # 交由 Stars 处理a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果比如添加回复前缀、t2i、转换为语音 等
"RespondStage", # 发送消息
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果比如添加回复前缀、t2i、转换为语音 等
"RespondStage" # 发送消息
]
__all__ = [
"WakingCheckStage",
"WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PlatformCompatibilityStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
"MessageEventResult",
"EventResultType",
]
"EventResultType"
]

View File

@@ -6,32 +6,26 @@ from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core import logger
from .strategies.strategy import StrategySelector
@register_stage
class ContentSafetyCheckStage(Stage):
"""检查内容安全
'''检查内容安全
当前只会检查文本的。
"""
'''
async def initialize(self, ctx: PipelineContext):
config = ctx.astrbot_config["content_safety"]
config = ctx.astrbot_config['content_safety']
self.strategy_selector = StrategySelector(config)
async def process(
self, event: AstrMessageEvent, check_text: str = None
) -> Union[None, AsyncGenerator[None, None]]:
"""检查内容安全"""
async def process(self, event: AstrMessageEvent, check_text: str = None) -> Union[None, AsyncGenerator[None, None]]:
'''检查内容安全'''
text = check_text if check_text else event.get_message_str()
ok, info = self.strategy_selector.check(text)
if not ok:
if event.is_at_or_wake_command:
event.set_result(
MessageEventResult().message(
"你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。"
)
)
yield
event.set_result(MessageEventResult().message("你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。"))
yield
event.stop_event()
logger.info(f"内容安全检查不通过,原因:{info}")
return
event.continue_event()

View File

@@ -1,8 +1,8 @@
import abc
from typing import Tuple
class ContentSafetyStrategy(abc.ABC):
@abc.abstractmethod
def check(self, content: str) -> Tuple[bool, str]:
raise NotImplementedError
raise NotImplementedError

View File

@@ -1,30 +1,30 @@
"""
'''
使用此功能应该先 pip install baidu-aip
"""
'''
from . import ContentSafetyStrategy
from aip import AipContentCensor
class BaiduAipStrategy(ContentSafetyStrategy):
def __init__(self, appid: str, ak: str, sk: str) -> None:
self.app_id = appid
self.api_key = ak
self.secret_key = sk
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
self.client = AipContentCensor(self.app_id,
self.api_key,
self.secret_key)
def check(self, content: str):
res = self.client.textCensorUserDefined(content)
if "conclusionType" not in res:
if 'conclusionType' not in res:
return False, ""
if res["conclusionType"] == 1:
if res['conclusionType'] == 1:
return True, ""
else:
if "data" not in res:
if 'data' not in res:
return False, ""
count = len(res["data"])
count = len(res['data'])
info = f"百度审核服务发现 {count} 处违规:\n"
for i in res["data"]:
for i in res['data']:
info += f"{i['msg']}\n"
info += "\n判断结果:" + res["conclusion"]
return False, info
info += "\n判断结果:"+res['conclusion']
return False, info

View File

@@ -1,23 +1,23 @@
import re
import os
import json
import base64
from . import ContentSafetyStrategy
class KeywordsStrategy(ContentSafetyStrategy):
def __init__(self, extra_keywords: list) -> None:
self.keywords = []
if extra_keywords is None:
extra_keywords = []
self.keywords.extend(extra_keywords)
# keywords_path = os.path.join(os.path.dirname(__file__), "unfit_words")
keywords_path = os.path.join(os.path.dirname(__file__), 'unfit_words')
# internal keywords
# if os.path.exists(keywords_path):
# with open(keywords_path, "r", encoding="utf-8") as f:
# self.keywords.extend(
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
# )
if os.path.exists(keywords_path):
with open(keywords_path, "r", encoding="utf-8") as f:
self.keywords.extend(json.loads(base64.b64decode(f.read()).decode("utf-8"))['keywords'])
def check(self, content: str) -> bool:
for keyword in self.keywords:
if re.search(keyword, content):
return False, "内容安全检查不通过,匹配到敏感词。"
return True, ""
return True, ""

View File

@@ -2,7 +2,6 @@ from . import ContentSafetyStrategy
from typing import List, Tuple
from astrbot import logger
class StrategySelector:
def __init__(self, config: dict) -> None:
self.enabled_strategies: List[ContentSafetyStrategy] = []

View File

@@ -0,0 +1 @@
ewogICAgImtleXdvcmRzIjogWwogICAgICAgICLkuaDov5HlubMiLAogICAgICAgICLog6HplKbmtpsiLAogICAgICAgICLmsZ/ms73msJEiLAogICAgICAgICLmuKnlrrblrp0iLAogICAgICAgICLmnY7lhYvlvLoiLAogICAgICAgICLmnY7plb/mmKUiLAogICAgICAgICLmr5vms73kuJwiLAogICAgICAgICLpgpPlsI/lubMiLAogICAgICAgICLlkajmganmnaUiLAogICAgICAgICLnpL7kvJrkuLvkuYkiLAogICAgICAgICLlhbHkuqflhZoiLAogICAgICAgICLlhbHkuqfkuLvkuYkiLAogICAgICAgICLlpKfpmYblrpjmlrkiLAogICAgICAgICLljJfkuqzmlL/mnYMiLAogICAgICAgICLkuK3ljY7luJ3lm70iLAogICAgICAgICLkuK3lm73mlL/lupwiLAogICAgICAgICLlhbHni5ciLAogICAgICAgICLlha3lm5vkuovku7YiLAogICAgICAgICLlpKnlronpl6giLAogICAgICAgICLlha3lm5siLAogICAgICAgICLmlL/msrvlsYDluLjlp5QiLAogICAgICAgICLlrabmva4iLAogICAgICAgICLlhavkuZ0iLAogICAgICAgICLkuozljYHlpKciLAogICAgICAgICLmsJHov5vlhZoiLAogICAgICAgICLlj7Dni6wiLAogICAgICAgICLlj7Dmub7ni6znq4siLAogICAgICAgICLlj7Dmub7lm70iLAogICAgICAgICLlm73msJHlhZoiLAogICAgICAgICLlj7Dmub7msJHlm70iLAogICAgICAgICLkuK3ljY7msJHlm70iLAogICAgICAgICJwb3JuaHViIiwKICAgICAgICAiUG9ybmh1YiIsCiAgICAgICAgIuS9nOeIsSIsCiAgICAgICAgIuWBmueIsSIsCiAgICAgICAgIuaAp+S6pCIsCiAgICAgICAgIuiHquaFsCIsCiAgICAgICAgIumYtOiMjiIsCiAgICAgICAgIua3q+WmhyIsCiAgICAgICAgIuiCm+S6pCIsCiAgICAgICAgIuS6pOmFjSIsCiAgICAgICAgIuaAp+WFs+ezuyIsCiAgICAgICAgIuaAp+a0u+WKqCIsCiAgICAgICAgIuiJsuaDhSIsCiAgICAgICAgIuiJsuWbviIsCiAgICAgICAgIuijuOS9kyIsCiAgICAgICAgIuWwj+eptCIsCiAgICAgICAgIua3q+iNoSIsCiAgICAgICAgIuaAp+eIsSIsCiAgICAgICAgIua4r+eLrCIsCiAgICAgICAgIuazlei9ruWKnyIsCiAgICAgICAgIuWFreWbmyIKICAgIF0KfQ==

View File

@@ -1,113 +1,8 @@
import inspect
import traceback
import typing as T
from dataclasses import dataclass
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star import PluginManager
from astrbot.api import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
@dataclass
class PipelineContext:
"""上下文对象,包含管道执行所需的上下文信息"""
astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: PluginManager # 插件管理器对象
async def call_event_hook(
self,
event: AstrMessageEvent,
hook_type: EventType,
*args,
) -> bool:
"""调用事件钩子函数
Returns:
bool: 如果事件被终止,返回 True
"""
platform_id = event.get_platform_id()
handlers = star_handlers_registry.get_handlers_by_event_type(
hook_type, platform_id=platform_id
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, *args)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return event.is_stopped()
async def call_handler(
self,
event: AstrMessageEvent,
handler: T.Awaitable,
*args,
**kwargs,
) -> T.AsyncGenerator[None, None]:
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
2. 协程: 执行一次并处理返回值
Args:
ctx (PipelineContext): 消息管道上下文对象
event (AstrMessageEvent): 事件对象
handler (Awaitable): 事件处理函数
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError as _:
# 向下兼容
trace_ = traceback.format_exc()
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret
astrbot_config: AstrBotConfig
plugin_manager: PluginManager

View File

@@ -1,56 +0,0 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core import logger
@register_stage
class PlatformCompatibilityStage(Stage):
"""检查所有处理器的平台兼容性。
这个阶段会检查所有处理器是否在当前平台启用如果未启用则设置platform_compatible属性为False。
"""
async def initialize(self, ctx: PipelineContext) -> None:
"""初始化平台兼容性检查阶段
Args:
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
"""
self.ctx = ctx
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
# 获取当前平台ID
platform_id = event.get_platform_id()
# 获取已激活的处理器
activated_handlers = event.get_extra("activated_handlers")
if activated_handlers is None:
activated_handlers = []
# 标记不兼容的处理器
for handler in activated_handlers:
if not isinstance(handler, StarHandlerMetadata):
continue
# 检查处理器是否在当前平台启用
enabled = handler.is_enabled_for_platform(platform_id)
if not enabled:
if handler.handler_module_path in star_map:
plugin_name = star_map[handler.handler_module_path].name
logger.debug(
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
)
# 设置处理器为平台不兼容状态
# TODO: 更好的标记方式
handler.platform_compatible = False
else:
# 确保处理器为平台兼容状态
handler.platform_compatible = True
# 更新已激活的处理器列表
event.set_extra("activated_handlers", activated_handlers)

View File

@@ -7,67 +7,64 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger
from astrbot.core.message.components import Plain, Record, Image
@register_stage
class PreProcessStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.plugin_manager = ctx.plugin_manager
self.stt_settings: dict = self.config.get('provider_stt_settings', {})
self.platform_settings: dict = self.config.get('platform_settings', {})
self.stt_settings: dict = self.config.get("provider_stt_settings", {})
self.platform_settings: dict = self.config.get("platform_settings", {})
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""在处理事件之前的预处理"""
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''在处理事件之前的预处理'''
# 路径映射
if mappings := self.platform_settings.get("path_mapping", []):
if mappings := self.platform_settings.get('path_mapping', []):
# 支持 RecordImage 消息段的路径映射。
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, (Record, Image)) and component.url:
for mapping in mappings:
from_, to_ = mapping.split(":")
from_ = from_.removesuffix("/")
to_ = to_.removesuffix("/")
url = component.url.removeprefix("file://")
if url.startswith(from_):
component.url = url.replace(from_, to_, 1)
logger.debug(f"路径映射: {url} -> {component.url}")
message_chain[idx] = component
# STT
if self.stt_settings.get("enable", False):
if self.stt_settings.get('enable', False):
# TODO: 独立
ctx = self.plugin_manager.context
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
if not stt_provider:
return
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record) and component.url:
path = component.url.removeprefix("file://")
retry = 5
for i in range(retry):
try:
result = await stt_provider.get_text(audio_url=path)
if result:
logger.info("语音转文本结果: " + result)
message_chain[idx] = Plain(result)
event.message_str += result
event.message_obj.message_str += result
break
except FileNotFoundError as e:
# napcat workaround
logger.warning(e)
logger.warning(f"重试中: {i + 1}/{retry}")
await asyncio.sleep(0.5)
continue
except BaseException as e:
logger.error(traceback.format_exc())
logger.error(f"语音转文本失败: {e}")
break
stt_provider = self.plugin_manager.context.provider_manager.curr_stt_provider_inst
if stt_provider:
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record) and component.url:
path = component.url.removeprefix("file://")
retry = 5
for i in range(retry):
try:
result = await stt_provider.get_text(audio_url=path)
if result:
logger.info("语音转文本结果: " + result)
message_chain[idx] = Plain(result)
event.message_str += result
event.message_obj.message_str += result
break
except FileNotFoundError as e:
# napcat workaround
logger.warning(e)
logger.warning(f"重试中: {i + 1}/{retry}")
await asyncio.sleep(0.5)
continue
except BaseException as e:
logger.error(traceback.format_exc())
logger.error(f"语音转文本失败: {e}")
break

View File

@@ -1,57 +0,0 @@
import abc
import typing as T
from dataclasses import dataclass
from astrbot.core.provider.entities import LLMResponse
from ....message.message_event_result import MessageChain
from enum import Enum, auto
class AgentState(Enum):
"""Agent 状态枚举"""
IDLE = auto() # 初始状态
RUNNING = auto() # 运行中
DONE = auto() # 完成
ERROR = auto() # 错误状态
class AgentResponseData(T.TypedDict):
chain: MessageChain
@dataclass
class AgentResponse:
type: str
data: AgentResponseData
class BaseAgentRunner:
@abc.abstractmethod
async def reset(self) -> None:
"""
Reset the agent to its initial state.
This method should be called before starting a new run.
"""
...
@abc.abstractmethod
async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
"""
Process a single step of the agent.
"""
...
@abc.abstractmethod
def done(self) -> bool:
"""
Check if the agent has completed its task.
Returns True if the agent is done, False otherwise.
"""
...
@abc.abstractmethod
def get_final_llm_resp(self) -> LLMResponse | None:
"""
Get the final observation from the agent.
This method should be called after the agent is done.
"""
...

View File

@@ -1,306 +0,0 @@
import sys
import traceback
import typing as T
from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState
from ...context import PipelineContext
from astrbot.core.provider.provider import Provider
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import (
MessageChain,
)
from astrbot.core.provider.entities import (
ProviderRequest,
LLMResponse,
ToolCallMessageSegment,
AssistantMessageSegment,
ToolCallsResult,
)
from mcp.types import (
TextContent,
ImageContent,
EmbeddedResource,
TextResourceContents,
BlobResourceContents,
)
from astrbot.core.star.star_handler import EventType
from astrbot import logger
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# TODO:
# 1. 处理平台不兼容的处理器
class ToolLoopAgent(BaseAgentRunner):
def __init__(
self, provider: Provider, event: AstrMessageEvent, pipeline_ctx: PipelineContext
) -> None:
self.provider = provider
self.req = None
self.event = event
self.pipeline_ctx = pipeline_ctx
self._state = AgentState.IDLE
self.final_llm_resp = None
self.streaming = False
@override
async def reset(self, req: ProviderRequest, streaming: bool) -> None:
self.req = req
self.streaming = streaming
self.final_llm_resp = None
self._state = AgentState.IDLE
def _transition_state(self, new_state: AgentState) -> None:
"""转换 Agent 状态"""
if self._state != new_state:
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
self._state = new_state
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
if self.streaming:
stream = self.provider.text_chat_stream(**self.req.__dict__)
async for resp in stream: # type: ignore
yield resp
else:
yield await self.provider.text_chat(**self.req.__dict__)
@override
async def step(self):
"""
Process a single step of the agent.
This method should return the result of the step.
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
# 开始处理,转换到运行状态
self._transition_state(AgentState.RUNNING)
llm_resp_result = None
async for llm_response in self._iter_llm_responses():
assert isinstance(llm_response, LLMResponse)
if llm_response.is_chunk:
if llm_response.result_chain:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(chain=llm_response.result_chain),
)
else:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(llm_response.completion_text)
),
)
continue
llm_resp_result = llm_response
break # got final response
if not llm_resp_result:
return
# 处理 LLM 响应
llm_resp = llm_resp_result
if llm_resp.role == "err":
# 如果 LLM 响应错误,转换到错误状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.ERROR)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(
f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
)
),
)
if not llm_resp.tools_call_name:
# 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE)
# 执行事件钩子
if await self.pipeline_ctx.call_event_hook(
self.event, EventType.OnLLMResponseEvent, llm_resp
):
return
# 返回 LLM 结果
if llm_resp.result_chain:
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=llm_resp.result_chain),
)
elif llm_resp.completion_text:
yield AgentResponse(
type="llm_result",
data=AgentResponseData(
chain=MessageChain().message(llm_resp.completion_text)
),
)
# 如果有工具调用,还需处理工具调用
if llm_resp.tools_call_name:
tool_call_result_blocks = []
for tool_call_name in llm_resp.tools_call_name:
yield AgentResponse(
type="tool_call",
data=AgentResponseData(
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}")
),
)
async for result in self._handle_function_tools(self.req, llm_resp):
if isinstance(result, list):
tool_call_result_blocks = result
elif isinstance(result, MessageChain):
yield AgentResponse(
type="tool_call_result",
data=AgentResponseData(chain=result),
)
# 将结果添加到上下文中
tool_calls_result = ToolCallsResult(
tool_calls_info=AssistantMessageSegment(
role="assistant",
tool_calls=llm_resp.to_openai_tool_calls(),
content=llm_resp.completion_text,
),
tool_calls_result=tool_call_result_blocks,
)
self.req.append_tool_calls_result(tool_calls_result)
async def _handle_function_tools(
self,
req: ProviderRequest,
llm_response: LLMResponse,
) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]:
"""处理函数工具调用。"""
tool_call_result_blocks: list[ToolCallMessageSegment] = []
logger.info(f"Agent 使用工具: {llm_response.tools_call_name}")
# 执行函数调用
for func_tool_name, func_tool_args, func_tool_id in zip(
llm_response.tools_call_name,
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
try:
if not req.func_tool:
return
func_tool = req.func_tool.get_func(func_tool_name)
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
)
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
res = await client.session.call_tool(func_tool.name, func_tool_args)
if not res:
continue
if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
)
)
yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resource.text,
)
)
yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回的数据类型不受支持",
)
)
yield MessageChain().message("返回的数据类型不受支持。")
else:
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
# 尝试调用工具函数
wrapper = self.pipeline_ctx.call_handler(
self.event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None:
# Tool 返回结果
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resp,
)
)
yield MessageChain().message(resp)
else:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
self._transition_state(AgentState.DONE)
if res := self.event.get_result():
if res.chain:
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
self.event.clear_result()
except Exception as e:
logger.warning(traceback.format_exc())
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: {str(e)}",
)
)
# 处理函数调用响应
if tool_call_result_blocks:
yield tool_call_result_blocks
def done(self) -> bool:
"""检查 Agent 是否已完成工作"""
return self._state in (AgentState.DONE, AgentState.ERROR)
def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp

View File

@@ -1,369 +1,159 @@
"""
'''
本地 Agent 模式的 LLM 调用 Stage
"""
import asyncio
import copy
import json
'''
import traceback
from typing import AsyncGenerator, Union
from astrbot.core import logger
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
ResultContentType,
)
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
import json
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..agent_runner.tool_loop_agent import ToolLoopAgent
from ..stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest, LLMResponse
from astrbot.core.star.star_handler import star_handlers_registry, EventType
class LLMRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
conf = ctx.astrbot_config
settings = conf["provider_settings"]
self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
self.provider_wake_prefix: str = settings["wake_prefix"] # str
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
)
self.streaming_response: bool = settings["streaming_response"]
self.max_step: int = settings.get("max_agent_step", 10)
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
self.bot_wake_prefixs = ctx.astrbot_config['wake_prefix'] # list
self.provider_wake_prefix = ctx.astrbot_config['provider_settings']['wake_prefix'] # str
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
logger.info(
f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。"
)
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :]
logger.info(f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。")
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp):]
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
"""选择使用的 LLM 提供商"""
sel_provider = event.get_extra("selected_provider")
_ctx = self.ctx.plugin_manager.context
if sel_provider and isinstance(sel_provider, str):
provider = _ctx.get_provider_by_id(sel_provider)
if not provider:
logger.error(f"未找到指定的提供商: {sel_provider}")
return provider
return _ctx.get_using_provider(umo=event.unified_msg_origin)
async def process(
self, event: AstrMessageEvent, _nested: bool = False
) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
# 检查会话级别的LLM启停状态
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM跳过处理。")
return
provider = self._select_provider(event)
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
if provider is None:
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), (
"provider_request 必须是 ProviderRequest 类型。"
)
if req.conversation:
req.contexts = json.loads(req.conversation.history)
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
else:
req = ProviderRequest(prompt="", image_urls=[])
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if self.provider_wake_prefix:
if not event.message_str.startswith(self.provider_wake_prefix):
return
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
req.prompt = event.message_str[len(self.provider_wake_prefix):]
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
image_url = comp.url if comp.url else comp.file
req.image_urls.append(image_url)
# 获取对话上下文
conversation_id = await self.conv_manager.get_curr_conversation_id(
event.unified_msg_origin
)
conversation_id = await self.conv_manager.get_curr_conversation_id(event.unified_msg_origin)
if not conversation_id:
conversation_id = await self.conv_manager.new_conversation(
event.unified_msg_origin
)
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, conversation_id
)
if not conversation:
conversation_id = await self.conv_manager.new_conversation(
event.unified_msg_origin
)
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, conversation_id
)
conversation_id = await self.conv_manager.new_conversation(event.unified_msg_origin)
req.session_id = event.unified_msg_origin
conversation = await self.conv_manager.get_conversation(event.unified_msg_origin, conversation_id)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
event.set_extra("provider_request", req)
if not req.prompt and not req.image_urls:
return
# 执行请求 LLM 前事件钩子。
if await self.ctx.call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
# 装饰 system_prompt 等功能
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
for handler in handlers:
try:
await handler.handler(event, req)
except BaseException:
logger.error(traceback.format_exc())
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# max context length
if (
self.max_context_length != -1 # -1 为不限制
and len(req.contexts) // 2 > self.max_context_length
):
logger.debug("上下文长度超过限制,将截断。")
req.contexts = req.contexts[
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
]
# 找到第一个role 为 user 的索引,确保上下文格式正确
index = next(
(
i
for i, item in enumerate(req.contexts)
if item.get("role") == "user"
),
None,
)
if index is not None and index > 0:
req.contexts = req.contexts[index:]
# session_id
if not req.session_id:
req.session_id = event.unified_msg_origin
# fix messages
req.contexts = self.fix_messages(req.contexts)
# Call Agent
tool_loop_agent = ToolLoopAgent(
provider=provider,
event=event,
pipeline_ctx=self.ctx,
)
logger.debug(
f"handle provider[id: {provider.provider_config['id']}] request: {req}"
)
await tool_loop_agent.reset(req=req, streaming=self.streaming_response)
async def requesting():
step_idx = 0
while step_idx < self.max_step:
step_idx += 1
try:
logger.debug(f"提供商请求 Payload: {req}")
if _nested:
req.func_tool = None # 暂时不支持递归工具调用
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
# 执行 LLM 响应后的事件。
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent)
for handler in handlers:
try:
async for resp in tool_loop_agent.step():
if event.is_stopped():
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
resp.data["chain"].type = "tool_call_result"
await event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if self.streaming_response:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if (
self.show_tool_use
or event.get_platform_name() == "webchat"
):
resp.data["chain"].type = "tool_call"
await event.send(resp.data["chain"])
continue
if not self.streaming_response:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
)
)
yield
event.clear_result()
else:
if resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
if tool_loop_agent.done():
break
except Exception as e:
await handler.handler(event, llm_response)
except BaseException:
logger.error(traceback.format_exc())
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
)
)
return
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=provider.get_model(),
provider_type=provider.meta().type,
)
)
# 保存到历史记录
await self._save_to_history(event, req, llm_response)
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if self.streaming_response:
# 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(requesting())
)
yield
if tool_loop_agent.done():
if final_llm_resp := tool_loop_agent.get_final_llm_resp():
if final_llm_resp.completion_text:
chain = (
MessageChain().message(final_llm_resp.completion_text).chain
)
else:
chain = final_llm_resp.result_chain.chain
event.set_result(
MessageEventResult(
chain=chain,
result_content_type=ResultContentType.STREAMING_FINISH,
)
)
else:
async for _ in requesting():
yield
# 异步处理 WebChat 特殊情况
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp())
async def _handle_webchat(
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, req.conversation.cid
)
if conversation and not req.conversation.title:
messages = json.loads(conversation.history)
latest_pair = messages[-2:]
if not latest_pair:
return
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
llm_resp = await prov.text_chat(
system_prompt="You are expert in summarizing user's query.",
prompt=(
f"Please summarize the following query of user:\n"
f"{cleaned_text}\n"
"Only output the summary within 10 words, DO NOT INCLUDE any other text."
"You must use the same language as the user."
"If you think the dialog is too short to summarize, only output a special mark: `<None>`"
),
)
if llm_resp and llm_resp.completion_text:
logger.debug(
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}"
)
title = llm_resp.completion_text.strip()
if not title or "<None>" in title:
return
await self.conv_manager.update_conversation_title(
event.unified_msg_origin, title=title
)
# 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题
# webchat adapter 中session_id 的格式是 f"webchat!{username}!{cid}"
# TODO: 优化 WebChat 适配器的对话管理
if event.session_id:
username, cid = event.session_id.split("!")[1:3]
db_helper = self.ctx.plugin_manager.context._db
db_helper.update_conversation_title(
user_id=username,
cid=cid,
title=title,
)
async def _save_to_history(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse | None,
):
if (
not req
or not req.conversation
or not llm_response
or llm_response.role != "assistant"
):
return
# 历史上下文
messages = copy.deepcopy(req.contexts)
# 这一轮对话请求的用户输入
messages.append(await req.assemble_context())
# 这一轮对话的 LLM 响应
if req.tool_calls_result:
if not isinstance(req.tool_calls_result, list):
messages.extend(req.tool_calls_result.to_openai_messages())
elif isinstance(req.tool_calls_result, list):
for tcr in req.tool_calls_result:
messages.extend(tcr.to_openai_messages())
messages.append({"role": "assistant", "content": llm_response.completion_text})
messages = list(filter(lambda item: "_no_save" not in item, messages))
await self.conv_manager.update_conversation(
event.unified_msg_origin, req.conversation.cid, history=messages
)
def fix_messages(self, messages: list[dict]) -> list[dict]:
"""验证并且修复上下文"""
fixed_messages = []
for message in messages:
if message.get("role") == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
if llm_response.role == 'assistant':
# text completion
event.set_result(MessageEventResult().message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT))
elif llm_response.role == 'err':
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"))
elif llm_response.role == 'tool':
# function calling
function_calling_result = {}
for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args):
func_tool = req.func_tool.get_func(func_tool_name)
logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
try:
# 尝试调用工具函数
wrapper = self._call_handler(self.ctx, event, func_tool.handler, **func_tool_args)
async for resp in wrapper:
if resp is not None: # 有 return 返回
function_calling_result[func_tool_name] = resp
else:
yield # 有生成器返回
event.clear_result() # 清除上一个 handler 的结果
except BaseException as e:
logger.warning(traceback.format_exc())
function_calling_result[func_tool_name] = "When calling the function, an error occurred: " + str(e)
if function_calling_result:
# 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。
# 我们重新执行一遍这个 stage
req.func_tool = None # 暂时不支持递归工具调用
extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n"
for tool_name, tool_result in function_calling_result.items():
extra_prompt += f"Tool: {tool_name}\nTool Result: {tool_result}\n"
req.prompt += extra_prompt
async for _ in self.process(event, _nested=True):
yield
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages
if llm_response.completion_text:
event.set_result(MessageEventResult().message(llm_response.completion_text))
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"))
return
async def _save_to_history(self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse):
if llm_response.role == "assistant":
# 文本回复
contexts = req.contexts
new_record = {
"role": "user",
"content": req.prompt
}
contexts.append(new_record)
contexts.append({
"role": "assistant",
"content": llm_response.completion_text
})
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
await self.conv_manager.update_conversation(
event.unified_msg_origin,
req.conversation.cid,
history=contexts_to_save
)

View File

@@ -1,7 +1,6 @@
"""
'''
本地 Agent 模式的 AstrBot 插件调用 Stage
"""
'''
from ...context import PipelineContext
from ..stage import Stage
from typing import Dict, Any, List, AsyncGenerator, Union
@@ -12,56 +11,39 @@ from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.star.star import star_map
import traceback
class StarRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"]
self.identifier = ctx.astrbot_config["provider_settings"]["identifier"]
self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix']
self.identifier = ctx.astrbot_config['provider_settings']['identifier']
self.ctx = ctx
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
activated_handlers: List[StarHandlerMetadata] = event.get_extra(
"activated_handlers"
)
handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra(
"handlers_parsed_params"
)
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra("handlers_parsed_params")
if not handlers_parsed_params:
handlers_parsed_params = {}
for handler in activated_handlers:
# 检查处理器是否在当前平台兼容
if (
hasattr(handler, "platform_compatible")
and handler.platform_compatible is False
):
logger.debug(
f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行"
)
continue
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_path not in star_map:
# 孤立无援的 star handler
continue
logger.debug(
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
)
wrapper = self.ctx.call_handler(event, handler.handler, **params)
logger.debug(f"执行插件 handler {handler.handler_full_name}")
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
async for ret in wrapper:
yield ret
event.clear_result() # 清除上一个 handler 的结果
event.clear_result() # 清除上一个 handler 的结果
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
if event.is_at_or_wake_command:
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
event.stop_event()
event.stop_event()

View File

@@ -5,29 +5,26 @@ from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core import logger
@register_stage
class ProcessStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.plugin_manager = ctx.plugin_manager
self.llm_request_sub_stage = LLMRequestSubStage()
await self.llm_request_sub_stage.initialize(ctx)
self.star_request_sub_stage = StarRequestSubStage()
await self.star_request_sub_stage.initialize(ctx)
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""处理事件"""
activated_handlers: List[StarHandlerMetadata] = event.get_extra(
"activated_handlers"
)
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''处理事件
'''
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
# 有插件 Handler 被激活
if activated_handlers:
async for resp in self.star_request_sub_stage.process(event):
@@ -43,26 +40,20 @@ class ProcessStage(Stage):
yield
else:
yield
# 调用 LLM 相关请求
if not self.ctx.astrbot_config["provider_settings"].get("enable", True):
if not self.ctx.astrbot_config['provider_settings'].get('enable', True):
return
if (
not event._has_send_oper
and event.is_at_or_wake_command
and not event.call_llm
):
if not event._has_send_oper and event.is_at_or_wake_command:
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
if (
event.get_result() and not event.get_result().is_stopped()
) or not event.get_result():
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
# 事件没有终止传播
provider = self.ctx.plugin_manager.context.get_using_provider()
if not provider:
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
return
async for _ in self.llm_request_sub_stage.process(event):
yield
yield

View File

@@ -5,6 +5,7 @@ from typing import DefaultDict, Deque, Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core import logger
from astrbot.core.config.astrbot_config import RateLimitStrategy
@@ -31,19 +32,11 @@ class RateLimitStage(Stage):
"""
初始化限流器,根据配置设置限流参数。
"""
self.rate_limit_count = ctx.astrbot_config["platform_settings"]["rate_limit"][
"count"
]
self.rate_limit_time = timedelta(
seconds=ctx.astrbot_config["platform_settings"]["rate_limit"]["time"]
)
self.rl_strategy = ctx.astrbot_config["platform_settings"]["rate_limit"][
"strategy"
] # stall or discard
self.rate_limit_count = ctx.astrbot_config['platform_settings']['rate_limit']['count']
self.rate_limit_time = timedelta(seconds=ctx.astrbot_config['platform_settings']['rate_limit']['time'])
self.rl_strategy = ctx.astrbot_config['platform_settings']['rate_limit']['strategy'] # stall or discard
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
"""
检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。
@@ -58,34 +51,31 @@ class RateLimitStage(Stage):
now = datetime.now()
async with self.locks[session_id]: # 确保同一会话不会并发修改队列
# 检查并处理限流,可能需要多次检查直到满足条件
while True:
timestamps = self.event_timestamps[session_id]
self._remove_expired_timestamps(timestamps, now)
timestamps = self.event_timestamps[session_id]
if len(timestamps) < self.rate_limit_count:
timestamps.append(now)
break
else:
next_window_time = timestamps[0] + self.rate_limit_time
stall_duration = (next_window_time - now).total_seconds() + 0.3
self._remove_expired_timestamps(timestamps, now)
match self.rl_strategy:
case RateLimitStrategy.STALL.value:
logger.info(
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
)
await asyncio.sleep(stall_duration)
now = datetime.now()
case RateLimitStrategy.DISCARD.value:
logger.info(
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
)
return event.stop_event()
if len(timestamps) >= self.rate_limit_count:
# 达到限流阈值,计算下一个窗口的时间
next_window_time = timestamps[0] + self.rate_limit_time
stall_duration = (next_window_time - now).total_seconds()
match self.rl_strategy:
case RateLimitStrategy.STALL.value:
logger.info(f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。")
await asyncio.sleep(stall_duration)
case RateLimitStrategy.DISCARD.value:
# event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
logger.info(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。")
return event.stop_event()
self._remove_expired_timestamps(timestamps, now + timedelta(seconds=stall_duration))
def _remove_expired_timestamps(
self, timestamps: Deque[datetime], now: datetime
) -> None:
timestamps.append(now)
return event.continue_event()
def _remove_expired_timestamps(self, timestamps: Deque[datetime], now: datetime) -> None:
"""
移除时间窗口外的时间戳。
@@ -95,4 +85,4 @@ class RateLimitStage(Stage):
"""
expiry_threshold: datetime = now - self.rate_limit_time
while timestamps and timestamps[0] < expiry_threshold:
timestamps.popleft()
timestamps.popleft()

View File

@@ -2,89 +2,50 @@ import random
import asyncio
import math
import traceback
import astrbot.core.message.components as Comp
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core import logger
from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.utils.path_util import path_Mapping
from astrbot.core.utils.session_lock import session_lock_manager
from astrbot.core.message.components import Plain, Reply, At
@register_stage
class RespondStage(Stage):
# 组件类型到其非空判断函数的映射
_component_validators = {
Comp.Plain: lambda comp: bool(
comp.text and comp.text.strip()
), # 纯文本消息需要strip
Comp.Face: lambda comp: comp.id is not None, # QQ表情
Comp.Record: lambda comp: bool(comp.file), # 语音
Comp.Video: lambda comp: bool(comp.file), # 视频
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
Comp.Image: lambda comp: bool(comp.file), # 图片
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
Comp.Node: lambda comp: bool(comp.content), # 转发节点
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
Comp.File: lambda comp: bool(comp.file_ or comp.url),
Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情
}
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
self.config = ctx.astrbot_config
self.platform_settings: dict = self.config.get("platform_settings", {})
self.reply_with_mention = ctx.astrbot_config["platform_settings"][
"reply_with_mention"
]
self.reply_with_quote = ctx.astrbot_config["platform_settings"][
"reply_with_quote"
]
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
# 分段回复
self.enable_seg: bool = ctx.astrbot_config["platform_settings"][
"segmented_reply"
]["enable"]
self.only_llm_result = ctx.astrbot_config["platform_settings"][
"segmented_reply"
]["only_llm_result"]
self.interval_method = ctx.astrbot_config["platform_settings"][
"segmented_reply"
]["interval_method"]
self.log_base = float(
ctx.astrbot_config["platform_settings"]["segmented_reply"]["log_base"]
)
interval_str: str = ctx.astrbot_config["platform_settings"]["segmented_reply"][
"interval"
]
self.enable_seg: bool = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
self.interval_method = ctx.astrbot_config['platform_settings']['segmented_reply']['interval_method']
self.log_base = float(ctx.astrbot_config['platform_settings']['segmented_reply']['log_base'])
interval_str: str = ctx.astrbot_config['platform_settings']['segmented_reply']['interval']
interval_str_ls = interval_str.replace(" ", "").split(",")
try:
self.interval = [float(t) for t in interval_str_ls]
except BaseException as e:
logger.error(f"解析分段回复的间隔时间失败。{e}")
logger.error(f'解析分段回复的间隔时间失败。{e}')
self.interval = [1.5, 3.5]
logger.info(f"分段回复间隔时间:{self.interval}")
async def _word_cnt(self, text: str) -> int:
"""分段回复 统计字数"""
'''分段回复 统计字数'''
if all(ord(c) < 128 for c in text):
word_count = len(text.split())
else:
word_count = len([c for c in text if c.isalnum()])
return word_count
async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float:
"""分段回复 计算间隔时间"""
if self.interval_method == "log":
if isinstance(comp, Comp.Plain):
'''分段回复 计算间隔时间'''
if self.interval_method == 'log':
if isinstance(comp, Plain):
wc = await self._word_cnt(comp.text)
i = math.log(wc + 1, self.log_base)
return random.uniform(i, i + 0.5)
@@ -94,143 +55,43 @@ class RespondStage(Stage):
# random
return random.uniform(self.interval[0], self.interval[1])
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]):
"""检查消息链是否为空
Args:
chain (list[BaseMessageComponent]): 包含消息对象的列表
"""
if not chain:
return True
for comp in chain:
comp_type = type(comp)
# 检查组件类型是否在字典中
if comp_type in self._component_validators:
if self._component_validators[comp_type](comp):
return False
# 如果所有组件都为空
return True
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
if result is None:
return
if result.result_content_type == ResultContentType.STREAMING_FINISH:
return
if result.result_content_type == ResultContentType.STREAMING_RESULT:
# 流式结果直接交付平台适配器处理
use_fallback = self.config.get("provider_settings", {}).get(
"streaming_segmented", False
)
logger.info(f"应用流式输出({event.get_platform_name()})")
await event.send_streaming(result.async_stream, use_fallback)
return
elif len(result.chain) > 0:
# 检查路径映射
if mappings := self.platform_settings.get("path_mapping", []):
for idx, component in enumerate(result.chain):
if isinstance(component, Comp.File) and component.file:
# 支持 File 消息段的路径映射。
component.file = path_Mapping(mappings, component.file)
event.get_result().chain[idx] = component
# 检查消息链是否为空
try:
if await self._is_empty_message_chain(result.chain):
logger.info("消息为空,跳过发送阶段")
event.clear_result()
event.stop_event()
return
except Exception as e:
logger.warning(f"空内容检查异常: {e}")
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
non_record_comps = [
c for c in result.chain if not isinstance(c, Comp.Record)
]
if (
self.enable_seg
and (
(self.only_llm_result and result.is_llm_result())
or not self.only_llm_result
)
and event.get_platform_name()
not in ["qq_official", "weixin_official_account", "dingtalk"]
):
if len(result.chain) > 0:
await event._pre_send()
if self.enable_seg and ((self.only_llm_result and result.is_llm_result()) or not self.only_llm_result):
decorated_comps = []
if self.reply_with_mention:
for comp in result.chain:
if isinstance(comp, Comp.At):
if isinstance(comp, At):
decorated_comps.append(comp)
result.chain.remove(comp)
break
if self.reply_with_quote:
for comp in result.chain:
if isinstance(comp, Comp.Reply):
if isinstance(comp, Reply):
decorated_comps.append(comp)
result.chain.remove(comp)
break
# leverage lock to guarentee the order of message sending among different events
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
for rcomp in record_comps:
i = await self._calc_comp_interval(rcomp)
await asyncio.sleep(i)
try:
await event.send(MessageChain([rcomp]))
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
# 分段回复
for comp in non_record_comps:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
try:
await event.send(MessageChain([*decorated_comps, comp]))
decorated_comps = [] # 清空已发送的装饰组件
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
# 分段回复
for comp in result.chain:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
await event.send(MessageChain([*decorated_comps, comp]))
else:
for rcomp in record_comps:
try:
await event.send(MessageChain([rcomp]))
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
try:
await event.send(MessageChain(non_record_comps))
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"发送消息失败: {e} chain: {result.chain}")
logger.info(
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
)
await event.send(result)
await event._post_send()
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
for handler in handlers:
try:
logger.debug(
f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
event.clear_result()
event.clear_result()

View File

@@ -1,138 +1,77 @@
import re
import time
import re
import traceback
from typing import AsyncGenerator, Union
from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage, registered_stages
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
from ..context import PipelineContext
from ..stage import Stage, register_stage, registered_stages
from astrbot.core import logger
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
from astrbot.core import html_renderer
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@register_stage
class ResultDecorateStage(Stage):
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
self.reply_prefix = ctx.astrbot_config["platform_settings"]["reply_prefix"]
self.reply_with_mention = ctx.astrbot_config["platform_settings"][
"reply_with_mention"
]
self.reply_with_quote = ctx.astrbot_config["platform_settings"][
"reply_with_quote"
]
self.t2i_word_threshold = ctx.astrbot_config["t2i_word_threshold"]
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
self.t2i_word_threshold = ctx.astrbot_config['t2i_word_threshold']
try:
self.t2i_word_threshold = int(self.t2i_word_threshold)
if self.t2i_word_threshold < 50:
self.t2i_word_threshold = 50
except BaseException:
self.t2i_word_threshold = 150
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
self.t2i_use_network = self.t2i_strategy == "remote"
self.forward_threshold = ctx.astrbot_config["platform_settings"][
"forward_threshold"
]
self.forward_threshold = ctx.astrbot_config['platform_settings']['forward_threshold']
# 分段回复
self.words_count_threshold = int(
ctx.astrbot_config["platform_settings"]["segmented_reply"][
"words_count_threshold"
]
)
self.enable_segmented_reply = ctx.astrbot_config["platform_settings"][
"segmented_reply"
]["enable"]
self.only_llm_result = ctx.astrbot_config["platform_settings"][
"segmented_reply"
]["only_llm_result"]
self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"]
self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][
"segmented_reply"
]["content_cleanup_rule"]
self.words_count_threshold = int(ctx.astrbot_config['platform_settings']['segmented_reply']['words_count_threshold'])
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex']
self.content_cleanup_rule = ctx.astrbot_config['platform_settings']['segmented_reply']['content_cleanup_rule']
# exception
self.content_safe_check_reply = ctx.astrbot_config["content_safety"][
"also_use_in_response"
]
self.content_safe_check_reply = ctx.astrbot_config['content_safety']['also_use_in_response']
self.content_safe_check_stage = None
if self.content_safe_check_reply:
for stage in registered_stages:
if stage.__class__.__name__ == "ContentSafetyCheckStage":
self.content_safe_check_stage = stage
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
if result is None or not result.chain:
if result is None:
return
if result.result_content_type == ResultContentType.STREAMING_RESULT:
return
is_stream = result.result_content_type == ResultContentType.STREAMING_FINISH
# 回复时检查内容安全
if (
self.content_safe_check_reply
and self.content_safe_check_stage
and result.is_llm_result()
and not is_stream # 流式输出不检查内容安全
):
if self.content_safe_check_reply and self.content_safe_check_stage and result.is_llm_result():
text = ""
for comp in result.chain:
if isinstance(comp, Plain):
text += comp.text
async for _ in self.content_safe_check_stage.process(
event, check_text=text
):
async for _ in self.content_safe_check_stage.process(event, check_text=text):
yield
# 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id()
)
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
for handler in handlers:
try:
logger.debug(
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
if is_stream:
logger.warning(
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作"
)
await handler.handler(event)
if event.get_result() is None or not event.get_result().chain:
logger.debug(
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。"
)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
# 流式输出不执行下面的逻辑
if is_stream:
logger.info("流式输出已启用,跳过结果装饰阶段")
return
# 需要再获取一次。插件可能直接对 chain 进行了替换。
result = event.get_result()
if result is None:
return
if len(result.chain) > 0:
# 回复前缀
if self.reply_prefix:
@@ -140,16 +79,10 @@ class ResultDecorateStage(Stage):
if isinstance(comp, Plain):
comp.text = self.reply_prefix + comp.text
break
# 分段回复
if self.enable_segmented_reply and event.get_platform_name() not in [
"qq_official",
"weixin_official_account",
"dingtalk",
]:
if (
self.only_llm_result and result.is_llm_result()
) or not self.only_llm_result:
# 分段回复
if self.enable_segmented_reply:
if (self.only_llm_result and result.is_llm_result()) or not self.only_llm_result:
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain):
@@ -157,9 +90,9 @@ class ResultDecorateStage(Stage):
# 不分段回复
new_chain.append(comp)
continue
split_response = re.findall(
self.regex, comp.text, re.DOTALL | re.MULTILINE
)
split_response = []
for line in comp.text.split("\n"):
split_response.extend(re.findall(self.regex, line))
if not split_response:
new_chain.append(comp)
continue
@@ -172,70 +105,32 @@ class ResultDecorateStage(Stage):
# 非 Plain 类型的消息段不分段
new_chain.append(comp)
result.chain = new_chain
# TTS
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
event.unified_msg_origin
)
if (
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
and result.is_llm_result()
and tts_provider
and SessionServiceManager.should_process_tts_request(event)
):
if self.ctx.astrbot_config['provider_tts_settings']['enable'] and result.is_llm_result():
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info(f"TTS 请求: {comp.text}")
logger.info("TTS 请求: " + comp.text)
audio_path = await tts_provider.get_audio(comp.text)
logger.info(f"TTS 结果: {audio_path}")
if not audio_path:
logger.error(
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
)
logger.info("TTS 结果: " + audio_path)
if audio_path:
new_chain.append(Record(file=audio_path, url=audio_path))
else:
logger.error(f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}")
new_chain.append(comp)
continue
use_file_service = self.ctx.astrbot_config[
"provider_tts_settings"
]["use_file_service"]
callback_api_base = self.ctx.astrbot_config[
"callback_api_base"
]
dual_output = self.ctx.astrbot_config[
"provider_tts_settings"
]["dual_output"]
url = None
if use_file_service and callback_api_base:
token = await file_token_service.register_file(
audio_path
)
url = f"{callback_api_base}/api/file/{token}"
logger.debug(f"已注册:{url}")
new_chain.append(
Record(
file=url or audio_path,
url=url or audio_path,
)
)
if dual_output:
new_chain.append(comp)
except Exception:
except BaseException:
logger.error(traceback.format_exc())
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
else:
new_chain.append(comp)
result.chain = new_chain
# 文本转图片
elif (
result.use_t2i_ is None and self.ctx.astrbot_config["t2i"]
) or result.use_t2i_:
elif (result.use_t2i_ is None and self.ctx.astrbot_config['t2i']) or result.use_t2i_:
plain_str = ""
for comp in result.chain:
if isinstance(comp, Plain):
@@ -245,56 +140,40 @@ class ResultDecorateStage(Stage):
if plain_str and len(plain_str) > self.t2i_word_threshold:
render_start = time.time()
try:
url = await html_renderer.render_t2i(
plain_str, return_url=True, use_network=self.t2i_use_network
)
url = await html_renderer.render_t2i(plain_str, return_url=True)
except BaseException:
logger.error("文本转图片失败,使用文本发送。")
return
if time.time() - render_start > 3:
logger.warning(
"文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。"
)
logger.warning("文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。")
if url:
if url.startswith("http"):
result.chain = [Image.fromURL(url)]
elif (
self.ctx.astrbot_config["t2i_use_file_service"]
and self.ctx.astrbot_config["callback_api_base"]
):
token = await file_token_service.register_file(url)
url = f"{self.ctx.astrbot_config['callback_api_base']}/api/file/{token}"
logger.debug(f"已注册:{url}")
result.chain = [Image.fromURL(url)]
else:
result.chain = [Image.fromFileSystem(url)]
result.chain = [Image.fromURL(url)]
# 触发转发消息
has_forwarded = False
if event.get_platform_name() == "aiocqhttp":
if event.get_platform_name() == 'aiocqhttp':
word_cnt = 0
for comp in result.chain:
if isinstance(comp, Plain):
word_cnt += len(comp.text)
if word_cnt > self.forward_threshold:
node = Node(
uin=event.get_self_id(), name="AstrBot", content=[*result.chain]
uin=event.get_self_id(),
name="AstrBot",
content=[
*result.chain
]
)
result.chain = [node]
has_forwarded = True
if not has_forwarded:
# at 回复
if (
self.reply_with_mention
and event.get_message_type() != MessageType.FRIEND_MESSAGE
):
result.chain.insert(
0, At(qq=event.get_sender_id(), name=event.get_sender_name())
)
if self.reply_with_mention and event.get_message_type() != MessageType.FRIEND_MESSAGE:
result.chain.insert(0, At(qq=event.get_sender_id(), name=event.get_sender_name()))
if len(result.chain) > 1 and isinstance(result.chain[1], Plain):
result.chain[1].text = "\n" + result.chain[1].text
# 引用回复
if self.reply_with_quote:
if not any(isinstance(item, File) for item in result.chain):

View File

@@ -5,75 +5,44 @@ from typing import AsyncGenerator
from astrbot.core.platform import AstrMessageEvent
from astrbot.core import logger
class PipelineScheduler:
"""管道调度器,负责调度各个阶段的执行"""
class PipelineScheduler():
def __init__(self, context: PipelineContext):
registered_stages.sort(
key=lambda x: STAGES_ORDER.index(x.__class__.__name__)
) # 按照顺序排序
self.ctx = context # 上下文对象
registered_stages.sort(key=lambda x: STAGES_ORDER.index(x.__class__ .__name__))
self.ctx = context
async def initialize(self):
"""初始化管道调度器时, 初始化所有阶段"""
for stage in registered_stages:
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
await stage.initialize(self.ctx)
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
"""依次执行各个阶段
Args:
event (AstrMessageEvent): 事件对象
from_stage (int): 从第几个阶段开始执行, 默认从0开始
"""
for i in range(from_stage, len(registered_stages)):
stage = registered_stages[i] # 获取当前要执行的阶段
stage = registered_stages[i]
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
coroutine = stage.process(
event
) # 调用阶段的process方法, 返回协程或者异步生成器
if isinstance(coroutine, AsyncGenerator):
# 如果返回的是异步生成器, 实现洋葱模型的核心
async for _ in coroutine:
# 此处是前置处理完成后的暂停点(yield), 下面开始执行后续阶段
coro = stage.process(event)
if isinstance(coro, AsyncGenerator):
async for _ in coro:
if event.is_stopped():
logger.debug(
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
)
logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。")
break
# 递归调用, 处理所有后续阶段
await self._process_stages(event, i + 1)
# 此处是后续所有阶段处理完毕后返回的点, 执行后置处理
if event.is_stopped():
logger.debug(
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
)
break
else:
# 如果返回的是普通协程(不含yield的async函数), 则不进入下一层(基线条件)
# 简单地等待它执行完成, 然后继续执行下一个阶段
await coroutine
await coro
if event.is_stopped():
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。")
break
if event.is_stopped():
logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。")
break
async def execute(self, event: AstrMessageEvent):
"""执行 pipeline
Args:
event (AstrMessageEvent): 事件对象
"""
'''执行 pipeline'''
await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if event.get_platform_name() == "webchat":
if not event._has_send_oper and event.get_platform_name() == "webchat":
await event.send(None)
logger.debug("pipeline 执行完毕。")
logger.debug("pipeline 执行完毕。")

View File

@@ -1,22 +0,0 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core import logger
@register_stage
class SessionStatusCheckStage(Stage):
"""检查会话是否整体启用"""
async def initialize(self, ctx: PipelineContext) -> None:
pass
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
# 检查会话是否整体启用
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
event.stop_event()

View File

@@ -1,39 +1,68 @@
from __future__ import annotations
import abc
from typing import List, AsyncGenerator, Union
import inspect
from astrbot.api import logger
from typing import List, AsyncGenerator, Union, Awaitable
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
registered_stages: List[Stage] = []
'''维护了所有已注册的 Stage 实现类'''
def register_stage(cls):
"""一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类"""
'''一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类
'''
registered_stages.append(cls())
return cls
class Stage(abc.ABC):
"""描述一个 Pipeline 的某个阶段"""
'''描述一个 Pipeline 的某个阶段
'''
@abc.abstractmethod
async def initialize(self, ctx: PipelineContext) -> None:
"""初始化阶段
Args:
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
"""
'''初始化阶段
'''
raise NotImplementedError
@abc.abstractmethod
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""处理事件
Args:
event (AstrMessageEvent): 事件对象,包含事件的相关信息
Returns:
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
"""
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''处理事件
'''
raise NotImplementedError
async def _call_handler(
self,
ctx: PipelineContext,
event: AstrMessageEvent,
handler: Awaitable,
*args,
**kwargs,
) -> AsyncGenerator[None, None]:
'''调用 Handler。'''
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError as e:
# 向下兼容
logger.debug(str(e))
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
if isinstance(ready_to_call, AsyncGenerator):
async for ret in ready_to_call:
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None无返回值
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个 coroutine
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret

View File

@@ -1,17 +1,13 @@
from typing import AsyncGenerator, Union
from astrbot import logger
from astrbot.core.message.components import At, AtAll, Reply
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
from ..context import PipelineContext
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.message.components import At, Reply
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
@register_stage
class WakingCheckStage(Stage):
@@ -25,41 +21,18 @@ class WakingCheckStage(Stage):
"""
async def initialize(self, ctx: PipelineContext) -> None:
"""初始化唤醒检查阶段
Args:
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
"""
self.ctx = ctx
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
"no_permission_reply", True
)
# 私聊是否需要 wake_prefix 才能唤醒机器人
self.friend_message_needs_wake_prefix = self.ctx.astrbot_config[
"platform_settings"
].get("friend_message_needs_wake_prefix", False)
# 是否忽略机器人自己发送的消息
self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get(
"ignore_bot_self_message", False
)
self.ignore_at_all = self.ctx.astrbot_config["platform_settings"].get(
"ignore_at_all", False
)
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
if (
self.ignore_bot_self_message
and event.get_self_id() == event.get_sender_id()
):
# 忽略机器人自己发送的消息
event.stop_event()
return
# 设置 sender 身份
event.message_str = event.message_str.strip()
for admin_id in self.ctx.astrbot_config["admins_id"]:
if str(event.get_sender_id()) == admin_id:
if event.get_sender_id() == admin_id:
event.role = "admin"
break
@@ -83,18 +56,11 @@ class WakingCheckStage(Stage):
event.message_str = event.message_str[len(wake_prefix) :].strip()
break
if not is_wake:
# 检查是否有at消息 / at全体成员消息 / 引用了bot的消息
# 检查是否有 at 消息
for message in messages:
if (
(
isinstance(message, At)
and (str(message.qq) == str(event.get_self_id()))
)
or (isinstance(message, AtAll) and not self.ignore_at_all)
or (
isinstance(message, Reply)
and str(message.sender_id) == str(event.get_self_id())
)
if isinstance(message, At) and (
str(message.qq) == str(event.get_self_id())
or str(message.qq) == "all"
):
is_wake = True
event.is_wake = True
@@ -102,7 +68,7 @@ class WakingCheckStage(Stage):
event.is_at_or_wake_command = True
break
# 检查是否是私聊
if event.is_private_chat() and not self.friend_message_needs_wake_prefix:
if event.is_private_chat():
is_wake = True
event.is_wake = True
event.is_at_or_wake_command = True
@@ -111,14 +77,11 @@ class WakingCheckStage(Stage):
# 检查插件的 handler filter
activated_handlers = []
handlers_parsed_params = {} # 注册了指令的 handler
for handler in star_handlers_registry.get_handlers_by_event_type(
EventType.AdapterMessageEvent
):
for handler in star_handlers_registry.get_handlers_by_event_type(EventType.AdapterMessageEvent):
# filter 需满足 AND 逻辑关系
passed = True
passed = True
permission_not_pass = False
permission_filter_raise_error = False
if len(handler.event_filters) == 0:
continue
@@ -127,7 +90,6 @@ class WakingCheckStage(Stage):
if isinstance(filter, PermissionTypeFilter):
if not filter.filter(event, self.ctx.astrbot_config):
permission_not_pass = True
permission_filter_raise_error = filter.raise_error
else:
if not filter.filter(event, self.ctx.astrbot_config):
passed = False
@@ -143,21 +105,11 @@ class WakingCheckStage(Stage):
break
if passed:
if permission_not_pass:
if not permission_filter_raise_error:
# 跳过
continue
if self.no_permission_reply:
await event.send(
MessageChain().message(
f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。"
)
)
logger.info(
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
)
await event.send(MessageChain().message(f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"))
event.stop_event()
return
is_wake = True
event.is_wake = True
@@ -166,13 +118,8 @@ class WakingCheckStage(Stage):
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
"parsed_params"
)
event._extras.pop("parsed_params", None)
# 根据会话配置过滤插件处理器
activated_handlers = SessionPluginManager.filter_handlers_by_session(
event, activated_handlers
)
event.clear_extra()
event.set_extra("activated_handlers", activated_handlers)
event.set_extra("handlers_parsed_params", handlers_parsed_params)

View File

@@ -5,61 +5,38 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core import logger
@register_stage
class WhitelistCheckStage(Stage):
"""检查是否在群聊/私聊白名单"""
'''检查是否在群聊/私聊白名单
'''
async def initialize(self, ctx: PipelineContext) -> None:
self.enable_whitelist_check = ctx.astrbot_config["platform_settings"][
"enable_id_white_list"
]
self.whitelist = ctx.astrbot_config["platform_settings"]["id_whitelist"]
self.whitelist = [
str(i).strip() for i in self.whitelist if str(i).strip() != ""
]
self.wl_ignore_admin_on_group = ctx.astrbot_config["platform_settings"][
"wl_ignore_admin_on_group"
]
self.wl_ignore_admin_on_friend = ctx.astrbot_config["platform_settings"][
"wl_ignore_admin_on_friend"
]
self.wl_log = ctx.astrbot_config["platform_settings"]["id_whitelist_log"]
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
self.enable_whitelist_check = ctx.astrbot_config['platform_settings']['enable_id_white_list']
self.whitelist = ctx.astrbot_config['platform_settings']['id_whitelist']
self.wl_ignore_admin_on_group = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_group']
self.wl_ignore_admin_on_friend = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_friend']
self.wl_log = ctx.astrbot_config['platform_settings']['id_whitelist_log']
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
if not self.enable_whitelist_check:
# 白名单检查未启用
return
if len(self.whitelist) == 0:
# 白名单为空,不检查
return
if event.get_platform_name() == "webchat":
if event.get_platform_name() == 'webchat':
# WebChat 豁免
return
# 检查是否在白名单
if self.wl_ignore_admin_on_group:
if (
event.role == "admin"
and event.get_message_type() == MessageType.GROUP_MESSAGE
):
if event.role == 'admin' and event.get_message_type() == MessageType.GROUP_MESSAGE:
return
if self.wl_ignore_admin_on_friend:
if (
event.role == "admin"
and event.get_message_type() == MessageType.FRIEND_MESSAGE
):
if event.role == 'admin' and event.get_message_type() == MessageType.FRIEND_MESSAGE:
return
if (
event.unified_msg_origin not in self.whitelist
and str(event.get_group_id()).strip() not in self.whitelist
):
if event.unified_msg_origin not in self.whitelist:
if self.wl_log:
logger.info(
f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。"
)
event.stop_event()
logger.info(f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。")
event.stop_event()

View File

@@ -1,14 +1,4 @@
from .platform import Platform
from .astr_message_event import AstrMessageEvent
from .platform_metadata import PlatformMetadata
from .astrbot_message import AstrBotMessage, MessageMember, MessageType, Group
__all__ = [
"Platform",
"AstrMessageEvent",
"PlatformMetadata",
"AstrBotMessage",
"MessageMember",
"MessageType",
"Group",
]
from .astrbot_message import AstrBotMessage, MessageMember, MessageType

View File

@@ -1,98 +1,76 @@
import abc
import asyncio
import re
import hashlib
import uuid
from dataclasses import dataclass
from typing import List, Union, Optional, AsyncGenerator
from astrbot.core.db.po import Conversation
from astrbot.core.message.components import (
Plain,
Image,
BaseMessageComponent,
Face,
At,
AtAll,
Forward,
Reply,
)
from .astrbot_message import AstrBotMessage
from .platform_metadata import PlatformMetadata
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.platform.message_type import MessageType
from astrbot.core.provider.entities import ProviderRequest
from typing import List, Union
from astrbot.core.message.components import Plain, Image, BaseMessageComponent, Face, At, AtAll, Forward
from astrbot.core.utils.metrics import Metric
from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.db.po import Conversation
@dataclass
class MessageSesion:
platform_name: str
message_type: MessageType
session_id: str
def __str__(self):
return f"{self.platform_name}:{self.message_type.value}:{self.session_id}"
@staticmethod
def from_str(session_str: str):
platform_name, message_type, session_id = session_str.split(":")
return MessageSesion(platform_name, MessageType(message_type), session_id)
class AstrMessageEvent(abc.ABC):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
):
def __init__(self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,):
self.message_str = message_str
"""纯文本的消息"""
'''纯文本的消息'''
self.message_obj = message_obj
"""消息对象, AstrBotMessage。带有完整的消息结构。"""
'''消息对象AstrBotMessage。带有完整的消息结构。'''
self.platform_meta = platform_meta
"""消息平台的信息, 其中 name 是平台的类型,如 aiocqhttp"""
'''消息平台的信息, 其中 name 是平台的类型,如 aiocqhttp'''
self.session_id = session_id
"""用户的会话 ID。可以直接使用下面的 unified_msg_origin"""
'''用户的会话 ID。可以直接使用下面的 unified_msg_origin'''
self.role = "member"
"""用户是否是管理员。如果是管理员,这里是 admin"""
self.is_wake = False
"""是否唤醒(是否通过 WakingStage)"""
'''用户是否是管理员。如果是管理员,这里是 admin'''
self.is_wake = False # 是否通过 WakingStage
'''是否唤醒'''
self.is_at_or_wake_command = False
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
'''是否是 At 机器人或者带有唤醒词或者是私聊事件监听器会让 is_wake 设为 True但是不会让这个属性置为 True'''
self._extras = {}
self.session = MessageSesion(
platform_name=platform_meta.name,
message_type=message_obj.type,
session_id=session_id,
session_id=session_id
)
self.unified_msg_origin = str(self.session)
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
'''统一的消息来源字符串。格式为 platform_name:message_type:session_id'''
self._result: MessageEventResult = None
"""消息事件的结果"""
self._has_send_oper = False
"""在此次事件中是否有过至少一次发送消息的操作"""
self.call_llm = False
"""是否在此消息事件中禁止默认的 LLM 请求"""
'''消息事件的结果'''
self._has_send_oper = False
'''是否有过至少一次发送操作'''
# back_compability
self.platform = platform_meta
def get_platform_name(self):
return self.platform_meta.name
def get_platform_id(self):
return self.platform_meta.id
def get_message_str(self) -> str:
"""
'''
获取消息字符串。
"""
'''
return self.message_str
def _outline_chain(self, chain: List[BaseMessageComponent]) -> str:
outline = ""
for i in chain:
@@ -109,212 +87,183 @@ class AstrMessageEvent(abc.ABC):
elif isinstance(i, Forward):
# 转发消息
outline += "[转发消息]"
elif isinstance(i, Reply):
# 引用回复
if i.message_str:
outline += f"[引用消息({i.sender_nickname}: {i.message_str})]"
else:
outline += "[引用消息]"
else:
outline += f"[{i.type}]"
outline += " "
return outline
def get_message_outline(self) -> str:
"""
'''
获取消息概要。
除了文本消息外,其他消息类型会被转换为对应的占位符。如图片消息会被转换为 [图片]。
"""
'''
return self._outline_chain(self.message_obj.message)
def get_messages(self) -> List[BaseMessageComponent]:
"""
'''
获取消息链。
"""
'''
return self.message_obj.message
def get_message_type(self) -> MessageType:
"""
'''
获取消息类型。
"""
'''
return self.message_obj.type
def get_session_id(self) -> str:
"""
'''
获取会话id。
"""
'''
return self.session_id
def get_group_id(self) -> str:
"""
'''
获取群组id。如果不是群组消息返回空字符串。
"""
'''
return self.message_obj.group_id
def get_self_id(self) -> str:
"""
'''
获取机器人自身的id。
"""
'''
return self.message_obj.self_id
def get_sender_id(self) -> str:
"""
'''
获取消息发送者的id。
"""
'''
return self.message_obj.sender.user_id
def get_sender_name(self) -> str:
"""
'''
获取消息发送者的名称。(可能会返回空字符串)
"""
'''
return self.message_obj.sender.nickname
def set_extra(self, key, value):
"""
'''
设置额外的信息。
"""
'''
self._extras[key] = value
def get_extra(self, key=None):
"""
def get_extra(self, key = None):
'''
获取额外的信息。
"""
'''
if key is None:
return self._extras
return self._extras.get(key, None)
def clear_extra(self):
"""
'''
清除额外的信息。
"""
'''
self._extras.clear()
def is_private_chat(self) -> bool:
"""
'''
是否是私聊。
"""
'''
return self.message_obj.type.value == (MessageType.FRIEND_MESSAGE).value
def is_wake_up(self) -> bool:
"""
'''
是否是唤醒机器人的事件。
"""
'''
return self.is_wake
def is_admin(self) -> bool:
"""
'''
是否是管理员。
"""
'''
return self.role == "admin"
async def process_buffer(self, buffer: str, pattern: re.Pattern) -> str:
"""
将消息缓冲区中的文本按指定正则表达式分割后发送消息平台作为不支持流式输出平台的Fallback
"""
while True:
match = re.search(pattern, buffer)
if not match:
break
matched_text = match.group()
await self.send(MessageChain([Plain(matched_text)]))
buffer = buffer[match.end() :]
await asyncio.sleep(1.5) # 限速
return buffer
async def send_streaming(
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
):
"""发送流式消息到消息平台,使用异步生成器。
目前仅支持: telegramqq official 私聊。
Fallback仅支持 aiocqhttp。
"""
asyncio.create_task(
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
)
async def send(self, message: MessageChain):
'''
发送消息到消息平台。
'''
await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name)
self._has_send_oper = True
async def _pre_send(self):
"""调度器会在执行 send() 前调用该方法 deprecated in v3.5.18"""
'''调度器会在执行 send() 前调用该方法'''
pass
async def _post_send(self):
"""调度器会在执行 send() 后调用该方法 deprecated in v3.5.18"""
'''调度器会在执行 send() 后调用该方法'''
pass
def set_result(self, result: Union[MessageEventResult, str]):
"""设置消息事件的结果。
'''设置消息事件的结果。
Note:
事件处理器可以通过设置结果来控制事件是否继续传播,并向消息适配器发送消息。
如果没有设置 `MessageEventResult` 中的 result_type默认为 CONTINUE。即事件将会继续向后面的 listener 或者 command 传播。
Example:
```
async def ban_handler(self, event: AstrMessageEvent):
if event.get_sender_id() in self.blacklist:
event.set_result(MessageEventResult().set_console_log("由于用户在黑名单,因此消息事件中断处理。")).set_result_type(EventResultType.STOP)
return
async def check_count(self, event: AstrMessageEvent):
self.count += 1
event.set_result(MessageEventResult().set_console_log("数量已增加", logging.DEBUG).set_result_type(EventResultType.CONTINUE))
return
```
"""
'''
if isinstance(result, str):
result = MessageEventResult().message(result)
self._result = result
def stop_event(self):
"""终止事件传播。"""
'''终止事件传播。
'''
if self._result is None:
self.set_result(MessageEventResult().stop_event())
else:
self._result.stop_event()
def continue_event(self):
"""继续事件传播。"""
'''继续事件传播。
'''
if self._result is None:
self.set_result(MessageEventResult().continue_event())
else:
self._result.continue_event()
def is_stopped(self) -> bool:
"""
'''
是否终止事件传播。
"""
'''
if self._result is None:
return False # 默认是继续传播
return self._result.is_stopped()
def should_call_llm(self, call_llm: bool):
"""
是否在此消息事件中禁止默认的 LLM 请求。
只会阻止 AstrBot 默认的 LLM 请求链路,不会阻止插件中的 LLM 请求。
"""
self.call_llm = call_llm
return False # 默认是继续传播
return self._result.is_stopped()
def get_result(self) -> MessageEventResult:
"""
'''
获取消息事件的结果。
"""
'''
return self._result
def clear_result(self):
"""
'''
清除消息事件的结果。
"""
'''
self._result = None
"""消息链相关"""
'''消息链相关'''
def make_result(self) -> MessageEventResult:
"""
'''
创建一个空的消息事件结果。
Example:
```python
# 纯文本回复
yield event.make_result().message("Hi")
@@ -322,103 +271,65 @@ class AstrMessageEvent(abc.ABC):
yield event.make_result().url_image("https://example.com/image.jpg")
yield event.make_result().file_image("image.jpg")
```
"""
'''
return MessageEventResult()
def plain_result(self, text: str) -> MessageEventResult:
"""
'''
创建一个空的消息事件结果,只包含一条文本消息。
"""
'''
return MessageEventResult().message(text)
def image_result(self, url_or_path: str) -> MessageEventResult:
"""
'''
创建一个空的消息事件结果,只包含一条图片消息。
根据开头是否包含 http 来判断是网络图片还是本地图片。
"""
'''
if url_or_path.startswith("http"):
return MessageEventResult().url_image(url_or_path)
return MessageEventResult().file_image(url_or_path)
def chain_result(self, chain: List[BaseMessageComponent]) -> MessageEventResult:
"""
'''
创建一个空的消息事件结果,包含指定的消息链。
"""
'''
mer = MessageEventResult()
mer.chain = chain
return mer
"""LLM 请求相关"""
'''LLM 请求相关'''
def request_llm(
self,
prompt: str,
func_tool_manager=None,
func_tool_manager = None,
session_id: str = None,
image_urls: List[str] = [],
contexts: List = [],
system_prompt: str = "",
conversation: Conversation = None,
conversation: Conversation = None
) -> ProviderRequest:
"""
'''
创建一个 LLM 请求。
Examples:
```py
yield event.request_llm(prompt="hi")
```
prompt: 提示词
system_prompt: 系统提示词
session_id: 已经过时,留空即可
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
contexts: 当指定 contexts 时,将会使用 contexts 作为上下文。如果同时传入了 conversation将会忽略 conversation。
contexts: 当指定 contexts 时,将会使用 contexts 作为上下文。
func_tool_manager: 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。
conversation: 可选。如果指定,将在指定的对话中进行 LLM 请求。对话的人格会被用于 LLM 请求,并且结果将会被记录到对话中。
"""
if len(contexts) > 0 and conversation:
conversation = None
'''
return ProviderRequest(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool_manager,
contexts=contexts,
system_prompt=system_prompt,
conversation=conversation,
)
"""平台适配器"""
async def send(self, message: MessageChain):
"""发送消息到消息平台。
Args:
message (MessageChain): 消息链,具体使用方式请参考文档。
"""
# Leverage BLAKE2 hash function to generate a non-reversible hash of the sender ID for privacy.
hash_obj = hashlib.blake2b(self.get_sender_id().encode("utf-8"), digest_size=16)
sid = str(uuid.UUID(bytes=hash_obj.digest()))
asyncio.create_task(
Metric.upload(
msg_event_tick=1, adapter_name=self.platform_meta.name, sid=sid
)
)
self._has_send_oper = True
async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]:
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息返回当前群聊的数据。
适配情况:
- aiocqhttp(OneBotv11)
"""
...
prompt = prompt,
session_id = session_id,
image_urls = image_urls,
func_tool = func_tool_manager,
contexts = contexts,
system_prompt = system_prompt,
conversation=conversation
)

View File

@@ -4,64 +4,26 @@ from dataclasses import dataclass
from astrbot.core.message.components import BaseMessageComponent
from .message_type import MessageType
@dataclass
class MessageMember:
class MessageMember():
user_id: str # 发送者id
nickname: str = None
def __str__(self):
# 使用 f-string 来构建返回的字符串表示形式
return (
f"User ID: {self.user_id},"
f"Nickname: {self.nickname if self.nickname else 'N/A'}"
)
@dataclass
class Group:
group_id: str
"""群号"""
group_name: str = None
"""群名称"""
group_avatar: str = None
"""群头像"""
group_owner: str = None
"""群主 id"""
group_admins: List[str] = None
"""群管理员 id"""
members: List[MessageMember] = None
"""所有群成员"""
def __str__(self):
# 使用 f-string 来构建返回的字符串表示形式
return (
f"Group ID: {self.group_id}\n"
f"Name: {self.group_name if self.group_name else 'N/A'}\n"
f"Avatar: {self.group_avatar if self.group_avatar else 'N/A'}\n"
f"Owner ID: {self.group_owner if self.group_owner else 'N/A'}\n"
f"Admin IDs: {self.group_admins if self.group_admins else 'N/A'}\n"
f"Members Len: {len(self.members) if self.members else 0}\n"
f"First Member: {self.members[0] if self.members else 'N/A'}\n"
)
class AstrBotMessage:
"""
'''
AstrBot 的消息对象
"""
'''
type: MessageType # 消息类型
self_id: str # 机器人的识别id
session_id: str # 会话id。取决于 unique_session 的设置。
message_id: str # 消息id
group_id: str = "" # 群组id如果为私聊则为空
group_id: str = "" # 群组id如果为私聊则为空
sender: MessageMember # 发送者
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串
raw_message: object
timestamp: int # 消息时间戳
def __init__(self) -> None:
self.timestamp = int(time.time())

View File

@@ -1,5 +1,3 @@
import traceback
import asyncio
from astrbot.core.config.astrbot_config import AstrBotConfig
from .platform import Platform
from typing import List
@@ -8,157 +6,48 @@ from .register import platform_cls_map
from astrbot.core import logger
from .sources.webchat.webchat_adapter import WebChatAdapter
class PlatformManager:
class PlatformManager():
def __init__(self, config: AstrBotConfig, event_queue: Queue):
self.platform_insts: List[Platform] = []
"""加载的 Platform 的实例"""
self._inst_map = {}
self.platforms_config = config["platform"]
self.settings = config["platform_settings"]
'''加载的 Platform 的实例'''
self.platforms_config = config['platform']
self.settings = config['platform_settings']
self.event_queue = event_queue
try:
for platform in self.platforms_config:
if not platform['enable']:
continue
match platform['type']:
case "aiocqhttp":
from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401
case "qq_official":
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
case "qq_official_webhook":
from .sources.qqofficial_webhook.qo_webhook_adapter import QQOfficialWebhookPlatformAdapter # noqa: F401
case "gewechat":
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
case "lark":
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。")
except Exception as e:
logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}")
async def initialize(self):
"""初始化所有平台适配器"""
for platform in self.platforms_config:
try:
await self.load_platform(platform)
except Exception as e:
logger.error(f"初始化 {platform} 平台适配器失败: {e}")
# 网页聊天
webchat_inst = WebChatAdapter({}, self.settings, self.event_queue)
self.platform_insts.append(webchat_inst)
asyncio.create_task(
self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat"))
)
async def load_platform(self, platform_config: dict):
"""实例化一个平台"""
# 动态导入
try:
if not platform_config["enable"]:
return
logger.info(
f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ..."
)
match platform_config["type"]:
case "aiocqhttp":
from .sources.aiocqhttp.aiocqhttp_platform_adapter import (
AiocqhttpAdapter, # noqa: F401
)
case "qq_official":
from .sources.qqofficial.qqofficial_platform_adapter import (
QQOfficialPlatformAdapter, # noqa: F401
)
case "qq_official_webhook":
from .sources.qqofficial_webhook.qo_webhook_adapter import (
QQOfficialWebhookPlatformAdapter, # noqa: F401
)
case "wechatpadpro":
from .sources.wechatpadpro.wechatpadpro_adapter import (
WeChatPadProAdapter, # noqa: F401
)
case "lark":
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
case "dingtalk":
from .sources.dingtalk.dingtalk_adapter import (
DingtalkPlatformAdapter, # noqa: F401
)
case "telegram":
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
case "wecom":
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
case "weixin_official_account":
from .sources.weixin_official_account.weixin_offacc_adapter import (
WeixinOfficialAccountPlatformAdapter, # noqa
)
case "discord":
from .sources.discord.discord_platform_adapter import (
DiscordPlatformAdapter, # noqa: F401
)
case "slack":
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
)
except Exception as e:
logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}")
if platform_config["type"] not in platform_cls_map:
logger.error(
f"未找到适用于 {platform_config['type']}({platform_config['id']}) 平台适配器,请检查是否已经安装或者名称填写错误"
)
return
cls_type = platform_cls_map[platform_config["type"]]
inst: Platform = cls_type(platform_config, self.settings, self.event_queue)
self._inst_map[platform_config["id"]] = {
"inst": inst,
"client_id": inst.client_self_id,
}
self.platform_insts.append(inst)
asyncio.create_task(
self._task_wrapper(
asyncio.create_task(
inst.run(),
name=f"platform_{platform_config['type']}_{platform_config['id']}",
)
)
)
async def _task_wrapper(self, task: asyncio.Task):
try:
await task
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
for line in traceback.format_exc().split("\n"):
logger.error(f"| {line}")
logger.error("-------")
async def reload(self, platform_config: dict):
await self.terminate_platform(platform_config["id"])
if platform_config["enable"]:
await self.load_platform(platform_config)
# 和配置文件保持同步
config_ids = [provider["id"] for provider in self.platforms_config]
for key in list(self._inst_map.keys()):
if key not in config_ids:
await self.terminate_platform(key)
async def terminate_platform(self, platform_id: str):
if platform_id in self._inst_map:
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
# client_id = self._inst_map.pop(platform_id, None)
info = self._inst_map.pop(platform_id, None)
client_id = info["client_id"]
inst = info["inst"]
try:
self.platform_insts.remove(
next(
inst
for inst in self.platform_insts
if inst.client_self_id == client_id
)
)
except Exception:
logger.warning(f"可能未完全移除 {platform_id} 平台适配器")
if getattr(inst, "terminate", None):
await inst.terminate()
async def terminate(self):
for inst in self.platform_insts:
if getattr(inst, "terminate", None):
await inst.terminate()
if not platform['enable']:
continue
if platform['type'] not in platform_cls_map:
logger.error(f"未找到适用于 {platform['type']}({platform['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
cls_type = platform_cls_map[platform['type']]
logger.debug(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
inst = cls_type(platform, self.settings, self.event_queue)
self.platform_insts.append(inst)
self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue))
def get_insts(self):
return self.platform_insts
return self.platform_insts

View File

@@ -1,7 +1,6 @@
from enum import Enum
class MessageType(Enum):
GROUP_MESSAGE = "GroupMessage" # 群组形式的消息
FRIEND_MESSAGE = "FriendMessage" # 私聊、好友等单聊消息
OTHER_MESSAGE = "OtherMessage" # 其他类型的消息,如系统消息等
GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息
FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息
OTHER_MESSAGE = 'OtherMessage' # 其他类型的消息,如系统消息等

View File

@@ -1,5 +1,4 @@
import abc
import uuid
from typing import Awaitable, Any
from asyncio import Queue
from .platform_metadata import PlatformMetadata
@@ -8,52 +7,36 @@ from astrbot.core.message.message_event_result import MessageChain
from .astr_message_event import MessageSesion
from astrbot.core.utils.metrics import Metric
class Platform(abc.ABC):
def __init__(self, event_queue: Queue):
super().__init__()
# 维护了消息平台的事件队列EventBus 会从这里取出事件并处理。
self._event_queue = event_queue
self.client_self_id = uuid.uuid4().hex
@abc.abstractmethod
def run(self) -> Awaitable[Any]:
"""
'''
得到一个平台的运行实例,需要返回一个协程对象。
"""
'''
raise NotImplementedError
async def terminate(self):
"""
终止一个平台的运行实例。
"""
...
@abc.abstractmethod
def meta(self) -> PlatformMetadata:
"""
'''
得到一个平台的元数据。
"""
'''
raise NotImplementedError
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
) -> Awaitable[Any]:
"""
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain) -> Awaitable[Any]:
'''
通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
异步方法。
"""
await Metric.upload(msg_event_tick=1, adapter_name=self.meta().name)
'''
await Metric.upload(msg_event_tick = 1, adapter_name = self.meta().name)
def commit_event(self, event: AstrMessageEvent):
"""
'''
提交一个事件到事件队列。
"""
self._event_queue.put_nowait(event)
def get_client(self):
"""
获取平台的客户端对象。
"""
pass
'''
self._event_queue.put_nowait(event)

View File

@@ -1,16 +1,12 @@
from dataclasses import dataclass
@dataclass
class PlatformMetadata:
class PlatformMetadata():
name: str
"""平台的名称"""
'''平台的名称'''
description: str
"""平台的描述"""
id: str = None
"""平台的唯一标识符,用于配置中识别特定平台"""
'''平台的描述'''
default_config_tmpl: dict = None
"""平台的默认配置模板"""
'''平台的默认配置模板'''
adapter_display_name: str = None
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
'''显示在 WebUI 配置页中的平台名称,如空则是 name'''

View File

@@ -3,46 +3,42 @@ from .platform_metadata import PlatformMetadata
from astrbot.core import logger
platform_registry: List[PlatformMetadata] = []
"""维护了通过装饰器注册的平台适配器"""
'''维护了通过装饰器注册的平台适配器'''
platform_cls_map: Dict[str, Type] = {}
"""维护了平台适配器名称和适配器类的映射"""
'''维护了平台适配器名称和适配器类的映射'''
def register_platform_adapter(
adapter_name: str,
desc: str,
adapter_name: str,
desc: str,
default_config_tmpl: dict = None,
adapter_display_name: str = None,
adapter_display_name: str = None
):
"""用于注册平台适配器的带参装饰器。
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
"""
'''用于注册平台适配器的带参装饰器。
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
'''
def decorator(cls):
if adapter_name in platform_cls_map:
raise ValueError(
f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。"
)
raise ValueError(f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。")
# 添加必备选项
if default_config_tmpl:
if "type" not in default_config_tmpl:
default_config_tmpl["type"] = adapter_name
if "enable" not in default_config_tmpl:
default_config_tmpl["enable"] = False
if "id" not in default_config_tmpl:
default_config_tmpl["id"] = adapter_name
if 'type' not in default_config_tmpl:
default_config_tmpl['type'] = adapter_name
if 'enable' not in default_config_tmpl:
default_config_tmpl['enable'] = False
if 'id' not in default_config_tmpl:
default_config_tmpl['id'] = adapter_name
pm = PlatformMetadata(
name=adapter_name,
description=desc,
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name,
adapter_display_name=adapter_display_name
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls
logger.debug(f"平台适配器 {adapter_name} 已注册")
return cls
return decorator

View File

@@ -1,217 +1,60 @@
import asyncio
import re
from typing import AsyncGenerator, Dict, List
from aiocqhttp import CQHttp, Event
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import (
Image,
Node,
Nodes,
Plain,
Record,
Video,
File,
BaseMessageComponent,
)
from astrbot.api.platform import Group, MessageMember
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image, Record, At, Node, Music, Video
from aiocqhttp import CQHttp
from astrbot.core.utils.io import file_to_base64, download_image_by_url
class AiocqhttpMessageEvent(AstrMessageEvent):
def __init__(
self, message_str, message_obj, platform_meta, session_id, bot: CQHttp
):
def __init__(self, message_str, message_obj, platform_meta, session_id, bot: CQHttp):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
@staticmethod
async def _from_segment_to_dict(segment: BaseMessageComponent) -> dict:
"""修复部分字段"""
if isinstance(segment, (Image, Record)):
# For Image and Record segments, we convert them to base64
bs64 = await segment.convert_to_base64()
return {
"type": segment.type.lower(),
"data": {
"file": f"base64://{bs64}",
},
}
elif isinstance(segment, File):
# For File segments, we need to handle the file differently
d = await segment.to_dict()
return d
elif isinstance(segment, Video):
d = await segment.to_dict()
return d
else:
# For other segments, we simply convert them to a dict by calling toDict
return segment.toDict()
@staticmethod
async def _parse_onebot_json(message_chain: MessageChain):
"""解析成 OneBot json 格式"""
'''解析成 OneBot json 格式'''
ret = []
for segment in message_chain.chain:
d = segment.toDict()
if isinstance(segment, Plain):
if not segment.text.strip():
continue
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
d['type'] = 'text'
elif isinstance(segment, (Image, Record)):
# convert to base64
if segment.file and segment.file.startswith("file:///"):
bs64_data = file_to_base64(segment.file[8:])
image_file_path = segment.file[8:]
elif segment.file and segment.file.startswith("http"):
image_file_path = await download_image_by_url(segment.file)
bs64_data = file_to_base64(image_file_path)
elif segment.file and segment.file.startswith("base64://"):
bs64_data = segment.file
else:
bs64_data = file_to_base64(segment.file)
d['data'] = {
'file': bs64_data,
}
elif isinstance(segment, At):
d['data'] = {
'qq': str(segment.qq) # 转换为字符串
}
ret.append(d)
return ret
@classmethod
async def _dispatch_send(
cls,
bot: CQHttp,
event: Event | None,
is_group: bool,
session_id: str,
messages: list[dict],
):
if event:
await bot.send(event=event, message=messages)
elif is_group:
await bot.send_group_msg(group_id=session_id, message=messages)
else:
await bot.send_private_msg(user_id=session_id, message=messages)
@classmethod
async def send_message(
cls,
bot: CQHttp,
message_chain: MessageChain,
event: Event | None = None,
is_group: bool = False,
session_id: str = None,
):
"""发送消息"""
# 转发消息、文件消息不能和普通消息混在一起发送
send_one_by_one = any(
isinstance(seg, (Node, Nodes, File)) for seg in message_chain.chain
)
if not send_one_by_one:
ret = await cls._parse_onebot_json(message_chain)
if not ret:
return
await cls._dispatch_send(bot, event, is_group, session_id, ret)
return
for seg in message_chain.chain:
if isinstance(seg, (Node, Nodes)):
# 合并转发消息
if isinstance(seg, Node):
nodes = Nodes([seg])
seg = nodes
payload = await seg.to_dict()
if is_group:
payload["group_id"] = session_id
await bot.call_action("send_group_forward_msg", **payload)
else:
payload["user_id"] = session_id
await bot.call_action("send_private_forward_msg", **payload)
elif isinstance(seg, File):
d = await cls._from_segment_to_dict(seg)
await cls._dispatch_send(bot, event, is_group, session_id, [d])
else:
messages = await cls._parse_onebot_json(MessageChain([seg]))
if not messages:
continue
await cls._dispatch_send(bot, event, is_group, session_id, messages)
await asyncio.sleep(0.5)
async def send(self, message: MessageChain):
"""发送消息"""
event = self.message_obj.raw_message
assert isinstance(event, Event), "Event must be an instance of aiocqhttp.Event"
is_group = False
if self.get_group_id():
is_group = True
session_id = self.get_group_id()
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
send_one_by_one = False
for seg in message.chain:
if isinstance(seg, (Node, Music)):
# 转发消息不能和普通消息混在一起发送
send_one_by_one = True
break
if send_one_by_one:
for seg in message.chain:
await self.bot.send(self.message_obj.raw_message, await AiocqhttpMessageEvent._parse_onebot_json(MessageChain([seg])))
await asyncio.sleep(0.5)
else:
session_id = self.get_sender_id()
await self.send_message(
bot=self.bot,
message_chain=message,
event=event,
is_group=is_group,
session_id=session_id,
)
await super().send(message)
async def send_streaming(
self, generator: AsyncGenerator, use_fallback: bool = False
):
if not use_fallback:
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
buffer = ""
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
async for chain in generator:
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
buffer = await self.process_buffer(buffer, pattern)
else:
await self.send(MessageChain(chain=[comp]))
await asyncio.sleep(1.5) # 限速
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator, use_fallback)
async def get_group(self, group_id=None, **kwargs):
if isinstance(group_id, str) and group_id.isdigit():
group_id = int(group_id)
elif self.get_group_id():
group_id = int(self.get_group_id())
else:
return None
info: dict = await self.bot.call_action(
"get_group_info",
group_id=group_id,
)
members: List[Dict] = await self.bot.call_action(
"get_group_member_list",
group_id=group_id,
)
owner_id = None
admin_ids = []
for member in members:
if member["role"] == "owner":
owner_id = member["user_id"]
if member["role"] == "admin":
admin_ids.append(member["user_id"])
group = Group(
group_id=str(group_id),
group_name=info.get("group_name"),
group_avatar="",
group_admins=admin_ids,
group_owner=str(owner_id),
members=[
MessageMember(
user_id=member["user_id"],
nickname=member.get("nickname") or member.get("card"),
)
for member in members
],
)
return group
await self.bot.send(self.message_obj.raw_message, ret)
await super().send(message)

View File

@@ -1,17 +1,11 @@
import os
import time
import asyncio
import logging
import uuid
import itertools
from typing import Awaitable, Any
from aiocqhttp import CQHttp, Event
from astrbot.api.platform import (
Platform,
AstrBotMessage,
MessageMember,
MessageType,
PlatformMetadata,
)
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from .aiocqhttp_message_event import * # noqa: F403
from astrbot.api.message_components import * # noqa: F403
@@ -20,189 +14,122 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from aiocqhttp.exceptions import ActionFailed
from astrbot.core.utils.io import download_file
@register_platform_adapter(
"aiocqhttp", "适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。"
)
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
class AiocqhttpAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings["unique_session"]
self.host = platform_config["ws_reverse_host"]
self.port = platform_config["ws_reverse_port"]
self.unique_session = platform_settings['unique_session']
self.host = platform_config['ws_reverse_host']
self.port = platform_config['ws_reverse_port']
self.metadata = PlatformMetadata(
name="aiocqhttp",
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
id=self.config.get("id"),
)
self.bot = CQHttp(
use_ws_reverse=True,
import_name="aiocqhttp",
api_timeout_sec=180,
access_token=platform_config.get(
"ws_reverse_token"
), # 以防旧版本配置不存在
)
@self.bot.on_request()
async def request(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_notice()
async def notice(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message("group")
async def group(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message("private")
async def private(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_websocket_connection
def on_websocket_connection(_):
logger.info("aiocqhttp(OneBot v11) 适配器已连接。")
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
is_group = session.message_type == MessageType.GROUP_MESSAGE
if is_group:
session_id = session.session_id.split("_")[-1]
else:
session_id = session.session_id
await AiocqhttpMessageEvent.send_message(
bot=self.bot,
message_chain=message_chain,
event=None, # 这里不需要 event因为是通过 session 发送的
is_group=is_group,
session_id=session_id,
"aiocqhttp",
"适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
)
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
match session.message_type.value:
case MessageType.GROUP_MESSAGE.value:
if "_" in session.session_id:
# 独立会话
_, group_id = session.session_id.split("_")
await self.bot.send_group_msg(group_id=group_id, message=ret)
else:
await self.bot.send_group_msg(group_id=session.session_id, message=ret)
case MessageType.FRIEND_MESSAGE.value:
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
await super().send_by_session(session, message_chain)
async def convert_message(self, event: Event) -> AstrBotMessage:
logger.debug(f"[aiocqhttp] RawMessage {event}")
if event["post_type"] == "message":
if event['post_type'] == 'message':
abm = await self._convert_handle_message_event(event)
if abm.sender.user_id == "2854196310":
# 屏蔽 QQ 管家的消息
return
elif event["post_type"] == "notice":
elif event['post_type'] == 'notice':
abm = await self._convert_handle_notice_event(event)
elif event["post_type"] == "request":
elif event['post_type'] == 'request':
abm = await self._convert_handle_request_event(event)
return abm
async def _convert_handle_request_event(self, event: Event) -> AstrBotMessage:
"""OneBot V11 请求类事件"""
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
abm.type = MessageType.OTHER_MESSAGE
if "group_id" in event and event["group_id"]:
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id)
else:
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id)
else:
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)
abm.message_str = ""
abm.message = []
abm.timestamp = int(time.time())
abm.message_id = uuid.uuid4().hex
abm.raw_message = event
return abm
async def _convert_handle_notice_event(self, event: Event) -> AstrBotMessage:
"""OneBot V11 通知类事件"""
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
abm.type = MessageType.OTHER_MESSAGE
if "group_id" in event and event["group_id"]:
abm.group_id = str(event.group_id)
abm.type = MessageType.GROUP_MESSAGE
else:
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = (
str(abm.sender.user_id) + "_" + str(event.group_id)
) # 也保留群组 id
else:
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)
abm.message_str = ""
abm.message = []
abm.raw_message = event
abm.timestamp = int(time.time())
abm.message_id = uuid.uuid4().hex
if "sub_type" in event:
if event["sub_type"] == "poke" and "target_id" in event:
abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) # noqa: F405
return abm
async def _convert_handle_message_event(
self, event: Event, get_reply=True
) -> AstrBotMessage:
"""OneBot V11 消息类事件
@param event: 事件对象
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
"""
'''OneBot V11 请求类事件'''
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(
str(event.sender["user_id"]), event.sender["nickname"]
user_id=event.user_id,
nickname=event.user_id
)
if event["message_type"] == "group":
abm.type = MessageType.OTHER_MESSAGE
if 'group_id' in event and event['group_id']:
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id)
elif event["message_type"] == "private":
else:
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = (
abm.sender.user_id + "_" + str(event.group_id)
) # 也保留群组 id
abm.session_id = abm.sender.user_id + "_" + str(event.group_id)
abm.message_str = ''
abm.message = []
abm.timestamp = int(time.time())
abm.message_id = uuid.uuid4().hex
abm.raw_message = event
return abm
async def _convert_handle_notice_event(self, event: Event) -> AstrBotMessage:
'''OneBot V11 通知类事件'''
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(
user_id=event.user_id,
nickname=event.user_id
)
abm.type = MessageType.OTHER_MESSAGE
if 'group_id' in event and event['group_id']:
abm.group_id = str(event.group_id)
abm.type = MessageType.GROUP_MESSAGE
else:
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = abm.sender.user_id + "_" + str(event.group_id) # 也保留群组 id
else:
abm.session_id = str(event.group_id) if abm.type == MessageType.GROUP_MESSAGE else abm.sender.user_id
abm.message_str = ""
abm.message = []
abm.raw_message = event
abm.timestamp = int(time.time())
abm.message_id = uuid.uuid4().hex
if 'sub_type' in event:
if event['sub_type'] == 'poke' and 'target_id' in event:
abm.message.append(Poke(qq=str(event['target_id']), type='poke')) # noqa: F405
return abm
async def _convert_handle_message_event(self, event: Event) -> AstrBotMessage:
'''OneBot V11 消息类事件'''
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(str(event.sender['user_id']), event.sender['nickname'])
if event['message_type'] == 'group':
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id)
elif event['message_type'] == 'private':
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = abm.sender.user_id + "_" + str(event.group_id) # 也保留群组 id
else:
abm.session_id = str(event.group_id) if abm.type == MessageType.GROUP_MESSAGE else abm.sender.user_id
abm.message_id = str(event.message_id)
abm.message = []
message_str = ""
if not isinstance(event.message, list):
err = f"aiocqhttp: 无法识别的消息类型: {str(event.message)},此条消息将被忽略。如果您在使用 go-cqhttp请将其配置文件中的 message.post-format 更改为 array。"
@@ -212,181 +139,113 @@ class AiocqhttpAdapter(Platform):
except BaseException as e:
logger.error(f"回复消息失败: {e}")
return
# 按消息段类型类型适配
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
for m in event.message:
t = m['type']
a = None
if t == "text":
current_text = "".join(m["data"]["text"] for m in m_group).strip()
if not current_text:
# 如果文本段为空,则跳过
continue
message_str += current_text
a = ComponentTypes[t](text=current_text) # noqa: F405
abm.message.append(a)
elif t == "file":
for m in m_group:
if m["data"].get("url") and m["data"].get("url").startswith("http"):
# Lagrange
logger.info("guessing lagrange")
file_name = m["data"].get("file_name", "file")
abm.message.append(File(name=file_name, url=m["data"]["url"]))
else:
try:
# Napcat
ret = None
if abm.type == MessageType.GROUP_MESSAGE:
ret = await self.bot.call_action(
action="get_group_file_url",
file_id=event.message[0]["data"]["file_id"],
group_id=event.group_id,
)
elif abm.type == MessageType.FRIEND_MESSAGE:
ret = await self.bot.call_action(
action="get_private_file_url",
file_id=event.message[0]["data"]["file_id"],
)
if ret and "url" in ret:
file_url = ret["url"] # https
a = File(name="", url=file_url)
abm.message.append(a)
else:
logger.error(f"获取文件失败: {ret}")
except ActionFailed as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
except BaseException as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
elif t == "reply":
for m in m_group:
if not get_reply:
a = ComponentTypes[t](**m["data"]) # noqa: F405
abm.message.append(a)
else:
try:
reply_event_data = await self.bot.call_action(
action="get_msg",
message_id=int(m["data"]["id"]),
)
# 添加必要的 post_type 字段,防止 Event.from_payload 报错
reply_event_data["post_type"] = "message"
new_event = Event.from_payload(reply_event_data)
if not new_event:
logger.error(
f"无法从回复消息数据构造 Event 对象: {reply_event_data}"
)
continue
abm_reply = await self._convert_handle_message_event(
new_event, get_reply=False
)
reply_seg = Reply(
id=abm_reply.message_id,
chain=abm_reply.message,
sender_id=abm_reply.sender.user_id,
sender_nickname=abm_reply.sender.nickname,
time=abm_reply.timestamp,
message_str=abm_reply.message_str,
text=abm_reply.message_str, # for compatibility
qq=abm_reply.sender.user_id, # for compatibility
)
abm.message.append(reply_seg)
except BaseException as e:
logger.error(f"获取引用消息失败: {e}")
a = ComponentTypes[t](**m["data"]) # noqa: F405
abm.message.append(a)
elif t == "at":
first_at_self_processed = False
for m in m_group:
if t == 'text':
message_str += m['data']['text'].strip()
elif t == 'file':
if m['data']['url'] and m['data']['url'].startswith("http"):
# Lagrange
logger.info("guessing lagrange")
file_name = m['data'].get('file_name', "file")
path = os.path.join("data/temp", file_name)
await download_file(m['data']['url'], path)
m['data'] = {
"file": path,
"name": file_name
}
else:
try:
if m["data"]["qq"] == "all":
abm.message.append(At(qq="all", name="全体成员"))
continue
at_info = await self.bot.call_action(
action="get_stranger_info",
user_id=int(m["data"]["qq"]),
)
if at_info:
nickname = at_info.get("nick", "") or at_info.get(
"nickname", ""
)
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
abm.message.append(
At(
qq=m["data"]["qq"],
name=nickname,
)
)
if is_at_self and not first_at_self_processed:
# 第一个@是机器人不添加到message_str
first_at_self_processed = True
else:
# 非第一个@机器人或@其他用户添加到message_str
message_str += f" @{nickname}({m['data']['qq']}) "
else:
abm.message.append(At(qq=str(m["data"]["qq"]), name=""))
# Napcat, LLBot
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
if not ret.get('file', None):
raise ValueError(f"无法解析文件响应: {ret}")
if not os.path.exists(ret['file']):
raise FileNotFoundError(f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot")
m['data'] = {
"file": ret['file'],
"name": ret['file_name']
}
except ActionFailed as e:
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
except BaseException as e:
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
else:
for m in m_group:
a = ComponentTypes[t](**m["data"]) # noqa: F405
abm.message.append(a)
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
a = ComponentTypes[t](**m['data']) # noqa: F405
abm.message.append(a)
abm.timestamp = int(time.time())
abm.message_str = message_str
abm.raw_message = event
return abm
def run(self) -> Awaitable[Any]:
if not self.host or not self.port:
logger.warning(
"aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port将使用默认值http://0.0.0.0:6199"
)
logger.warning("aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port将使用默认值http://0.0.0.0:6199")
self.host = "0.0.0.0"
self.port = 6199
coro = self.bot.run_task(
host=self.host,
port=int(self.port),
shutdown_trigger=self.shutdown_trigger_placeholder,
)
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
@self.bot.on_request()
async def request(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_notice()
async def notice(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message('group')
async def group(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message('private')
async def private(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_websocket_connection
def on_websocket_connection(_):
logger.info("aiocqhttp 适配器已连接。")
bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder)
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
logging.getLogger("aiocqhttp").setLevel(logging.ERROR)
self.shutdown_event = asyncio.Event()
return coro
async def terminate(self):
self.shutdown_event.set()
async def shutdown_trigger_placeholder(self):
await self.shutdown_event.wait()
logger.info("aiocqhttp 适配器已被优雅地关闭")
logging.getLogger('aiocqhttp').setLevel(logging.ERROR)
return bot
def meta(self) -> PlatformMetadata:
return self.metadata
async def shutdown_trigger_placeholder(self):
while not self._event_queue.closed:
await asyncio.sleep(1)
logger.info("aiocqhttp 适配器已关闭。")
async def handle_msg(self, message: AstrBotMessage):
message_event = AiocqhttpMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
bot=self.bot,
bot=self.bot
)
self.commit_event(message_event)
def get_client(self) -> CQHttp:
return self.bot
self.commit_event(message_event)

View File

@@ -1,231 +0,0 @@
import asyncio
import os
import uuid
import aiohttp
import dingtalk_stream
import threading
from astrbot.api.platform import (
Platform,
AstrBotMessage,
MessageMember,
MessageType,
PlatformMetadata,
)
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Image, Plain, At
from astrbot.core.platform.astr_message_event import MessageSesion
from .dingtalk_event import DingtalkMessageEvent
from ...register import register_platform_adapter
from astrbot import logger
from dingtalk_stream import AckMessage
from astrbot.core.utils.io import download_file
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class MyEventHandler(dingtalk_stream.EventHandler):
async def process(self, event: dingtalk_stream.EventMessage):
print(
"2",
event.headers.event_type,
event.headers.event_id,
event.headers.event_born_time,
event.data,
)
return AckMessage.STATUS_OK, "OK"
@register_platform_adapter("dingtalk", "钉钉机器人官方 API 适配器")
class DingtalkPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.unique_session = platform_settings["unique_session"]
self.client_id = platform_config["client_id"]
self.client_secret = platform_config["client_secret"]
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
async def process(self_, message: dingtalk_stream.CallbackMessage):
logger.debug(f"dingtalk: {message.data}")
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
abm = await self.convert_msg(im)
await self.handle_msg(abm)
return AckMessage.STATUS_OK, "OK"
self.client = AstrCallbackClient()
credential = dingtalk_stream.Credential(self.client_id, self.client_secret)
client = dingtalk_stream.DingTalkStreamClient(credential, logger=logger)
client.register_all_event_handler(MyEventHandler())
client.register_callback_handler(
dingtalk_stream.ChatbotMessage.TOPIC, self.client
)
self.client_ = client # 用于 websockets 的 client
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
raise NotImplementedError("钉钉机器人适配器不支持 send_by_session")
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="dingtalk",
description="钉钉机器人官方 API 适配器",
id=self.config.get("id"),
)
async def convert_msg(
self, message: dingtalk_stream.ChatbotMessage
) -> AstrBotMessage:
abm = AstrBotMessage()
abm.message = []
abm.message_str = ""
abm.timestamp = int(message.create_at / 1000)
abm.type = (
MessageType.GROUP_MESSAGE
if message.conversation_type == "2"
else MessageType.FRIEND_MESSAGE
)
abm.sender = MessageMember(
user_id=message.sender_id, nickname=message.sender_nick
)
abm.self_id = message.chatbot_user_id
abm.message_id = message.message_id
abm.raw_message = message
if abm.type == MessageType.GROUP_MESSAGE:
if message.is_in_at_list:
abm.message.append(At(qq=abm.self_id))
abm.group_id = message.conversation_id
if self.unique_session:
abm.session_id = abm.sender.user_id
else:
abm.session_id = abm.group_id
else:
abm.session_id = abm.sender.user_id
message_type: str = message.message_type
match message_type:
case "text":
abm.message_str = message.text.content.strip()
abm.message.append(Plain(abm.message_str))
case "richText":
rtc: dingtalk_stream.RichTextContent = message.rich_text_content
contents: list[dict] = rtc.rich_text_list
for content in contents:
plains = ""
if "text" in content:
plains += content["text"]
abm.message.append(Plain(plains))
elif "type" in content and content["type"] == "picture":
f_path = await self.download_ding_file(
content["downloadCode"],
message.robot_code,
"jpg",
)
abm.message.append(Image.fromFileSystem(f_path))
case "audio":
pass
return abm # 别忘了返回转换后的消息对象
async def download_ding_file(
self, download_code: str, robot_code: str, ext: str
) -> str:
"""下载钉钉文件
:param access_token: 钉钉机器人的 access_token
:param download_code: 下载码
:param robot_code: 机器人码
:param ext: 文件后缀
:return: 文件路径
"""
access_token = await self.get_access_token()
headers = {
"x-acs-dingtalk-access-token": access_token,
}
payload = {
"downloadCode": download_code,
"robotCode": robot_code,
}
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
f_path = os.path.join(temp_dir, f"dingtalk_file_{uuid.uuid4()}.{ext}")
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.dingtalk.com/v1.0/robot/messageFiles/download",
headers=headers,
json=payload,
) as resp:
if resp.status != 200:
logger.error(
f"下载钉钉文件失败: {resp.status}, {await resp.text()}"
)
return None
resp_data = await resp.json()
download_url = resp_data["data"]["downloadUrl"]
await download_file(download_url, f_path)
return f_path
async def get_access_token(self) -> str:
payload = {
"appKey": self.client_id,
"appSecret": self.client_secret,
}
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.dingtalk.com/v1.0/oauth2/accessToken",
json=payload,
) as resp:
if resp.status != 200:
logger.error(
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}"
)
return None
return (await resp.json())["data"]["accessToken"]
async def handle_msg(self, abm: AstrBotMessage):
event = DingtalkMessageEvent(
message_str=abm.message_str,
message_obj=abm,
platform_meta=self.meta(),
session_id=abm.session_id,
client=self.client,
)
self._event_queue.put_nowait(event)
async def run(self):
# await self.client_.start()
# 钉钉的 SDK 并没有实现真正的异步start() 里面有堵塞方法。
def start_client(loop: asyncio.AbstractEventLoop):
try:
self._shutdown_event = threading.Event()
task = loop.create_task(self.client_.start())
self._shutdown_event.wait()
if task.done():
task.result()
except Exception as e:
if "Graceful shutdown" in str(e):
logger.info("钉钉适配器已被优雅地关闭")
return
logger.error(f"钉钉机器人启动失败: {e}")
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, start_client, loop)
async def terminate(self):
def monkey_patch_close():
raise Exception("Graceful shutdown")
self.client_.open_connection = monkey_patch_close
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
self._shutdown_event.set()
def get_client(self):
return self.client

View File

@@ -1,75 +0,0 @@
import asyncio
import dingtalk_stream
import astrbot.api.message_components as Comp
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot import logger
class DingtalkMessageEvent(AstrMessageEvent):
def __init__(
self,
message_str,
message_obj,
platform_meta,
session_id,
client: dingtalk_stream.ChatbotHandler,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
async def send_with_client(
self, client: dingtalk_stream.ChatbotHandler, message: MessageChain
):
for segment in message.chain:
if isinstance(segment, Comp.Plain):
segment.text = segment.text.strip()
await asyncio.get_event_loop().run_in_executor(
None,
client.reply_markdown,
"AstrBot",
segment.text,
self.message_obj.raw_message,
)
elif isinstance(segment, Comp.Image):
markdown_str = ""
try:
if not segment.file:
logger.warning("钉钉图片 segment 缺少 file 字段,跳过")
continue
if segment.file.startswith(("http://", "https://")):
image_url = segment.file
else:
image_url = await segment.register_to_file_service()
markdown_str = f"![image]({image_url})\n\n"
ret = await asyncio.get_event_loop().run_in_executor(
None,
client.reply_markdown,
"😄",
markdown_str,
self.message_obj.raw_message,
)
logger.debug(f"send image: {ret}")
except Exception as e:
logger.error(f"钉钉图片处理失败: {e}")
logger.warning(f"跳过图片发送: {image_path}")
continue
async def send(self, message: MessageChain):
await self.send_with_client(self.client, message)
await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)

View File

@@ -1,126 +0,0 @@
import discord
from astrbot import logger
import sys
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# Discord Bot客户端
class DiscordBotClient(discord.Bot):
"""Discord客户端封装"""
def __init__(self, token: str, proxy: str = None):
self.token = token
self.proxy = proxy
# 设置Intent权限遵循权限最小化原则
intents = discord.Intents.default()
intents.message_content = True # 订阅消息内容事件 (Privileged)
intents.members = True # 订阅成员事件 (Privileged)
# 初始化Bot
super().__init__(intents=intents, proxy=proxy)
# 回调函数
self.on_message_received = None
self.on_ready_once_callback = None
self._ready_once_fired = False
@override
async def on_ready(self):
"""当机器人成功连接并准备就绪时触发"""
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
logger.info("[Discord] 客户端已准备就绪。")
if self.on_ready_once_callback and not self._ready_once_fired:
self._ready_once_fired = True
try:
await self.on_ready_once_callback()
except Exception as e:
logger.error(
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True)
def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典"""
is_mentioned = self.user in message.mentions
return {
"message": message,
"bot_id": str(self.user.id),
"content": message.content,
"username": message.author.display_name,
"userid": str(message.author.id),
"message_id": str(message.id),
"channel_id": str(message.channel.id),
"guild_id": str(message.guild.id) if message.guild else None,
"type": "message",
"is_mentioned": is_mentioned,
"clean_content": message.clean_content,
}
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
"""从 discord.Interaction 创建数据字典"""
return {
"interaction": interaction,
"bot_id": str(self.user.id),
"content": self._extract_interaction_content(interaction),
"username": interaction.user.display_name,
"userid": str(interaction.user.id),
"message_id": str(interaction.id),
"channel_id": str(interaction.channel_id)
if interaction.channel_id
else None,
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
"type": "interaction",
}
@override
async def on_message(self, message: discord.Message):
"""当接收到消息时触发"""
if message.author.bot:
return
logger.debug(
f"[Discord] 收到原始消息 from {message.author.name}: {message.content}"
)
if self.on_message_received:
message_data = self._create_message_data(message)
await self.on_message_received(message_data)
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
"""从交互中提取内容"""
interaction_type = interaction.type
interaction_data = getattr(interaction, "data", {})
if not interaction_data:
return ""
if interaction_type == discord.InteractionType.application_command:
command_name = interaction_data.get("name", "")
if options := interaction_data.get("options", []):
params = " ".join(
[f"{opt['name']}:{opt.get('value', '')}" for opt in options]
)
return f"/{command_name} {params}"
return f"/{command_name}"
elif interaction_type == discord.InteractionType.component:
custom_id = interaction_data.get("custom_id", "")
component_type = interaction_data.get("component_type", "")
return f"component:{custom_id}:{component_type}"
return str(interaction_data)
async def start_polling(self):
"""开始轮询消息,这是个阻塞方法"""
await self.start(self.token)
@override
async def close(self):
"""关闭客户端"""
if not self.is_closed():
await super().close()

Some files were not shown because too many files have changed in this diff Show More