Compare commits
1 Commits
v4.3.1
...
features/s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
38486bc5aa |
51
.github/PULL_REQUEST_TEMPLATE.md
vendored
51
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,46 +1,19 @@
|
|||||||
<!-- 如果有的话,请指定此 PR 旨在解决的 ISSUE 编号。 -->
|
<!-- 如果有的话,指定这个 PR 要解决的 ISSUE -->
|
||||||
<!-- If applicable, please specify the ISSUE number this PR aims to resolve. -->
|
解决了 #XYZ
|
||||||
|
|
||||||
fixes #XYZ
|
### Motivation
|
||||||
|
|
||||||
---
|
<!--解释为什么要改动-->
|
||||||
|
|
||||||
### Motivation / 动机
|
### Modifications
|
||||||
|
|
||||||
<!--请描述此项更改的动机:它解决了什么问题?(例如:修复了 XX 错误,添加了 YY 功能)-->
|
<!--简单解释你的改动-->
|
||||||
<!--Please describe the motivation for this change: What problem does it solve? (e.g., Fixes XX bug, adds YY feature)-->
|
|
||||||
|
|
||||||
### Modifications / 改动点
|
### Check
|
||||||
|
|
||||||
<!--请总结你的改动:哪些核心文件被修改了?实现了什么功能?-->
|
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容-->
|
||||||
<!--Please summarize your changes: What core files were modified? What functionality was implemented?-->
|
|
||||||
|
|
||||||
### Verification Steps / 验证步骤
|
- [ ] 😊 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
|
||||||
|
- [ ] 👀 我的更改经过良好的测试
|
||||||
<!--请为审查者 (Reviewer) 提供清晰、可复现的验证步骤(例如:1. 导航到... 2. 点击...)。-->
|
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt` 和 `pyproject.toml` 文件相应位置。
|
||||||
<!--Please provide clear and reproducible verification steps for the Reviewer (e.g., 1. Navigate to... 2. Click...).-->
|
- [ ] 😮 我的更改没有引入恶意代码
|
||||||
|
|
||||||
### Screenshots or Test Results / 运行截图或测试结果
|
|
||||||
|
|
||||||
<!--请粘贴截图、GIF 或测试日志,作为执行“验证步骤”的证据,证明此改动有效。-->
|
|
||||||
<!--Please paste screenshots, GIFs, or test logs here as evidence of executing the "Verification Steps" to prove this change is effective.-->
|
|
||||||
|
|
||||||
### Compatibility & Breaking Changes / 兼容性与破坏性变更
|
|
||||||
|
|
||||||
<!--请说明此变更的兼容性:哪些是破坏性变更?哪些地方做了向后兼容处理?是否提供了数据迁移方法?-->
|
|
||||||
<!--Please explain the compatibility of this change: What are the breaking changes? What backward-compatible measures were taken? Are data migration paths provided?-->
|
|
||||||
|
|
||||||
- [ ] 这是一个破坏性变更 (Breaking Change)。/ This is a breaking change.
|
|
||||||
- [ ] 这不是一个破坏性变更。/ This is NOT a breaking change.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Checklist / 检查清单
|
|
||||||
|
|
||||||
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
|
|
||||||
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
|
|
||||||
|
|
||||||
- [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
|
|
||||||
- [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
|
|
||||||
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt` 和 `pyproject.toml` 文件相应位置。/ I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in `requirements.txt` and `pyproject.toml`.
|
|
||||||
- [ ] 😮 我的更改没有引入恶意代码。/ My changes do not introduce malicious code.
|
|
||||||
|
|||||||
36
.github/auto_assign.yml
vendored
36
.github/auto_assign.yml
vendored
@@ -1,36 +0,0 @@
|
|||||||
# Set to true to add reviewers to pull requests
|
|
||||||
addReviewers: true
|
|
||||||
|
|
||||||
# Set to true to add assignees to pull requests
|
|
||||||
addAssignees: false
|
|
||||||
|
|
||||||
# A list of reviewers to be added to pull requests (GitHub user name)
|
|
||||||
reviewers:
|
|
||||||
- Soulter
|
|
||||||
- Raven95676
|
|
||||||
- Larch-C
|
|
||||||
- anka-afk
|
|
||||||
- advent259141
|
|
||||||
# - zouyonghe
|
|
||||||
|
|
||||||
# A number of reviewers added to the pull request
|
|
||||||
# Set 0 to add all the reviewers (default: 0)
|
|
||||||
numberOfReviewers: 2
|
|
||||||
|
|
||||||
# A list of assignees, overrides reviewers if set
|
|
||||||
# assignees:
|
|
||||||
# - assigneeA
|
|
||||||
|
|
||||||
# A number of assignees to add to the pull request
|
|
||||||
# Set to 0 to add all of the assignees.
|
|
||||||
# Uses numberOfReviewers if unset.
|
|
||||||
# numberOfAssignees: 2
|
|
||||||
|
|
||||||
# A list of keywords to be skipped the process that add reviewers if pull requests include it
|
|
||||||
skipKeywords:
|
|
||||||
- wip
|
|
||||||
- draft
|
|
||||||
|
|
||||||
# A list of users to be skipped by both the add reviewers and add assignees processes
|
|
||||||
# skipUsers:
|
|
||||||
# - dependabot[bot]
|
|
||||||
2
.github/workflows/auto_release.yml
vendored
2
.github/workflows/auto_release.yml
vendored
@@ -73,7 +73,7 @@ jobs:
|
|||||||
uses: actions/checkout@v5
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
|
|
||||||
|
|||||||
34
.github/workflows/code-format.yml
vendored
34
.github/workflows/code-format.yml
vendored
@@ -1,34 +0,0 @@
|
|||||||
name: Code Format Check
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
branches: [ master ]
|
|
||||||
push:
|
|
||||||
branches: [ master ]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
format-check:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v5
|
|
||||||
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v6
|
|
||||||
with:
|
|
||||||
python-version: '3.10'
|
|
||||||
|
|
||||||
- name: Install UV
|
|
||||||
run: pip install uv
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: uv sync
|
|
||||||
|
|
||||||
- name: Check code formatting with ruff
|
|
||||||
run: |
|
|
||||||
uv run ruff format --check .
|
|
||||||
|
|
||||||
- name: Check code style with ruff
|
|
||||||
run: |
|
|
||||||
uv run ruff check .
|
|
||||||
2
.github/workflows/coverage_test.yml
vendored
2
.github/workflows/coverage_test.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
|||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v5
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
1
.github/workflows/dashboard_ci.yml
vendored
1
.github/workflows/dashboard_ci.yml
vendored
@@ -37,7 +37,6 @@ jobs:
|
|||||||
!dist/**/*.md
|
!dist/**/*.md
|
||||||
|
|
||||||
- name: Create GitHub Release
|
- name: Create GitHub Release
|
||||||
if: github.event_name == 'push'
|
|
||||||
uses: ncipollo/release-action@v1
|
uses: ncipollo/release-action@v1
|
||||||
with:
|
with:
|
||||||
tag: release-${{ github.sha }}
|
tag: release-${{ github.sha }}
|
||||||
|
|||||||
31
.github/workflows/docker-image.yml
vendored
31
.github/workflows/docker-image.yml
vendored
@@ -27,33 +27,6 @@ jobs:
|
|||||||
if: github.event_name == 'workflow_dispatch'
|
if: github.event_name == 'workflow_dispatch'
|
||||||
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
|
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
|
||||||
|
|
||||||
- name: Check if version is pre-release
|
|
||||||
id: check-prerelease
|
|
||||||
run: |
|
|
||||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
|
||||||
version="${{ steps.get-latest-tag.outputs.latest_tag }}"
|
|
||||||
else
|
|
||||||
version="${{ github.ref_name }}"
|
|
||||||
fi
|
|
||||||
if [[ "$version" == *"beta"* ]] || [[ "$version" == *"alpha"* ]]; then
|
|
||||||
echo "is_prerelease=true" >> $GITHUB_OUTPUT
|
|
||||||
echo "Version $version is a pre-release, will not push latest tag"
|
|
||||||
else
|
|
||||||
echo "is_prerelease=false" >> $GITHUB_OUTPUT
|
|
||||||
echo "Version $version is a stable release, will push latest tag"
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Build Dashboard
|
|
||||||
run: |
|
|
||||||
cd dashboard
|
|
||||||
npm install
|
|
||||||
npm run build
|
|
||||||
mkdir -p dist/assets
|
|
||||||
echo $(git rev-parse HEAD) > dist/assets/version
|
|
||||||
cd ..
|
|
||||||
mkdir -p data
|
|
||||||
cp -r dashboard/dist data/
|
|
||||||
|
|
||||||
- name: Set QEMU
|
- name: Set QEMU
|
||||||
uses: docker/setup-qemu-action@v3
|
uses: docker/setup-qemu-action@v3
|
||||||
|
|
||||||
@@ -80,9 +53,9 @@ jobs:
|
|||||||
platforms: linux/amd64,linux/arm64
|
platforms: linux/amd64,linux/arm64
|
||||||
push: true
|
push: true
|
||||||
tags: |
|
tags: |
|
||||||
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && format('{0}/astrbot:latest', secrets.DOCKER_HUB_USERNAME) || '' }}
|
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
|
||||||
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
|
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
|
||||||
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && 'ghcr.io/soulter/astrbot:latest' || '' }}
|
ghcr.io/soulter/astrbot:latest
|
||||||
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
|
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
|
||||||
|
|
||||||
- name: Post build notifications
|
- name: Post build notifications
|
||||||
|
|||||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
|||||||
pull-requests: write
|
pull-requests: write
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/stale@v10
|
- uses: actions/stale@v9
|
||||||
with:
|
with:
|
||||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
stale-issue-message: 'Stale issue message'
|
stale-issue-message: 'Stale issue message'
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -30,6 +30,4 @@ packages/python_interpreter/workplace
|
|||||||
.conda/
|
.conda/
|
||||||
.idea
|
.idea
|
||||||
pytest.ini
|
pytest.ini
|
||||||
.astrbot
|
.astrbot
|
||||||
|
|
||||||
uv.lock
|
|
||||||
21
Dockerfile
21
Dockerfile
@@ -4,6 +4,8 @@ WORKDIR /AstrBot
|
|||||||
COPY . /AstrBot/
|
COPY . /AstrBot/
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
nodejs \
|
||||||
|
npm \
|
||||||
gcc \
|
gcc \
|
||||||
build-essential \
|
build-essential \
|
||||||
python3-dev \
|
python3-dev \
|
||||||
@@ -11,20 +13,23 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||||||
libssl-dev \
|
libssl-dev \
|
||||||
ca-certificates \
|
ca-certificates \
|
||||||
bash \
|
bash \
|
||||||
ffmpeg \
|
|
||||||
&& apt-get clean \
|
&& apt-get clean \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y curl gnupg && \
|
|
||||||
curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
|
|
||||||
apt-get install -y nodejs && \
|
|
||||||
rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
RUN python -m pip install uv
|
RUN python -m pip install uv
|
||||||
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||||
RUN uv pip install socksio uv pilk --no-cache-dir --system
|
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
|
||||||
|
|
||||||
EXPOSE 6185
|
# 释出 ffmpeg
|
||||||
|
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
|
||||||
|
|
||||||
|
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
|
||||||
|
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
|
||||||
|
|
||||||
|
EXPOSE 6185
|
||||||
EXPOSE 6186
|
EXPOSE 6186
|
||||||
|
|
||||||
CMD [ "python", "main.py" ]
|
CMD [ "python", "main.py" ]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
97
README.md
97
README.md
@@ -6,6 +6,8 @@
|
|||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
|
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||||
|
|
||||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||||
|
|
||||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||||
@@ -14,18 +16,18 @@
|
|||||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||||
|

|
||||||

|

|
||||||
|
|
||||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||||
<a href="https://astrbot.app/">文档</a> |
|
<a href="https://astrbot.app/">查看文档</a> |
|
||||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
|
||||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
|
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
|
||||||
|
|
||||||
## 主要功能
|
## ✨ 主要功能
|
||||||
|
|
||||||
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
|
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
|
||||||
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
|
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
|
||||||
@@ -33,7 +35,7 @@ AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架
|
|||||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
|
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
|
||||||
5. **WebUI**。可视化配置和管理机器人,功能齐全。
|
5. **WebUI**。可视化配置和管理机器人,功能齐全。
|
||||||
|
|
||||||
## 部署方式
|
## ✨ 使用方式
|
||||||
|
|
||||||
#### Docker 部署
|
#### Docker 部署
|
||||||
|
|
||||||
@@ -77,7 +79,9 @@ AstrBot 已由雨云官方上架至云应用平台,可一键部署。
|
|||||||
|
|
||||||
#### 手动部署
|
#### 手动部署
|
||||||
|
|
||||||
首先安装 uv:
|
> 推荐使用 `uv`。
|
||||||
|
|
||||||
|
首先,安装 uv:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install uv
|
pip install uv
|
||||||
@@ -92,25 +96,6 @@ uv run main.py
|
|||||||
|
|
||||||
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||||
|
|
||||||
## 🌍 社区
|
|
||||||
|
|
||||||
### QQ 群组
|
|
||||||
|
|
||||||
- 1 群:322154837
|
|
||||||
- 3 群:630166526
|
|
||||||
- 5 群:822130018
|
|
||||||
- 6 群:753075035
|
|
||||||
- 开发者群:975206796
|
|
||||||
- 开发者群(备份):295657329
|
|
||||||
|
|
||||||
### Telegram 群组
|
|
||||||
|
|
||||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
|
||||||
|
|
||||||
### Discord 群组
|
|
||||||
|
|
||||||
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
|
||||||
|
|
||||||
## ⚡ 消息平台支持情况
|
## ⚡ 消息平台支持情况
|
||||||
|
|
||||||
| 平台 | 支持性 |
|
| 平台 | 支持性 |
|
||||||
@@ -127,20 +112,22 @@ uv run main.py
|
|||||||
| Discord | ✔ |
|
| Discord | ✔ |
|
||||||
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
||||||
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
||||||
| Satori | ✔ |
|
| 微信对话开放平台 | 🚧 |
|
||||||
| Misskey | ✔ |
|
| WhatsApp | 🚧 |
|
||||||
|
| 小爱音响 | 🚧 |
|
||||||
|
|
||||||
## ⚡ 提供商支持情况
|
## ⚡ 提供商支持情况
|
||||||
|
|
||||||
| 名称 | 支持性 | 类型 | 备注 |
|
| 名称 | 支持性 | 类型 | 备注 |
|
||||||
| -------- | ------- | ------- | ------- |
|
| -------- | ------- | ------- | ------- |
|
||||||
| OpenAI | ✔ | 文本生成 | 支持任何兼容 OpenAI API 的服务 |
|
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Gemini、Kimi、xAI 等兼容 OpenAI API 的服务 |
|
||||||
| Anthropic | ✔ | 文本生成 | |
|
| Claude API | ✔ | 文本生成 | |
|
||||||
| Google Gemini | ✔ | 文本生成 | |
|
| Google Gemini API | ✔ | 文本生成 | |
|
||||||
| Dify | ✔ | LLMOps | |
|
| Dify | ✔ | LLMOps | |
|
||||||
| 阿里云百炼应用 | ✔ | LLMOps | |
|
| 阿里云百炼应用 | ✔ | LLMOps | |
|
||||||
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||||
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||||
|
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
|
||||||
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
|
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
|
||||||
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
|
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
|
||||||
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
||||||
@@ -156,6 +143,7 @@ uv run main.py
|
|||||||
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
|
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
|
||||||
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
|
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
|
||||||
|
|
||||||
|
|
||||||
## ❤️ 贡献
|
## ❤️ 贡献
|
||||||
|
|
||||||
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
|
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
|
||||||
@@ -174,6 +162,39 @@ pip install pre-commit
|
|||||||
pre-commit install
|
pre-commit install
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 🌟 支持
|
||||||
|
|
||||||
|
- Star 这个项目!
|
||||||
|
- 在[爱发电](https://afdian.com/a/soulter)支持我!
|
||||||
|
|
||||||
|
## ✨ Demo
|
||||||
|
|
||||||
|
<details><summary>👉 点击展开多张 Demo 截图 👈</summary>
|
||||||
|
|
||||||
|
<div align='center'>
|
||||||
|
|
||||||
|
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||||
|
|
||||||
|
_✨基于 Docker 的沙箱化代码执行器(Beta 测试)✨_
|
||||||
|
|
||||||
|
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||||
|
|
||||||
|
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
||||||
|
|
||||||
|
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
||||||
|
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
||||||
|
|
||||||
|
_✨ 插件系统——部分插件展示 ✨_
|
||||||
|
|
||||||
|
<img src="https://github.com/user-attachments/assets/0cdbf564-2f59-4da5-b524-ce0e7ef3d978" width=600>
|
||||||
|
|
||||||
|
_✨ WebUI ✨_
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
## ❤️ Special Thanks
|
## ❤️ Special Thanks
|
||||||
|
|
||||||
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
||||||
@@ -182,18 +203,10 @@ pre-commit install
|
|||||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
此外,本项目的诞生离不开以下开源项目的帮助:
|
此外,本项目的诞生离不开以下开源项目:
|
||||||
|
|
||||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
|
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
|
||||||
|
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
|
||||||
另外,一些同类型其他的活跃开源 Bot 项目:
|
|
||||||
|
|
||||||
- [nonebot/nonebot2](https://github.com/nonebot/nonebot2) - 扩展性极强的 Bot 框架
|
|
||||||
- [koishijs/koishi](https://github.com/koishijs/koishi) - 扩展性极强的 Bot 框架
|
|
||||||
- [MaiM-with-u/MaiBot](https://github.com/MaiM-with-u/MaiBot) - 注重拟人功能的 ChatBot
|
|
||||||
- [langbot-app/LangBot](https://github.com/langbot-app/LangBot) - 功能丰富的 Bot 平台
|
|
||||||
- [KroMiose/nekro-agent](https://github.com/KroMiose/nekro-agent) - 注重 Agent 的 ChatBot
|
|
||||||
- [zhenxun-org/zhenxun_bot](https://github.com/zhenxun-org/zhenxun_bot) - 功能完善的 ChatBot
|
|
||||||
|
|
||||||
## ⭐ Star History
|
## ⭐ Star History
|
||||||
|
|
||||||
@@ -201,11 +214,13 @@ pre-commit install
|
|||||||
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我维护这个开源项目的动力 <3
|
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我维护这个开源项目的动力 <3
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
[](https://star-history.com/#soulter/astrbot&Date)
|
[](https://star-history.com/#soulter/astrbot&Date)
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
</details>
|

|
||||||
|
|
||||||
|
|
||||||
_私は、高性能ですから!_
|
_私は、高性能ですから!_
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from astrbot.core.star.register import (
|
|||||||
register_permission_type as permission_type,
|
register_permission_type as permission_type,
|
||||||
register_custom_filter as custom_filter,
|
register_custom_filter as custom_filter,
|
||||||
register_on_astrbot_loaded as on_astrbot_loaded,
|
register_on_astrbot_loaded as on_astrbot_loaded,
|
||||||
register_on_platform_loaded as on_platform_loaded,
|
|
||||||
register_on_llm_request as on_llm_request,
|
register_on_llm_request as on_llm_request,
|
||||||
register_on_llm_response as on_llm_response,
|
register_on_llm_response as on_llm_response,
|
||||||
register_llm_tool as llm_tool,
|
register_llm_tool as llm_tool,
|
||||||
@@ -42,7 +41,6 @@ __all__ = [
|
|||||||
"custom_filter",
|
"custom_filter",
|
||||||
"PermissionType",
|
"PermissionType",
|
||||||
"on_astrbot_loaded",
|
"on_astrbot_loaded",
|
||||||
"on_platform_loaded",
|
|
||||||
"on_llm_request",
|
"on_llm_request",
|
||||||
"llm_tool",
|
"llm_tool",
|
||||||
"on_decorating_result",
|
"on_decorating_result",
|
||||||
|
|||||||
@@ -37,10 +37,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
|
|||||||
):
|
):
|
||||||
click.echo("正在安装管理面板...")
|
click.echo("正在安装管理面板...")
|
||||||
await download_dashboard(
|
await download_dashboard(
|
||||||
path="data/dashboard.zip",
|
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||||
extract_path=str(astrbot_root),
|
|
||||||
version=f"v{VERSION}",
|
|
||||||
latest=False,
|
|
||||||
)
|
)
|
||||||
click.echo("管理面板安装完成")
|
click.echo("管理面板安装完成")
|
||||||
|
|
||||||
@@ -53,10 +50,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
|
|||||||
version = dashboard_version.split("v")[1]
|
version = dashboard_version.split("v")[1]
|
||||||
click.echo(f"管理面板版本: {version}")
|
click.echo(f"管理面板版本: {version}")
|
||||||
await download_dashboard(
|
await download_dashboard(
|
||||||
path="data/dashboard.zip",
|
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||||
extract_path=str(astrbot_root),
|
|
||||||
version=f"v{VERSION}",
|
|
||||||
latest=False,
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(f"下载管理面板失败: {e}")
|
click.echo(f"下载管理面板失败: {e}")
|
||||||
@@ -65,10 +59,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
|
|||||||
click.echo("初始化管理面板目录...")
|
click.echo("初始化管理面板目录...")
|
||||||
try:
|
try:
|
||||||
await download_dashboard(
|
await download_dashboard(
|
||||||
path=str(astrbot_root / "dashboard.zip"),
|
path=str(astrbot_root / "dashboard.zip"), extract_path=str(astrbot_root)
|
||||||
extract_path=str(astrbot_root),
|
|
||||||
version=f"v{VERSION}",
|
|
||||||
latest=False,
|
|
||||||
)
|
)
|
||||||
click.echo("管理面板初始化完成")
|
click.echo("管理面板初始化完成")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -124,17 +124,15 @@ def build_plug_list(plugins_dir: Path) -> list:
|
|||||||
if metadata and all(
|
if metadata and all(
|
||||||
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
||||||
):
|
):
|
||||||
result.append(
|
result.append({
|
||||||
{
|
"name": str(metadata.get("name", "")),
|
||||||
"name": str(metadata.get("name", "")),
|
"desc": str(metadata.get("desc", "")),
|
||||||
"desc": str(metadata.get("desc", "")),
|
"version": str(metadata.get("version", "")),
|
||||||
"version": str(metadata.get("version", "")),
|
"author": str(metadata.get("author", "")),
|
||||||
"author": str(metadata.get("author", "")),
|
"repo": str(metadata.get("repo", "")),
|
||||||
"repo": str(metadata.get("repo", "")),
|
"status": PluginStatus.INSTALLED,
|
||||||
"status": PluginStatus.INSTALLED,
|
"local_path": str(plugin_dir),
|
||||||
"local_path": str(plugin_dir),
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取在线插件列表
|
# 获取在线插件列表
|
||||||
online_plugins = []
|
online_plugins = []
|
||||||
@@ -144,17 +142,15 @@ def build_plug_list(plugins_dir: Path) -> list:
|
|||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
for plugin_id, plugin_info in data.items():
|
for plugin_id, plugin_info in data.items():
|
||||||
online_plugins.append(
|
online_plugins.append({
|
||||||
{
|
"name": str(plugin_id),
|
||||||
"name": str(plugin_id),
|
"desc": str(plugin_info.get("desc", "")),
|
||||||
"desc": str(plugin_info.get("desc", "")),
|
"version": str(plugin_info.get("version", "")),
|
||||||
"version": str(plugin_info.get("version", "")),
|
"author": str(plugin_info.get("author", "")),
|
||||||
"author": str(plugin_info.get("author", "")),
|
"repo": str(plugin_info.get("repo", "")),
|
||||||
"repo": str(plugin_info.get("repo", "")),
|
"status": PluginStatus.NOT_INSTALLED,
|
||||||
"status": PluginStatus.NOT_INSTALLED,
|
"local_path": None,
|
||||||
"local_path": None,
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
||||||
|
|
||||||
|
|||||||
@@ -9,5 +9,5 @@ from .hooks import BaseAgentRunHooks
|
|||||||
class Agent(Generic[TContext]):
|
class Agent(Generic[TContext]):
|
||||||
name: str
|
name: str
|
||||||
instructions: str | None = None
|
instructions: str | None = None
|
||||||
tools: list[str | FunctionTool] | None = None
|
tools: list[str, FunctionTool] | None = None
|
||||||
run_hooks: BaseAgentRunHooks[TContext] | None = None
|
run_hooks: BaseAgentRunHooks[TContext] | None = None
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ class MCPClient:
|
|||||||
self.session: Optional[mcp.ClientSession] = None
|
self.session: Optional[mcp.ClientSession] = None
|
||||||
self.exit_stack = AsyncExitStack()
|
self.exit_stack = AsyncExitStack()
|
||||||
|
|
||||||
self.name: str | None = None
|
self.name = None
|
||||||
self.active: bool = True
|
self.active: bool = True
|
||||||
self.tools: list[mcp.Tool] = []
|
self.tools: list[mcp.Tool] = []
|
||||||
self.server_errlogs: list[str] = []
|
self.server_errlogs: list[str] = []
|
||||||
@@ -198,8 +198,6 @@ class MCPClient:
|
|||||||
|
|
||||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||||
"""List all tools from the server and save them to self.tools"""
|
"""List all tools from the server and save them to self.tools"""
|
||||||
if not self.session:
|
|
||||||
raise Exception("MCP Client is not initialized")
|
|
||||||
response = await self.session.list_tools()
|
response = await self.session.list_tools()
|
||||||
self.tools = response.tools
|
self.tools = response.tools
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ from dataclasses import dataclass
|
|||||||
import typing as T
|
import typing as T
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
|
||||||
|
|
||||||
class AgentResponseData(T.TypedDict):
|
class AgentResponseData(T.TypedDict):
|
||||||
chain: MessageChain
|
chain: MessageChain
|
||||||
|
|
||||||
|
|||||||
@@ -14,5 +14,4 @@ class ContextWrapper(Generic[TContext]):
|
|||||||
context: TContext
|
context: TContext
|
||||||
event: AstrMessageEvent
|
event: AstrMessageEvent
|
||||||
|
|
||||||
|
|
||||||
NoContext = ContextWrapper[None]
|
NoContext = ContextWrapper[None]
|
||||||
|
|||||||
@@ -258,7 +258,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
)
|
)
|
||||||
yield MessageChain(
|
yield MessageChain(
|
||||||
type="tool_direct_result"
|
type="tool_direct_result"
|
||||||
).base64_image(resource.blob)
|
).base64_image(res.content[0].data)
|
||||||
else:
|
else:
|
||||||
tool_call_result_blocks.append(
|
tool_call_result_blocks.append(
|
||||||
ToolCallMessageSegment(
|
ToolCallMessageSegment(
|
||||||
@@ -269,6 +269,17 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
)
|
)
|
||||||
yield MessageChain().message("返回的数据类型不受支持。")
|
yield MessageChain().message("返回的数据类型不受支持。")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.agent_hooks.on_tool_end(
|
||||||
|
self.run_context,
|
||||||
|
func_tool_name,
|
||||||
|
func_tool_args,
|
||||||
|
resp,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||||
|
)
|
||||||
elif resp is None:
|
elif resp is None:
|
||||||
# Tool 直接请求发送消息给用户
|
# Tool 直接请求发送消息给用户
|
||||||
# 这里我们将直接结束 Agent Loop。
|
# 这里我们将直接结束 Agent Loop。
|
||||||
@@ -278,17 +289,27 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
yield MessageChain(
|
yield MessageChain(
|
||||||
chain=res.chain, type="tool_direct_result"
|
chain=res.chain, type="tool_direct_result"
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
|
await self.agent_hooks.on_tool_end(
|
||||||
|
self.run_context, func_tool_name, func_tool_args, None
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
|
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.agent_hooks.on_tool_end(
|
await self.agent_hooks.on_tool_end(
|
||||||
self.run_context, func_tool, func_tool_args, None
|
self.run_context, func_tool_name, func_tool_args, None
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
|
logger.error(
|
||||||
|
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
self.run_context.event.clear_result()
|
self.run_context.event.clear_result()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from deprecated import deprecated
|
from deprecated import deprecated
|
||||||
from typing import Awaitable, Callable, Literal, Any, Optional
|
from typing import Awaitable, Literal, Any, Optional
|
||||||
from .mcp_client import MCPClient
|
from .mcp_client import MCPClient
|
||||||
|
|
||||||
|
|
||||||
@@ -8,10 +8,10 @@ from .mcp_client import MCPClient
|
|||||||
class FunctionTool:
|
class FunctionTool:
|
||||||
"""A class representing a function tool that can be used in function calling."""
|
"""A class representing a function tool that can be used in function calling."""
|
||||||
|
|
||||||
name: str
|
name: str | None = None
|
||||||
parameters: dict | None = None
|
parameters: dict | None = None
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
handler: Callable[..., Awaitable[Any]] | None = None
|
handler: Awaitable | None = None
|
||||||
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
||||||
handler_module_path: str | None = None
|
handler_module_path: str | None = None
|
||||||
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
||||||
@@ -51,7 +51,7 @@ class ToolSet:
|
|||||||
This class provides methods to add, remove, and retrieve tools, as well as
|
This class provides methods to add, remove, and retrieve tools, as well as
|
||||||
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
|
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
|
||||||
|
|
||||||
def __init__(self, tools: list[FunctionTool] | None = None):
|
def __init__(self, tools: list[FunctionTool] = None):
|
||||||
self.tools: list[FunctionTool] = tools or []
|
self.tools: list[FunctionTool] = tools or []
|
||||||
|
|
||||||
def empty(self) -> bool:
|
def empty(self) -> bool:
|
||||||
@@ -79,13 +79,7 @@ class ToolSet:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@deprecated(reason="Use add_tool() instead", version="4.0.0")
|
@deprecated(reason="Use add_tool() instead", version="4.0.0")
|
||||||
def add_func(
|
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
func_args: list,
|
|
||||||
desc: str,
|
|
||||||
handler: Callable[..., Awaitable[Any]],
|
|
||||||
):
|
|
||||||
"""Add a function tool to the set."""
|
"""Add a function tool to the set."""
|
||||||
params = {
|
params = {
|
||||||
"type": "object", # hard-coded here
|
"type": "object", # hard-coded here
|
||||||
@@ -110,7 +104,7 @@ class ToolSet:
|
|||||||
self.remove_tool(name)
|
self.remove_tool(name)
|
||||||
|
|
||||||
@deprecated(reason="Use get_tool() instead", version="4.0.0")
|
@deprecated(reason="Use get_tool() instead", version="4.0.0")
|
||||||
def get_func(self, name: str) -> FunctionTool | None:
|
def get_func(self, name: str) -> list[FunctionTool]:
|
||||||
"""Get all function tools."""
|
"""Get all function tools."""
|
||||||
return self.get_tool(name)
|
return self.get_tool(name)
|
||||||
|
|
||||||
@@ -131,11 +125,7 @@ class ToolSet:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
if tool.parameters.get("properties") or not omit_empty_parameter_field:
|
||||||
tool.parameters
|
|
||||||
and tool.parameters.get("properties")
|
|
||||||
or not omit_empty_parameter_field
|
|
||||||
):
|
|
||||||
func_def["function"]["parameters"] = tool.parameters
|
func_def["function"]["parameters"] = tool.parameters
|
||||||
|
|
||||||
result.append(func_def)
|
result.append(func_def)
|
||||||
@@ -145,14 +135,14 @@ class ToolSet:
|
|||||||
"""Convert tools to Anthropic API format."""
|
"""Convert tools to Anthropic API format."""
|
||||||
result = []
|
result = []
|
||||||
for tool in self.tools:
|
for tool in self.tools:
|
||||||
input_schema = {"type": "object"}
|
|
||||||
if tool.parameters:
|
|
||||||
input_schema["properties"] = tool.parameters.get("properties", {})
|
|
||||||
input_schema["required"] = tool.parameters.get("required", [])
|
|
||||||
tool_def = {
|
tool_def = {
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description,
|
"description": tool.description,
|
||||||
"input_schema": input_schema,
|
"input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": tool.parameters.get("properties", {}),
|
||||||
|
"required": tool.parameters.get("required", []),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
result.append(tool_def)
|
result.append(tool_def)
|
||||||
return result
|
return result
|
||||||
@@ -220,15 +210,14 @@ class ToolSet:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
tools = []
|
tools = [
|
||||||
for tool in self.tools:
|
{
|
||||||
d = {
|
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description,
|
"description": tool.description,
|
||||||
|
"parameters": convert_schema(tool.parameters),
|
||||||
}
|
}
|
||||||
if tool.parameters:
|
for tool in self.tools
|
||||||
d["parameters"] = convert_schema(tool.parameters)
|
]
|
||||||
tools.append(d)
|
|
||||||
|
|
||||||
declarations = {}
|
declarations = {}
|
||||||
if tools:
|
if tools:
|
||||||
|
|||||||
@@ -36,21 +36,13 @@ class AstrBotConfigManager:
|
|||||||
self.confs: dict[str, AstrBotConfig] = {}
|
self.confs: dict[str, AstrBotConfig] = {}
|
||||||
"""uuid / "default" -> AstrBotConfig"""
|
"""uuid / "default" -> AstrBotConfig"""
|
||||||
self.confs["default"] = default_config
|
self.confs["default"] = default_config
|
||||||
self.abconf_data = None
|
|
||||||
self._load_all_configs()
|
self._load_all_configs()
|
||||||
|
|
||||||
def _get_abconf_data(self) -> dict:
|
|
||||||
"""获取所有的 abconf 数据"""
|
|
||||||
if self.abconf_data is None:
|
|
||||||
self.abconf_data = self.sp.get(
|
|
||||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
|
||||||
)
|
|
||||||
return self.abconf_data
|
|
||||||
|
|
||||||
def _load_all_configs(self):
|
def _load_all_configs(self):
|
||||||
"""Load all configurations from the shared preferences."""
|
"""Load all configurations from the shared preferences."""
|
||||||
abconf_data = self._get_abconf_data()
|
abconf_data = self.sp.get(
|
||||||
self.abconf_data = abconf_data
|
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||||
|
)
|
||||||
for uuid_, meta in abconf_data.items():
|
for uuid_, meta in abconf_data.items():
|
||||||
filename = meta["path"]
|
filename = meta["path"]
|
||||||
conf_path = os.path.join(get_astrbot_config_path(), filename)
|
conf_path = os.path.join(get_astrbot_config_path(), filename)
|
||||||
@@ -80,7 +72,9 @@ class AstrBotConfigManager:
|
|||||||
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
|
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
|
||||||
"""
|
"""
|
||||||
# uuid -> { "umop": list, "path": str, "name": str }
|
# uuid -> { "umop": list, "path": str, "name": str }
|
||||||
abconf_data = self._get_abconf_data()
|
abconf_data = self.sp.get(
|
||||||
|
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||||
|
)
|
||||||
if isinstance(umo, MessageSession):
|
if isinstance(umo, MessageSession):
|
||||||
umo = str(umo)
|
umo = str(umo)
|
||||||
else:
|
else:
|
||||||
@@ -121,7 +115,6 @@ class AstrBotConfigManager:
|
|||||||
"name": random_word,
|
"name": random_word,
|
||||||
}
|
}
|
||||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||||
self.abconf_data = abconf_data
|
|
||||||
|
|
||||||
def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig:
|
def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig:
|
||||||
"""获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。"""
|
"""获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。"""
|
||||||
@@ -154,7 +147,9 @@ class AstrBotConfigManager:
|
|||||||
"""获取所有配置文件的元数据列表"""
|
"""获取所有配置文件的元数据列表"""
|
||||||
conf_list = []
|
conf_list = []
|
||||||
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
|
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
|
||||||
abconf_mapping = self._get_abconf_data()
|
abconf_mapping = self.sp.get(
|
||||||
|
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||||
|
)
|
||||||
for uuid_, meta in abconf_mapping.items():
|
for uuid_, meta in abconf_mapping.items():
|
||||||
conf_list.append(ConfInfo(**meta, id=uuid_))
|
conf_list.append(ConfInfo(**meta, id=uuid_))
|
||||||
return conf_list
|
return conf_list
|
||||||
@@ -223,7 +218,6 @@ class AstrBotConfigManager:
|
|||||||
# 从映射中移除
|
# 从映射中移除
|
||||||
del abconf_data[conf_id]
|
del abconf_data[conf_id]
|
||||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||||
self.abconf_data = abconf_data
|
|
||||||
|
|
||||||
logger.info(f"成功删除配置文件 {conf_id}")
|
logger.info(f"成功删除配置文件 {conf_id}")
|
||||||
return True
|
return True
|
||||||
@@ -269,7 +263,6 @@ class AstrBotConfigManager:
|
|||||||
|
|
||||||
# 保存更新
|
# 保存更新
|
||||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||||
self.abconf_data = abconf_data
|
|
||||||
logger.info(f"成功更新配置文件 {conf_id} 的信息")
|
logger.info(f"成功更新配置文件 {conf_id} 的信息")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import os
|
|||||||
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
VERSION = "4.3.1"
|
VERSION = "4.0.0"
|
||||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||||
|
|
||||||
# 默认配置
|
# 默认配置
|
||||||
@@ -51,21 +51,23 @@ DEFAULT_CONFIG = {
|
|||||||
"enable": True,
|
"enable": True,
|
||||||
"default_provider_id": "",
|
"default_provider_id": "",
|
||||||
"default_image_caption_provider_id": "",
|
"default_image_caption_provider_id": "",
|
||||||
|
"default_summarize_provider_id": "",
|
||||||
|
"context_exceed_calc_method": "token_size",
|
||||||
|
"max_token_size": 128000,
|
||||||
|
"max_context_length": 100,
|
||||||
"image_caption_prompt": "Please describe the image using Chinese.",
|
"image_caption_prompt": "Please describe the image using Chinese.",
|
||||||
"provider_pool": ["*"], # "*" 表示使用所有可用的提供者
|
"provider_pool": ["*"], # "*" 表示使用所有可用的提供者
|
||||||
"wake_prefix": "",
|
"wake_prefix": "",
|
||||||
"web_search": False,
|
"web_search": False,
|
||||||
"websearch_provider": "default",
|
"websearch_provider": "default",
|
||||||
"websearch_tavily_key": [],
|
"websearch_tavily_key": "",
|
||||||
"web_search_link": False,
|
"web_search_link": False,
|
||||||
"display_reasoning_text": False,
|
"display_reasoning_text": False,
|
||||||
"identifier": False,
|
"identifier": False,
|
||||||
"group_name_display": False,
|
|
||||||
"datetime_system_prompt": True,
|
"datetime_system_prompt": True,
|
||||||
"default_personality": "default",
|
"default_personality": "default",
|
||||||
"persona_pool": ["*"],
|
"persona_pool": ["*"],
|
||||||
"prompt_prefix": "{{prompt}}",
|
"prompt_prefix": "",
|
||||||
"max_context_length": -1,
|
|
||||||
"dequeue_context_length": 1,
|
"dequeue_context_length": 1,
|
||||||
"streaming_response": False,
|
"streaming_response": False,
|
||||||
"show_tool_use_status": False,
|
"show_tool_use_status": False,
|
||||||
@@ -104,7 +106,6 @@ DEFAULT_CONFIG = {
|
|||||||
"t2i_strategy": "remote",
|
"t2i_strategy": "remote",
|
||||||
"t2i_endpoint": "",
|
"t2i_endpoint": "",
|
||||||
"t2i_use_file_service": False,
|
"t2i_use_file_service": False,
|
||||||
"t2i_active_template": "base",
|
|
||||||
"http_proxy": "",
|
"http_proxy": "",
|
||||||
"no_proxy": ["localhost", "127.0.0.1", "::1"],
|
"no_proxy": ["localhost", "127.0.0.1", "::1"],
|
||||||
"dashboard": {
|
"dashboard": {
|
||||||
@@ -116,15 +117,6 @@ DEFAULT_CONFIG = {
|
|||||||
"port": 6185,
|
"port": 6185,
|
||||||
},
|
},
|
||||||
"platform": [],
|
"platform": [],
|
||||||
"platform_specific": {
|
|
||||||
# 平台特异配置:按平台分类,平台下按功能分组
|
|
||||||
"lark": {
|
|
||||||
"pre_ack_emoji": {"enable": False, "emojis": ["Typing"]},
|
|
||||||
},
|
|
||||||
"telegram": {
|
|
||||||
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"wake_prefix": ["/"],
|
"wake_prefix": ["/"],
|
||||||
"log_level": "INFO",
|
"log_level": "INFO",
|
||||||
"pip_install_arg": "",
|
"pip_install_arg": "",
|
||||||
@@ -245,16 +237,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"discord_guild_id_for_debug": "",
|
"discord_guild_id_for_debug": "",
|
||||||
"discord_activity_name": "",
|
"discord_activity_name": "",
|
||||||
},
|
},
|
||||||
"Misskey": {
|
|
||||||
"id": "misskey",
|
|
||||||
"type": "misskey",
|
|
||||||
"enable": False,
|
|
||||||
"misskey_instance_url": "https://misskey.example",
|
|
||||||
"misskey_token": "",
|
|
||||||
"misskey_default_visibility": "public",
|
|
||||||
"misskey_local_only": False,
|
|
||||||
"misskey_enable_chat": True,
|
|
||||||
},
|
|
||||||
"Slack": {
|
"Slack": {
|
||||||
"id": "slack",
|
"id": "slack",
|
||||||
"type": "slack",
|
"type": "slack",
|
||||||
@@ -267,49 +249,8 @@ CONFIG_METADATA_2 = {
|
|||||||
"slack_webhook_port": 6197,
|
"slack_webhook_port": 6197,
|
||||||
"slack_webhook_path": "/astrbot-slack-webhook/callback",
|
"slack_webhook_path": "/astrbot-slack-webhook/callback",
|
||||||
},
|
},
|
||||||
"Satori": {
|
|
||||||
"id": "satori",
|
|
||||||
"type": "satori",
|
|
||||||
"enable": False,
|
|
||||||
"satori_api_base_url": "http://localhost:5140/satori/v1",
|
|
||||||
"satori_endpoint": "ws://localhost:5140/satori/v1/events",
|
|
||||||
"satori_token": "",
|
|
||||||
"satori_auto_reconnect": True,
|
|
||||||
"satori_heartbeat_interval": 10,
|
|
||||||
"satori_reconnect_delay": 5,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"items": {
|
"items": {
|
||||||
"satori_api_base_url": {
|
|
||||||
"description": "Satori API 终结点",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "Satori API 的基础地址。",
|
|
||||||
},
|
|
||||||
"satori_endpoint": {
|
|
||||||
"description": "Satori WebSocket 终结点",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "Satori 事件的 WebSocket 端点。",
|
|
||||||
},
|
|
||||||
"satori_token": {
|
|
||||||
"description": "Satori 令牌",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "用于 Satori API 身份验证的令牌。",
|
|
||||||
},
|
|
||||||
"satori_auto_reconnect": {
|
|
||||||
"description": "启用自动重连",
|
|
||||||
"type": "bool",
|
|
||||||
"hint": "断开连接时是否自动重新连接 WebSocket。",
|
|
||||||
},
|
|
||||||
"satori_heartbeat_interval": {
|
|
||||||
"description": "Satori 心跳间隔",
|
|
||||||
"type": "int",
|
|
||||||
"hint": "发送心跳消息的间隔(秒)。",
|
|
||||||
},
|
|
||||||
"satori_reconnect_delay": {
|
|
||||||
"description": "Satori 重连延迟",
|
|
||||||
"type": "int",
|
|
||||||
"hint": "尝试重新连接前的延迟时间(秒)。",
|
|
||||||
},
|
|
||||||
"slack_connection_mode": {
|
"slack_connection_mode": {
|
||||||
"description": "Slack Connection Mode",
|
"description": "Slack Connection Mode",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -356,32 +297,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
|
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
|
||||||
},
|
},
|
||||||
"misskey_instance_url": {
|
|
||||||
"description": "Misskey 实例 URL",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "例如 https://misskey.example,填写 Bot 账号所在的 Misskey 实例地址",
|
|
||||||
},
|
|
||||||
"misskey_token": {
|
|
||||||
"description": "Misskey Access Token",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "连接服务设置生成的 API 鉴权访问令牌(Access token)",
|
|
||||||
},
|
|
||||||
"misskey_default_visibility": {
|
|
||||||
"description": "默认帖子可见性",
|
|
||||||
"type": "string",
|
|
||||||
"options": ["public", "home", "followers"],
|
|
||||||
"hint": "机器人发帖时的默认可见性设置。public:公开,home:主页时间线,followers:仅关注者。",
|
|
||||||
},
|
|
||||||
"misskey_local_only": {
|
|
||||||
"description": "仅限本站(不参与联合)",
|
|
||||||
"type": "bool",
|
|
||||||
"hint": "启用后,机器人发出的帖子将仅在本实例可见,不会联合到其他实例",
|
|
||||||
},
|
|
||||||
"misskey_enable_chat": {
|
|
||||||
"description": "启用聊天消息响应",
|
|
||||||
"type": "bool",
|
|
||||||
"hint": "启用后,机器人将会监听和响应私信聊天消息",
|
|
||||||
},
|
|
||||||
"telegram_command_register": {
|
"telegram_command_register": {
|
||||||
"description": "Telegram 命令注册",
|
"description": "Telegram 命令注册",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
@@ -645,7 +560,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "https://api.openai.com/v1",
|
"api_base": "https://api.openai.com/v1",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||||
"custom_extra_body": {},
|
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
|
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
|
||||||
},
|
},
|
||||||
@@ -660,7 +574,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "",
|
"api_base": "",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||||
"custom_extra_body": {},
|
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"xAI": {
|
"xAI": {
|
||||||
@@ -673,7 +586,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "https://api.x.ai/v1",
|
"api_base": "https://api.x.ai/v1",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
|
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
|
||||||
"custom_extra_body": {},
|
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"Anthropic": {
|
"Anthropic": {
|
||||||
@@ -703,7 +615,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||||
"api_base": "http://localhost:11434/v1",
|
"api_base": "http://localhost:11434/v1",
|
||||||
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
|
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
|
||||||
"custom_extra_body": {},
|
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"LM Studio": {
|
"LM Studio": {
|
||||||
@@ -717,7 +628,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "llama-3.1-8b",
|
"model": "llama-3.1-8b",
|
||||||
},
|
},
|
||||||
"custom_extra_body": {},
|
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"Gemini(OpenAI兼容)": {
|
"Gemini(OpenAI兼容)": {
|
||||||
@@ -733,7 +643,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"model": "gemini-1.5-flash",
|
"model": "gemini-1.5-flash",
|
||||||
"temperature": 0.4,
|
"temperature": 0.4,
|
||||||
},
|
},
|
||||||
"custom_extra_body": {},
|
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"Gemini": {
|
"Gemini": {
|
||||||
@@ -774,7 +683,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "https://api.deepseek.com/v1",
|
"api_base": "https://api.deepseek.com/v1",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
||||||
"custom_extra_body": {},
|
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"302.AI": {
|
"302.AI": {
|
||||||
@@ -787,7 +695,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "https://api.302.ai/v1",
|
"api_base": "https://api.302.ai/v1",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
|
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
|
||||||
"custom_extra_body": {},
|
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"硅基流动": {
|
"硅基流动": {
|
||||||
@@ -803,7 +710,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"model": "deepseek-ai/DeepSeek-V3",
|
"model": "deepseek-ai/DeepSeek-V3",
|
||||||
"temperature": 0.4,
|
"temperature": 0.4,
|
||||||
},
|
},
|
||||||
"custom_extra_body": {},
|
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"PPIO派欧云": {
|
"PPIO派欧云": {
|
||||||
@@ -819,7 +725,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"model": "deepseek/deepseek-r1",
|
"model": "deepseek/deepseek-r1",
|
||||||
"temperature": 0.4,
|
"temperature": 0.4,
|
||||||
},
|
},
|
||||||
"custom_extra_body": {},
|
|
||||||
},
|
},
|
||||||
"优云智算": {
|
"优云智算": {
|
||||||
"id": "compshare",
|
"id": "compshare",
|
||||||
@@ -833,7 +738,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "moonshotai/Kimi-K2-Instruct",
|
"model": "moonshotai/Kimi-K2-Instruct",
|
||||||
},
|
},
|
||||||
"custom_extra_body": {},
|
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"Kimi": {
|
"Kimi": {
|
||||||
@@ -846,7 +750,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"api_base": "https://api.moonshot.cn/v1",
|
"api_base": "https://api.moonshot.cn/v1",
|
||||||
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
|
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
|
||||||
"custom_extra_body": {},
|
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"智谱 AI": {
|
"智谱 AI": {
|
||||||
@@ -878,18 +781,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"timeout": 60,
|
"timeout": 60,
|
||||||
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
|
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
|
||||||
},
|
},
|
||||||
"Coze": {
|
|
||||||
"id": "coze",
|
|
||||||
"provider": "coze",
|
|
||||||
"provider_type": "chat_completion",
|
|
||||||
"type": "coze",
|
|
||||||
"enable": True,
|
|
||||||
"coze_api_key": "",
|
|
||||||
"bot_id": "",
|
|
||||||
"coze_api_base": "https://api.coze.cn",
|
|
||||||
"timeout": 60,
|
|
||||||
"auto_save_history": True,
|
|
||||||
},
|
|
||||||
"阿里云百炼应用": {
|
"阿里云百炼应用": {
|
||||||
"id": "dashscope",
|
"id": "dashscope",
|
||||||
"provider": "dashscope",
|
"provider": "dashscope",
|
||||||
@@ -917,7 +808,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||||
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
|
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
|
||||||
"custom_extra_body": {},
|
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"FastGPT": {
|
"FastGPT": {
|
||||||
@@ -929,7 +819,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"key": [],
|
"key": [],
|
||||||
"api_base": "https://api.fastgpt.in/api/v1",
|
"api_base": "https://api.fastgpt.in/api/v1",
|
||||||
"timeout": 60,
|
"timeout": 60,
|
||||||
"custom_extra_body": {},
|
|
||||||
},
|
},
|
||||||
"Whisper(API)": {
|
"Whisper(API)": {
|
||||||
"id": "whisper",
|
"id": "whisper",
|
||||||
@@ -980,9 +869,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"provider_type": "text_to_speech",
|
"provider_type": "text_to_speech",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
"edge-tts-voice": "zh-CN-XiaoxiaoNeural",
|
"edge-tts-voice": "zh-CN-XiaoxiaoNeural",
|
||||||
"rate": "+0%",
|
|
||||||
"volume": "+0%",
|
|
||||||
"pitch": "+0Hz",
|
|
||||||
"timeout": 20,
|
"timeout": 20,
|
||||||
},
|
},
|
||||||
"GSV TTS(本地加载)": {
|
"GSV TTS(本地加载)": {
|
||||||
@@ -1174,12 +1060,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"render_type": "checkbox",
|
"render_type": "checkbox",
|
||||||
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
|
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
|
||||||
},
|
},
|
||||||
"custom_extra_body": {
|
|
||||||
"description": "自定义请求体参数",
|
|
||||||
"type": "dict",
|
|
||||||
"items": {},
|
|
||||||
"hint": "此处添加的键值对将被合并到发送给 API 的 extra_body 中。值可以是字符串、数字或布尔值。",
|
|
||||||
},
|
|
||||||
"provider": {
|
"provider": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"invisible": True,
|
"invisible": True,
|
||||||
@@ -1756,26 +1636,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
|
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
|
||||||
"obvious": True,
|
"obvious": True,
|
||||||
},
|
},
|
||||||
"coze_api_key": {
|
|
||||||
"description": "Coze API Key",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "Coze API 密钥,用于访问 Coze 服务。",
|
|
||||||
},
|
|
||||||
"bot_id": {
|
|
||||||
"description": "Bot ID",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "Coze 机器人的 ID,在 Coze 平台上创建机器人后获得。",
|
|
||||||
},
|
|
||||||
"coze_api_base": {
|
|
||||||
"description": "API Base URL",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn",
|
|
||||||
},
|
|
||||||
"auto_save_history": {
|
|
||||||
"description": "由 Coze 管理对话记录",
|
|
||||||
"type": "bool",
|
|
||||||
"hint": "启用后,将由 Coze 进行对话历史记录管理, 此时 AstrBot 本地保存的上下文不会生效(仅供浏览), 对 AstrBot 的上下文进行的操作也不会生效。如果为禁用, 则使用 AstrBot 管理上下文。",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"provider_settings": {
|
"provider_settings": {
|
||||||
@@ -1802,9 +1662,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"identifier": {
|
"identifier": {
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
"group_name_display": {
|
|
||||||
"type": "bool",
|
|
||||||
},
|
|
||||||
"datetime_system_prompt": {
|
"datetime_system_prompt": {
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
@@ -1978,39 +1835,51 @@ CONFIG_METADATA_3 = {
|
|||||||
"_special": "select_provider",
|
"_special": "select_provider",
|
||||||
"hint": "留空时使用第一个模型。",
|
"hint": "留空时使用第一个模型。",
|
||||||
},
|
},
|
||||||
|
"provider_settings.default_summarize_provider_id": {
|
||||||
|
"description": "默认对话总结模型",
|
||||||
|
"type": "string",
|
||||||
|
"_special": "select_provider",
|
||||||
|
"hint": "留空代表不进行对话总结。可用于压缩上下文以减少 token 用量,并一定程度上保持历史聊天记忆。",
|
||||||
|
},
|
||||||
"provider_settings.default_image_caption_provider_id": {
|
"provider_settings.default_image_caption_provider_id": {
|
||||||
"description": "默认图片转述模型",
|
"description": "默认图片转述模型",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"_special": "select_provider",
|
"_special": "select_provider",
|
||||||
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
|
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
|
||||||
},
|
},
|
||||||
"provider_stt_settings.enable": {
|
|
||||||
"description": "启用语音转文本",
|
|
||||||
"type": "bool",
|
|
||||||
"hint": "STT 总开关。",
|
|
||||||
},
|
|
||||||
"provider_stt_settings.provider_id": {
|
"provider_stt_settings.provider_id": {
|
||||||
"description": "默认语音转文本模型",
|
"description": "语音转文本模型",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "用户也可使用 /provider 指令单独选择会话的 STT 模型。",
|
"hint": "留空代表不使用。",
|
||||||
"_special": "select_provider_stt",
|
"_special": "select_provider_stt",
|
||||||
"condition": {
|
|
||||||
"provider_stt_settings.enable": True,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"provider_tts_settings.enable": {
|
|
||||||
"description": "启用文本转语音",
|
|
||||||
"type": "bool",
|
|
||||||
"hint": "TTS 总开关。当关闭时,会话启用 TTS 也不会生效。",
|
|
||||||
},
|
},
|
||||||
"provider_tts_settings.provider_id": {
|
"provider_tts_settings.provider_id": {
|
||||||
"description": "默认文本转语音模型",
|
"description": "文本转语音模型",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "用户也可使用 /provider 单独选择会话的 TTS 模型。",
|
"hint": "留空代表不使用。",
|
||||||
"_special": "select_provider_tts",
|
"_special": "select_provider_tts",
|
||||||
|
},
|
||||||
|
"provider_settings.context_exceed_calc_method": {
|
||||||
|
"description": "上下文超限的触发策略",
|
||||||
|
"type": "string",
|
||||||
|
"options": ["token_size", "context_length"],
|
||||||
|
"labels": ["基于 Token 长度(估算)", "基于对话轮数"],
|
||||||
|
"hint": "如配置了对话总结模型,则触发时总结对话内容,否则丢弃最旧部分。"
|
||||||
|
},
|
||||||
|
"provider_settings.max_context_length": {
|
||||||
|
"description": "对话轮数上限",
|
||||||
|
"type": "int",
|
||||||
"condition": {
|
"condition": {
|
||||||
"provider_tts_settings.enable": True,
|
"provider_settings.context_exceed_calc_method": "context_length"
|
||||||
},
|
}
|
||||||
|
},
|
||||||
|
"provider_settings.max_token_size": {
|
||||||
|
"description": "Token 长度上限(估算)",
|
||||||
|
"type": "int",
|
||||||
|
"hint": "超出这个数量时丢弃最旧的部分。",
|
||||||
|
"condition": {
|
||||||
|
"provider_settings.context_exceed_calc_method": "token_size"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"provider_settings.image_caption_prompt": {
|
"provider_settings.image_caption_prompt": {
|
||||||
"description": "图片转述提示词",
|
"description": "图片转述提示词",
|
||||||
@@ -2055,9 +1924,7 @@ CONFIG_METADATA_3 = {
|
|||||||
},
|
},
|
||||||
"provider_settings.websearch_tavily_key": {
|
"provider_settings.websearch_tavily_key": {
|
||||||
"description": "Tavily API Key",
|
"description": "Tavily API Key",
|
||||||
"type": "list",
|
"type": "string",
|
||||||
"items": {"type": "string"},
|
|
||||||
"hint": "可添加多个 Key 进行轮询。",
|
|
||||||
"condition": {
|
"condition": {
|
||||||
"provider_settings.websearch_provider": "tavily",
|
"provider_settings.websearch_provider": "tavily",
|
||||||
},
|
},
|
||||||
@@ -2077,14 +1944,9 @@ CONFIG_METADATA_3 = {
|
|||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
"provider_settings.identifier": {
|
"provider_settings.identifier": {
|
||||||
"description": "用户识别",
|
"description": "用户感知",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
"provider_settings.group_name_display": {
|
|
||||||
"description": "显示群名称",
|
|
||||||
"type": "bool",
|
|
||||||
"hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
|
|
||||||
},
|
|
||||||
"provider_settings.datetime_system_prompt": {
|
"provider_settings.datetime_system_prompt": {
|
||||||
"description": "现实世界时间感知",
|
"description": "现实世界时间感知",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
@@ -2105,11 +1967,6 @@ CONFIG_METADATA_3 = {
|
|||||||
"description": "不支持流式回复的平台采取分段输出",
|
"description": "不支持流式回复的平台采取分段输出",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
"provider_settings.max_context_length": {
|
|
||||||
"description": "最多携带对话轮数",
|
|
||||||
"type": "int",
|
|
||||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条。-1 为不限制。",
|
|
||||||
},
|
|
||||||
"provider_settings.dequeue_context_length": {
|
"provider_settings.dequeue_context_length": {
|
||||||
"description": "丢弃对话轮数",
|
"description": "丢弃对话轮数",
|
||||||
"type": "int",
|
"type": "int",
|
||||||
@@ -2118,14 +1975,12 @@ CONFIG_METADATA_3 = {
|
|||||||
"provider_settings.wake_prefix": {
|
"provider_settings.wake_prefix": {
|
||||||
"description": "LLM 聊天额外唤醒前缀 ",
|
"description": "LLM 聊天额外唤醒前缀 ",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "例子: 如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
|
|
||||||
},
|
},
|
||||||
"provider_settings.prompt_prefix": {
|
"provider_settings.prompt_prefix": {
|
||||||
"description": "用户提示词",
|
"description": "额外前缀提示词",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。",
|
|
||||||
},
|
},
|
||||||
"provider_tts_settings.dual_output": {
|
"provider_settings.dual_output": {
|
||||||
"description": "开启 TTS 时同时输出语音和文字内容",
|
"description": "开启 TTS 时同时输出语音和文字内容",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
@@ -2234,41 +2089,41 @@ CONFIG_METADATA_3 = {
|
|||||||
"description": "内容安全",
|
"description": "内容安全",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"items": {
|
"items": {
|
||||||
"content_safety.also_use_in_response": {
|
"platform_settings.content_safety.also_use_in_response": {
|
||||||
"description": "同时检查模型的响应内容",
|
"description": "同时检查模型的响应内容",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
"content_safety.baidu_aip.enable": {
|
"platform_settings.content_safety.baidu_aip.enable": {
|
||||||
"description": "使用百度内容安全审核",
|
"description": "使用百度内容安全审核",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "您需要手动安装 baidu-aip 库。",
|
"hint": "您需要手动安装 baidu-aip 库。",
|
||||||
},
|
},
|
||||||
"content_safety.baidu_aip.app_id": {
|
"platform_settings.content_safety.baidu_aip.app_id": {
|
||||||
"description": "App ID",
|
"description": "App ID",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"condition": {
|
"condition": {
|
||||||
"content_safety.baidu_aip.enable": True,
|
"platform_settings.content_safety.baidu_aip.enable": True,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"content_safety.baidu_aip.api_key": {
|
"platform_settings.content_safety.baidu_aip.api_key": {
|
||||||
"description": "API Key",
|
"description": "API Key",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"condition": {
|
"condition": {
|
||||||
"content_safety.baidu_aip.enable": True,
|
"platform_settings.content_safety.baidu_aip.enable": True,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"content_safety.baidu_aip.secret_key": {
|
"platform_settings.content_safety.baidu_aip.secret_key": {
|
||||||
"description": "Secret Key",
|
"description": "Secret Key",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"condition": {
|
"condition": {
|
||||||
"content_safety.baidu_aip.enable": True,
|
"platform_settings.content_safety.baidu_aip.enable": True,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"content_safety.internal_keywords.enable": {
|
"platform_settings.content_safety.internal_keywords.enable": {
|
||||||
"description": "关键词检查",
|
"description": "关键词检查",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
"content_safety.internal_keywords.extra_keywords": {
|
"platform_settings.content_safety.internal_keywords.extra_keywords": {
|
||||||
"description": "额外关键词",
|
"description": "额外关键词",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
@@ -2306,32 +2161,6 @@ CONFIG_METADATA_3 = {
|
|||||||
"description": "用户权限不足时是否回复",
|
"description": "用户权限不足时是否回复",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
"platform_specific.lark.pre_ack_emoji.enable": {
|
|
||||||
"description": "[飞书] 启用预回应表情",
|
|
||||||
"type": "bool",
|
|
||||||
},
|
|
||||||
"platform_specific.lark.pre_ack_emoji.emojis": {
|
|
||||||
"description": "表情列表(飞书表情枚举名)",
|
|
||||||
"type": "list",
|
|
||||||
"items": {"type": "string"},
|
|
||||||
"hint": "表情枚举名参考:https://open.feishu.cn/document/server-docs/im-v1/message-reaction/emojis-introduce",
|
|
||||||
"condition": {
|
|
||||||
"platform_specific.lark.pre_ack_emoji.enable": True,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"platform_specific.telegram.pre_ack_emoji.enable": {
|
|
||||||
"description": "[Telegram] 启用预回应表情",
|
|
||||||
"type": "bool",
|
|
||||||
},
|
|
||||||
"platform_specific.telegram.pre_ack_emoji.emojis": {
|
|
||||||
"description": "表情列表(Unicode)",
|
|
||||||
"type": "list",
|
|
||||||
"items": {"type": "string"},
|
|
||||||
"hint": "Telegram 仅支持固定反应集合,参考:https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9",
|
|
||||||
"condition": {
|
|
||||||
"platform_specific.telegram.pre_ack_emoji.enable": True,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -2485,13 +2314,7 @@ CONFIG_METADATA_3_SYSTEM = {
|
|||||||
"condition": {
|
"condition": {
|
||||||
"t2i_strategy": "remote",
|
"t2i_strategy": "remote",
|
||||||
},
|
},
|
||||||
"_special": "t2i_template",
|
"_special": "t2i_template"
|
||||||
},
|
|
||||||
"t2i_active_template": {
|
|
||||||
"description": "当前应用的文转图渲染模板",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "此处的值由文转图模板管理页面进行维护。",
|
|
||||||
"invisible": True,
|
|
||||||
},
|
},
|
||||||
"log_level": {
|
"log_level": {
|
||||||
"description": "控制台日志级别",
|
"description": "控制台日志级别",
|
||||||
@@ -2524,11 +2347,6 @@ CONFIG_METADATA_3_SYSTEM = {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`",
|
"hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`",
|
||||||
},
|
},
|
||||||
"no_proxy": {
|
|
||||||
"description": "直连地址列表",
|
|
||||||
"type": "list",
|
|
||||||
"items": {"type": "string"},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -87,25 +87,17 @@ class ConversationManager:
|
|||||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||||
"""
|
"""
|
||||||
|
f = False
|
||||||
if not conversation_id:
|
if not conversation_id:
|
||||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||||
|
if conversation_id:
|
||||||
|
f = True
|
||||||
if conversation_id:
|
if conversation_id:
|
||||||
await self.db.delete_conversation(cid=conversation_id)
|
await self.db.delete_conversation(cid=conversation_id)
|
||||||
curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
|
if f:
|
||||||
if curr_cid == conversation_id:
|
|
||||||
self.session_conversations.pop(unified_msg_origin, None)
|
self.session_conversations.pop(unified_msg_origin, None)
|
||||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||||
|
|
||||||
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
|
|
||||||
"""删除会话的所有对话
|
|
||||||
|
|
||||||
Args:
|
|
||||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
|
||||||
"""
|
|
||||||
await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
|
|
||||||
self.session_conversations.pop(unified_msg_origin, None)
|
|
||||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
|
||||||
|
|
||||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
|
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
|
||||||
"""获取会话当前的对话 ID
|
"""获取会话当前的对话 ID
|
||||||
|
|
||||||
|
|||||||
@@ -154,17 +154,12 @@ class BaseDatabase(abc.ABC):
|
|||||||
"""Delete a conversation by its ID."""
|
"""Delete a conversation by its ID."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
|
||||||
"""Delete all conversations for a specific user."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def insert_platform_message_history(
|
async def insert_platform_message_history(
|
||||||
self,
|
self,
|
||||||
platform_id: str,
|
platform_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
content: dict,
|
content: list[dict],
|
||||||
sender_id: str | None = None,
|
sender_id: str | None = None,
|
||||||
sender_name: str | None = None,
|
sender_name: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -287,14 +282,3 @@ class BaseDatabase(abc.ABC):
|
|||||||
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
|
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
|
||||||
# """Get all LLM messages for a specific conversation."""
|
# """Get all LLM messages for a specific conversation."""
|
||||||
# ...
|
# ...
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def get_session_conversations(
|
|
||||||
self,
|
|
||||||
page: int = 1,
|
|
||||||
page_size: int = 20,
|
|
||||||
search_query: str | None = None,
|
|
||||||
platform: str | None = None,
|
|
||||||
) -> tuple[list[dict], int]:
|
|
||||||
"""Get paginated session conversations with joined conversation and persona details, support search and platform filter."""
|
|
||||||
...
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ async def do_migration_v4(
|
|||||||
await migration_webchat_data(db_helper, platform_id_map)
|
await migration_webchat_data(db_helper, platform_id_map)
|
||||||
|
|
||||||
# 执行偏好设置迁移
|
# 执行偏好设置迁移
|
||||||
await migration_preferences(db_helper, platform_id_map)
|
await migration_preferences(db_helper,platform_id_map)
|
||||||
|
|
||||||
# 执行平台统计表迁移
|
# 执行平台统计表迁移
|
||||||
await migration_platform_table(db_helper, platform_id_map)
|
await migration_platform_table(db_helper, platform_id_map)
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|||||||
|
|
||||||
_VT = TypeVar("_VT")
|
_VT = TypeVar("_VT")
|
||||||
|
|
||||||
|
|
||||||
class SharedPreferences:
|
class SharedPreferences:
|
||||||
def __init__(self, path=None):
|
def __init__(self, path=None):
|
||||||
if path is None:
|
if path is None:
|
||||||
@@ -43,5 +42,4 @@ class SharedPreferences:
|
|||||||
self._data.clear()
|
self._data.clear()
|
||||||
self._save_preferences()
|
self._save_preferences()
|
||||||
|
|
||||||
|
|
||||||
sp = SharedPreferences()
|
sp = SharedPreferences()
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from astrbot.core.db.po import Platform, Stats
|
|||||||
from typing import Tuple, List, Dict, Any
|
from typing import Tuple, List, Dict, Any
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Conversation:
|
class Conversation:
|
||||||
"""LLM 对话存储
|
"""LLM 对话存储
|
||||||
@@ -77,7 +76,7 @@ PRAGMA encoding = 'UTF-8';
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class SQLiteDatabase:
|
class SQLiteDatabase():
|
||||||
def __init__(self, db_path: str) -> None:
|
def __init__(self, db_path: str) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
|
|||||||
@@ -75,9 +75,7 @@ class Persona(SQLModel, table=True):
|
|||||||
|
|
||||||
__tablename__ = "personas"
|
__tablename__ = "personas"
|
||||||
|
|
||||||
id: int | None = Field(
|
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
|
||||||
)
|
|
||||||
persona_id: str = Field(max_length=255, nullable=False)
|
persona_id: str = Field(max_length=255, nullable=False)
|
||||||
system_prompt: str = Field(sa_type=Text, nullable=False)
|
system_prompt: str = Field(sa_type=Text, nullable=False)
|
||||||
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
|
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
|
||||||
@@ -137,9 +135,7 @@ class PlatformMessageHistory(SQLModel, table=True):
|
|||||||
|
|
||||||
__tablename__ = "platform_message_history"
|
__tablename__ = "platform_message_history"
|
||||||
|
|
||||||
id: int | None = Field(
|
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
|
||||||
)
|
|
||||||
platform_id: str = Field(nullable=False)
|
platform_id: str = Field(nullable=False)
|
||||||
user_id: str = Field(nullable=False) # An id of group, user in platform
|
user_id: str = Field(nullable=False) # An id of group, user in platform
|
||||||
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
|
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
|
||||||
@@ -162,8 +158,8 @@ class Attachment(SQLModel, table=True):
|
|||||||
|
|
||||||
__tablename__ = "attachments"
|
__tablename__ = "attachments"
|
||||||
|
|
||||||
inner_attachment_id: int | None = Field(
|
inner_attachment_id: int = Field(
|
||||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||||
)
|
)
|
||||||
attachment_id: str = Field(
|
attachment_id: str = Field(
|
||||||
max_length=36,
|
max_length=36,
|
||||||
|
|||||||
@@ -15,8 +15,9 @@ from astrbot.core.db.po import (
|
|||||||
SQLModel,
|
SQLModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sqlmodel import select, update, delete, text, func, or_, desc, col
|
from sqlalchemy import select, update, delete, text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||||
|
|
||||||
@@ -40,10 +41,10 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
|
|
||||||
async def insert_platform_stats(
|
async def insert_platform_stats(
|
||||||
self,
|
self,
|
||||||
platform_id,
|
platform_id: str,
|
||||||
platform_type,
|
platform_type: str,
|
||||||
count=1,
|
count: int = 1,
|
||||||
timestamp=None,
|
timestamp: datetime = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Insert a new platform statistic record."""
|
"""Insert a new platform statistic record."""
|
||||||
async with self.get_db() as session:
|
async with self.get_db() as session:
|
||||||
@@ -74,9 +75,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
async with self.get_db() as session:
|
async with self.get_db() as session:
|
||||||
session: AsyncSession
|
session: AsyncSession
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(func.count(col(PlatformStat.platform_id))).select_from(
|
select(func.count(PlatformStat.platform_id)).select_from(PlatformStat)
|
||||||
PlatformStat
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
count = result.scalar_one_or_none()
|
count = result.scalar_one_or_none()
|
||||||
return count if count is not None else 0
|
return count if count is not None else 0
|
||||||
@@ -96,7 +95,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
"""),
|
"""),
|
||||||
{"start_time": start_time},
|
{"start_time": start_time},
|
||||||
)
|
)
|
||||||
return list(result.scalars().all())
|
return result.scalars().all()
|
||||||
|
|
||||||
# ====
|
# ====
|
||||||
# Conversation Management
|
# Conversation Management
|
||||||
@@ -112,7 +111,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
if platform_id:
|
if platform_id:
|
||||||
query = query.where(ConversationV2.platform_id == platform_id)
|
query = query.where(ConversationV2.platform_id == platform_id)
|
||||||
# order by
|
# order by
|
||||||
query = query.order_by(desc(ConversationV2.created_at))
|
query = query.order_by(ConversationV2.created_at.desc())
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
|
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
@@ -130,7 +129,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(ConversationV2)
|
select(ConversationV2)
|
||||||
.order_by(desc(ConversationV2.created_at))
|
.order_by(ConversationV2.created_at.desc())
|
||||||
.offset(offset)
|
.offset(offset)
|
||||||
.limit(page_size)
|
.limit(page_size)
|
||||||
)
|
)
|
||||||
@@ -151,25 +150,11 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
|
|
||||||
if platform_ids:
|
if platform_ids:
|
||||||
base_query = base_query.where(
|
base_query = base_query.where(
|
||||||
col(ConversationV2.platform_id).in_(platform_ids)
|
ConversationV2.platform_id.in_(platform_ids)
|
||||||
)
|
)
|
||||||
if search_query:
|
if search_query:
|
||||||
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
|
||||||
base_query = base_query.where(
|
base_query = base_query.where(
|
||||||
or_(
|
ConversationV2.title.ilike(f"%{search_query}%")
|
||||||
col(ConversationV2.title).ilike(f"%{search_query}%"),
|
|
||||||
col(ConversationV2.content).ilike(f"%{search_query}%"),
|
|
||||||
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
|
|
||||||
for msg_type in kwargs["message_types"]:
|
|
||||||
base_query = base_query.where(
|
|
||||||
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%")
|
|
||||||
)
|
|
||||||
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
|
|
||||||
base_query = base_query.where(
|
|
||||||
col(ConversationV2.platform_id).in_(kwargs["platforms"])
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get total count matching the filters
|
# Get total count matching the filters
|
||||||
@@ -180,7 +165,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
# Get paginated results
|
# Get paginated results
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
result_query = (
|
result_query = (
|
||||||
base_query.order_by(desc(ConversationV2.created_at))
|
base_query.order_by(ConversationV2.created_at.desc())
|
||||||
.offset(offset)
|
.offset(offset)
|
||||||
.limit(page_size)
|
.limit(page_size)
|
||||||
)
|
)
|
||||||
@@ -226,7 +211,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
session: AsyncSession
|
session: AsyncSession
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
query = update(ConversationV2).where(
|
query = update(ConversationV2).where(
|
||||||
col(ConversationV2.conversation_id) == cid
|
ConversationV2.conversation_id == cid
|
||||||
)
|
)
|
||||||
values = {}
|
values = {}
|
||||||
if title is not None:
|
if title is not None:
|
||||||
@@ -246,126 +231,9 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
session: AsyncSession
|
session: AsyncSession
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
await session.execute(
|
await session.execute(
|
||||||
delete(ConversationV2).where(
|
delete(ConversationV2).where(ConversationV2.conversation_id == cid)
|
||||||
col(ConversationV2.conversation_id) == cid
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
|
||||||
async with self.get_db() as session:
|
|
||||||
session: AsyncSession
|
|
||||||
async with session.begin():
|
|
||||||
await session.execute(
|
|
||||||
delete(ConversationV2).where(col(ConversationV2.user_id) == user_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_session_conversations(
|
|
||||||
self,
|
|
||||||
page=1,
|
|
||||||
page_size=20,
|
|
||||||
search_query=None,
|
|
||||||
platform=None,
|
|
||||||
) -> tuple[list[dict], int]:
|
|
||||||
"""Get paginated session conversations with joined conversation and persona details."""
|
|
||||||
async with self.get_db() as session:
|
|
||||||
session: AsyncSession
|
|
||||||
offset = (page - 1) * page_size
|
|
||||||
|
|
||||||
base_query = (
|
|
||||||
select(
|
|
||||||
col(Preference.scope_id).label("session_id"),
|
|
||||||
func.json_extract(Preference.value, "$.val").label(
|
|
||||||
"conversation_id"
|
|
||||||
), # type: ignore
|
|
||||||
col(ConversationV2.persona_id).label("persona_id"),
|
|
||||||
col(ConversationV2.title).label("title"),
|
|
||||||
col(Persona.persona_id).label("persona_name"),
|
|
||||||
)
|
|
||||||
.select_from(Preference)
|
|
||||||
.outerjoin(
|
|
||||||
ConversationV2,
|
|
||||||
func.json_extract(Preference.value, "$.val")
|
|
||||||
== ConversationV2.conversation_id,
|
|
||||||
)
|
|
||||||
.outerjoin(
|
|
||||||
Persona, col(ConversationV2.persona_id) == Persona.persona_id
|
|
||||||
)
|
|
||||||
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
|
||||||
)
|
|
||||||
|
|
||||||
# 搜索筛选
|
|
||||||
if search_query:
|
|
||||||
search_pattern = f"%{search_query}%"
|
|
||||||
base_query = base_query.where(
|
|
||||||
or_(
|
|
||||||
col(Preference.scope_id).ilike(search_pattern),
|
|
||||||
col(ConversationV2.title).ilike(search_pattern),
|
|
||||||
col(Persona.persona_id).ilike(search_pattern),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 平台筛选
|
|
||||||
if platform:
|
|
||||||
platform_pattern = f"{platform}:%"
|
|
||||||
base_query = base_query.where(
|
|
||||||
col(Preference.scope_id).like(platform_pattern)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 排序
|
|
||||||
base_query = base_query.order_by(Preference.scope_id)
|
|
||||||
|
|
||||||
# 分页结果
|
|
||||||
result_query = base_query.offset(offset).limit(page_size)
|
|
||||||
result = await session.execute(result_query)
|
|
||||||
rows = result.fetchall()
|
|
||||||
|
|
||||||
# 查询总数(应用相同的筛选条件)
|
|
||||||
count_base_query = (
|
|
||||||
select(func.count(col(Preference.scope_id)))
|
|
||||||
.select_from(Preference)
|
|
||||||
.outerjoin(
|
|
||||||
ConversationV2,
|
|
||||||
func.json_extract(Preference.value, "$.val")
|
|
||||||
== ConversationV2.conversation_id,
|
|
||||||
)
|
|
||||||
.outerjoin(
|
|
||||||
Persona, col(ConversationV2.persona_id) == Persona.persona_id
|
|
||||||
)
|
|
||||||
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
|
||||||
)
|
|
||||||
|
|
||||||
# 应用相同的搜索和平台筛选条件到计数查询
|
|
||||||
if search_query:
|
|
||||||
search_pattern = f"%{search_query}%"
|
|
||||||
count_base_query = count_base_query.where(
|
|
||||||
or_(
|
|
||||||
col(Preference.scope_id).ilike(search_pattern),
|
|
||||||
col(ConversationV2.title).ilike(search_pattern),
|
|
||||||
col(Persona.persona_id).ilike(search_pattern),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if platform:
|
|
||||||
platform_pattern = f"{platform}:%"
|
|
||||||
count_base_query = count_base_query.where(
|
|
||||||
col(Preference.scope_id).like(platform_pattern)
|
|
||||||
)
|
|
||||||
|
|
||||||
total_result = await session.execute(count_base_query)
|
|
||||||
total = total_result.scalar() or 0
|
|
||||||
|
|
||||||
sessions_data = [
|
|
||||||
{
|
|
||||||
"session_id": row.session_id,
|
|
||||||
"conversation_id": row.conversation_id,
|
|
||||||
"persona_id": row.persona_id,
|
|
||||||
"title": row.title,
|
|
||||||
"persona_name": row.persona_name,
|
|
||||||
}
|
|
||||||
for row in rows
|
|
||||||
]
|
|
||||||
return sessions_data, total
|
|
||||||
|
|
||||||
async def insert_platform_message_history(
|
async def insert_platform_message_history(
|
||||||
self,
|
self,
|
||||||
platform_id,
|
platform_id,
|
||||||
@@ -399,9 +267,9 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
cutoff_time = now - timedelta(seconds=offset_sec)
|
cutoff_time = now - timedelta(seconds=offset_sec)
|
||||||
await session.execute(
|
await session.execute(
|
||||||
delete(PlatformMessageHistory).where(
|
delete(PlatformMessageHistory).where(
|
||||||
col(PlatformMessageHistory.platform_id) == platform_id,
|
PlatformMessageHistory.platform_id == platform_id,
|
||||||
col(PlatformMessageHistory.user_id) == user_id,
|
PlatformMessageHistory.user_id == user_id,
|
||||||
col(PlatformMessageHistory.created_at) < cutoff_time,
|
PlatformMessageHistory.created_at < cutoff_time,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -418,7 +286,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
PlatformMessageHistory.platform_id == platform_id,
|
PlatformMessageHistory.platform_id == platform_id,
|
||||||
PlatformMessageHistory.user_id == user_id,
|
PlatformMessageHistory.user_id == user_id,
|
||||||
)
|
)
|
||||||
.order_by(desc(PlatformMessageHistory.created_at))
|
.order_by(PlatformMessageHistory.created_at.desc())
|
||||||
)
|
)
|
||||||
result = await session.execute(query.offset(offset).limit(page_size))
|
result = await session.execute(query.offset(offset).limit(page_size))
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
@@ -440,7 +308,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
"""Get an attachment by its ID."""
|
"""Get an attachment by its ID."""
|
||||||
async with self.get_db() as session:
|
async with self.get_db() as session:
|
||||||
session: AsyncSession
|
session: AsyncSession
|
||||||
query = select(Attachment).where(Attachment.attachment_id == attachment_id)
|
query = select(Attachment).where(Attachment.id == attachment_id)
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
@@ -483,7 +351,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
async with self.get_db() as session:
|
async with self.get_db() as session:
|
||||||
session: AsyncSession
|
session: AsyncSession
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
query = update(Persona).where(col(Persona.persona_id) == persona_id)
|
query = update(Persona).where(Persona.persona_id == persona_id)
|
||||||
values = {}
|
values = {}
|
||||||
if system_prompt is not None:
|
if system_prompt is not None:
|
||||||
values["system_prompt"] = system_prompt
|
values["system_prompt"] = system_prompt
|
||||||
@@ -503,7 +371,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
session: AsyncSession
|
session: AsyncSession
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
await session.execute(
|
await session.execute(
|
||||||
delete(Persona).where(col(Persona.persona_id) == persona_id)
|
delete(Persona).where(Persona.persona_id == persona_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def insert_preference_or_update(self, scope, scope_id, key, value):
|
async def insert_preference_or_update(self, scope, scope_id, key, value):
|
||||||
@@ -558,9 +426,9 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
async with session.begin():
|
async with session.begin():
|
||||||
await session.execute(
|
await session.execute(
|
||||||
delete(Preference).where(
|
delete(Preference).where(
|
||||||
col(Preference.scope) == scope,
|
Preference.scope == scope,
|
||||||
col(Preference.scope_id) == scope_id,
|
Preference.scope_id == scope_id,
|
||||||
col(Preference.key) == key,
|
Preference.key == key,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -572,8 +440,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
async with session.begin():
|
async with session.begin():
|
||||||
await session.execute(
|
await session.execute(
|
||||||
delete(Preference).where(
|
delete(Preference).where(
|
||||||
col(Preference.scope) == scope,
|
Preference.scope == scope, Preference.scope_id == scope_id
|
||||||
col(Preference.scope_id) == scope_id,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -600,7 +467,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
DeprecatedPlatformStat(
|
DeprecatedPlatformStat(
|
||||||
name=data.platform_id,
|
name=data.platform_id,
|
||||||
count=data.count,
|
count=data.count,
|
||||||
timestamp=int(data.timestamp.timestamp()),
|
timestamp=data.timestamp.timestamp(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return deprecated_stats
|
return deprecated_stats
|
||||||
@@ -658,7 +525,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
DeprecatedPlatformStat(
|
DeprecatedPlatformStat(
|
||||||
name=platform_id,
|
name=platform_id,
|
||||||
count=count,
|
count=count,
|
||||||
timestamp=int(start_time.timestamp()),
|
timestamp=start_time.timestamp(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return deprecated_stats
|
return deprecated_stats
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
from .vec_db import FaissVecDB
|
from .vec_db import FaissVecDB
|
||||||
|
|
||||||
__all__ = ["FaissVecDB"]
|
__all__ = ["FaissVecDB"]
|
||||||
@@ -113,8 +113,7 @@ class FaissVecDB(BaseVecDB):
|
|||||||
reranked_results, key=lambda x: x.relevance_score, reverse=True
|
reranked_results, key=lambda x: x.relevance_score, reverse=True
|
||||||
)
|
)
|
||||||
top_k_results = [
|
top_k_results = [
|
||||||
top_k_results[reranked_result.index]
|
top_k_results[reranked_result.index] for reranked_result in reranked_results
|
||||||
for reranked_result in reranked_results
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return top_k_results
|
return top_k_results
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ class InitialLoader:
|
|||||||
self.db = db
|
self.db = db
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.log_broker = log_broker
|
self.log_broker = log_broker
|
||||||
self.webui_dir: str | None = None
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
||||||
@@ -36,10 +35,8 @@ class InitialLoader:
|
|||||||
|
|
||||||
core_task = core_lifecycle.start()
|
core_task = core_lifecycle.start()
|
||||||
|
|
||||||
webui_dir = self.webui_dir
|
|
||||||
|
|
||||||
self.dashboard_server = AstrBotDashboard(
|
self.dashboard_server = AstrBotDashboard(
|
||||||
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event, webui_dir
|
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
|
||||||
)
|
)
|
||||||
task = asyncio.gather(
|
task = asyncio.gather(
|
||||||
core_task, self.dashboard_server.run()
|
core_task, self.dashboard_server.run()
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|||||||
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
|
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
|
||||||
|
|
||||||
|
|
||||||
class ComponentType(str, Enum):
|
class ComponentType(Enum):
|
||||||
Plain = "Plain" # 纯文本消息
|
Plain = "Plain" # 纯文本消息
|
||||||
Face = "Face" # QQ表情
|
Face = "Face" # QQ表情
|
||||||
Record = "Record" # 语音
|
Record = "Record" # 语音
|
||||||
@@ -108,7 +108,7 @@ class BaseMessageComponent(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Plain(BaseMessageComponent):
|
class Plain(BaseMessageComponent):
|
||||||
type = ComponentType.Plain
|
type: ComponentType = "Plain"
|
||||||
text: str
|
text: str
|
||||||
convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息
|
convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息
|
||||||
|
|
||||||
@@ -128,9 +128,8 @@ class Plain(BaseMessageComponent):
|
|||||||
async def to_dict(self):
|
async def to_dict(self):
|
||||||
return {"type": "text", "data": {"text": self.text}}
|
return {"type": "text", "data": {"text": self.text}}
|
||||||
|
|
||||||
|
|
||||||
class Face(BaseMessageComponent):
|
class Face(BaseMessageComponent):
|
||||||
type = ComponentType.Face
|
type: ComponentType = "Face"
|
||||||
id: int
|
id: int
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
@@ -138,7 +137,7 @@ class Face(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Record(BaseMessageComponent):
|
class Record(BaseMessageComponent):
|
||||||
type = ComponentType.Record
|
type: ComponentType = "Record"
|
||||||
file: T.Optional[str] = ""
|
file: T.Optional[str] = ""
|
||||||
magic: T.Optional[bool] = False
|
magic: T.Optional[bool] = False
|
||||||
url: T.Optional[str] = ""
|
url: T.Optional[str] = ""
|
||||||
@@ -165,24 +164,19 @@ class Record(BaseMessageComponent):
|
|||||||
return Record(file=url, **_)
|
return Record(file=url, **_)
|
||||||
raise Exception("not a valid url")
|
raise Exception("not a valid url")
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def fromBase64(bs64_data: str, **_):
|
|
||||||
return Record(file=f"base64://{bs64_data}", **_)
|
|
||||||
|
|
||||||
async def convert_to_file_path(self) -> str:
|
async def convert_to_file_path(self) -> str:
|
||||||
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
|
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 语音的本地路径,以绝对路径表示。
|
str: 语音的本地路径,以绝对路径表示。
|
||||||
"""
|
"""
|
||||||
if not self.file:
|
if self.file and self.file.startswith("file:///"):
|
||||||
raise Exception(f"not a valid file: {self.file}")
|
file_path = self.file[8:]
|
||||||
if self.file.startswith("file:///"):
|
return file_path
|
||||||
return self.file[8:]
|
elif self.file and self.file.startswith("http"):
|
||||||
elif self.file.startswith("http"):
|
|
||||||
file_path = await download_image_by_url(self.file)
|
file_path = await download_image_by_url(self.file)
|
||||||
return os.path.abspath(file_path)
|
return os.path.abspath(file_path)
|
||||||
elif self.file.startswith("base64://"):
|
elif self.file and self.file.startswith("base64://"):
|
||||||
bs64_data = self.file.removeprefix("base64://")
|
bs64_data = self.file.removeprefix("base64://")
|
||||||
image_bytes = base64.b64decode(bs64_data)
|
image_bytes = base64.b64decode(bs64_data)
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
@@ -191,7 +185,8 @@ class Record(BaseMessageComponent):
|
|||||||
f.write(image_bytes)
|
f.write(image_bytes)
|
||||||
return os.path.abspath(file_path)
|
return os.path.abspath(file_path)
|
||||||
elif os.path.exists(self.file):
|
elif os.path.exists(self.file):
|
||||||
return os.path.abspath(self.file)
|
file_path = self.file
|
||||||
|
return os.path.abspath(file_path)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"not a valid file: {self.file}")
|
raise Exception(f"not a valid file: {self.file}")
|
||||||
|
|
||||||
@@ -202,14 +197,12 @@ class Record(BaseMessageComponent):
|
|||||||
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
||||||
"""
|
"""
|
||||||
# convert to base64
|
# convert to base64
|
||||||
if not self.file:
|
if self.file and self.file.startswith("file:///"):
|
||||||
raise Exception(f"not a valid file: {self.file}")
|
|
||||||
if self.file.startswith("file:///"):
|
|
||||||
bs64_data = file_to_base64(self.file[8:])
|
bs64_data = file_to_base64(self.file[8:])
|
||||||
elif self.file.startswith("http"):
|
elif self.file and self.file.startswith("http"):
|
||||||
file_path = await download_image_by_url(self.file)
|
file_path = await download_image_by_url(self.file)
|
||||||
bs64_data = file_to_base64(file_path)
|
bs64_data = file_to_base64(file_path)
|
||||||
elif self.file.startswith("base64://"):
|
elif self.file and self.file.startswith("base64://"):
|
||||||
bs64_data = self.file
|
bs64_data = self.file
|
||||||
elif os.path.exists(self.file):
|
elif os.path.exists(self.file):
|
||||||
bs64_data = file_to_base64(self.file)
|
bs64_data = file_to_base64(self.file)
|
||||||
@@ -243,7 +236,7 @@ class Record(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Video(BaseMessageComponent):
|
class Video(BaseMessageComponent):
|
||||||
type = ComponentType.Video
|
type: ComponentType = "Video"
|
||||||
file: str
|
file: str
|
||||||
cover: T.Optional[str] = ""
|
cover: T.Optional[str] = ""
|
||||||
c: T.Optional[int] = 2
|
c: T.Optional[int] = 2
|
||||||
@@ -329,7 +322,7 @@ class Video(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class At(BaseMessageComponent):
|
class At(BaseMessageComponent):
|
||||||
type = ComponentType.At
|
type: ComponentType = "At"
|
||||||
qq: T.Union[int, str] # 此处str为all时代表所有人
|
qq: T.Union[int, str] # 此处str为all时代表所有人
|
||||||
name: T.Optional[str] = ""
|
name: T.Optional[str] = ""
|
||||||
|
|
||||||
@@ -351,28 +344,28 @@ class AtAll(At):
|
|||||||
|
|
||||||
|
|
||||||
class RPS(BaseMessageComponent): # TODO
|
class RPS(BaseMessageComponent): # TODO
|
||||||
type = ComponentType.RPS
|
type: ComponentType = "RPS"
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
super().__init__(**_)
|
super().__init__(**_)
|
||||||
|
|
||||||
|
|
||||||
class Dice(BaseMessageComponent): # TODO
|
class Dice(BaseMessageComponent): # TODO
|
||||||
type = ComponentType.Dice
|
type: ComponentType = "Dice"
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
super().__init__(**_)
|
super().__init__(**_)
|
||||||
|
|
||||||
|
|
||||||
class Shake(BaseMessageComponent): # TODO
|
class Shake(BaseMessageComponent): # TODO
|
||||||
type = ComponentType.Shake
|
type: ComponentType = "Shake"
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
super().__init__(**_)
|
super().__init__(**_)
|
||||||
|
|
||||||
|
|
||||||
class Anonymous(BaseMessageComponent): # TODO
|
class Anonymous(BaseMessageComponent): # TODO
|
||||||
type = ComponentType.Anonymous
|
type: ComponentType = "Anonymous"
|
||||||
ignore: T.Optional[bool] = False
|
ignore: T.Optional[bool] = False
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
@@ -380,7 +373,7 @@ class Anonymous(BaseMessageComponent): # TODO
|
|||||||
|
|
||||||
|
|
||||||
class Share(BaseMessageComponent):
|
class Share(BaseMessageComponent):
|
||||||
type = ComponentType.Share
|
type: ComponentType = "Share"
|
||||||
url: str
|
url: str
|
||||||
title: str
|
title: str
|
||||||
content: T.Optional[str] = ""
|
content: T.Optional[str] = ""
|
||||||
@@ -391,7 +384,7 @@ class Share(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Contact(BaseMessageComponent): # TODO
|
class Contact(BaseMessageComponent): # TODO
|
||||||
type = ComponentType.Contact
|
type: ComponentType = "Contact"
|
||||||
_type: str # type 字段冲突
|
_type: str # type 字段冲突
|
||||||
id: T.Optional[int] = 0
|
id: T.Optional[int] = 0
|
||||||
|
|
||||||
@@ -400,7 +393,7 @@ class Contact(BaseMessageComponent): # TODO
|
|||||||
|
|
||||||
|
|
||||||
class Location(BaseMessageComponent): # TODO
|
class Location(BaseMessageComponent): # TODO
|
||||||
type = ComponentType.Location
|
type: ComponentType = "Location"
|
||||||
lat: float
|
lat: float
|
||||||
lon: float
|
lon: float
|
||||||
title: T.Optional[str] = ""
|
title: T.Optional[str] = ""
|
||||||
@@ -411,7 +404,7 @@ class Location(BaseMessageComponent): # TODO
|
|||||||
|
|
||||||
|
|
||||||
class Music(BaseMessageComponent):
|
class Music(BaseMessageComponent):
|
||||||
type = ComponentType.Music
|
type: ComponentType = "Music"
|
||||||
_type: str
|
_type: str
|
||||||
id: T.Optional[int] = 0
|
id: T.Optional[int] = 0
|
||||||
url: T.Optional[str] = ""
|
url: T.Optional[str] = ""
|
||||||
@@ -428,7 +421,7 @@ class Music(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Image(BaseMessageComponent):
|
class Image(BaseMessageComponent):
|
||||||
type = ComponentType.Image
|
type: ComponentType = "Image"
|
||||||
file: T.Optional[str] = ""
|
file: T.Optional[str] = ""
|
||||||
_type: T.Optional[str] = ""
|
_type: T.Optional[str] = ""
|
||||||
subType: T.Optional[int] = 0
|
subType: T.Optional[int] = 0
|
||||||
@@ -471,15 +464,14 @@ class Image(BaseMessageComponent):
|
|||||||
Returns:
|
Returns:
|
||||||
str: 图片的本地路径,以绝对路径表示。
|
str: 图片的本地路径,以绝对路径表示。
|
||||||
"""
|
"""
|
||||||
url = self.url or self.file
|
url = self.url if self.url else self.file
|
||||||
if not url:
|
if url and url.startswith("file:///"):
|
||||||
raise ValueError("No valid file or URL provided")
|
image_file_path = url[8:]
|
||||||
if url.startswith("file:///"):
|
return image_file_path
|
||||||
return url[8:]
|
elif url and url.startswith("http"):
|
||||||
elif url.startswith("http"):
|
|
||||||
image_file_path = await download_image_by_url(url)
|
image_file_path = await download_image_by_url(url)
|
||||||
return os.path.abspath(image_file_path)
|
return os.path.abspath(image_file_path)
|
||||||
elif url.startswith("base64://"):
|
elif url and url.startswith("base64://"):
|
||||||
bs64_data = url.removeprefix("base64://")
|
bs64_data = url.removeprefix("base64://")
|
||||||
image_bytes = base64.b64decode(bs64_data)
|
image_bytes = base64.b64decode(bs64_data)
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
@@ -488,7 +480,8 @@ class Image(BaseMessageComponent):
|
|||||||
f.write(image_bytes)
|
f.write(image_bytes)
|
||||||
return os.path.abspath(image_file_path)
|
return os.path.abspath(image_file_path)
|
||||||
elif os.path.exists(url):
|
elif os.path.exists(url):
|
||||||
return os.path.abspath(url)
|
image_file_path = url
|
||||||
|
return os.path.abspath(image_file_path)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"not a valid file: {url}")
|
raise Exception(f"not a valid file: {url}")
|
||||||
|
|
||||||
@@ -499,15 +492,13 @@ class Image(BaseMessageComponent):
|
|||||||
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
||||||
"""
|
"""
|
||||||
# convert to base64
|
# convert to base64
|
||||||
url = self.url or self.file
|
url = self.url if self.url else self.file
|
||||||
if not url:
|
if url and url.startswith("file:///"):
|
||||||
raise ValueError("No valid file or URL provided")
|
|
||||||
if url.startswith("file:///"):
|
|
||||||
bs64_data = file_to_base64(url[8:])
|
bs64_data = file_to_base64(url[8:])
|
||||||
elif url.startswith("http"):
|
elif url and url.startswith("http"):
|
||||||
image_file_path = await download_image_by_url(url)
|
image_file_path = await download_image_by_url(url)
|
||||||
bs64_data = file_to_base64(image_file_path)
|
bs64_data = file_to_base64(image_file_path)
|
||||||
elif url.startswith("base64://"):
|
elif url and url.startswith("base64://"):
|
||||||
bs64_data = url
|
bs64_data = url
|
||||||
elif os.path.exists(url):
|
elif os.path.exists(url):
|
||||||
bs64_data = file_to_base64(url)
|
bs64_data = file_to_base64(url)
|
||||||
@@ -541,7 +532,7 @@ class Image(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Reply(BaseMessageComponent):
|
class Reply(BaseMessageComponent):
|
||||||
type = ComponentType.Reply
|
type: ComponentType = "Reply"
|
||||||
id: T.Union[str, int]
|
id: T.Union[str, int]
|
||||||
"""所引用的消息 ID"""
|
"""所引用的消息 ID"""
|
||||||
chain: T.Optional[T.List["BaseMessageComponent"]] = []
|
chain: T.Optional[T.List["BaseMessageComponent"]] = []
|
||||||
@@ -567,7 +558,7 @@ class Reply(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class RedBag(BaseMessageComponent):
|
class RedBag(BaseMessageComponent):
|
||||||
type = ComponentType.RedBag
|
type: ComponentType = "RedBag"
|
||||||
title: str
|
title: str
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
@@ -575,7 +566,7 @@ class RedBag(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Poke(BaseMessageComponent):
|
class Poke(BaseMessageComponent):
|
||||||
type: str = ComponentType.Poke
|
type: str = ""
|
||||||
id: T.Optional[int] = 0
|
id: T.Optional[int] = 0
|
||||||
qq: T.Optional[int] = 0
|
qq: T.Optional[int] = 0
|
||||||
|
|
||||||
@@ -585,7 +576,7 @@ class Poke(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Forward(BaseMessageComponent):
|
class Forward(BaseMessageComponent):
|
||||||
type = ComponentType.Forward
|
type: ComponentType = "Forward"
|
||||||
id: str
|
id: str
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
@@ -595,7 +586,7 @@ class Forward(BaseMessageComponent):
|
|||||||
class Node(BaseMessageComponent):
|
class Node(BaseMessageComponent):
|
||||||
"""群合并转发消息"""
|
"""群合并转发消息"""
|
||||||
|
|
||||||
type = ComponentType.Node
|
type: ComponentType = "Node"
|
||||||
id: T.Optional[int] = 0 # 忽略
|
id: T.Optional[int] = 0 # 忽略
|
||||||
name: T.Optional[str] = "" # qq昵称
|
name: T.Optional[str] = "" # qq昵称
|
||||||
uin: T.Optional[str] = "0" # qq号
|
uin: T.Optional[str] = "0" # qq号
|
||||||
@@ -647,7 +638,7 @@ class Node(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Nodes(BaseMessageComponent):
|
class Nodes(BaseMessageComponent):
|
||||||
type = ComponentType.Nodes
|
type: ComponentType = "Nodes"
|
||||||
nodes: T.List[Node]
|
nodes: T.List[Node]
|
||||||
|
|
||||||
def __init__(self, nodes: T.List[Node], **_):
|
def __init__(self, nodes: T.List[Node], **_):
|
||||||
@@ -673,7 +664,7 @@ class Nodes(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Xml(BaseMessageComponent):
|
class Xml(BaseMessageComponent):
|
||||||
type = ComponentType.Xml
|
type: ComponentType = "Xml"
|
||||||
data: str
|
data: str
|
||||||
resid: T.Optional[int] = 0
|
resid: T.Optional[int] = 0
|
||||||
|
|
||||||
@@ -682,7 +673,7 @@ class Xml(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Json(BaseMessageComponent):
|
class Json(BaseMessageComponent):
|
||||||
type = ComponentType.Json
|
type: ComponentType = "Json"
|
||||||
data: T.Union[str, dict]
|
data: T.Union[str, dict]
|
||||||
resid: T.Optional[int] = 0
|
resid: T.Optional[int] = 0
|
||||||
|
|
||||||
@@ -693,7 +684,7 @@ class Json(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class CardImage(BaseMessageComponent):
|
class CardImage(BaseMessageComponent):
|
||||||
type = ComponentType.CardImage
|
type: ComponentType = "CardImage"
|
||||||
file: str
|
file: str
|
||||||
cache: T.Optional[bool] = True
|
cache: T.Optional[bool] = True
|
||||||
minwidth: T.Optional[int] = 400
|
minwidth: T.Optional[int] = 400
|
||||||
@@ -712,7 +703,7 @@ class CardImage(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class TTS(BaseMessageComponent):
|
class TTS(BaseMessageComponent):
|
||||||
type = ComponentType.TTS
|
type: ComponentType = "TTS"
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
@@ -720,7 +711,7 @@ class TTS(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Unknown(BaseMessageComponent):
|
class Unknown(BaseMessageComponent):
|
||||||
type = ComponentType.Unknown
|
type: ComponentType = "Unknown"
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
def toString(self):
|
def toString(self):
|
||||||
@@ -732,7 +723,7 @@ class File(BaseMessageComponent):
|
|||||||
文件消息段
|
文件消息段
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type = ComponentType.File
|
type: ComponentType = "File"
|
||||||
name: T.Optional[str] = "" # 名字
|
name: T.Optional[str] = "" # 名字
|
||||||
file_: T.Optional[str] = "" # 本地路径
|
file_: T.Optional[str] = "" # 本地路径
|
||||||
url: T.Optional[str] = "" # url
|
url: T.Optional[str] = "" # url
|
||||||
@@ -862,7 +853,7 @@ class File(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class WechatEmoji(BaseMessageComponent):
|
class WechatEmoji(BaseMessageComponent):
|
||||||
type = ComponentType.WechatEmoji
|
type: ComponentType = "WechatEmoji"
|
||||||
md5: T.Optional[str] = ""
|
md5: T.Optional[str] = ""
|
||||||
md5_len: T.Optional[int] = 0
|
md5_len: T.Optional[int] = 0
|
||||||
cdnurl: T.Optional[str] = ""
|
cdnurl: T.Optional[str] = ""
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class ContentSafetyCheckStage(Stage):
|
|||||||
self.strategy_selector = StrategySelector(config)
|
self.strategy_selector = StrategySelector(config)
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent, check_text: str | None = None
|
self, event: AstrMessageEvent, check_text: str = None
|
||||||
) -> Union[None, AsyncGenerator[None, None]]:
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
"""检查内容安全"""
|
"""检查内容安全"""
|
||||||
text = check_text if check_text else event.get_message_str()
|
text = check_text if check_text else event.get_message_str()
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class BaiduAipStrategy(ContentSafetyStrategy):
|
|||||||
self.secret_key = sk
|
self.secret_key = sk
|
||||||
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
|
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
|
||||||
|
|
||||||
def check(self, content: str) -> tuple[bool, str]:
|
def check(self, content: str):
|
||||||
res = self.client.textCensorUserDefined(content)
|
res = self.client.textCensorUserDefined(content)
|
||||||
if "conclusionType" not in res:
|
if "conclusionType" not in res:
|
||||||
return False, ""
|
return False, ""
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class KeywordsStrategy(ContentSafetyStrategy):
|
|||||||
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
|
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
|
||||||
# )
|
# )
|
||||||
|
|
||||||
def check(self, content: str) -> tuple[bool, str]:
|
def check(self, content: str) -> bool:
|
||||||
for keyword in self.keywords:
|
for keyword in self.keywords:
|
||||||
if re.search(keyword, content):
|
if re.search(keyword, content):
|
||||||
return False, "内容安全检查不通过,匹配到敏感词。"
|
return False, "内容安全检查不通过,匹配到敏感词。"
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|||||||
|
|
||||||
async def call_handler(
|
async def call_handler(
|
||||||
event: AstrMessageEvent,
|
event: AstrMessageEvent,
|
||||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
handler: T.Awaitable,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> T.AsyncGenerator[T.Any, None]:
|
) -> T.AsyncGenerator[T.Any, None]:
|
||||||
@@ -36,9 +36,6 @@ async def call_handler(
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||||
|
|
||||||
if not ready_to_call:
|
|
||||||
return
|
|
||||||
|
|
||||||
if inspect.isasyncgen(ready_to_call):
|
if inspect.isasyncgen(ready_to_call):
|
||||||
_has_yielded = False
|
_has_yielded = False
|
||||||
try:
|
try:
|
||||||
@@ -80,7 +77,7 @@ async def call_event_hook(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 如果事件被终止,返回 True
|
bool: 如果事件被终止,返回 True
|
||||||
#"""
|
# """
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
hook_type, plugins_name=event.plugins_name
|
hook_type, plugins_name=event.plugins_name
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import traceback
|
import traceback
|
||||||
import asyncio
|
import asyncio
|
||||||
import random
|
|
||||||
from typing import Union, AsyncGenerator
|
from typing import Union, AsyncGenerator
|
||||||
from ..stage import Stage, register_stage
|
from ..stage import Stage, register_stage
|
||||||
from ..context import PipelineContext
|
from ..context import PipelineContext
|
||||||
@@ -23,26 +22,6 @@ class PreProcessStage(Stage):
|
|||||||
self, event: AstrMessageEvent
|
self, event: AstrMessageEvent
|
||||||
) -> Union[None, AsyncGenerator[None, None]]:
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
"""在处理事件之前的预处理"""
|
"""在处理事件之前的预处理"""
|
||||||
# 平台特异配置:platform_specific.<platform>.pre_ack_emoji
|
|
||||||
supported = {"telegram", "lark"}
|
|
||||||
platform = event.get_platform_name()
|
|
||||||
cfg = (
|
|
||||||
self.config.get("platform_specific", {})
|
|
||||||
.get(platform, {})
|
|
||||||
.get("pre_ack_emoji", {})
|
|
||||||
) or {}
|
|
||||||
emojis = cfg.get("emojis") or []
|
|
||||||
if (
|
|
||||||
cfg.get("enable", False)
|
|
||||||
and platform in supported
|
|
||||||
and emojis
|
|
||||||
and event.is_at_or_wake_command
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
await event.react(random.choice(emojis))
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"{platform} 预回应表情发送失败: {e}")
|
|
||||||
|
|
||||||
# 路径映射
|
# 路径映射
|
||||||
if mappings := self.platform_settings.get("path_mapping", []):
|
if mappings := self.platform_settings.get("path_mapping", []):
|
||||||
# 支持 Record,Image 消息段的路径映射。
|
# 支持 Record,Image 消息段的路径映射。
|
||||||
@@ -67,9 +46,6 @@ class PreProcessStage(Stage):
|
|||||||
ctx = self.plugin_manager.context
|
ctx = self.plugin_manager.context
|
||||||
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
|
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
|
||||||
if not stt_provider:
|
if not stt_provider:
|
||||||
logger.warning(
|
|
||||||
f"会话 {event.unified_msg_origin} 未配置语音转文本模型。"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
message_chain = event.get_messages()
|
message_chain = event.get_messages()
|
||||||
for idx, component in enumerate(message_chain):
|
for idx, component in enumerate(message_chain):
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import copy
|
|||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
from typing import AsyncGenerator, Union
|
from typing import AsyncGenerator, Union
|
||||||
from astrbot.core.conversation_mgr import Conversation
|
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.message.components import Image
|
from astrbot.core.message.components import Image
|
||||||
from astrbot.core.message.message_event_result import (
|
from astrbot.core.message.message_event_result import (
|
||||||
@@ -134,15 +133,6 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|||||||
|
|
||||||
if agent_runner.done():
|
if agent_runner.done():
|
||||||
llm_response = agent_runner.get_final_llm_resp()
|
llm_response = agent_runner.get_final_llm_resp()
|
||||||
|
|
||||||
if not llm_response:
|
|
||||||
text_content = mcp.types.TextContent(
|
|
||||||
type="text",
|
|
||||||
text=f"error when deligate task to {tool.agent.name}",
|
|
||||||
)
|
|
||||||
yield mcp.types.CallToolResult(content=[text_content])
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
|
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
|
||||||
)
|
)
|
||||||
@@ -158,7 +148,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|||||||
)
|
)
|
||||||
yield mcp.types.CallToolResult(content=[text_content])
|
yield mcp.types.CallToolResult(content=[text_content])
|
||||||
else:
|
else:
|
||||||
text_content = mcp.types.TextContent(
|
yield mcp.types.TextContent(
|
||||||
type="text",
|
type="text",
|
||||||
text=f"error when deligate task to {tool.agent.name}",
|
text=f"error when deligate task to {tool.agent.name}",
|
||||||
)
|
)
|
||||||
@@ -210,11 +200,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|||||||
):
|
):
|
||||||
if not tool.mcp_client:
|
if not tool.mcp_client:
|
||||||
raise ValueError("MCP client is not available for MCP function tools.")
|
raise ValueError("MCP client is not available for MCP function tools.")
|
||||||
|
res = await tool.mcp_client.session.call_tool(
|
||||||
session = tool.mcp_client.session
|
|
||||||
if not session:
|
|
||||||
raise ValueError("MCP session is not available for MCP function tools.")
|
|
||||||
res = await session.call_tool(
|
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
arguments=tool_args,
|
arguments=tool_args,
|
||||||
)
|
)
|
||||||
@@ -285,12 +271,19 @@ async def run_agent(
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
astr_event.set_result(
|
||||||
if agent_runner.streaming:
|
MessageEventResult().message(
|
||||||
yield MessageChain().message(err_msg)
|
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
||||||
else:
|
)
|
||||||
astr_event.set_result(MessageEventResult().message(err_msg))
|
)
|
||||||
return
|
return
|
||||||
|
asyncio.create_task(
|
||||||
|
Metric.upload(
|
||||||
|
llm_tick=1,
|
||||||
|
model_name=agent_runner.provider.get_model(),
|
||||||
|
provider_type=agent_runner.provider.meta().type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LLMRequestSubStage(Stage):
|
class LLMRequestSubStage(Stage):
|
||||||
@@ -306,9 +299,7 @@ class LLMRequestSubStage(Stage):
|
|||||||
self.max_context_length - 1,
|
self.max_context_length - 1,
|
||||||
)
|
)
|
||||||
self.streaming_response: bool = settings["streaming_response"]
|
self.streaming_response: bool = settings["streaming_response"]
|
||||||
self.max_step: int = settings.get("max_agent_step", 30)
|
self.max_step: int = settings.get("max_agent_step", 10)
|
||||||
if isinstance(self.max_step, bool): # workaround: #2622
|
|
||||||
self.max_step = 30
|
|
||||||
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
||||||
|
|
||||||
for bwp in self.bot_wake_prefixs:
|
for bwp in self.bot_wake_prefixs:
|
||||||
@@ -332,7 +323,7 @@ class LLMRequestSubStage(Stage):
|
|||||||
|
|
||||||
return _ctx.get_using_provider(umo=event.unified_msg_origin)
|
return _ctx.get_using_provider(umo=event.unified_msg_origin)
|
||||||
|
|
||||||
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
|
async def _get_session_conv(self, event: AstrMessageEvent):
|
||||||
umo = event.unified_msg_origin
|
umo = event.unified_msg_origin
|
||||||
conv_mgr = self.conv_manager
|
conv_mgr = self.conv_manager
|
||||||
|
|
||||||
@@ -344,8 +335,6 @@ class LLMRequestSubStage(Stage):
|
|||||||
if not conversation:
|
if not conversation:
|
||||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||||
if not conversation:
|
|
||||||
raise RuntimeError("无法创建新的对话。")
|
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
@@ -445,18 +434,13 @@ class LLMRequestSubStage(Stage):
|
|||||||
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
|
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
|
||||||
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
|
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
|
||||||
if "tool_use" not in provider_cfg:
|
if "tool_use" not in provider_cfg:
|
||||||
logger.debug(
|
logger.debug(f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。")
|
||||||
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。"
|
|
||||||
)
|
|
||||||
req.func_tool = None
|
req.func_tool = None
|
||||||
# 插件可用性设置
|
# 插件可用性设置
|
||||||
if event.plugins_name is not None and req.func_tool:
|
if event.plugins_name is not None and req.func_tool:
|
||||||
new_tool_set = ToolSet()
|
new_tool_set = ToolSet()
|
||||||
for tool in req.func_tool.tools:
|
for tool in req.func_tool.tools:
|
||||||
mp = tool.handler_module_path
|
plugin = star_map.get(tool.handler_module_path)
|
||||||
if not mp:
|
|
||||||
continue
|
|
||||||
plugin = star_map.get(mp)
|
|
||||||
if not plugin:
|
if not plugin:
|
||||||
continue
|
continue
|
||||||
if plugin.name in event.plugins_name or plugin.reserved:
|
if plugin.name in event.plugins_name or plugin.reserved:
|
||||||
@@ -517,14 +501,6 @@ class LLMRequestSubStage(Stage):
|
|||||||
if event.get_platform_name() == "webchat":
|
if event.get_platform_name() == "webchat":
|
||||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||||
|
|
||||||
asyncio.create_task(
|
|
||||||
Metric.upload(
|
|
||||||
llm_tick=1,
|
|
||||||
model_name=agent_runner.provider.get_model(),
|
|
||||||
provider_type=agent_runner.provider.meta().type,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_webchat(
|
async def _handle_webchat(
|
||||||
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
|
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
|
||||||
):
|
):
|
||||||
@@ -537,23 +513,7 @@ class LLMRequestSubStage(Stage):
|
|||||||
latest_pair = messages[-2:]
|
latest_pair = messages[-2:]
|
||||||
if not latest_pair:
|
if not latest_pair:
|
||||||
return
|
return
|
||||||
content = latest_pair[0].get("content", "")
|
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
|
||||||
if isinstance(content, list):
|
|
||||||
# 多模态
|
|
||||||
text_parts = []
|
|
||||||
for item in content:
|
|
||||||
if isinstance(item, dict):
|
|
||||||
if item.get("type") == "text":
|
|
||||||
text_parts.append(item.get("text", ""))
|
|
||||||
elif item.get("type") == "image":
|
|
||||||
text_parts.append("[图片]")
|
|
||||||
elif isinstance(item, str):
|
|
||||||
text_parts.append(item)
|
|
||||||
cleaned_text = "User: " + " ".join(text_parts).strip()
|
|
||||||
elif isinstance(content, str):
|
|
||||||
cleaned_text = "User: " + content.strip()
|
|
||||||
else:
|
|
||||||
return
|
|
||||||
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
||||||
llm_resp = await prov.text_chat(
|
llm_resp = await prov.text_chat(
|
||||||
system_prompt="You are expert in summarizing user's query.",
|
system_prompt="You are expert in summarizing user's query.",
|
||||||
|
|||||||
@@ -34,14 +34,12 @@ class StarRequestSubStage(Stage):
|
|||||||
|
|
||||||
for handler in activated_handlers:
|
for handler in activated_handlers:
|
||||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||||
md = star_map.get(handler.handler_module_path)
|
|
||||||
if not md:
|
|
||||||
logger.warning(
|
|
||||||
f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
logger.debug(f"plugin -> {md.name} - {handler.handler_name}")
|
|
||||||
try:
|
try:
|
||||||
|
if handler.handler_module_path not in star_map:
|
||||||
|
continue
|
||||||
|
logger.debug(
|
||||||
|
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
|
||||||
|
)
|
||||||
wrapper = call_handler(event, handler.handler, **params)
|
wrapper = call_handler(event, handler.handler, **params)
|
||||||
async for ret in wrapper:
|
async for ret in wrapper:
|
||||||
yield ret
|
yield ret
|
||||||
@@ -51,7 +49,7 @@ class StarRequestSubStage(Stage):
|
|||||||
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
||||||
|
|
||||||
if event.is_at_or_wake_command:
|
if event.is_at_or_wake_command:
|
||||||
ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||||
event.set_result(MessageEventResult().message(ret))
|
event.set_result(MessageEventResult().message(ret))
|
||||||
yield
|
yield
|
||||||
event.clear_result()
|
event.clear_result()
|
||||||
|
|||||||
@@ -1,15 +1,17 @@
|
|||||||
import random
|
import random
|
||||||
import asyncio
|
import asyncio
|
||||||
import math
|
import math
|
||||||
|
import traceback
|
||||||
import astrbot.core.message.components as Comp
|
import astrbot.core.message.components as Comp
|
||||||
from typing import Union, AsyncGenerator
|
from typing import Union, AsyncGenerator
|
||||||
from ..stage import register_stage, Stage
|
from ..stage import register_stage, Stage
|
||||||
from ..context import PipelineContext, call_event_hook
|
from ..context import PipelineContext
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.message.components import BaseMessageComponent, ComponentType
|
from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||||
from astrbot.core.star.star_handler import EventType
|
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||||
|
from astrbot.core.star.star import star_map
|
||||||
from astrbot.core.utils.path_util import path_Mapping
|
from astrbot.core.utils.path_util import path_Mapping
|
||||||
from astrbot.core.utils.session_lock import session_lock_manager
|
from astrbot.core.utils.session_lock import session_lock_manager
|
||||||
|
|
||||||
@@ -112,43 +114,6 @@ class RespondStage(Stage):
|
|||||||
# 如果所有组件都为空
|
# 如果所有组件都为空
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def is_seg_reply_required(self, event: AstrMessageEvent) -> bool:
|
|
||||||
"""检查是否需要分段回复"""
|
|
||||||
if not self.enable_seg:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.only_llm_result and not event.get_result().is_llm_result():
|
|
||||||
return False
|
|
||||||
|
|
||||||
if event.get_platform_name() in [
|
|
||||||
"qq_official",
|
|
||||||
"weixin_official_account",
|
|
||||||
"dingtalk",
|
|
||||||
]:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _extract_comp(
|
|
||||||
self,
|
|
||||||
raw_chain: list[BaseMessageComponent],
|
|
||||||
extract_types: set[ComponentType],
|
|
||||||
modify_raw_chain: bool = True,
|
|
||||||
):
|
|
||||||
extracted = []
|
|
||||||
if modify_raw_chain:
|
|
||||||
remaining = []
|
|
||||||
for comp in raw_chain:
|
|
||||||
if comp.type in extract_types:
|
|
||||||
extracted.append(comp)
|
|
||||||
else:
|
|
||||||
remaining.append(comp)
|
|
||||||
raw_chain[:] = remaining
|
|
||||||
else:
|
|
||||||
extracted = [comp for comp in raw_chain if comp.type in extract_types]
|
|
||||||
|
|
||||||
return extracted
|
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent
|
self, event: AstrMessageEvent
|
||||||
) -> Union[None, AsyncGenerator[None, None]]:
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
@@ -158,14 +123,7 @@ class RespondStage(Stage):
|
|||||||
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||||
if result.async_stream is None:
|
|
||||||
logger.warning("async_stream 为空,跳过发送。")
|
|
||||||
return
|
|
||||||
# 流式结果直接交付平台适配器处理
|
# 流式结果直接交付平台适配器处理
|
||||||
use_fallback = self.config.get("provider_settings", {}).get(
|
use_fallback = self.config.get("provider_settings", {}).get(
|
||||||
"streaming_segmented", False
|
"streaming_segmented", False
|
||||||
@@ -190,71 +148,87 @@ class RespondStage(Stage):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"空内容检查异常: {e}")
|
logger.warning(f"空内容检查异常: {e}")
|
||||||
|
|
||||||
# 发送消息链
|
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
|
||||||
# Record 需要强制单独发送
|
non_record_comps = [
|
||||||
need_separately = {ComponentType.Record}
|
c for c in result.chain if not isinstance(c, Comp.Record)
|
||||||
if self.is_seg_reply_required(event):
|
]
|
||||||
header_comps = self._extract_comp(
|
|
||||||
result.chain,
|
if (
|
||||||
{ComponentType.Reply, ComponentType.At},
|
self.enable_seg
|
||||||
modify_raw_chain=True,
|
and (
|
||||||
|
(self.only_llm_result and result.is_llm_result())
|
||||||
|
or not self.only_llm_result
|
||||||
)
|
)
|
||||||
if not result.chain or len(result.chain) == 0:
|
and event.get_platform_name()
|
||||||
# may fix #2670
|
not in ["qq_official", "weixin_official_account", "dingtalk"]
|
||||||
logger.warning(
|
):
|
||||||
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}"
|
decorated_comps = []
|
||||||
)
|
if self.reply_with_mention:
|
||||||
return
|
|
||||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
|
||||||
for comp in result.chain:
|
for comp in result.chain:
|
||||||
|
if isinstance(comp, Comp.At):
|
||||||
|
decorated_comps.append(comp)
|
||||||
|
result.chain.remove(comp)
|
||||||
|
break
|
||||||
|
if self.reply_with_quote:
|
||||||
|
for comp in result.chain:
|
||||||
|
if isinstance(comp, Comp.Reply):
|
||||||
|
decorated_comps.append(comp)
|
||||||
|
result.chain.remove(comp)
|
||||||
|
break
|
||||||
|
|
||||||
|
# leverage lock to guarentee the order of message sending among different events
|
||||||
|
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||||
|
for rcomp in record_comps:
|
||||||
|
i = await self._calc_comp_interval(rcomp)
|
||||||
|
await asyncio.sleep(i)
|
||||||
|
try:
|
||||||
|
await event.send(MessageChain([rcomp]))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||||
|
break
|
||||||
|
# 分段回复
|
||||||
|
for comp in non_record_comps:
|
||||||
i = await self._calc_comp_interval(comp)
|
i = await self._calc_comp_interval(comp)
|
||||||
await asyncio.sleep(i)
|
await asyncio.sleep(i)
|
||||||
try:
|
try:
|
||||||
if comp.type in need_separately:
|
await event.send(MessageChain([*decorated_comps, comp]))
|
||||||
await event.send(MessageChain([comp]))
|
decorated_comps = [] # 清空已发送的装饰组件
|
||||||
else:
|
|
||||||
await event.send(MessageChain([*header_comps, comp]))
|
|
||||||
header_comps.clear()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||||
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
|
break
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if all(
|
for rcomp in record_comps:
|
||||||
comp.type in {ComponentType.Reply, ComponentType.At}
|
|
||||||
for comp in result.chain
|
|
||||||
):
|
|
||||||
# may fix #2670
|
|
||||||
logger.warning(
|
|
||||||
f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
sep_comps = self._extract_comp(
|
|
||||||
result.chain,
|
|
||||||
need_separately,
|
|
||||||
modify_raw_chain=True,
|
|
||||||
)
|
|
||||||
for comp in sep_comps:
|
|
||||||
chain = MessageChain([comp])
|
|
||||||
try:
|
try:
|
||||||
await event.send(chain)
|
await event.send(MessageChain([rcomp]))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||||
f"发送消息链失败: chain = {chain}, error = {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
chain = MessageChain(result.chain)
|
|
||||||
if result.chain and len(result.chain) > 0:
|
|
||||||
try:
|
|
||||||
await event.send(chain)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"发送消息链失败: chain = {chain}, error = {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if await call_event_hook(event, EventType.OnAfterMessageSentEvent):
|
try:
|
||||||
return
|
await event.send(MessageChain(non_record_comps))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
|
EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name
|
||||||
|
)
|
||||||
|
for handler in handlers:
|
||||||
|
try:
|
||||||
|
logger.debug(
|
||||||
|
f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||||
|
)
|
||||||
|
await handler.handler(event)
|
||||||
|
except BaseException:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
if event.is_stopped():
|
||||||
|
logger.info(
|
||||||
|
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
event.clear_result()
|
event.clear_result()
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ class ResultDecorateStage(Stage):
|
|||||||
self.t2i_word_threshold = 150
|
self.t2i_word_threshold = 150
|
||||||
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
|
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
|
||||||
self.t2i_use_network = self.t2i_strategy == "remote"
|
self.t2i_use_network = self.t2i_strategy == "remote"
|
||||||
self.t2i_active_template = ctx.astrbot_config["t2i_active_template"]
|
|
||||||
|
|
||||||
self.forward_threshold = ctx.astrbot_config["platform_settings"][
|
self.forward_threshold = ctx.astrbot_config["platform_settings"][
|
||||||
"forward_threshold"
|
"forward_threshold"
|
||||||
@@ -183,13 +182,9 @@ class ResultDecorateStage(Stage):
|
|||||||
if (
|
if (
|
||||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||||
and result.is_llm_result()
|
and result.is_llm_result()
|
||||||
|
and tts_provider
|
||||||
and SessionServiceManager.should_process_tts_request(event)
|
and SessionServiceManager.should_process_tts_request(event)
|
||||||
):
|
):
|
||||||
if not tts_provider:
|
|
||||||
logger.warning(
|
|
||||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
new_chain = []
|
new_chain = []
|
||||||
for comp in result.chain:
|
for comp in result.chain:
|
||||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||||
@@ -252,10 +247,7 @@ class ResultDecorateStage(Stage):
|
|||||||
render_start = time.time()
|
render_start = time.time()
|
||||||
try:
|
try:
|
||||||
url = await html_renderer.render_t2i(
|
url = await html_renderer.render_t2i(
|
||||||
plain_str,
|
plain_str, return_url=True, use_network=self.t2i_use_network
|
||||||
return_url=True,
|
|
||||||
use_network=self.t2i_use_network,
|
|
||||||
template_name=self.t2i_active_template,
|
|
||||||
)
|
)
|
||||||
except BaseException:
|
except BaseException:
|
||||||
logger.error("文本转图片失败,使用文本发送。")
|
logger.error("文本转图片失败,使用文本发送。")
|
||||||
|
|||||||
@@ -11,8 +11,7 @@ class SessionStatusCheckStage(Stage):
|
|||||||
"""检查会话是否整体启用"""
|
"""检查会话是否整体启用"""
|
||||||
|
|
||||||
async def initialize(self, ctx: PipelineContext) -> None:
|
async def initialize(self, ctx: PipelineContext) -> None:
|
||||||
self.ctx = ctx
|
pass
|
||||||
self.conv_mgr = ctx.plugin_manager.context.conversation_manager
|
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent
|
self, event: AstrMessageEvent
|
||||||
@@ -20,14 +19,4 @@ class SessionStatusCheckStage(Stage):
|
|||||||
# 检查会话是否整体启用
|
# 检查会话是否整体启用
|
||||||
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||||
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
||||||
|
|
||||||
# workaround for #2309
|
|
||||||
conv_id = await self.conv_mgr.get_curr_conversation_id(
|
|
||||||
event.unified_msg_origin
|
|
||||||
)
|
|
||||||
if not conv_id:
|
|
||||||
await self.conv_mgr.new_conversation(
|
|
||||||
event.unified_msg_origin, platform_id=event.get_platform_id()
|
|
||||||
)
|
|
||||||
|
|
||||||
event.stop_event()
|
event.stop_event()
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from astrbot.core.message.components import At, AtAll, Reply
|
|||||||
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
|
||||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||||
from astrbot.core.star.star import star_map
|
from astrbot.core.star.star import star_map
|
||||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||||
@@ -171,15 +170,11 @@ class WakingCheckStage(Stage):
|
|||||||
is_wake = True
|
is_wake = True
|
||||||
event.is_wake = True
|
event.is_wake = True
|
||||||
|
|
||||||
is_group_cmd_handler = any(
|
activated_handlers.append(handler)
|
||||||
isinstance(f, CommandGroupFilter) for f in handler.event_filters
|
if "parsed_params" in event.get_extra():
|
||||||
)
|
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
|
||||||
if not is_group_cmd_handler:
|
"parsed_params"
|
||||||
activated_handlers.append(handler)
|
)
|
||||||
if "parsed_params" in event.get_extra(default={}):
|
|
||||||
handlers_parsed_params[handler.handler_full_name] = (
|
|
||||||
event.get_extra("parsed_params")
|
|
||||||
)
|
|
||||||
|
|
||||||
event._extras.pop("parsed_params", None)
|
event._extras.pop("parsed_params", None)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import re
|
|||||||
import hashlib
|
import hashlib
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
|
from typing import List, Union, Optional, AsyncGenerator
|
||||||
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core.db.po import Conversation
|
from astrbot.core.db.po import Conversation
|
||||||
@@ -24,9 +24,7 @@ from astrbot.core.provider.entities import ProviderRequest
|
|||||||
from astrbot.core.utils.metrics import Metric
|
from astrbot.core.utils.metrics import Metric
|
||||||
from .astrbot_message import AstrBotMessage, Group
|
from .astrbot_message import AstrBotMessage, Group
|
||||||
from .platform_metadata import PlatformMetadata
|
from .platform_metadata import PlatformMetadata
|
||||||
from .message_session import MessageSession, MessageSesion # noqa
|
from .message_session import MessageSession, MessageSesion # noqa
|
||||||
|
|
||||||
_VT = TypeVar("_VT")
|
|
||||||
|
|
||||||
|
|
||||||
class AstrMessageEvent(abc.ABC):
|
class AstrMessageEvent(abc.ABC):
|
||||||
@@ -51,7 +49,7 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
"""是否唤醒(是否通过 WakingStage)"""
|
"""是否唤醒(是否通过 WakingStage)"""
|
||||||
self.is_at_or_wake_command = False
|
self.is_at_or_wake_command = False
|
||||||
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
|
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
|
||||||
self._extras: dict[str, Any] = {}
|
self._extras = {}
|
||||||
self.session = MessageSesion(
|
self.session = MessageSesion(
|
||||||
platform_name=platform_meta.id,
|
platform_name=platform_meta.id,
|
||||||
message_type=message_obj.type,
|
message_type=message_obj.type,
|
||||||
@@ -59,7 +57,7 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
)
|
)
|
||||||
self.unified_msg_origin = str(self.session)
|
self.unified_msg_origin = str(self.session)
|
||||||
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
|
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
|
||||||
self._result: MessageEventResult | None = None
|
self._result: MessageEventResult = None
|
||||||
"""消息事件的结果"""
|
"""消息事件的结果"""
|
||||||
|
|
||||||
self._has_send_oper = False
|
self._has_send_oper = False
|
||||||
@@ -175,15 +173,13 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
"""
|
"""
|
||||||
self._extras[key] = value
|
self._extras[key] = value
|
||||||
|
|
||||||
def get_extra(
|
def get_extra(self, key=None):
|
||||||
self, key: str | None = None, default: _VT = None
|
|
||||||
) -> dict[str, Any] | _VT:
|
|
||||||
"""
|
"""
|
||||||
获取额外的信息。
|
获取额外的信息。
|
||||||
"""
|
"""
|
||||||
if key is None:
|
if key is None:
|
||||||
return self._extras
|
return self._extras
|
||||||
return self._extras.get(key, default)
|
return self._extras.get(key, None)
|
||||||
|
|
||||||
def clear_extra(self):
|
def clear_extra(self):
|
||||||
"""
|
"""
|
||||||
@@ -416,16 +412,6 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
)
|
)
|
||||||
self._has_send_oper = True
|
self._has_send_oper = True
|
||||||
|
|
||||||
async def react(self, emoji: str):
|
|
||||||
"""
|
|
||||||
对消息添加表情回应。
|
|
||||||
|
|
||||||
默认实现为发送一条包含该表情的消息。
|
|
||||||
注意:此实现并不一定符合所有平台的原生“表情回应”行为。
|
|
||||||
如需支持平台原生的消息反应功能,请在对应平台的子类中重写本方法。
|
|
||||||
"""
|
|
||||||
await self.send(MessageChain([Plain(emoji)]))
|
|
||||||
|
|
||||||
async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]:
|
async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]:
|
||||||
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。
|
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class AstrBotMessage:
|
|||||||
self_id: str # 机器人的识别id
|
self_id: str # 机器人的识别id
|
||||||
session_id: str # 会话id。取决于 unique_session 的设置。
|
session_id: str # 会话id。取决于 unique_session 的设置。
|
||||||
message_id: str # 消息id
|
message_id: str # 消息id
|
||||||
group: Group # 群组
|
group_id: str = "" # 群组id,如果为私聊,则为空
|
||||||
sender: MessageMember # 发送者
|
sender: MessageMember # 发送者
|
||||||
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
||||||
message_str: str # 最直观的纯文本消息字符串
|
message_str: str # 最直观的纯文本消息字符串
|
||||||
@@ -64,28 +64,6 @@ class AstrBotMessage:
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.timestamp = int(time.time())
|
self.timestamp = int(time.time())
|
||||||
self.group = None
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return str(self.__dict__)
|
return str(self.__dict__)
|
||||||
|
|
||||||
@property
|
|
||||||
def group_id(self) -> str:
|
|
||||||
"""
|
|
||||||
向后兼容的 group_id 属性
|
|
||||||
群组id,如果为私聊,则为空
|
|
||||||
"""
|
|
||||||
if self.group:
|
|
||||||
return self.group.group_id
|
|
||||||
return ""
|
|
||||||
|
|
||||||
@group_id.setter
|
|
||||||
def group_id(self, value: str):
|
|
||||||
"""设置 group_id"""
|
|
||||||
if value:
|
|
||||||
if self.group:
|
|
||||||
self.group.group_id = value
|
|
||||||
else:
|
|
||||||
self.group = Group(group_id=value)
|
|
||||||
else:
|
|
||||||
self.group = None
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from typing import List
|
|||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from .register import platform_cls_map
|
from .register import platform_cls_map
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, star_map, EventType
|
|
||||||
from .sources.webchat.webchat_adapter import WebChatAdapter
|
from .sources.webchat.webchat_adapter import WebChatAdapter
|
||||||
|
|
||||||
|
|
||||||
@@ -67,39 +66,25 @@ class PlatformManager:
|
|||||||
WeChatPadProAdapter, # noqa: F401
|
WeChatPadProAdapter, # noqa: F401
|
||||||
)
|
)
|
||||||
case "lark":
|
case "lark":
|
||||||
from .sources.lark.lark_adapter import (
|
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
|
||||||
LarkPlatformAdapter, # noqa: F401
|
|
||||||
)
|
|
||||||
case "dingtalk":
|
case "dingtalk":
|
||||||
from .sources.dingtalk.dingtalk_adapter import (
|
from .sources.dingtalk.dingtalk_adapter import (
|
||||||
DingtalkPlatformAdapter, # noqa: F401
|
DingtalkPlatformAdapter, # noqa: F401
|
||||||
)
|
)
|
||||||
case "telegram":
|
case "telegram":
|
||||||
from .sources.telegram.tg_adapter import (
|
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
|
||||||
TelegramPlatformAdapter, # noqa: F401
|
|
||||||
)
|
|
||||||
case "wecom":
|
case "wecom":
|
||||||
from .sources.wecom.wecom_adapter import (
|
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
|
||||||
WecomPlatformAdapter, # noqa: F401
|
|
||||||
)
|
|
||||||
case "weixin_official_account":
|
case "weixin_official_account":
|
||||||
from .sources.weixin_official_account.weixin_offacc_adapter import (
|
from .sources.weixin_official_account.weixin_offacc_adapter import (
|
||||||
WeixinOfficialAccountPlatformAdapter, # noqa: F401
|
WeixinOfficialAccountPlatformAdapter, # noqa
|
||||||
)
|
)
|
||||||
case "discord":
|
case "discord":
|
||||||
from .sources.discord.discord_platform_adapter import (
|
from .sources.discord.discord_platform_adapter import (
|
||||||
DiscordPlatformAdapter, # noqa: F401
|
DiscordPlatformAdapter, # noqa: F401
|
||||||
)
|
)
|
||||||
case "misskey":
|
|
||||||
from .sources.misskey.misskey_adapter import (
|
|
||||||
MisskeyPlatformAdapter, # noqa: F401
|
|
||||||
)
|
|
||||||
case "slack":
|
case "slack":
|
||||||
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
|
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
|
||||||
case "satori":
|
|
||||||
from .sources.satori.satori_adapter import (
|
|
||||||
SatoriPlatformAdapter, # noqa: F401
|
|
||||||
)
|
|
||||||
except (ImportError, ModuleNotFoundError) as e:
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
|
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
|
||||||
@@ -128,17 +113,6 @@ class PlatformManager:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
|
||||||
EventType.OnPlatformLoadedEvent
|
|
||||||
)
|
|
||||||
for handler in handlers:
|
|
||||||
try:
|
|
||||||
logger.info(
|
|
||||||
f"hook(on_platform_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
|
||||||
)
|
|
||||||
await handler.handler()
|
|
||||||
except Exception:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
async def _task_wrapper(self, task: asyncio.Task):
|
async def _task_wrapper(self, task: asyncio.Task):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -14,5 +14,3 @@ class PlatformMetadata:
|
|||||||
"""平台的默认配置模板"""
|
"""平台的默认配置模板"""
|
||||||
adapter_display_name: str = None
|
adapter_display_name: str = None
|
||||||
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
|
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
|
||||||
logo_path: str = None
|
|
||||||
"""平台适配器的 logo 文件路径(相对于插件目录)"""
|
|
||||||
|
|||||||
@@ -13,12 +13,10 @@ def register_platform_adapter(
|
|||||||
desc: str,
|
desc: str,
|
||||||
default_config_tmpl: dict = None,
|
default_config_tmpl: dict = None,
|
||||||
adapter_display_name: str = None,
|
adapter_display_name: str = None,
|
||||||
logo_path: str = None,
|
|
||||||
):
|
):
|
||||||
"""用于注册平台适配器的带参装饰器。
|
"""用于注册平台适配器的带参装饰器。
|
||||||
|
|
||||||
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
|
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
|
||||||
logo_path 指定了平台适配器的 logo 文件路径,是相对于插件目录的路径。
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(cls):
|
def decorator(cls):
|
||||||
@@ -41,7 +39,6 @@ def register_platform_adapter(
|
|||||||
description=desc,
|
description=desc,
|
||||||
default_config_tmpl=default_config_tmpl,
|
default_config_tmpl=default_config_tmpl,
|
||||||
adapter_display_name=adapter_display_name,
|
adapter_display_name=adapter_display_name,
|
||||||
logo_path=logo_path,
|
|
||||||
)
|
)
|
||||||
platform_registry.append(pm)
|
platform_registry.append(pm)
|
||||||
platform_cls_map[adapter_name] = cls
|
platform_cls_map[adapter_name] = cls
|
||||||
|
|||||||
@@ -67,19 +67,12 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
|||||||
session_id: str,
|
session_id: str,
|
||||||
messages: list[dict],
|
messages: list[dict],
|
||||||
):
|
):
|
||||||
# session_id 必须是纯数字字符串
|
if event:
|
||||||
session_id = int(session_id) if session_id.isdigit() else None
|
|
||||||
|
|
||||||
if is_group and isinstance(session_id, int):
|
|
||||||
await bot.send_group_msg(group_id=session_id, message=messages)
|
|
||||||
elif not is_group and isinstance(session_id, int):
|
|
||||||
await bot.send_private_msg(user_id=session_id, message=messages)
|
|
||||||
elif isinstance(event, Event): # 最后兜底
|
|
||||||
await bot.send(event=event, message=messages)
|
await bot.send(event=event, message=messages)
|
||||||
|
elif is_group:
|
||||||
|
await bot.send_group_msg(group_id=session_id, message=messages)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
await bot.send_private_msg(user_id=session_id, message=messages)
|
||||||
f"无法发送消息:缺少有效的数字 session_id({session_id}) 或 event({event})"
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def send_message(
|
async def send_message(
|
||||||
@@ -90,15 +83,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
|||||||
is_group: bool = False,
|
is_group: bool = False,
|
||||||
session_id: str = None,
|
session_id: str = None,
|
||||||
):
|
):
|
||||||
"""发送消息至 QQ 协议端(aiocqhttp)。
|
"""发送消息"""
|
||||||
|
|
||||||
Args:
|
|
||||||
bot (CQHttp): aiocqhttp 机器人实例
|
|
||||||
message_chain (MessageChain): 要发送的消息链
|
|
||||||
event (Event | None, optional): aiocqhttp 事件对象.
|
|
||||||
is_group (bool, optional): 是否为群消息.
|
|
||||||
session_id (str | None, optional): 会话 ID(群号或 QQ 号
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 转发消息、文件消息不能和普通消息混在一起发送
|
# 转发消息、文件消息不能和普通消息混在一起发送
|
||||||
send_one_by_one = any(
|
send_one_by_one = any(
|
||||||
@@ -137,15 +122,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
"""发送消息"""
|
"""发送消息"""
|
||||||
event = getattr(self.message_obj, "raw_message", None)
|
event = self.message_obj.raw_message
|
||||||
|
assert isinstance(event, Event), "Event must be an instance of aiocqhttp.Event"
|
||||||
is_group = bool(self.get_group_id())
|
is_group = False
|
||||||
session_id = self.get_group_id() if is_group else self.get_sender_id()
|
if self.get_group_id():
|
||||||
|
is_group = True
|
||||||
|
session_id = self.get_group_id()
|
||||||
|
else:
|
||||||
|
session_id = self.get_sender_id()
|
||||||
await self.send_message(
|
await self.send_message(
|
||||||
bot=self.bot,
|
bot=self.bot,
|
||||||
message_chain=message,
|
message_chain=message,
|
||||||
event=event, # 不强制要求一定是 Event
|
event=event,
|
||||||
is_group=is_group,
|
is_group=is_group,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -182,13 +182,11 @@ class AiocqhttpAdapter(Platform):
|
|||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.self_id = str(event.self_id)
|
abm.self_id = str(event.self_id)
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(
|
||||||
str(event.sender["user_id"]),
|
str(event.sender["user_id"]), event.sender["nickname"]
|
||||||
event.sender.get("card") or event.sender.get("nickname", "N/A"),
|
|
||||||
)
|
)
|
||||||
if event["message_type"] == "group":
|
if event["message_type"] == "group":
|
||||||
abm.type = MessageType.GROUP_MESSAGE
|
abm.type = MessageType.GROUP_MESSAGE
|
||||||
abm.group_id = str(event.group_id)
|
abm.group_id = str(event.group_id)
|
||||||
abm.group.group_name = event.get("group_name", "N/A")
|
|
||||||
elif event["message_type"] == "private":
|
elif event["message_type"] == "private":
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||||
@@ -310,22 +308,13 @@ class AiocqhttpAdapter(Platform):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
at_info = await self.bot.call_action(
|
at_info = await self.bot.call_action(
|
||||||
action="get_group_member_info",
|
action="get_stranger_info",
|
||||||
group_id=event.group_id,
|
|
||||||
user_id=int(m["data"]["qq"]),
|
user_id=int(m["data"]["qq"]),
|
||||||
no_cache=False,
|
|
||||||
)
|
)
|
||||||
if at_info:
|
if at_info:
|
||||||
nickname = at_info.get("card", "")
|
nickname = at_info.get("nick", "") or at_info.get(
|
||||||
if nickname == "":
|
"nickname", ""
|
||||||
at_info = await self.bot.call_action(
|
)
|
||||||
action="get_stranger_info",
|
|
||||||
user_id=int(m["data"]["qq"]),
|
|
||||||
no_cache=False,
|
|
||||||
)
|
|
||||||
nickname = at_info.get("nick", "") or at_info.get(
|
|
||||||
"nickname", ""
|
|
||||||
)
|
|
||||||
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
|
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
|
||||||
|
|
||||||
abm.message.append(
|
abm.message.append(
|
||||||
|
|||||||
@@ -54,9 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
|||||||
logger.debug(f"send image: {ret}")
|
logger.debug(f"send image: {ret}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"钉钉图片处理失败: {e}, 跳过图片发送")
|
logger.error(f"钉钉图片处理失败: {e}")
|
||||||
|
logger.warning(f"跳过图片发送: {image_path}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
await self.send_with_client(self.client, message)
|
await self.send_with_client(self.client, message)
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|||||||
@@ -41,8 +41,7 @@ class DiscordBotClient(discord.Bot):
|
|||||||
await self.on_ready_once_callback()
|
await self.on_ready_once_callback()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True
|
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True)
|
||||||
)
|
|
||||||
|
|
||||||
def _create_message_data(self, message: discord.Message) -> dict:
|
def _create_message_data(self, message: discord.Message) -> dict:
|
||||||
"""从 discord.Message 创建数据字典"""
|
"""从 discord.Message 创建数据字典"""
|
||||||
@@ -91,6 +90,7 @@ class DiscordBotClient(discord.Bot):
|
|||||||
message_data = self._create_message_data(message)
|
message_data = self._create_message_data(message)
|
||||||
await self.on_message_received(message_data)
|
await self.on_message_received(message_data)
|
||||||
|
|
||||||
|
|
||||||
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
|
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
|
||||||
"""从交互中提取内容"""
|
"""从交互中提取内容"""
|
||||||
interaction_type = interaction.type
|
interaction_type = interaction.type
|
||||||
|
|||||||
@@ -79,12 +79,9 @@ class DiscordButton(BaseMessageComponent):
|
|||||||
self.url = url
|
self.url = url
|
||||||
self.disabled = disabled
|
self.disabled = disabled
|
||||||
|
|
||||||
|
|
||||||
class DiscordReference(BaseMessageComponent):
|
class DiscordReference(BaseMessageComponent):
|
||||||
"""Discord引用组件"""
|
"""Discord引用组件"""
|
||||||
|
|
||||||
type: str = "discord_reference"
|
type: str = "discord_reference"
|
||||||
|
|
||||||
def __init__(self, message_id: str, channel_id: str):
|
def __init__(self, message_id: str, channel_id: str):
|
||||||
self.message_id = message_id
|
self.message_id = message_id
|
||||||
self.channel_id = channel_id
|
self.channel_id = channel_id
|
||||||
@@ -101,6 +98,7 @@ class DiscordView(BaseMessageComponent):
|
|||||||
self.components = components or []
|
self.components = components or []
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
|
|
||||||
def to_discord_view(self) -> discord.ui.View:
|
def to_discord_view(self) -> discord.ui.View:
|
||||||
"""转换为Discord View对象"""
|
"""转换为Discord View对象"""
|
||||||
view = discord.ui.View(timeout=self.timeout)
|
view = discord.ui.View(timeout=self.timeout)
|
||||||
|
|||||||
@@ -53,13 +53,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
# 解析消息链为 Discord 所需的对象
|
# 解析消息链为 Discord 所需的对象
|
||||||
try:
|
try:
|
||||||
(
|
content, files, view, embeds, reference_message_id = await self._parse_to_discord(message)
|
||||||
content,
|
|
||||||
files,
|
|
||||||
view,
|
|
||||||
embeds,
|
|
||||||
reference_message_id,
|
|
||||||
) = await self._parse_to_discord(message)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
|
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
|
||||||
return
|
return
|
||||||
@@ -212,7 +206,8 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|||||||
if await asyncio.to_thread(path.exists):
|
if await asyncio.to_thread(path.exists):
|
||||||
file_bytes = await asyncio.to_thread(path.read_bytes)
|
file_bytes = await asyncio.to_thread(path.read_bytes)
|
||||||
files.append(
|
files.append(
|
||||||
discord.File(BytesIO(file_bytes), filename=i.name)
|
discord.File(BytesIO(file_bytes),
|
||||||
|
filename=i.name)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -107,22 +107,6 @@ class LarkMessageEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
async def react(self, emoji: str):
|
|
||||||
request = (
|
|
||||||
CreateMessageReactionRequest.builder()
|
|
||||||
.message_id(self.message_obj.message_id)
|
|
||||||
.request_body(
|
|
||||||
CreateMessageReactionRequestBody.builder()
|
|
||||||
.reaction_type(Emoji.builder().emoji_type(emoji).build())
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
response = await self.bot.im.v1.message_reaction.acreate(request)
|
|
||||||
if not response.success():
|
|
||||||
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
buffer = None
|
buffer = None
|
||||||
async for chain in generator:
|
async for chain in generator:
|
||||||
|
|||||||
@@ -1,391 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import json
|
|
||||||
from typing import Dict, Any, Optional, Awaitable
|
|
||||||
|
|
||||||
from astrbot.api import logger
|
|
||||||
from astrbot.api.event import MessageChain
|
|
||||||
from astrbot.api.platform import (
|
|
||||||
AstrBotMessage,
|
|
||||||
Platform,
|
|
||||||
PlatformMetadata,
|
|
||||||
register_platform_adapter,
|
|
||||||
)
|
|
||||||
from astrbot.core.platform.astr_message_event import MessageSession
|
|
||||||
import astrbot.api.message_components as Comp
|
|
||||||
|
|
||||||
from .misskey_api import MisskeyAPI
|
|
||||||
from .misskey_event import MisskeyPlatformEvent
|
|
||||||
from .misskey_utils import (
|
|
||||||
serialize_message_chain,
|
|
||||||
resolve_message_visibility,
|
|
||||||
is_valid_user_session_id,
|
|
||||||
is_valid_room_session_id,
|
|
||||||
add_at_mention_if_needed,
|
|
||||||
process_files,
|
|
||||||
extract_sender_info,
|
|
||||||
create_base_message,
|
|
||||||
process_at_mention,
|
|
||||||
cache_user_info,
|
|
||||||
cache_room_info,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@register_platform_adapter("misskey", "Misskey 平台适配器")
|
|
||||||
class MisskeyPlatformAdapter(Platform):
|
|
||||||
def __init__(
|
|
||||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
|
||||||
) -> None:
|
|
||||||
super().__init__(event_queue)
|
|
||||||
self.config = platform_config or {}
|
|
||||||
self.settings = platform_settings or {}
|
|
||||||
self.instance_url = self.config.get("misskey_instance_url", "")
|
|
||||||
self.access_token = self.config.get("misskey_token", "")
|
|
||||||
self.max_message_length = self.config.get("max_message_length", 3000)
|
|
||||||
self.default_visibility = self.config.get(
|
|
||||||
"misskey_default_visibility", "public"
|
|
||||||
)
|
|
||||||
self.local_only = self.config.get("misskey_local_only", False)
|
|
||||||
self.enable_chat = self.config.get("misskey_enable_chat", True)
|
|
||||||
|
|
||||||
self.unique_session = platform_settings["unique_session"]
|
|
||||||
|
|
||||||
self.api: Optional[MisskeyAPI] = None
|
|
||||||
self._running = False
|
|
||||||
self.client_self_id = ""
|
|
||||||
self._bot_username = ""
|
|
||||||
self._user_cache = {}
|
|
||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
|
||||||
default_config = {
|
|
||||||
"misskey_instance_url": "",
|
|
||||||
"misskey_token": "",
|
|
||||||
"max_message_length": 3000,
|
|
||||||
"misskey_default_visibility": "public",
|
|
||||||
"misskey_local_only": False,
|
|
||||||
"misskey_enable_chat": True,
|
|
||||||
}
|
|
||||||
default_config.update(self.config)
|
|
||||||
|
|
||||||
return PlatformMetadata(
|
|
||||||
name="misskey",
|
|
||||||
description="Misskey 平台适配器",
|
|
||||||
id=self.config.get("id", "misskey"),
|
|
||||||
default_config_tmpl=default_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
if not self.instance_url or not self.access_token:
|
|
||||||
logger.error("[Misskey] 配置不完整,无法启动")
|
|
||||||
return
|
|
||||||
|
|
||||||
self.api = MisskeyAPI(self.instance_url, self.access_token)
|
|
||||||
self._running = True
|
|
||||||
|
|
||||||
try:
|
|
||||||
user_info = await self.api.get_current_user()
|
|
||||||
self.client_self_id = str(user_info.get("id", ""))
|
|
||||||
self._bot_username = user_info.get("username", "")
|
|
||||||
logger.info(
|
|
||||||
f"[Misskey] 已连接用户: {self._bot_username} (ID: {self.client_self_id})"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[Misskey] 获取用户信息失败: {e}")
|
|
||||||
self._running = False
|
|
||||||
return
|
|
||||||
|
|
||||||
await self._start_websocket_connection()
|
|
||||||
|
|
||||||
async def _start_websocket_connection(self):
|
|
||||||
backoff_delay = 1.0
|
|
||||||
max_backoff = 300.0
|
|
||||||
backoff_multiplier = 1.5
|
|
||||||
connection_attempts = 0
|
|
||||||
|
|
||||||
while self._running:
|
|
||||||
try:
|
|
||||||
connection_attempts += 1
|
|
||||||
if not self.api:
|
|
||||||
logger.error("[Misskey] API 客户端未初始化")
|
|
||||||
break
|
|
||||||
|
|
||||||
streaming = self.api.get_streaming_client()
|
|
||||||
streaming.add_message_handler("notification", self._handle_notification)
|
|
||||||
if self.enable_chat:
|
|
||||||
streaming.add_message_handler(
|
|
||||||
"newChatMessage", self._handle_chat_message
|
|
||||||
)
|
|
||||||
streaming.add_message_handler("_debug", self._debug_handler)
|
|
||||||
|
|
||||||
if await streaming.connect():
|
|
||||||
logger.info(
|
|
||||||
f"[Misskey] WebSocket 已连接 (尝试 #{connection_attempts})"
|
|
||||||
)
|
|
||||||
connection_attempts = 0 # 重置计数器
|
|
||||||
await streaming.subscribe_channel("main")
|
|
||||||
if self.enable_chat:
|
|
||||||
await streaming.subscribe_channel("messaging")
|
|
||||||
await streaming.subscribe_channel("messagingIndex")
|
|
||||||
logger.info("[Misskey] 聊天频道已订阅")
|
|
||||||
|
|
||||||
backoff_delay = 1.0 # 重置延迟
|
|
||||||
await streaming.listen()
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
f"[Misskey] WebSocket 连接失败 (尝试 #{connection_attempts})"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"[Misskey] WebSocket 异常 (尝试 #{connection_attempts}): {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._running:
|
|
||||||
logger.info(
|
|
||||||
f"[Misskey] {backoff_delay:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})"
|
|
||||||
)
|
|
||||||
await asyncio.sleep(backoff_delay)
|
|
||||||
backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff)
|
|
||||||
|
|
||||||
async def _handle_notification(self, data: Dict[str, Any]):
|
|
||||||
try:
|
|
||||||
logger.debug(
|
|
||||||
f"[Misskey] 收到通知事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
|
|
||||||
)
|
|
||||||
notification_type = data.get("type")
|
|
||||||
if notification_type in ["mention", "reply", "quote"]:
|
|
||||||
note = data.get("note")
|
|
||||||
if note and self._is_bot_mentioned(note):
|
|
||||||
logger.info(
|
|
||||||
f"[Misskey] 处理贴文提及: {note.get('text', '')[:50]}..."
|
|
||||||
)
|
|
||||||
message = await self.convert_message(note)
|
|
||||||
event = MisskeyPlatformEvent(
|
|
||||||
message_str=message.message_str,
|
|
||||||
message_obj=message,
|
|
||||||
platform_meta=self.meta(),
|
|
||||||
session_id=message.session_id,
|
|
||||||
client=self.api,
|
|
||||||
)
|
|
||||||
self.commit_event(event)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[Misskey] 处理通知失败: {e}")
|
|
||||||
|
|
||||||
async def _handle_chat_message(self, data: Dict[str, Any]):
|
|
||||||
try:
|
|
||||||
logger.debug(
|
|
||||||
f"[Misskey] 收到聊天事件数据:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
sender_id = str(
|
|
||||||
data.get("fromUserId", "") or data.get("fromUser", {}).get("id", "")
|
|
||||||
)
|
|
||||||
if sender_id == self.client_self_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
room_id = data.get("toRoomId")
|
|
||||||
if room_id:
|
|
||||||
raw_text = data.get("text", "")
|
|
||||||
logger.debug(
|
|
||||||
f"[Misskey] 检查群聊消息: '{raw_text}', 机器人用户名: '{self._bot_username}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
message = await self.convert_room_message(data)
|
|
||||||
logger.info(f"[Misskey] 处理群聊消息: {message.message_str[:50]}...")
|
|
||||||
else:
|
|
||||||
message = await self.convert_chat_message(data)
|
|
||||||
logger.info(f"[Misskey] 处理私聊消息: {message.message_str[:50]}...")
|
|
||||||
|
|
||||||
event = MisskeyPlatformEvent(
|
|
||||||
message_str=message.message_str,
|
|
||||||
message_obj=message,
|
|
||||||
platform_meta=self.meta(),
|
|
||||||
session_id=message.session_id,
|
|
||||||
client=self.api,
|
|
||||||
)
|
|
||||||
self.commit_event(event)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[Misskey] 处理聊天消息失败: {e}")
|
|
||||||
|
|
||||||
async def _debug_handler(self, data: Dict[str, Any]):
|
|
||||||
logger.debug(
|
|
||||||
f"[Misskey] 收到未处理事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _is_bot_mentioned(self, note: Dict[str, Any]) -> bool:
|
|
||||||
text = note.get("text", "")
|
|
||||||
if not text:
|
|
||||||
return False
|
|
||||||
|
|
||||||
mentions = note.get("mentions", [])
|
|
||||||
if self._bot_username and f"@{self._bot_username}" in text:
|
|
||||||
return True
|
|
||||||
if self.client_self_id in [str(uid) for uid in mentions]:
|
|
||||||
return True
|
|
||||||
|
|
||||||
reply = note.get("reply")
|
|
||||||
if reply and isinstance(reply, dict):
|
|
||||||
reply_user_id = str(reply.get("user", {}).get("id", ""))
|
|
||||||
if reply_user_id == self.client_self_id:
|
|
||||||
return bool(self._bot_username and f"@{self._bot_username}" in text)
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def send_by_session(
|
|
||||||
self, session: MessageSession, message_chain: MessageChain
|
|
||||||
) -> Awaitable[Any]:
|
|
||||||
if not self.api:
|
|
||||||
logger.error("[Misskey] API 客户端未初始化")
|
|
||||||
return await super().send_by_session(session, message_chain)
|
|
||||||
|
|
||||||
try:
|
|
||||||
session_id = session.session_id
|
|
||||||
text, has_at_user = serialize_message_chain(message_chain.chain)
|
|
||||||
|
|
||||||
if not has_at_user and session_id:
|
|
||||||
user_info = self._user_cache.get(session_id)
|
|
||||||
text = add_at_mention_if_needed(text, user_info, has_at_user)
|
|
||||||
|
|
||||||
if not text or not text.strip():
|
|
||||||
logger.warning("[Misskey] 消息内容为空,跳过发送")
|
|
||||||
return await super().send_by_session(session, message_chain)
|
|
||||||
|
|
||||||
if len(text) > self.max_message_length:
|
|
||||||
text = text[: self.max_message_length] + "..."
|
|
||||||
|
|
||||||
if session_id and is_valid_user_session_id(session_id):
|
|
||||||
from .misskey_utils import extract_user_id_from_session_id
|
|
||||||
|
|
||||||
user_id = extract_user_id_from_session_id(session_id)
|
|
||||||
await self.api.send_message(user_id, text)
|
|
||||||
elif session_id and is_valid_room_session_id(session_id):
|
|
||||||
from .misskey_utils import extract_room_id_from_session_id
|
|
||||||
|
|
||||||
room_id = extract_room_id_from_session_id(session_id)
|
|
||||||
await self.api.send_room_message(room_id, text)
|
|
||||||
else:
|
|
||||||
visibility, visible_user_ids = resolve_message_visibility(
|
|
||||||
user_id=session_id,
|
|
||||||
user_cache=self._user_cache,
|
|
||||||
self_id=self.client_self_id,
|
|
||||||
default_visibility=self.default_visibility,
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.api.create_note(
|
|
||||||
text,
|
|
||||||
visibility=visibility,
|
|
||||||
visible_user_ids=visible_user_ids,
|
|
||||||
local_only=self.local_only,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[Misskey] 发送消息失败: {e}")
|
|
||||||
|
|
||||||
return await super().send_by_session(session, message_chain)
|
|
||||||
|
|
||||||
async def convert_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
|
|
||||||
"""将 Misskey 贴文数据转换为 AstrBotMessage 对象"""
|
|
||||||
sender_info = extract_sender_info(raw_data, is_chat=False)
|
|
||||||
message = create_base_message(
|
|
||||||
raw_data,
|
|
||||||
sender_info,
|
|
||||||
self.client_self_id,
|
|
||||||
is_chat=False,
|
|
||||||
unique_session=self.unique_session,
|
|
||||||
)
|
|
||||||
cache_user_info(
|
|
||||||
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
|
|
||||||
)
|
|
||||||
|
|
||||||
message_parts = []
|
|
||||||
raw_text = raw_data.get("text", "")
|
|
||||||
|
|
||||||
if raw_text:
|
|
||||||
text_parts, processed_text = process_at_mention(
|
|
||||||
message, raw_text, self._bot_username, self.client_self_id
|
|
||||||
)
|
|
||||||
message_parts.extend(text_parts)
|
|
||||||
|
|
||||||
files = raw_data.get("files", [])
|
|
||||||
file_parts = process_files(message, files)
|
|
||||||
message_parts.extend(file_parts)
|
|
||||||
|
|
||||||
message.message_str = (
|
|
||||||
" ".join(part for part in message_parts if part.strip())
|
|
||||||
if message_parts
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
return message
|
|
||||||
|
|
||||||
async def convert_chat_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
|
|
||||||
"""将 Misskey 聊天消息数据转换为 AstrBotMessage 对象"""
|
|
||||||
sender_info = extract_sender_info(raw_data, is_chat=True)
|
|
||||||
message = create_base_message(
|
|
||||||
raw_data,
|
|
||||||
sender_info,
|
|
||||||
self.client_self_id,
|
|
||||||
is_chat=True,
|
|
||||||
unique_session=self.unique_session,
|
|
||||||
)
|
|
||||||
cache_user_info(
|
|
||||||
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=True
|
|
||||||
)
|
|
||||||
|
|
||||||
raw_text = raw_data.get("text", "")
|
|
||||||
if raw_text:
|
|
||||||
message.message.append(Comp.Plain(raw_text))
|
|
||||||
|
|
||||||
files = raw_data.get("files", [])
|
|
||||||
process_files(message, files, include_text_parts=False)
|
|
||||||
|
|
||||||
message.message_str = raw_text if raw_text else ""
|
|
||||||
return message
|
|
||||||
|
|
||||||
async def convert_room_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
|
|
||||||
"""将 Misskey 群聊消息数据转换为 AstrBotMessage 对象"""
|
|
||||||
sender_info = extract_sender_info(raw_data, is_chat=True)
|
|
||||||
room_id = raw_data.get("toRoomId", "")
|
|
||||||
message = create_base_message(
|
|
||||||
raw_data,
|
|
||||||
sender_info,
|
|
||||||
self.client_self_id,
|
|
||||||
is_chat=False,
|
|
||||||
room_id=room_id,
|
|
||||||
unique_session=self.unique_session,
|
|
||||||
)
|
|
||||||
|
|
||||||
cache_user_info(
|
|
||||||
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
|
|
||||||
)
|
|
||||||
cache_room_info(self._user_cache, raw_data, self.client_self_id)
|
|
||||||
|
|
||||||
raw_text = raw_data.get("text", "")
|
|
||||||
message_parts = []
|
|
||||||
|
|
||||||
if raw_text:
|
|
||||||
if self._bot_username and f"@{self._bot_username}" in raw_text:
|
|
||||||
text_parts, processed_text = process_at_mention(
|
|
||||||
message, raw_text, self._bot_username, self.client_self_id
|
|
||||||
)
|
|
||||||
message_parts.extend(text_parts)
|
|
||||||
else:
|
|
||||||
message.message.append(Comp.Plain(raw_text))
|
|
||||||
message_parts.append(raw_text)
|
|
||||||
|
|
||||||
files = raw_data.get("files", [])
|
|
||||||
file_parts = process_files(message, files)
|
|
||||||
message_parts.extend(file_parts)
|
|
||||||
|
|
||||||
message.message_str = (
|
|
||||||
" ".join(part for part in message_parts if part.strip())
|
|
||||||
if message_parts
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
return message
|
|
||||||
|
|
||||||
async def terminate(self):
|
|
||||||
self._running = False
|
|
||||||
if self.api:
|
|
||||||
await self.api.close()
|
|
||||||
|
|
||||||
def get_client(self) -> Any:
|
|
||||||
return self.api
|
|
||||||
@@ -1,404 +0,0 @@
|
|||||||
import json
|
|
||||||
from typing import Any, Optional, Dict, List, Callable, Awaitable
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
try:
|
|
||||||
import aiohttp
|
|
||||||
import websockets
|
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError(
|
|
||||||
"aiohttp and websockets are required for Misskey API. Please install them with: pip install aiohttp websockets"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
from astrbot.api import logger
|
|
||||||
|
|
||||||
# Constants
|
|
||||||
API_MAX_RETRIES = 3
|
|
||||||
HTTP_OK = 200
|
|
||||||
|
|
||||||
|
|
||||||
class APIError(Exception):
|
|
||||||
"""Misskey API 基础异常"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class APIConnectionError(APIError):
|
|
||||||
"""网络连接异常"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class APIRateLimitError(APIError):
|
|
||||||
"""API 频率限制异常"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationError(APIError):
|
|
||||||
"""认证失败异常"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketError(APIError):
|
|
||||||
"""WebSocket 连接异常"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class StreamingClient:
|
|
||||||
def __init__(self, instance_url: str, access_token: str):
|
|
||||||
self.instance_url = instance_url.rstrip("/")
|
|
||||||
self.access_token = access_token
|
|
||||||
self.websocket: Optional[Any] = None
|
|
||||||
self.is_connected = False
|
|
||||||
self.message_handlers: Dict[str, Callable] = {}
|
|
||||||
self.channels: Dict[str, str] = {}
|
|
||||||
self._running = False
|
|
||||||
self._last_pong = None
|
|
||||||
|
|
||||||
async def connect(self) -> bool:
|
|
||||||
try:
|
|
||||||
ws_url = self.instance_url.replace("https://", "wss://").replace(
|
|
||||||
"http://", "ws://"
|
|
||||||
)
|
|
||||||
ws_url += f"/streaming?i={self.access_token}"
|
|
||||||
|
|
||||||
self.websocket = await websockets.connect(
|
|
||||||
ws_url, ping_interval=30, ping_timeout=10
|
|
||||||
)
|
|
||||||
self.is_connected = True
|
|
||||||
self._running = True
|
|
||||||
|
|
||||||
logger.info("[Misskey WebSocket] 已连接")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[Misskey WebSocket] 连接失败: {e}")
|
|
||||||
self.is_connected = False
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def disconnect(self):
|
|
||||||
self._running = False
|
|
||||||
if self.websocket:
|
|
||||||
await self.websocket.close()
|
|
||||||
self.websocket = None
|
|
||||||
self.is_connected = False
|
|
||||||
logger.info("[Misskey WebSocket] 连接已断开")
|
|
||||||
|
|
||||||
async def subscribe_channel(
|
|
||||||
self, channel_type: str, params: Optional[Dict] = None
|
|
||||||
) -> str:
|
|
||||||
if not self.is_connected or not self.websocket:
|
|
||||||
raise WebSocketError("WebSocket 未连接")
|
|
||||||
|
|
||||||
channel_id = str(uuid.uuid4())
|
|
||||||
message = {
|
|
||||||
"type": "connect",
|
|
||||||
"body": {"channel": channel_type, "id": channel_id, "params": params or {}},
|
|
||||||
}
|
|
||||||
|
|
||||||
await self.websocket.send(json.dumps(message))
|
|
||||||
self.channels[channel_id] = channel_type
|
|
||||||
return channel_id
|
|
||||||
|
|
||||||
async def unsubscribe_channel(self, channel_id: str):
|
|
||||||
if (
|
|
||||||
not self.is_connected
|
|
||||||
or not self.websocket
|
|
||||||
or channel_id not in self.channels
|
|
||||||
):
|
|
||||||
return
|
|
||||||
|
|
||||||
message = {"type": "disconnect", "body": {"id": channel_id}}
|
|
||||||
|
|
||||||
await self.websocket.send(json.dumps(message))
|
|
||||||
del self.channels[channel_id]
|
|
||||||
|
|
||||||
def add_message_handler(
|
|
||||||
self, event_type: str, handler: Callable[[Dict], Awaitable[None]]
|
|
||||||
):
|
|
||||||
self.message_handlers[event_type] = handler
|
|
||||||
|
|
||||||
async def listen(self):
|
|
||||||
if not self.is_connected or not self.websocket:
|
|
||||||
raise WebSocketError("WebSocket 未连接")
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for message in self.websocket:
|
|
||||||
if not self._running:
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = json.loads(message)
|
|
||||||
await self._handle_message(data)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.warning(f"[Misskey WebSocket] 无法解析消息: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[Misskey WebSocket] 处理消息失败: {e}")
|
|
||||||
|
|
||||||
except websockets.exceptions.ConnectionClosedError as e:
|
|
||||||
logger.warning(f"[Misskey WebSocket] 连接意外关闭: {e}")
|
|
||||||
self.is_connected = False
|
|
||||||
except websockets.exceptions.ConnectionClosed as e:
|
|
||||||
logger.warning(
|
|
||||||
f"[Misskey WebSocket] 连接已关闭 (代码: {e.code}, 原因: {e.reason})"
|
|
||||||
)
|
|
||||||
self.is_connected = False
|
|
||||||
except websockets.exceptions.InvalidHandshake as e:
|
|
||||||
logger.error(f"[Misskey WebSocket] 握手失败: {e}")
|
|
||||||
self.is_connected = False
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[Misskey WebSocket] 监听消息失败: {e}")
|
|
||||||
self.is_connected = False
|
|
||||||
|
|
||||||
async def _handle_message(self, data: Dict[str, Any]):
|
|
||||||
message_type = data.get("type")
|
|
||||||
body = data.get("body", {})
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"[Misskey WebSocket] 收到消息类型: {message_type}\n数据: {json.dumps(data, indent=2, ensure_ascii=False)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if message_type == "channel":
|
|
||||||
channel_id = body.get("id")
|
|
||||||
event_type = body.get("type")
|
|
||||||
event_body = body.get("body", {})
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"[Misskey WebSocket] 频道消息: {channel_id}, 事件类型: {event_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if channel_id in self.channels:
|
|
||||||
channel_type = self.channels[channel_id]
|
|
||||||
handler_key = f"{channel_type}:{event_type}"
|
|
||||||
|
|
||||||
if handler_key in self.message_handlers:
|
|
||||||
logger.debug(f"[Misskey WebSocket] 使用处理器: {handler_key}")
|
|
||||||
await self.message_handlers[handler_key](event_body)
|
|
||||||
elif event_type in self.message_handlers:
|
|
||||||
logger.debug(f"[Misskey WebSocket] 使用事件处理器: {event_type}")
|
|
||||||
await self.message_handlers[event_type](event_body)
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
f"[Misskey WebSocket] 未找到处理器: {handler_key} 或 {event_type}"
|
|
||||||
)
|
|
||||||
if "_debug" in self.message_handlers:
|
|
||||||
await self.message_handlers["_debug"](
|
|
||||||
{
|
|
||||||
"type": event_type,
|
|
||||||
"body": event_body,
|
|
||||||
"channel": channel_type,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
elif message_type in self.message_handlers:
|
|
||||||
logger.debug(f"[Misskey WebSocket] 直接消息处理器: {message_type}")
|
|
||||||
await self.message_handlers[message_type](body)
|
|
||||||
else:
|
|
||||||
logger.debug(f"[Misskey WebSocket] 未处理的消息类型: {message_type}")
|
|
||||||
if "_debug" in self.message_handlers:
|
|
||||||
await self.message_handlers["_debug"](data)
|
|
||||||
|
|
||||||
|
|
||||||
def retry_async(max_retries: int = 3, retryable_exceptions: tuple = ()):
|
|
||||||
def decorator(func):
|
|
||||||
async def wrapper(*args, **kwargs):
|
|
||||||
last_exc = None
|
|
||||||
for _ in range(max_retries):
|
|
||||||
try:
|
|
||||||
return await func(*args, **kwargs)
|
|
||||||
except retryable_exceptions as e:
|
|
||||||
last_exc = e
|
|
||||||
continue
|
|
||||||
if last_exc:
|
|
||||||
raise last_exc
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
class MisskeyAPI:
|
|
||||||
def __init__(self, instance_url: str, access_token: str):
|
|
||||||
self.instance_url = instance_url.rstrip("/")
|
|
||||||
self.access_token = access_token
|
|
||||||
self._session: Optional[aiohttp.ClientSession] = None
|
|
||||||
self.streaming: Optional[StreamingClient] = None
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
await self.close()
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
if self.streaming:
|
|
||||||
await self.streaming.disconnect()
|
|
||||||
self.streaming = None
|
|
||||||
if self._session:
|
|
||||||
await self._session.close()
|
|
||||||
self._session = None
|
|
||||||
logger.debug("[Misskey API] 客户端已关闭")
|
|
||||||
|
|
||||||
def get_streaming_client(self) -> StreamingClient:
|
|
||||||
if not self.streaming:
|
|
||||||
self.streaming = StreamingClient(self.instance_url, self.access_token)
|
|
||||||
return self.streaming
|
|
||||||
|
|
||||||
@property
|
|
||||||
def session(self) -> aiohttp.ClientSession:
|
|
||||||
if self._session is None or self._session.closed:
|
|
||||||
headers = {"Authorization": f"Bearer {self.access_token}"}
|
|
||||||
self._session = aiohttp.ClientSession(headers=headers)
|
|
||||||
return self._session
|
|
||||||
|
|
||||||
def _handle_response_status(self, status: int, endpoint: str):
|
|
||||||
"""处理 HTTP 响应状态码"""
|
|
||||||
if status == 400:
|
|
||||||
logger.error(f"API 请求错误: {endpoint} (状态码: {status})")
|
|
||||||
raise APIError(f"Bad request for {endpoint}")
|
|
||||||
elif status in (401, 403):
|
|
||||||
logger.error(f"API 认证失败: {endpoint} (状态码: {status})")
|
|
||||||
raise AuthenticationError(f"Authentication failed for {endpoint}")
|
|
||||||
elif status == 429:
|
|
||||||
logger.warning(f"API 频率限制: {endpoint} (状态码: {status})")
|
|
||||||
raise APIRateLimitError(f"Rate limit exceeded for {endpoint}")
|
|
||||||
else:
|
|
||||||
logger.error(f"API 请求失败: {endpoint} (状态码: {status})")
|
|
||||||
raise APIConnectionError(f"HTTP {status} for {endpoint}")
|
|
||||||
|
|
||||||
async def _process_response(
|
|
||||||
self, response: aiohttp.ClientResponse, endpoint: str
|
|
||||||
) -> Any:
|
|
||||||
"""处理 API 响应"""
|
|
||||||
if response.status == HTTP_OK:
|
|
||||||
try:
|
|
||||||
result = await response.json()
|
|
||||||
if endpoint == "i/notifications":
|
|
||||||
notifications_data = (
|
|
||||||
result
|
|
||||||
if isinstance(result, list)
|
|
||||||
else result.get("notifications", [])
|
|
||||||
if isinstance(result, dict)
|
|
||||||
else []
|
|
||||||
)
|
|
||||||
if notifications_data:
|
|
||||||
logger.debug(f"获取到 {len(notifications_data)} 条新通知")
|
|
||||||
else:
|
|
||||||
logger.debug(f"API 请求成功: {endpoint}")
|
|
||||||
return result
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.error(f"响应不是有效的 JSON 格式: {e}")
|
|
||||||
raise APIConnectionError("Invalid JSON response") from e
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
error_text = await response.text()
|
|
||||||
logger.error(
|
|
||||||
f"API 请求失败: {endpoint} - 状态码: {response.status}, 响应: {error_text}"
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.error(f"API 请求失败: {endpoint} - 状态码: {response.status}")
|
|
||||||
|
|
||||||
self._handle_response_status(response.status, endpoint)
|
|
||||||
raise APIConnectionError(f"Request failed for {endpoint}")
|
|
||||||
|
|
||||||
@retry_async(
|
|
||||||
max_retries=API_MAX_RETRIES,
|
|
||||||
retryable_exceptions=(APIConnectionError, APIRateLimitError),
|
|
||||||
)
|
|
||||||
async def _make_request(
|
|
||||||
self, endpoint: str, data: Optional[Dict[str, Any]] = None
|
|
||||||
) -> Any:
|
|
||||||
url = f"{self.instance_url}/api/{endpoint}"
|
|
||||||
payload = {"i": self.access_token}
|
|
||||||
if data:
|
|
||||||
payload.update(data)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with self.session.post(url, json=payload) as response:
|
|
||||||
return await self._process_response(response, endpoint)
|
|
||||||
except aiohttp.ClientError as e:
|
|
||||||
logger.error(f"HTTP 请求错误: {e}")
|
|
||||||
raise APIConnectionError(f"HTTP request failed: {e}") from e
|
|
||||||
|
|
||||||
async def create_note(
|
|
||||||
self,
|
|
||||||
text: str,
|
|
||||||
visibility: str = "public",
|
|
||||||
reply_id: Optional[str] = None,
|
|
||||||
visible_user_ids: Optional[List[str]] = None,
|
|
||||||
local_only: bool = False,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""创建新贴文"""
|
|
||||||
data: Dict[str, Any] = {
|
|
||||||
"text": text,
|
|
||||||
"visibility": visibility,
|
|
||||||
"localOnly": local_only,
|
|
||||||
}
|
|
||||||
if reply_id:
|
|
||||||
data["replyId"] = reply_id
|
|
||||||
if visible_user_ids and visibility == "specified":
|
|
||||||
data["visibleUserIds"] = visible_user_ids
|
|
||||||
|
|
||||||
result = await self._make_request("notes/create", data)
|
|
||||||
note_id = result.get("createdNote", {}).get("id", "unknown")
|
|
||||||
logger.debug(f"发帖成功,note_id: {note_id}")
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def get_current_user(self) -> Dict[str, Any]:
|
|
||||||
"""获取当前用户信息"""
|
|
||||||
return await self._make_request("i", {})
|
|
||||||
|
|
||||||
async def send_message(self, user_id: str, text: str) -> Dict[str, Any]:
|
|
||||||
"""发送聊天消息"""
|
|
||||||
result = await self._make_request(
|
|
||||||
"chat/messages/create-to-user", {"toUserId": user_id, "text": text}
|
|
||||||
)
|
|
||||||
message_id = result.get("id", "unknown")
|
|
||||||
logger.debug(f"聊天发送成功,message_id: {message_id}")
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def send_room_message(self, room_id: str, text: str) -> Dict[str, Any]:
|
|
||||||
"""发送房间消息"""
|
|
||||||
result = await self._make_request(
|
|
||||||
"chat/messages/create-to-room", {"toRoomId": room_id, "text": text}
|
|
||||||
)
|
|
||||||
message_id = result.get("id", "unknown")
|
|
||||||
logger.debug(f"房间消息发送成功,message_id: {message_id}")
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def get_messages(
|
|
||||||
self, user_id: str, limit: int = 10, since_id: Optional[str] = None
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""获取聊天消息历史"""
|
|
||||||
data: Dict[str, Any] = {"userId": user_id, "limit": limit}
|
|
||||||
if since_id:
|
|
||||||
data["sinceId"] = since_id
|
|
||||||
|
|
||||||
result = await self._make_request("chat/messages/user-timeline", data)
|
|
||||||
if isinstance(result, list):
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
logger.warning(f"获取聊天消息响应格式异常: {type(result)}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def get_mentions(
|
|
||||||
self, limit: int = 10, since_id: Optional[str] = None
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""获取提及通知"""
|
|
||||||
data: Dict[str, Any] = {"limit": limit}
|
|
||||||
if since_id:
|
|
||||||
data["sinceId"] = since_id
|
|
||||||
data["includeTypes"] = ["mention", "reply", "quote"]
|
|
||||||
|
|
||||||
result = await self._make_request("i/notifications", data)
|
|
||||||
if isinstance(result, list):
|
|
||||||
return result
|
|
||||||
elif isinstance(result, dict) and "notifications" in result:
|
|
||||||
return result["notifications"]
|
|
||||||
else:
|
|
||||||
logger.warning(f"获取提及通知响应格式异常: {type(result)}")
|
|
||||||
return []
|
|
||||||
@@ -1,123 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import re
|
|
||||||
from typing import AsyncGenerator
|
|
||||||
from astrbot.api import logger
|
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
|
||||||
from astrbot.api.platform import PlatformMetadata, AstrBotMessage
|
|
||||||
from astrbot.api.message_components import Plain
|
|
||||||
|
|
||||||
from .misskey_utils import (
|
|
||||||
serialize_message_chain,
|
|
||||||
resolve_visibility_from_raw_message,
|
|
||||||
is_valid_user_session_id,
|
|
||||||
is_valid_room_session_id,
|
|
||||||
add_at_mention_if_needed,
|
|
||||||
extract_user_id_from_session_id,
|
|
||||||
extract_room_id_from_session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MisskeyPlatformEvent(AstrMessageEvent):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
message_str: str,
|
|
||||||
message_obj: AstrBotMessage,
|
|
||||||
platform_meta: PlatformMetadata,
|
|
||||||
session_id: str,
|
|
||||||
client,
|
|
||||||
):
|
|
||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
|
||||||
self.client = client
|
|
||||||
|
|
||||||
def _is_system_command(self, message_str: str) -> bool:
|
|
||||||
"""检测是否为系统指令"""
|
|
||||||
if not message_str or not message_str.strip():
|
|
||||||
return False
|
|
||||||
|
|
||||||
system_prefixes = ["/", "!", "#", ".", "^"]
|
|
||||||
message_trimmed = message_str.strip()
|
|
||||||
|
|
||||||
return any(message_trimmed.startswith(prefix) for prefix in system_prefixes)
|
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
|
||||||
content, has_at = serialize_message_chain(message.chain)
|
|
||||||
|
|
||||||
if not content:
|
|
||||||
logger.debug("[MisskeyEvent] 内容为空,跳过发送")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
original_message_id = getattr(self.message_obj, "message_id", None)
|
|
||||||
raw_message = getattr(self.message_obj, "raw_message", {})
|
|
||||||
|
|
||||||
if raw_message and not has_at:
|
|
||||||
user_data = raw_message.get("user", {})
|
|
||||||
user_info = {
|
|
||||||
"username": user_data.get("username", ""),
|
|
||||||
"nickname": user_data.get("name", user_data.get("username", "")),
|
|
||||||
}
|
|
||||||
content = add_at_mention_if_needed(content, user_info, has_at)
|
|
||||||
|
|
||||||
# 根据会话类型选择发送方式
|
|
||||||
if hasattr(self.client, "send_message") and is_valid_user_session_id(
|
|
||||||
self.session_id
|
|
||||||
):
|
|
||||||
user_id = extract_user_id_from_session_id(self.session_id)
|
|
||||||
await self.client.send_message(user_id, content)
|
|
||||||
elif hasattr(self.client, "send_room_message") and is_valid_room_session_id(
|
|
||||||
self.session_id
|
|
||||||
):
|
|
||||||
room_id = extract_room_id_from_session_id(self.session_id)
|
|
||||||
await self.client.send_room_message(room_id, content)
|
|
||||||
elif original_message_id and hasattr(self.client, "create_note"):
|
|
||||||
visibility, visible_user_ids = resolve_visibility_from_raw_message(
|
|
||||||
raw_message
|
|
||||||
)
|
|
||||||
await self.client.create_note(
|
|
||||||
content,
|
|
||||||
reply_id=original_message_id,
|
|
||||||
visibility=visibility,
|
|
||||||
visible_user_ids=visible_user_ids,
|
|
||||||
)
|
|
||||||
elif hasattr(self.client, "create_note"):
|
|
||||||
logger.debug("[MisskeyEvent] 创建新帖子")
|
|
||||||
await self.client.create_note(content)
|
|
||||||
|
|
||||||
await super().send(message)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[MisskeyEvent] 发送失败: {e}")
|
|
||||||
|
|
||||||
async def send_streaming(
|
|
||||||
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
|
||||||
):
|
|
||||||
if not use_fallback:
|
|
||||||
buffer = None
|
|
||||||
async for chain in generator:
|
|
||||||
if not buffer:
|
|
||||||
buffer = chain
|
|
||||||
else:
|
|
||||||
buffer.chain.extend(chain.chain)
|
|
||||||
if not buffer:
|
|
||||||
return
|
|
||||||
buffer.squash_plain()
|
|
||||||
await self.send(buffer)
|
|
||||||
return await super().send_streaming(generator, use_fallback)
|
|
||||||
|
|
||||||
buffer = ""
|
|
||||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
|
||||||
|
|
||||||
async for chain in generator:
|
|
||||||
if isinstance(chain, MessageChain):
|
|
||||||
for comp in chain.chain:
|
|
||||||
if isinstance(comp, Plain):
|
|
||||||
buffer += comp.text
|
|
||||||
if any(p in buffer for p in "。?!~…"):
|
|
||||||
buffer = await self.process_buffer(buffer, pattern)
|
|
||||||
else:
|
|
||||||
await self.send(MessageChain(chain=[comp]))
|
|
||||||
await asyncio.sleep(1.5) # 限速
|
|
||||||
|
|
||||||
if buffer.strip():
|
|
||||||
await self.send(MessageChain([Plain(buffer)]))
|
|
||||||
return await super().send_streaming(generator, use_fallback)
|
|
||||||
@@ -1,327 +0,0 @@
|
|||||||
"""Misskey 平台适配器通用工具函数"""
|
|
||||||
|
|
||||||
from typing import Dict, Any, List, Tuple, Optional, Union
|
|
||||||
import astrbot.api.message_components as Comp
|
|
||||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]:
|
|
||||||
"""将消息链序列化为文本字符串"""
|
|
||||||
text_parts = []
|
|
||||||
has_at = False
|
|
||||||
|
|
||||||
def process_component(component):
|
|
||||||
nonlocal has_at
|
|
||||||
if isinstance(component, Comp.Plain):
|
|
||||||
return component.text
|
|
||||||
elif isinstance(component, Comp.File):
|
|
||||||
file_name = getattr(component, "name", "文件")
|
|
||||||
return f"[文件: {file_name}]"
|
|
||||||
elif isinstance(component, Comp.At):
|
|
||||||
has_at = True
|
|
||||||
return f"@{component.qq}"
|
|
||||||
elif hasattr(component, "text"):
|
|
||||||
text = getattr(component, "text", "")
|
|
||||||
if "@" in text:
|
|
||||||
has_at = True
|
|
||||||
return text
|
|
||||||
else:
|
|
||||||
return str(component)
|
|
||||||
|
|
||||||
for component in chain:
|
|
||||||
if isinstance(component, Comp.Node) and component.content:
|
|
||||||
for node_comp in component.content:
|
|
||||||
result = process_component(node_comp)
|
|
||||||
if result:
|
|
||||||
text_parts.append(result)
|
|
||||||
else:
|
|
||||||
result = process_component(component)
|
|
||||||
if result:
|
|
||||||
text_parts.append(result)
|
|
||||||
|
|
||||||
return "".join(text_parts), has_at
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_message_visibility(
|
|
||||||
user_id: Optional[str],
|
|
||||||
user_cache: Dict[str, Any],
|
|
||||||
self_id: Optional[str],
|
|
||||||
default_visibility: str = "public",
|
|
||||||
) -> Tuple[str, Optional[List[str]]]:
|
|
||||||
"""解析 Misskey 消息的可见性设置"""
|
|
||||||
visibility = default_visibility
|
|
||||||
visible_user_ids = None
|
|
||||||
|
|
||||||
if user_id and user_cache:
|
|
||||||
user_info = user_cache.get(user_id)
|
|
||||||
if user_info:
|
|
||||||
original_visibility = user_info.get("visibility", default_visibility)
|
|
||||||
if original_visibility == "specified":
|
|
||||||
visibility = "specified"
|
|
||||||
original_visible_users = user_info.get("visible_user_ids", [])
|
|
||||||
users_to_include = [user_id]
|
|
||||||
if self_id:
|
|
||||||
users_to_include.append(self_id)
|
|
||||||
visible_user_ids = list(set(original_visible_users + users_to_include))
|
|
||||||
visible_user_ids = [uid for uid in visible_user_ids if uid]
|
|
||||||
else:
|
|
||||||
visibility = original_visibility
|
|
||||||
|
|
||||||
return visibility, visible_user_ids
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_visibility_from_raw_message(
|
|
||||||
raw_message: Dict[str, Any], self_id: Optional[str] = None
|
|
||||||
) -> Tuple[str, Optional[List[str]]]:
|
|
||||||
"""从原始消息数据中解析可见性设置"""
|
|
||||||
visibility = "public"
|
|
||||||
visible_user_ids = None
|
|
||||||
|
|
||||||
if not raw_message:
|
|
||||||
return visibility, visible_user_ids
|
|
||||||
|
|
||||||
original_visibility = raw_message.get("visibility", "public")
|
|
||||||
if original_visibility == "specified":
|
|
||||||
visibility = "specified"
|
|
||||||
original_visible_users = raw_message.get("visibleUserIds", [])
|
|
||||||
sender_id = raw_message.get("userId", "")
|
|
||||||
|
|
||||||
users_to_include = []
|
|
||||||
if sender_id:
|
|
||||||
users_to_include.append(sender_id)
|
|
||||||
if self_id:
|
|
||||||
users_to_include.append(self_id)
|
|
||||||
|
|
||||||
visible_user_ids = list(set(original_visible_users + users_to_include))
|
|
||||||
visible_user_ids = [uid for uid in visible_user_ids if uid]
|
|
||||||
else:
|
|
||||||
visibility = original_visibility
|
|
||||||
|
|
||||||
return visibility, visible_user_ids
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_user_session_id(session_id: Union[str, Any]) -> bool:
|
|
||||||
"""检查 session_id 是否是有效的聊天用户 session_id (仅限chat%前缀)"""
|
|
||||||
if not isinstance(session_id, str) or "%" not in session_id:
|
|
||||||
return False
|
|
||||||
|
|
||||||
parts = session_id.split("%")
|
|
||||||
return (
|
|
||||||
len(parts) == 2
|
|
||||||
and parts[0] == "chat"
|
|
||||||
and bool(parts[1])
|
|
||||||
and parts[1] != "unknown"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_room_session_id(session_id: Union[str, Any]) -> bool:
|
|
||||||
"""检查 session_id 是否是有效的房间 session_id (仅限room%前缀)"""
|
|
||||||
if not isinstance(session_id, str) or "%" not in session_id:
|
|
||||||
return False
|
|
||||||
|
|
||||||
parts = session_id.split("%")
|
|
||||||
return (
|
|
||||||
len(parts) == 2
|
|
||||||
and parts[0] == "room"
|
|
||||||
and bool(parts[1])
|
|
||||||
and parts[1] != "unknown"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_user_id_from_session_id(session_id: str) -> str:
|
|
||||||
"""从 session_id 中提取用户 ID"""
|
|
||||||
if "%" in session_id:
|
|
||||||
parts = session_id.split("%")
|
|
||||||
if len(parts) >= 2:
|
|
||||||
return parts[1]
|
|
||||||
return session_id
|
|
||||||
|
|
||||||
|
|
||||||
def extract_room_id_from_session_id(session_id: str) -> str:
|
|
||||||
"""从 session_id 中提取房间 ID"""
|
|
||||||
if "%" in session_id:
|
|
||||||
parts = session_id.split("%")
|
|
||||||
if len(parts) >= 2 and parts[0] == "room":
|
|
||||||
return parts[1]
|
|
||||||
return session_id
|
|
||||||
|
|
||||||
|
|
||||||
def add_at_mention_if_needed(
|
|
||||||
text: str, user_info: Optional[Dict[str, Any]], has_at: bool = False
|
|
||||||
) -> str:
|
|
||||||
"""如果需要且没有@用户,则添加@用户"""
|
|
||||||
if has_at or not user_info:
|
|
||||||
return text
|
|
||||||
|
|
||||||
username = user_info.get("username")
|
|
||||||
nickname = user_info.get("nickname")
|
|
||||||
|
|
||||||
if username:
|
|
||||||
mention = f"@{username}"
|
|
||||||
if not text.startswith(mention):
|
|
||||||
text = f"{mention}\n{text}".strip()
|
|
||||||
elif nickname:
|
|
||||||
mention = f"@{nickname}"
|
|
||||||
if not text.startswith(mention):
|
|
||||||
text = f"{mention}\n{text}".strip()
|
|
||||||
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def create_file_component(file_info: Dict[str, Any]) -> Tuple[Any, str]:
|
|
||||||
"""创建文件组件和描述文本"""
|
|
||||||
file_url = file_info.get("url", "")
|
|
||||||
file_name = file_info.get("name", "未知文件")
|
|
||||||
file_type = file_info.get("type", "")
|
|
||||||
|
|
||||||
if file_type.startswith("image/"):
|
|
||||||
return Comp.Image(url=file_url, file=file_name), f"图片[{file_name}]"
|
|
||||||
elif file_type.startswith("audio/"):
|
|
||||||
return Comp.Record(url=file_url, file=file_name), f"音频[{file_name}]"
|
|
||||||
elif file_type.startswith("video/"):
|
|
||||||
return Comp.Video(url=file_url, file=file_name), f"视频[{file_name}]"
|
|
||||||
else:
|
|
||||||
return Comp.File(name=file_name, url=file_url), f"文件[{file_name}]"
|
|
||||||
|
|
||||||
|
|
||||||
def process_files(
|
|
||||||
message: AstrBotMessage, files: list, include_text_parts: bool = True
|
|
||||||
) -> list:
|
|
||||||
"""处理文件列表,添加到消息组件中并返回文本描述"""
|
|
||||||
file_parts = []
|
|
||||||
for file_info in files:
|
|
||||||
component, part_text = create_file_component(file_info)
|
|
||||||
message.message.append(component)
|
|
||||||
if include_text_parts:
|
|
||||||
file_parts.append(part_text)
|
|
||||||
return file_parts
|
|
||||||
|
|
||||||
|
|
||||||
def extract_sender_info(
|
|
||||||
raw_data: Dict[str, Any], is_chat: bool = False
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""提取发送者信息"""
|
|
||||||
if is_chat:
|
|
||||||
sender = raw_data.get("fromUser", {})
|
|
||||||
sender_id = str(sender.get("id", "") or raw_data.get("fromUserId", ""))
|
|
||||||
else:
|
|
||||||
sender = raw_data.get("user", {})
|
|
||||||
sender_id = str(sender.get("id", ""))
|
|
||||||
|
|
||||||
return {
|
|
||||||
"sender": sender,
|
|
||||||
"sender_id": sender_id,
|
|
||||||
"nickname": sender.get("name", sender.get("username", "")),
|
|
||||||
"username": sender.get("username", ""),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def create_base_message(
|
|
||||||
raw_data: Dict[str, Any],
|
|
||||||
sender_info: Dict[str, Any],
|
|
||||||
client_self_id: str,
|
|
||||||
is_chat: bool = False,
|
|
||||||
room_id: Optional[str] = None,
|
|
||||||
unique_session: bool = False,
|
|
||||||
) -> AstrBotMessage:
|
|
||||||
"""创建基础消息对象"""
|
|
||||||
message = AstrBotMessage()
|
|
||||||
message.raw_message = raw_data
|
|
||||||
message.message = []
|
|
||||||
|
|
||||||
message.sender = MessageMember(
|
|
||||||
user_id=sender_info["sender_id"],
|
|
||||||
nickname=sender_info["nickname"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if room_id:
|
|
||||||
session_prefix = "room"
|
|
||||||
session_id = f"{session_prefix}%{room_id}"
|
|
||||||
if unique_session:
|
|
||||||
session_id += f"_{sender_info['sender_id']}"
|
|
||||||
message.type = MessageType.GROUP_MESSAGE
|
|
||||||
message.group_id = room_id
|
|
||||||
elif is_chat:
|
|
||||||
session_prefix = "chat"
|
|
||||||
session_id = f"{session_prefix}%{sender_info['sender_id']}"
|
|
||||||
message.type = MessageType.FRIEND_MESSAGE
|
|
||||||
else:
|
|
||||||
session_prefix = "note"
|
|
||||||
session_id = f"{session_prefix}%{sender_info['sender_id']}"
|
|
||||||
message.type = MessageType.FRIEND_MESSAGE
|
|
||||||
|
|
||||||
message.session_id = (
|
|
||||||
session_id if sender_info["sender_id"] else f"{session_prefix}%unknown"
|
|
||||||
)
|
|
||||||
message.message_id = str(raw_data.get("id", ""))
|
|
||||||
message.self_id = client_self_id
|
|
||||||
|
|
||||||
return message
|
|
||||||
|
|
||||||
|
|
||||||
def process_at_mention(
|
|
||||||
message: AstrBotMessage, raw_text: str, bot_username: str, client_self_id: str
|
|
||||||
) -> Tuple[List[str], str]:
|
|
||||||
"""处理@提及逻辑,返回消息部分列表和处理后的文本"""
|
|
||||||
message_parts = []
|
|
||||||
|
|
||||||
if not raw_text:
|
|
||||||
return message_parts, ""
|
|
||||||
|
|
||||||
if bot_username and raw_text.startswith(f"@{bot_username}"):
|
|
||||||
at_mention = f"@{bot_username}"
|
|
||||||
message.message.append(Comp.At(qq=client_self_id))
|
|
||||||
remaining_text = raw_text[len(at_mention) :].strip()
|
|
||||||
if remaining_text:
|
|
||||||
message.message.append(Comp.Plain(remaining_text))
|
|
||||||
message_parts.append(remaining_text)
|
|
||||||
return message_parts, remaining_text
|
|
||||||
else:
|
|
||||||
message.message.append(Comp.Plain(raw_text))
|
|
||||||
message_parts.append(raw_text)
|
|
||||||
return message_parts, raw_text
|
|
||||||
|
|
||||||
|
|
||||||
def cache_user_info(
|
|
||||||
user_cache: Dict[str, Any],
|
|
||||||
sender_info: Dict[str, Any],
|
|
||||||
raw_data: Dict[str, Any],
|
|
||||||
client_self_id: str,
|
|
||||||
is_chat: bool = False,
|
|
||||||
):
|
|
||||||
"""缓存用户信息"""
|
|
||||||
if is_chat:
|
|
||||||
user_cache_data = {
|
|
||||||
"username": sender_info["username"],
|
|
||||||
"nickname": sender_info["nickname"],
|
|
||||||
"visibility": "specified",
|
|
||||||
"visible_user_ids": [client_self_id, sender_info["sender_id"]],
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
user_cache_data = {
|
|
||||||
"username": sender_info["username"],
|
|
||||||
"nickname": sender_info["nickname"],
|
|
||||||
"visibility": raw_data.get("visibility", "public"),
|
|
||||||
"visible_user_ids": raw_data.get("visibleUserIds", []),
|
|
||||||
}
|
|
||||||
|
|
||||||
user_cache[sender_info["sender_id"]] = user_cache_data
|
|
||||||
|
|
||||||
|
|
||||||
def cache_room_info(
|
|
||||||
user_cache: Dict[str, Any], raw_data: Dict[str, Any], client_self_id: str
|
|
||||||
):
|
|
||||||
"""缓存房间信息"""
|
|
||||||
room_data = raw_data.get("toRoom")
|
|
||||||
room_id = raw_data.get("toRoomId")
|
|
||||||
|
|
||||||
if room_data and room_id:
|
|
||||||
room_cache_key = f"room:{room_id}"
|
|
||||||
user_cache[room_cache_key] = {
|
|
||||||
"room_id": room_id,
|
|
||||||
"room_name": room_data.get("name", ""),
|
|
||||||
"room_description": room_data.get("description", ""),
|
|
||||||
"owner_id": room_data.get("ownerId", ""),
|
|
||||||
"visibility": "specified",
|
|
||||||
"visible_user_ids": [client_self_id],
|
|
||||||
}
|
|
||||||
@@ -94,15 +94,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
plain_text,
|
plain_text,
|
||||||
image_base64,
|
image_base64,
|
||||||
image_path,
|
image_path,
|
||||||
record_file_path,
|
record_file_path
|
||||||
) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
|
) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
|
||||||
|
|
||||||
if (
|
if not plain_text and not image_base64 and not image_path and not record_file_path:
|
||||||
not plain_text
|
|
||||||
and not image_base64
|
|
||||||
and not image_path
|
|
||||||
and not record_file_path
|
|
||||||
):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@@ -123,7 +118,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
payload["media"] = media
|
payload["media"] = media
|
||||||
payload["msg_type"] = 7
|
payload["msg_type"] = 7
|
||||||
if record_file_path: # group record msg
|
if record_file_path: # group record msg
|
||||||
media = await self.upload_group_and_c2c_record(
|
media = await self.upload_group_and_c2c_record(
|
||||||
record_file_path, 3, group_openid=source.group_openid
|
record_file_path, 3, group_openid=source.group_openid
|
||||||
)
|
)
|
||||||
@@ -139,9 +134,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
payload["media"] = media
|
payload["media"] = media
|
||||||
payload["msg_type"] = 7
|
payload["msg_type"] = 7
|
||||||
if record_file_path: # c2c record
|
if record_file_path: # c2c record
|
||||||
media = await self.upload_group_and_c2c_record(
|
media = await self.upload_group_and_c2c_record(
|
||||||
record_file_path, 3, openid=source.author.user_openid
|
record_file_path, 3, openid = source.author.user_openid
|
||||||
)
|
)
|
||||||
payload["media"] = media
|
payload["media"] = media
|
||||||
payload["msg_type"] = 7
|
payload["msg_type"] = 7
|
||||||
@@ -195,55 +190,58 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
return await self.bot.api._http.request(route, json=payload)
|
return await self.bot.api._http.request(route, json=payload)
|
||||||
|
|
||||||
async def upload_group_and_c2c_record(
|
async def upload_group_and_c2c_record(
|
||||||
self, file_source: str, file_type: int, srv_send_msg: bool = False, **kwargs
|
self,
|
||||||
|
file_source: str,
|
||||||
|
file_type: int,
|
||||||
|
srv_send_msg: bool = False,
|
||||||
|
**kwargs
|
||||||
) -> Optional[Media]:
|
) -> Optional[Media]:
|
||||||
"""
|
"""
|
||||||
上传媒体文件
|
上传媒体文件
|
||||||
"""
|
"""
|
||||||
# 构建基础payload
|
# 构建基础payload
|
||||||
payload = {"file_type": file_type, "srv_send_msg": srv_send_msg}
|
payload = {
|
||||||
|
"file_type": file_type,
|
||||||
|
"srv_send_msg": srv_send_msg
|
||||||
|
}
|
||||||
|
|
||||||
# 处理文件数据
|
# 处理文件数据
|
||||||
if os.path.exists(file_source):
|
if os.path.exists(file_source):
|
||||||
# 读取本地文件
|
# 读取本地文件
|
||||||
async with aiofiles.open(file_source, "rb") as f:
|
async with aiofiles.open(file_source, 'rb') as f:
|
||||||
file_content = await f.read()
|
file_content = await f.read()
|
||||||
# use base64 encode
|
# use base64 encode
|
||||||
payload["file_data"] = base64.b64encode(file_content).decode("utf-8")
|
payload["file_data"] = base64.b64encode(file_content).decode('utf-8')
|
||||||
else:
|
else:
|
||||||
# 使用URL
|
# 使用URL
|
||||||
payload["url"] = file_source
|
payload["url"] = file_source
|
||||||
|
|
||||||
# 添加接收者信息和确定路由
|
# 添加接收者信息和确定路由
|
||||||
if "openid" in kwargs:
|
if "openid" in kwargs:
|
||||||
payload["openid"] = kwargs["openid"]
|
payload["openid"] = kwargs["openid"]
|
||||||
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
|
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
|
||||||
elif "group_openid" in kwargs:
|
elif "group_openid" in kwargs:
|
||||||
payload["group_openid"] = kwargs["group_openid"]
|
payload["group_openid"] =kwargs["group_openid"]
|
||||||
route = Route(
|
route = Route("POST", "/v2/groups/{group_openid}/files", group_openid=kwargs["group_openid"])
|
||||||
"POST",
|
|
||||||
"/v2/groups/{group_openid}/files",
|
|
||||||
group_openid=kwargs["group_openid"],
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用底层HTTP请求
|
# 使用底层HTTP请求
|
||||||
result = await self.bot.api._http.request(route, json=payload)
|
result = await self.bot.api._http.request(route, json=payload)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
return Media(
|
return Media(
|
||||||
file_uuid=result.get("file_uuid"),
|
file_uuid=result.get("file_uuid"),
|
||||||
file_info=result.get("file_info"),
|
file_info=result.get("file_info"),
|
||||||
ttl=result.get("ttl", 0),
|
ttl=result.get("ttl", 0),
|
||||||
file_id=result.get("id", ""),
|
file_id=result.get("id", "")
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"上传请求错误: {e}")
|
logger.error(f"上传请求错误: {e}")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def post_c2c_message(
|
async def post_c2c_message(
|
||||||
self,
|
self,
|
||||||
openid: str,
|
openid: str,
|
||||||
@@ -288,23 +286,19 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
image_base64 = image_base64.removeprefix("base64://")
|
image_base64 = image_base64.removeprefix("base64://")
|
||||||
elif isinstance(i, Record):
|
elif isinstance(i, Record):
|
||||||
if i.file:
|
if i.file:
|
||||||
record_wav_path = await i.convert_to_file_path() # wav 路径
|
record_wav_path = await i.convert_to_file_path() # wav 路径
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
record_tecent_silk_path = os.path.join(
|
record_tecent_silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
|
||||||
temp_dir, f"{uuid.uuid4()}.silk"
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
duration = await wav_to_tencent_silk(
|
duration = await wav_to_tencent_silk(record_wav_path, record_tecent_silk_path)
|
||||||
record_wav_path, record_tecent_silk_path
|
|
||||||
)
|
|
||||||
if duration > 0:
|
if duration > 0:
|
||||||
record_file_path = record_tecent_silk_path
|
record_file_path = record_tecent_silk_path
|
||||||
else:
|
else:
|
||||||
record_file_path = None
|
record_file_path = None
|
||||||
logger.error("转换音频格式时出错:音频时长不大于0")
|
logger.error("转换音频格式时出错:音频时长不大于0")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理语音时出错: {e}")
|
logger.error(f"处理语音时出错: {e}")
|
||||||
record_file_path = None
|
record_file_path = None
|
||||||
else:
|
else:
|
||||||
logger.debug(f"qq_official 忽略 {i.type}")
|
logger.debug(f"qq_official 忽略 {i.type}")
|
||||||
return plain_text, image_base64, image_file_path, record_file_path
|
return plain_text, image_base64, image_file_path, record_file_path
|
||||||
|
|||||||
@@ -1,748 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import websockets
|
|
||||||
from websockets.asyncio.client import connect
|
|
||||||
from typing import Optional
|
|
||||||
from aiohttp import ClientSession, ClientTimeout
|
|
||||||
from websockets.asyncio.client import ClientConnection
|
|
||||||
from astrbot.api import logger
|
|
||||||
from astrbot.api.event import MessageChain
|
|
||||||
from astrbot.api.platform import (
|
|
||||||
AstrBotMessage,
|
|
||||||
MessageMember,
|
|
||||||
MessageType,
|
|
||||||
Platform,
|
|
||||||
PlatformMetadata,
|
|
||||||
register_platform_adapter,
|
|
||||||
)
|
|
||||||
from astrbot.core.platform.astr_message_event import MessageSession
|
|
||||||
from astrbot.api.message_components import (
|
|
||||||
Plain,
|
|
||||||
Image,
|
|
||||||
At,
|
|
||||||
File,
|
|
||||||
Record,
|
|
||||||
Reply,
|
|
||||||
)
|
|
||||||
from xml.etree import ElementTree as ET
|
|
||||||
|
|
||||||
|
|
||||||
@register_platform_adapter(
|
|
||||||
"satori",
|
|
||||||
"Satori 协议适配器",
|
|
||||||
)
|
|
||||||
class SatoriPlatformAdapter(Platform):
|
|
||||||
def __init__(
|
|
||||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
|
||||||
) -> None:
|
|
||||||
super().__init__(event_queue)
|
|
||||||
self.config = platform_config
|
|
||||||
self.settings = platform_settings
|
|
||||||
|
|
||||||
self.api_base_url = self.config.get(
|
|
||||||
"satori_api_base_url", "http://localhost:5140/satori/v1"
|
|
||||||
)
|
|
||||||
self.token = self.config.get("satori_token", "")
|
|
||||||
self.endpoint = self.config.get(
|
|
||||||
"satori_endpoint", "ws://localhost:5140/satori/v1/events"
|
|
||||||
)
|
|
||||||
self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
|
|
||||||
self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
|
|
||||||
self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)
|
|
||||||
|
|
||||||
self.metadata = PlatformMetadata(
|
|
||||||
name="satori",
|
|
||||||
description="Satori 通用协议适配器",
|
|
||||||
id=self.config["id"],
|
|
||||||
)
|
|
||||||
|
|
||||||
self.ws: Optional[ClientConnection] = None
|
|
||||||
self.session: Optional[ClientSession] = None
|
|
||||||
self.sequence = 0
|
|
||||||
self.logins = []
|
|
||||||
self.running = False
|
|
||||||
self.heartbeat_task: Optional[asyncio.Task] = None
|
|
||||||
self.ready_received = False
|
|
||||||
|
|
||||||
async def send_by_session(
|
|
||||||
self, session: MessageSession, message_chain: MessageChain
|
|
||||||
):
|
|
||||||
from .satori_event import SatoriPlatformEvent
|
|
||||||
|
|
||||||
await SatoriPlatformEvent.send_with_adapter(
|
|
||||||
self, message_chain, session.session_id
|
|
||||||
)
|
|
||||||
await super().send_by_session(session, message_chain)
|
|
||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
|
||||||
return self.metadata
|
|
||||||
|
|
||||||
def _is_websocket_closed(self, ws) -> bool:
|
|
||||||
"""检查WebSocket连接是否已关闭"""
|
|
||||||
if not ws:
|
|
||||||
return True
|
|
||||||
try:
|
|
||||||
if hasattr(ws, "closed"):
|
|
||||||
return ws.closed
|
|
||||||
elif hasattr(ws, "close_code"):
|
|
||||||
return ws.close_code is not None
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
except AttributeError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
self.running = True
|
|
||||||
self.session = ClientSession(timeout=ClientTimeout(total=30))
|
|
||||||
|
|
||||||
retry_count = 0
|
|
||||||
max_retries = 10
|
|
||||||
|
|
||||||
while self.running:
|
|
||||||
try:
|
|
||||||
await self.connect_websocket()
|
|
||||||
retry_count = 0
|
|
||||||
except websockets.exceptions.ConnectionClosed as e:
|
|
||||||
logger.warning(f"Satori WebSocket 连接关闭: {e}")
|
|
||||||
retry_count += 1
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Satori WebSocket 连接失败: {e}")
|
|
||||||
retry_count += 1
|
|
||||||
|
|
||||||
if not self.running:
|
|
||||||
break
|
|
||||||
|
|
||||||
if retry_count >= max_retries:
|
|
||||||
logger.error(f"达到最大重试次数 ({max_retries}),停止重试")
|
|
||||||
break
|
|
||||||
|
|
||||||
if not self.auto_reconnect:
|
|
||||||
break
|
|
||||||
|
|
||||||
delay = min(self.reconnect_delay * (2 ** (retry_count - 1)), 60)
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
if self.session:
|
|
||||||
await self.session.close()
|
|
||||||
|
|
||||||
async def connect_websocket(self):
|
|
||||||
logger.info(f"Satori 适配器正在连接到 WebSocket: {self.endpoint}")
|
|
||||||
logger.info(f"Satori 适配器 HTTP API 地址: {self.api_base_url}")
|
|
||||||
|
|
||||||
if not self.endpoint.startswith(("ws://", "wss://")):
|
|
||||||
logger.error(f"无效的WebSocket URL: {self.endpoint}")
|
|
||||||
raise ValueError(f"WebSocket URL必须以ws://或wss://开头: {self.endpoint}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
websocket = await connect(self.endpoint, additional_headers={})
|
|
||||||
self.ws = websocket
|
|
||||||
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
await self.send_identify()
|
|
||||||
|
|
||||||
self.heartbeat_task = asyncio.create_task(self.heartbeat_loop())
|
|
||||||
|
|
||||||
async for message in websocket:
|
|
||||||
try:
|
|
||||||
await self.handle_message(message) # type: ignore
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Satori 处理消息异常: {e}")
|
|
||||||
|
|
||||||
except websockets.exceptions.ConnectionClosed as e:
|
|
||||||
logger.warning(f"Satori WebSocket 连接关闭: {e}")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Satori WebSocket 连接异常: {e}")
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
if self.heartbeat_task:
|
|
||||||
self.heartbeat_task.cancel()
|
|
||||||
try:
|
|
||||||
await self.heartbeat_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
if self.ws:
|
|
||||||
try:
|
|
||||||
await self.ws.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Satori WebSocket 关闭异常: {e}")
|
|
||||||
|
|
||||||
async def send_identify(self):
|
|
||||||
if not self.ws:
|
|
||||||
raise Exception("WebSocket连接未建立")
|
|
||||||
|
|
||||||
if self._is_websocket_closed(self.ws):
|
|
||||||
raise Exception("WebSocket连接已关闭")
|
|
||||||
|
|
||||||
identify_payload = {
|
|
||||||
"op": 3, # IDENTIFY
|
|
||||||
"body": {
|
|
||||||
"token": str(self.token) if self.token else "", # 字符串
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# 只有在有序列号时才添加sn字段
|
|
||||||
if self.sequence > 0:
|
|
||||||
identify_payload["body"]["sn"] = self.sequence
|
|
||||||
|
|
||||||
try:
|
|
||||||
message_str = json.dumps(identify_payload, ensure_ascii=False)
|
|
||||||
await self.ws.send(message_str)
|
|
||||||
except websockets.exceptions.ConnectionClosed as e:
|
|
||||||
logger.error(f"发送 IDENTIFY 信令时连接关闭: {e}")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"发送 IDENTIFY 信令失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def heartbeat_loop(self):
|
|
||||||
try:
|
|
||||||
while self.running and self.ws:
|
|
||||||
await asyncio.sleep(self.heartbeat_interval)
|
|
||||||
|
|
||||||
if self.ws and not self._is_websocket_closed(self.ws):
|
|
||||||
try:
|
|
||||||
ping_payload = {
|
|
||||||
"op": 1, # PING
|
|
||||||
"body": {},
|
|
||||||
}
|
|
||||||
await self.ws.send(json.dumps(ping_payload, ensure_ascii=False))
|
|
||||||
except websockets.exceptions.ConnectionClosed as e:
|
|
||||||
logger.error(f"Satori WebSocket 连接关闭: {e}")
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Satori WebSocket 发送心跳失败: {e}")
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"心跳任务异常: {e}")
|
|
||||||
|
|
||||||
async def handle_message(self, message: str):
|
|
||||||
try:
|
|
||||||
data = json.loads(message)
|
|
||||||
op = data.get("op")
|
|
||||||
body = data.get("body", {})
|
|
||||||
|
|
||||||
if op == 4: # READY
|
|
||||||
self.logins = body.get("logins", [])
|
|
||||||
self.ready_received = True
|
|
||||||
|
|
||||||
# 输出连接成功的bot信息
|
|
||||||
if self.logins:
|
|
||||||
for i, login in enumerate(self.logins):
|
|
||||||
platform = login.get("platform", "")
|
|
||||||
user = login.get("user", {})
|
|
||||||
user_id = user.get("id", "")
|
|
||||||
user_name = user.get("name", "")
|
|
||||||
logger.info(
|
|
||||||
f"Satori 连接成功 - Bot {i + 1}: platform={platform}, user_id={user_id}, user_name={user_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if "sn" in body:
|
|
||||||
self.sequence = body["sn"]
|
|
||||||
|
|
||||||
elif op == 2: # PONG
|
|
||||||
pass
|
|
||||||
|
|
||||||
elif op == 0: # EVENT
|
|
||||||
await self.handle_event(body)
|
|
||||||
if "sn" in body:
|
|
||||||
self.sequence = body["sn"]
|
|
||||||
|
|
||||||
elif op == 5: # META
|
|
||||||
if "sn" in body:
|
|
||||||
self.sequence = body["sn"]
|
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.error(f"解析 WebSocket 消息失败: {e}, 消息内容: {message}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理 WebSocket 消息异常: {e}")
|
|
||||||
|
|
||||||
async def handle_event(self, event_data: dict):
|
|
||||||
try:
|
|
||||||
event_type = event_data.get("type")
|
|
||||||
sn = event_data.get("sn")
|
|
||||||
if sn:
|
|
||||||
self.sequence = sn
|
|
||||||
|
|
||||||
if event_type == "message-created":
|
|
||||||
message = event_data.get("message", {})
|
|
||||||
user = event_data.get("user", {})
|
|
||||||
channel = event_data.get("channel", {})
|
|
||||||
guild = event_data.get("guild")
|
|
||||||
login = event_data.get("login", {})
|
|
||||||
timestamp = event_data.get("timestamp")
|
|
||||||
|
|
||||||
if user.get("id") == login.get("user", {}).get("id"):
|
|
||||||
return
|
|
||||||
|
|
||||||
abm = await self.convert_satori_message(
|
|
||||||
message, user, channel, guild, login, timestamp
|
|
||||||
)
|
|
||||||
if abm:
|
|
||||||
await self.handle_msg(abm)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理事件失败: {e}")
|
|
||||||
|
|
||||||
async def convert_satori_message(
|
|
||||||
self,
|
|
||||||
message: dict,
|
|
||||||
user: dict,
|
|
||||||
channel: dict,
|
|
||||||
guild: Optional[dict],
|
|
||||||
login: dict,
|
|
||||||
timestamp: Optional[int] = None,
|
|
||||||
) -> Optional[AstrBotMessage]:
|
|
||||||
try:
|
|
||||||
abm = AstrBotMessage()
|
|
||||||
abm.message_id = message.get("id", "")
|
|
||||||
abm.raw_message = {
|
|
||||||
"message": message,
|
|
||||||
"user": user,
|
|
||||||
"channel": channel,
|
|
||||||
"guild": guild,
|
|
||||||
"login": login,
|
|
||||||
}
|
|
||||||
|
|
||||||
if guild and guild.get("id"):
|
|
||||||
abm.type = MessageType.GROUP_MESSAGE
|
|
||||||
abm.group_id = guild.get("id", "")
|
|
||||||
abm.session_id = channel.get("id", "")
|
|
||||||
else:
|
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
|
||||||
abm.session_id = channel.get("id", "")
|
|
||||||
|
|
||||||
abm.sender = MessageMember(
|
|
||||||
user_id=user.get("id", ""),
|
|
||||||
nickname=user.get("nick", user.get("name", "")),
|
|
||||||
)
|
|
||||||
|
|
||||||
abm.self_id = login.get("user", {}).get("id", "")
|
|
||||||
|
|
||||||
# 消息链
|
|
||||||
abm.message = []
|
|
||||||
|
|
||||||
content = message.get("content", "")
|
|
||||||
|
|
||||||
quote = message.get("quote")
|
|
||||||
content_for_parsing = content # 副本
|
|
||||||
|
|
||||||
# 提取<quote>标签
|
|
||||||
if "<quote" in content:
|
|
||||||
try:
|
|
||||||
quote_info = await self._extract_quote_element(content)
|
|
||||||
if quote_info:
|
|
||||||
quote = quote_info["quote"]
|
|
||||||
content_for_parsing = quote_info["content_without_quote"]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"解析<quote>标签时发生错误: {e}, 错误内容: {content}")
|
|
||||||
|
|
||||||
if quote:
|
|
||||||
# 引用消息
|
|
||||||
quote_abm = await self._convert_quote_message(quote)
|
|
||||||
if quote_abm:
|
|
||||||
sender_id = quote_abm.sender.user_id
|
|
||||||
if isinstance(sender_id, str) and sender_id.isdigit():
|
|
||||||
sender_id = int(sender_id)
|
|
||||||
elif not isinstance(sender_id, int):
|
|
||||||
sender_id = 0 # 默认值
|
|
||||||
|
|
||||||
reply_component = Reply(
|
|
||||||
id=quote_abm.message_id,
|
|
||||||
chain=quote_abm.message,
|
|
||||||
sender_id=quote_abm.sender.user_id,
|
|
||||||
sender_nickname=quote_abm.sender.nickname,
|
|
||||||
time=quote_abm.timestamp,
|
|
||||||
message_str=quote_abm.message_str,
|
|
||||||
text=quote_abm.message_str,
|
|
||||||
qq=sender_id,
|
|
||||||
)
|
|
||||||
abm.message.append(reply_component)
|
|
||||||
|
|
||||||
# 解析消息内容
|
|
||||||
content_elements = await self.parse_satori_elements(content_for_parsing)
|
|
||||||
abm.message.extend(content_elements)
|
|
||||||
|
|
||||||
abm.message_str = ""
|
|
||||||
for comp in content_elements:
|
|
||||||
if isinstance(comp, Plain):
|
|
||||||
abm.message_str += comp.text
|
|
||||||
|
|
||||||
# 优先使用Satori事件中的时间戳
|
|
||||||
if timestamp is not None:
|
|
||||||
abm.timestamp = timestamp
|
|
||||||
else:
|
|
||||||
abm.timestamp = int(time.time())
|
|
||||||
|
|
||||||
return abm
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"转换 Satori 消息失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _extract_namespace_prefixes(self, content: str) -> set:
|
|
||||||
"""提取XML内容中的命名空间前缀"""
|
|
||||||
prefixes = set()
|
|
||||||
|
|
||||||
# 查找所有标签
|
|
||||||
i = 0
|
|
||||||
while i < len(content):
|
|
||||||
# 查找开始标签
|
|
||||||
if content[i] == "<" and i + 1 < len(content) and content[i + 1] != "/":
|
|
||||||
# 找到标签结束位置
|
|
||||||
tag_end = content.find(">", i)
|
|
||||||
if tag_end != -1:
|
|
||||||
# 提取标签内容
|
|
||||||
tag_content = content[i + 1 : tag_end]
|
|
||||||
# 检查是否有命名空间前缀
|
|
||||||
if ":" in tag_content and "xmlns:" not in tag_content:
|
|
||||||
# 分割标签名
|
|
||||||
parts = tag_content.split()
|
|
||||||
if parts:
|
|
||||||
tag_name = parts[0]
|
|
||||||
if ":" in tag_name:
|
|
||||||
prefix = tag_name.split(":")[0]
|
|
||||||
# 确保是有效的命名空间前缀
|
|
||||||
if (
|
|
||||||
prefix.isalnum()
|
|
||||||
or prefix.replace("_", "").isalnum()
|
|
||||||
):
|
|
||||||
prefixes.add(prefix)
|
|
||||||
i = tag_end + 1
|
|
||||||
else:
|
|
||||||
i += 1
|
|
||||||
# 查找结束标签
|
|
||||||
elif content[i] == "<" and i + 1 < len(content) and content[i + 1] == "/":
|
|
||||||
# 找到标签结束位置
|
|
||||||
tag_end = content.find(">", i)
|
|
||||||
if tag_end != -1:
|
|
||||||
# 提取标签内容
|
|
||||||
tag_content = content[i + 2 : tag_end]
|
|
||||||
# 检查是否有命名空间前缀
|
|
||||||
if ":" in tag_content:
|
|
||||||
prefix = tag_content.split(":")[0]
|
|
||||||
# 确保是有效的命名空间前缀
|
|
||||||
if prefix.isalnum() or prefix.replace("_", "").isalnum():
|
|
||||||
prefixes.add(prefix)
|
|
||||||
i = tag_end + 1
|
|
||||||
else:
|
|
||||||
i += 1
|
|
||||||
else:
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
return prefixes
|
|
||||||
|
|
||||||
async def _extract_quote_element(self, content: str) -> Optional[dict]:
|
|
||||||
"""提取<quote>标签信息"""
|
|
||||||
try:
|
|
||||||
# 处理命名空间前缀问题
|
|
||||||
processed_content = content
|
|
||||||
if ":" in content and not content.startswith("<root"):
|
|
||||||
prefixes = self._extract_namespace_prefixes(content)
|
|
||||||
|
|
||||||
# 构建命名空间声明
|
|
||||||
ns_declarations = " ".join(
|
|
||||||
[
|
|
||||||
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
|
|
||||||
for prefix in prefixes
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 包装内容
|
|
||||||
processed_content = f"<root {ns_declarations}>{content}</root>"
|
|
||||||
elif not content.startswith("<root"):
|
|
||||||
processed_content = f"<root>{content}</root>"
|
|
||||||
else:
|
|
||||||
processed_content = content
|
|
||||||
|
|
||||||
root = ET.fromstring(processed_content)
|
|
||||||
|
|
||||||
# 查找<quote>标签
|
|
||||||
quote_element = None
|
|
||||||
for elem in root.iter():
|
|
||||||
tag_name = elem.tag
|
|
||||||
if "}" in tag_name:
|
|
||||||
tag_name = tag_name.split("}")[1]
|
|
||||||
if tag_name.lower() == "quote":
|
|
||||||
quote_element = elem
|
|
||||||
break
|
|
||||||
|
|
||||||
if quote_element is not None:
|
|
||||||
# 提取quote标签的属性
|
|
||||||
quote_id = quote_element.get("id", "")
|
|
||||||
|
|
||||||
# 提取<quote>标签内部的内容
|
|
||||||
inner_content = ""
|
|
||||||
if quote_element.text:
|
|
||||||
inner_content += quote_element.text
|
|
||||||
for child in quote_element:
|
|
||||||
inner_content += ET.tostring(
|
|
||||||
child, encoding="unicode", method="xml"
|
|
||||||
)
|
|
||||||
if child.tail:
|
|
||||||
inner_content += child.tail
|
|
||||||
|
|
||||||
# 构造移除了<quote>标签的内容
|
|
||||||
content_without_quote = content.replace(
|
|
||||||
ET.tostring(quote_element, encoding="unicode", method="xml"), ""
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"quote": {"id": quote_id, "content": inner_content},
|
|
||||||
"content_without_quote": content_without_quote,
|
|
||||||
}
|
|
||||||
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"提取<quote>标签时发生错误: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]:
|
|
||||||
"""转换引用消息"""
|
|
||||||
try:
|
|
||||||
quote_abm = AstrBotMessage()
|
|
||||||
quote_abm.message_id = quote.get("id", "")
|
|
||||||
|
|
||||||
# 解析引用消息的发送者
|
|
||||||
quote_author = quote.get("author", {})
|
|
||||||
if quote_author:
|
|
||||||
quote_abm.sender = MessageMember(
|
|
||||||
user_id=quote_author.get("id", ""),
|
|
||||||
nickname=quote_author.get("nick", quote_author.get("name", "")),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 如果没有作者信息,使用默认值
|
|
||||||
quote_abm.sender = MessageMember(
|
|
||||||
user_id=quote.get("user_id", ""),
|
|
||||||
nickname="内容",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 解析引用消息内容
|
|
||||||
quote_content = quote.get("content", "")
|
|
||||||
quote_abm.message = await self.parse_satori_elements(quote_content)
|
|
||||||
|
|
||||||
quote_abm.message_str = ""
|
|
||||||
for comp in quote_abm.message:
|
|
||||||
if isinstance(comp, Plain):
|
|
||||||
quote_abm.message_str += comp.text
|
|
||||||
|
|
||||||
quote_abm.timestamp = int(quote.get("timestamp", time.time()))
|
|
||||||
|
|
||||||
# 如果没有任何内容,使用默认文本
|
|
||||||
if not quote_abm.message_str.strip():
|
|
||||||
quote_abm.message_str = "[引用消息]"
|
|
||||||
|
|
||||||
return quote_abm
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"转换引用消息失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def parse_satori_elements(self, content: str) -> list:
|
|
||||||
"""解析 Satori 消息元素"""
|
|
||||||
elements = []
|
|
||||||
|
|
||||||
if not content:
|
|
||||||
return elements
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 处理命名空间前缀问题
|
|
||||||
processed_content = content
|
|
||||||
if ":" in content and not content.startswith("<root"):
|
|
||||||
prefixes = self._extract_namespace_prefixes(content)
|
|
||||||
|
|
||||||
# 构建命名空间声明
|
|
||||||
ns_declarations = " ".join(
|
|
||||||
[
|
|
||||||
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
|
|
||||||
for prefix in prefixes
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 包装内容
|
|
||||||
processed_content = f"<root {ns_declarations}>{content}</root>"
|
|
||||||
elif not content.startswith("<root"):
|
|
||||||
processed_content = f"<root>{content}</root>"
|
|
||||||
else:
|
|
||||||
processed_content = content
|
|
||||||
|
|
||||||
root = ET.fromstring(processed_content)
|
|
||||||
await self._parse_xml_node(root, elements)
|
|
||||||
except ET.ParseError as e:
|
|
||||||
logger.error(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
|
|
||||||
# 如果解析失败,将整个内容当作纯文本
|
|
||||||
if content.strip():
|
|
||||||
elements.append(Plain(text=content))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"解析 Satori 元素时发生未知错误: {e}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# 如果没有解析到任何元素,将整个内容当作纯文本
|
|
||||||
if not elements and content.strip():
|
|
||||||
elements.append(Plain(text=content))
|
|
||||||
|
|
||||||
return elements
|
|
||||||
|
|
||||||
async def _parse_xml_node(self, node: ET.Element, elements: list) -> None:
|
|
||||||
"""递归解析 XML 节点"""
|
|
||||||
if node.text and node.text.strip():
|
|
||||||
elements.append(Plain(text=node.text))
|
|
||||||
|
|
||||||
for child in node:
|
|
||||||
# 获取标签名,去除命名空间前缀
|
|
||||||
tag_name = child.tag
|
|
||||||
if "}" in tag_name:
|
|
||||||
tag_name = tag_name.split("}")[1]
|
|
||||||
tag_name = tag_name.lower()
|
|
||||||
|
|
||||||
attrs = child.attrib
|
|
||||||
|
|
||||||
if tag_name == "at":
|
|
||||||
user_id = attrs.get("id") or attrs.get("name", "")
|
|
||||||
elements.append(At(qq=user_id, name=user_id))
|
|
||||||
|
|
||||||
elif tag_name in ("img", "image"):
|
|
||||||
src = attrs.get("src", "")
|
|
||||||
if not src:
|
|
||||||
continue
|
|
||||||
elements.append(Image(file=src))
|
|
||||||
|
|
||||||
elif tag_name == "file":
|
|
||||||
src = attrs.get("src", "")
|
|
||||||
name = attrs.get("name", "文件")
|
|
||||||
if src:
|
|
||||||
elements.append(File(name=name, file=src))
|
|
||||||
|
|
||||||
elif tag_name in ("audio", "record"):
|
|
||||||
src = attrs.get("src", "")
|
|
||||||
if not src:
|
|
||||||
continue
|
|
||||||
elements.append(Record(file=src))
|
|
||||||
|
|
||||||
elif tag_name == "quote":
|
|
||||||
# quote标签已经被特殊处理
|
|
||||||
pass
|
|
||||||
|
|
||||||
elif tag_name == "face":
|
|
||||||
face_id = attrs.get("id", "")
|
|
||||||
face_name = attrs.get("name", "")
|
|
||||||
face_type = attrs.get("type", "")
|
|
||||||
|
|
||||||
if face_name:
|
|
||||||
elements.append(Plain(text=f"[表情:{face_name}]"))
|
|
||||||
elif face_id and face_type:
|
|
||||||
elements.append(Plain(text=f"[表情ID:{face_id},类型:{face_type}]"))
|
|
||||||
elif face_id:
|
|
||||||
elements.append(Plain(text=f"[表情ID:{face_id}]"))
|
|
||||||
else:
|
|
||||||
elements.append(Plain(text="[表情]"))
|
|
||||||
|
|
||||||
elif tag_name == "ark":
|
|
||||||
# 作为纯文本添加到消息链中
|
|
||||||
data = attrs.get("data", "")
|
|
||||||
if data:
|
|
||||||
import html
|
|
||||||
|
|
||||||
decoded_data = html.unescape(data)
|
|
||||||
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
|
|
||||||
else:
|
|
||||||
elements.append(Plain(text="[ARK卡片]"))
|
|
||||||
|
|
||||||
elif tag_name == "json":
|
|
||||||
# JSON标签 视为ARK卡片消息
|
|
||||||
data = attrs.get("data", "")
|
|
||||||
if data:
|
|
||||||
import html
|
|
||||||
|
|
||||||
decoded_data = html.unescape(data)
|
|
||||||
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
|
|
||||||
else:
|
|
||||||
elements.append(Plain(text="[JSON卡片]"))
|
|
||||||
|
|
||||||
else:
|
|
||||||
# 未知标签,递归处理其内容
|
|
||||||
if child.text and child.text.strip():
|
|
||||||
elements.append(Plain(text=child.text))
|
|
||||||
await self._parse_xml_node(child, elements)
|
|
||||||
|
|
||||||
# 处理标签后的文本
|
|
||||||
if child.tail and child.tail.strip():
|
|
||||||
elements.append(Plain(text=child.tail))
|
|
||||||
|
|
||||||
async def handle_msg(self, message: AstrBotMessage):
|
|
||||||
from .satori_event import SatoriPlatformEvent
|
|
||||||
|
|
||||||
message_event = SatoriPlatformEvent(
|
|
||||||
message_str=message.message_str,
|
|
||||||
message_obj=message,
|
|
||||||
platform_meta=self.meta(),
|
|
||||||
session_id=message.session_id,
|
|
||||||
adapter=self,
|
|
||||||
)
|
|
||||||
self.commit_event(message_event)
|
|
||||||
|
|
||||||
async def send_http_request(
|
|
||||||
self,
|
|
||||||
method: str,
|
|
||||||
path: str,
|
|
||||||
data: dict | None = None,
|
|
||||||
platform: str | None = None,
|
|
||||||
user_id: str | None = None,
|
|
||||||
) -> dict:
|
|
||||||
if not self.session:
|
|
||||||
raise Exception("HTTP session 未初始化")
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.token:
|
|
||||||
headers["Authorization"] = f"Bearer {self.token}"
|
|
||||||
|
|
||||||
if platform and user_id:
|
|
||||||
headers["satori-platform"] = platform
|
|
||||||
headers["satori-user-id"] = user_id
|
|
||||||
elif self.logins:
|
|
||||||
current_login = self.logins[0]
|
|
||||||
headers["satori-platform"] = current_login.get("platform", "")
|
|
||||||
user = current_login.get("user", {})
|
|
||||||
headers["satori-user-id"] = user.get("id", "") if user else ""
|
|
||||||
|
|
||||||
if not path.startswith("/"):
|
|
||||||
path = "/" + path
|
|
||||||
|
|
||||||
# 使用新的API地址配置
|
|
||||||
url = f"{self.api_base_url.rstrip('/')}{path}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with self.session.request(
|
|
||||||
method, url, json=data, headers=headers
|
|
||||||
) as response:
|
|
||||||
if response.status == 200:
|
|
||||||
result = await response.json()
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Satori HTTP 请求异常: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def terminate(self):
|
|
||||||
self.running = False
|
|
||||||
|
|
||||||
if self.heartbeat_task:
|
|
||||||
self.heartbeat_task.cancel()
|
|
||||||
|
|
||||||
if self.ws:
|
|
||||||
try:
|
|
||||||
await self.ws.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Satori WebSocket 关闭异常: {e}")
|
|
||||||
|
|
||||||
if self.session:
|
|
||||||
await self.session.close()
|
|
||||||
@@ -1,230 +0,0 @@
|
|||||||
from typing import TYPE_CHECKING
|
|
||||||
from astrbot.api import logger
|
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
|
||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
|
||||||
from astrbot.api.message_components import Plain, Image, At, File, Record
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .satori_adapter import SatoriPlatformAdapter
|
|
||||||
|
|
||||||
|
|
||||||
class SatoriPlatformEvent(AstrMessageEvent):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
message_str: str,
|
|
||||||
message_obj: AstrBotMessage,
|
|
||||||
platform_meta: PlatformMetadata,
|
|
||||||
session_id: str,
|
|
||||||
adapter: "SatoriPlatformAdapter",
|
|
||||||
):
|
|
||||||
# 更新平台元数据
|
|
||||||
if adapter and hasattr(adapter, "logins") and adapter.logins:
|
|
||||||
current_login = adapter.logins[0]
|
|
||||||
platform_name = current_login.get("platform", "satori")
|
|
||||||
user = current_login.get("user", {})
|
|
||||||
user_id = user.get("id", "") if user else ""
|
|
||||||
if not platform_meta.id and user_id:
|
|
||||||
platform_meta.id = f"{platform_name}({user_id})"
|
|
||||||
|
|
||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
|
||||||
self.adapter = adapter
|
|
||||||
self.platform = None
|
|
||||||
self.user_id = None
|
|
||||||
if (
|
|
||||||
hasattr(message_obj, "raw_message")
|
|
||||||
and message_obj.raw_message
|
|
||||||
and isinstance(message_obj.raw_message, dict)
|
|
||||||
):
|
|
||||||
login = message_obj.raw_message.get("login", {})
|
|
||||||
self.platform = login.get("platform")
|
|
||||||
user = login.get("user", {})
|
|
||||||
self.user_id = user.get("id") if user else None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def send_with_adapter(
|
|
||||||
cls, adapter: "SatoriPlatformAdapter", message: MessageChain, session_id: str
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
content_parts = []
|
|
||||||
|
|
||||||
for component in message.chain:
|
|
||||||
if isinstance(component, Plain):
|
|
||||||
text = (
|
|
||||||
component.text.replace("&", "&")
|
|
||||||
.replace("<", "<")
|
|
||||||
.replace(">", ">")
|
|
||||||
)
|
|
||||||
content_parts.append(text)
|
|
||||||
|
|
||||||
elif isinstance(component, At):
|
|
||||||
if component.qq:
|
|
||||||
content_parts.append(f'<at id="{component.qq}"/>')
|
|
||||||
elif component.name:
|
|
||||||
content_parts.append(f'<at name="{component.name}"/>')
|
|
||||||
|
|
||||||
elif isinstance(component, Image):
|
|
||||||
try:
|
|
||||||
image_base64 = await component.convert_to_base64()
|
|
||||||
if image_base64:
|
|
||||||
content_parts.append(
|
|
||||||
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"图片转换为base64失败: {e}")
|
|
||||||
|
|
||||||
elif isinstance(component, File):
|
|
||||||
content_parts.append(
|
|
||||||
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(component, Record):
|
|
||||||
try:
|
|
||||||
record_base64 = await component.convert_to_base64()
|
|
||||||
if record_base64:
|
|
||||||
content_parts.append(
|
|
||||||
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"语音转换为base64失败: {e}")
|
|
||||||
|
|
||||||
content = "".join(content_parts)
|
|
||||||
channel_id = session_id
|
|
||||||
data = {"channel_id": channel_id, "content": content}
|
|
||||||
|
|
||||||
platform = None
|
|
||||||
user_id = None
|
|
||||||
|
|
||||||
if hasattr(adapter, "logins") and adapter.logins:
|
|
||||||
current_login = adapter.logins[0]
|
|
||||||
platform = current_login.get("platform", "")
|
|
||||||
user = current_login.get("user", {})
|
|
||||||
user_id = user.get("id", "") if user else ""
|
|
||||||
|
|
||||||
result = await adapter.send_http_request(
|
|
||||||
"POST", "/message.create", data, platform, user_id
|
|
||||||
)
|
|
||||||
if result:
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Satori 消息发送异常: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
|
||||||
platform = getattr(self, "platform", None)
|
|
||||||
user_id = getattr(self, "user_id", None)
|
|
||||||
|
|
||||||
if not platform or not user_id:
|
|
||||||
if hasattr(self.adapter, "logins") and self.adapter.logins:
|
|
||||||
current_login = self.adapter.logins[0]
|
|
||||||
platform = current_login.get("platform", "")
|
|
||||||
user = current_login.get("user", {})
|
|
||||||
user_id = user.get("id", "") if user else ""
|
|
||||||
|
|
||||||
try:
|
|
||||||
content_parts = []
|
|
||||||
|
|
||||||
for component in message.chain:
|
|
||||||
if isinstance(component, Plain):
|
|
||||||
text = (
|
|
||||||
component.text.replace("&", "&")
|
|
||||||
.replace("<", "<")
|
|
||||||
.replace(">", ">")
|
|
||||||
)
|
|
||||||
content_parts.append(text)
|
|
||||||
|
|
||||||
elif isinstance(component, At):
|
|
||||||
if component.qq:
|
|
||||||
content_parts.append(f'<at id="{component.qq}"/>')
|
|
||||||
elif component.name:
|
|
||||||
content_parts.append(f'<at name="{component.name}"/>')
|
|
||||||
|
|
||||||
elif isinstance(component, Image):
|
|
||||||
try:
|
|
||||||
image_base64 = await component.convert_to_base64()
|
|
||||||
if image_base64:
|
|
||||||
content_parts.append(
|
|
||||||
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"图片转换为base64失败: {e}")
|
|
||||||
|
|
||||||
elif isinstance(component, File):
|
|
||||||
content_parts.append(
|
|
||||||
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(component, Record):
|
|
||||||
try:
|
|
||||||
record_base64 = await component.convert_to_base64()
|
|
||||||
if record_base64:
|
|
||||||
content_parts.append(
|
|
||||||
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"语音转换为base64失败: {e}")
|
|
||||||
|
|
||||||
content = "".join(content_parts)
|
|
||||||
channel_id = self.session_id
|
|
||||||
data = {"channel_id": channel_id, "content": content}
|
|
||||||
|
|
||||||
result = await self.adapter.send_http_request(
|
|
||||||
"POST", "/message.create", data, platform, user_id
|
|
||||||
)
|
|
||||||
if not result:
|
|
||||||
logger.error("Satori 消息发送失败")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Satori 消息发送异常: {e}")
|
|
||||||
|
|
||||||
await super().send(message)
|
|
||||||
|
|
||||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
|
||||||
try:
|
|
||||||
content_parts = []
|
|
||||||
|
|
||||||
async for chain in generator:
|
|
||||||
if isinstance(chain, MessageChain):
|
|
||||||
if chain.type == "break":
|
|
||||||
if content_parts:
|
|
||||||
content = "".join(content_parts)
|
|
||||||
temp_chain = MessageChain([Plain(text=content)])
|
|
||||||
await self.send(temp_chain)
|
|
||||||
content_parts = []
|
|
||||||
continue
|
|
||||||
|
|
||||||
for component in chain.chain:
|
|
||||||
if isinstance(component, Plain):
|
|
||||||
content_parts.append(component.text)
|
|
||||||
elif isinstance(component, Image):
|
|
||||||
if content_parts:
|
|
||||||
content = "".join(content_parts)
|
|
||||||
temp_chain = MessageChain([Plain(text=content)])
|
|
||||||
await self.send(temp_chain)
|
|
||||||
content_parts = []
|
|
||||||
try:
|
|
||||||
image_base64 = await component.convert_to_base64()
|
|
||||||
if image_base64:
|
|
||||||
img_chain = MessageChain(
|
|
||||||
[
|
|
||||||
Plain(
|
|
||||||
text=f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
await self.send(img_chain)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"图片转换为base64失败: {e}")
|
|
||||||
else:
|
|
||||||
content_parts.append(str(component))
|
|
||||||
|
|
||||||
if content_parts:
|
|
||||||
content = "".join(content_parts)
|
|
||||||
temp_chain = MessageChain([Plain(text=content)])
|
|
||||||
await self.send(temp_chain)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Satori 流式消息发送异常: {e}")
|
|
||||||
|
|
||||||
return await super().send_streaming(generator, use_fallback)
|
|
||||||
@@ -308,9 +308,7 @@ class SlackAdapter(Platform):
|
|||||||
base64_content = base64.b64encode(content).decode("utf-8")
|
base64_content = base64.b64encode(content).decode("utf-8")
|
||||||
return base64_content
|
return base64_content
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(f"Failed to download slack file: {resp.status} {await resp.text()}")
|
||||||
f"Failed to download slack file: {resp.status} {await resp.text()}"
|
|
||||||
)
|
|
||||||
raise Exception(f"下载文件失败: {resp.status}")
|
raise Exception(f"下载文件失败: {resp.status}")
|
||||||
|
|
||||||
async def run(self) -> Awaitable[Any]:
|
async def run(self) -> Awaitable[Any]:
|
||||||
|
|||||||
@@ -75,13 +75,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
|||||||
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
||||||
}
|
}
|
||||||
file_url = response["files"][0]["permalink"]
|
file_url = response["files"][0]["permalink"]
|
||||||
return {
|
return {"type": "section", "text": {"type": "mrkdwn", "text": f"文件: <{file_url}|{segment.name or '文件'}>"}}
|
||||||
"type": "section",
|
|
||||||
"text": {
|
|
||||||
"type": "mrkdwn",
|
|
||||||
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
|
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
|
||||||
|
|
||||||
|
|||||||
@@ -95,8 +95,9 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
def meta(self) -> PlatformMetadata:
|
def meta(self) -> PlatformMetadata:
|
||||||
id_ = self.config.get("id") or "telegram"
|
return PlatformMetadata(
|
||||||
return PlatformMetadata(name="telegram", description="telegram 适配器", id=id_)
|
name="telegram", description="telegram 适配器", id=self.config.get("id")
|
||||||
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def run(self):
|
async def run(self):
|
||||||
@@ -116,10 +117,6 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
)
|
)
|
||||||
self.scheduler.start()
|
self.scheduler.start()
|
||||||
|
|
||||||
if not self.application.updater:
|
|
||||||
logger.error("Telegram Updater is not initialized. Cannot start polling.")
|
|
||||||
return
|
|
||||||
|
|
||||||
queue = self.application.updater.start_polling()
|
queue = self.application.updater.start_polling()
|
||||||
logger.info("Telegram Platform Adapter is running.")
|
logger.info("Telegram Platform Adapter is running.")
|
||||||
await queue
|
await queue
|
||||||
@@ -186,6 +183,7 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if not re.match(r"^[a-z0-9_]+$", cmd_name) or len(cmd_name) > 32:
|
if not re.match(r"^[a-z0-9_]+$", cmd_name) or len(cmd_name) > 32:
|
||||||
|
logger.debug(f"跳过无法注册的命令: {cmd_name}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Build description.
|
# Build description.
|
||||||
@@ -197,11 +195,6 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
return cmd_name, description
|
return cmd_name, description
|
||||||
|
|
||||||
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
if not update.effective_chat:
|
|
||||||
logger.warning(
|
|
||||||
"Received a start command without an effective chat, skipping /start reply."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
await context.bot.send_message(
|
await context.bot.send_message(
|
||||||
chat_id=update.effective_chat.id, text=self.config["start_message"]
|
chat_id=update.effective_chat.id, text=self.config["start_message"]
|
||||||
)
|
)
|
||||||
@@ -214,20 +207,15 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
|
|
||||||
async def convert_message(
|
async def convert_message(
|
||||||
self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True
|
self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True
|
||||||
) -> AstrBotMessage | None:
|
) -> AstrBotMessage:
|
||||||
"""转换 Telegram 的消息对象为 AstrBotMessage 对象。
|
"""转换 Telegram 的消息对象为 AstrBotMessage 对象。
|
||||||
|
|
||||||
@param update: Telegram 的 Update 对象。
|
@param update: Telegram 的 Update 对象。
|
||||||
@param context: Telegram 的 Context 对象。
|
@param context: Telegram 的 Context 对象。
|
||||||
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
||||||
"""
|
"""
|
||||||
if not update.message:
|
|
||||||
logger.warning("Received an update without a message.")
|
|
||||||
return None
|
|
||||||
|
|
||||||
message = AstrBotMessage()
|
message = AstrBotMessage()
|
||||||
message.session_id = str(update.message.chat.id)
|
message.session_id = str(update.message.chat.id)
|
||||||
|
|
||||||
# 获得是群聊还是私聊
|
# 获得是群聊还是私聊
|
||||||
if update.message.chat.type == ChatType.PRIVATE:
|
if update.message.chat.type == ChatType.PRIVATE:
|
||||||
message.type = MessageType.FRIEND_MESSAGE
|
message.type = MessageType.FRIEND_MESSAGE
|
||||||
@@ -238,13 +226,10 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
# Topic Group
|
# Topic Group
|
||||||
message.group_id += "#" + str(update.message.message_thread_id)
|
message.group_id += "#" + str(update.message.message_thread_id)
|
||||||
message.session_id = message.group_id
|
message.session_id = message.group_id
|
||||||
|
|
||||||
message.message_id = str(update.message.message_id)
|
message.message_id = str(update.message.message_id)
|
||||||
_from_user = update.message.from_user
|
|
||||||
if not _from_user:
|
|
||||||
logger.warning("[Telegram] Received a message without a from_user.")
|
|
||||||
return None
|
|
||||||
message.sender = MessageMember(
|
message.sender = MessageMember(
|
||||||
str(_from_user.id), _from_user.username or "Unknown"
|
str(update.message.from_user.id), update.message.from_user.username
|
||||||
)
|
)
|
||||||
message.self_id = str(context.bot.username)
|
message.self_id = str(context.bot.username)
|
||||||
message.raw_message = update
|
message.raw_message = update
|
||||||
@@ -263,32 +248,22 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
)
|
)
|
||||||
reply_abm = await self.convert_message(reply_update, context, False)
|
reply_abm = await self.convert_message(reply_update, context, False)
|
||||||
|
|
||||||
if reply_abm:
|
message.message.append(
|
||||||
message.message.append(
|
Comp.Reply(
|
||||||
Comp.Reply(
|
id=reply_abm.message_id,
|
||||||
id=reply_abm.message_id,
|
chain=reply_abm.message,
|
||||||
chain=reply_abm.message,
|
sender_id=reply_abm.sender.user_id,
|
||||||
sender_id=reply_abm.sender.user_id,
|
sender_nickname=reply_abm.sender.nickname,
|
||||||
sender_nickname=reply_abm.sender.nickname,
|
time=reply_abm.timestamp,
|
||||||
time=reply_abm.timestamp,
|
message_str=reply_abm.message_str,
|
||||||
message_str=reply_abm.message_str,
|
text=reply_abm.message_str,
|
||||||
text=reply_abm.message_str,
|
qq=reply_abm.sender.user_id,
|
||||||
qq=reply_abm.sender.user_id,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if update.message.text:
|
if update.message.text:
|
||||||
# 处理文本消息
|
# 处理文本消息
|
||||||
plain_text = update.message.text
|
plain_text = update.message.text
|
||||||
if (
|
|
||||||
message.type == MessageType.GROUP_MESSAGE
|
|
||||||
and update.message
|
|
||||||
and update.message.reply_to_message
|
|
||||||
and update.message.reply_to_message.from_user
|
|
||||||
and update.message.reply_to_message.from_user.id == context.bot.id
|
|
||||||
):
|
|
||||||
plain_text2 = f"/@{context.bot.username} " + plain_text
|
|
||||||
plain_text = plain_text2
|
|
||||||
|
|
||||||
# 群聊场景命令特殊处理
|
# 群聊场景命令特殊处理
|
||||||
if plain_text.startswith("/"):
|
if plain_text.startswith("/"):
|
||||||
@@ -354,25 +329,15 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
|
|
||||||
elif update.message.document:
|
elif update.message.document:
|
||||||
file = await update.message.document.get_file()
|
file = await update.message.document.get_file()
|
||||||
file_name = update.message.document.file_name or uuid.uuid4().hex
|
message.message = [
|
||||||
file_path = file.file_path
|
Comp.File(file=file.file_path, name=update.message.document.file_name),
|
||||||
if file_path is None:
|
]
|
||||||
logger.warning(
|
|
||||||
f"Telegram document file_path is None, cannot save the file {file_name}."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
message.message.append(Comp.File(file=file_path, name=file_name))
|
|
||||||
|
|
||||||
elif update.message.video:
|
elif update.message.video:
|
||||||
file = await update.message.video.get_file()
|
file = await update.message.video.get_file()
|
||||||
file_name = update.message.video.file_name or uuid.uuid4().hex
|
message.message = [
|
||||||
file_path = file.file_path
|
Comp.Video(file=file.file_path, path=file.file_path),
|
||||||
if file_path is None:
|
]
|
||||||
logger.warning(
|
|
||||||
f"Telegram video file_path is None, cannot save the file {file_name}."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
message.message.append(Comp.Video(file=file_path, path=file.file_path))
|
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from telegram.ext import ExtBot
|
|||||||
from astrbot.core.utils.io import download_file
|
from astrbot.core.utils.io import download_file
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
from telegram import ReactionTypeEmoji, ReactionTypeCustomEmoji
|
|
||||||
|
|
||||||
|
|
||||||
class TelegramPlatformEvent(AstrMessageEvent):
|
class TelegramPlatformEvent(AstrMessageEvent):
|
||||||
@@ -67,9 +66,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def send_with_client(
|
async def send_with_client(cls, client: ExtBot, message: MessageChain, user_name: str):
|
||||||
cls, client: ExtBot, message: MessageChain, user_name: str
|
|
||||||
):
|
|
||||||
image_path = None
|
image_path = None
|
||||||
|
|
||||||
has_reply = False
|
has_reply = False
|
||||||
@@ -136,39 +133,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
await self.send_with_client(self.client, message, self.get_sender_id())
|
await self.send_with_client(self.client, message, self.get_sender_id())
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
async def react(self, emoji: str | None, big: bool = False):
|
|
||||||
"""
|
|
||||||
给原消息添加 Telegram 反应:
|
|
||||||
- 普通 emoji:传入 '👍'、'😂' 等
|
|
||||||
- 自定义表情:传入其 custom_emoji_id(纯数字字符串)
|
|
||||||
- 取消本机器人的反应:传入 None 或空字符串
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 解析 chat_id(去掉超级群的 "#<thread_id>" 片段)
|
|
||||||
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
|
||||||
chat_id = (self.message_obj.group_id or "").split("#")[0]
|
|
||||||
else:
|
|
||||||
chat_id = self.get_sender_id()
|
|
||||||
|
|
||||||
message_id = int(self.message_obj.message_id)
|
|
||||||
|
|
||||||
# 组装 reaction 参数(必须是 ReactionType 的列表)
|
|
||||||
if not emoji: # 清空本 bot 的反应
|
|
||||||
reaction_param = [] # 空列表表示移除本 bot 的反应
|
|
||||||
elif emoji.isdigit(): # 自定义表情:传 custom_emoji_id
|
|
||||||
reaction_param = [ReactionTypeCustomEmoji(emoji)]
|
|
||||||
else: # 普通 emoji
|
|
||||||
reaction_param = [ReactionTypeEmoji(emoji)]
|
|
||||||
|
|
||||||
await self.client.set_message_reaction(
|
|
||||||
chat_id=chat_id,
|
|
||||||
message_id=message_id,
|
|
||||||
reaction=reaction_param, # 注意是列表
|
|
||||||
is_big=big, # 可选:大动画
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[Telegram] 添加反应失败: {e}")
|
|
||||||
|
|
||||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
message_thread_id = None
|
message_thread_id = None
|
||||||
|
|
||||||
@@ -252,6 +216,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
try:
|
try:
|
||||||
msg = await self.client.send_message(text=delta, **payload)
|
msg = await self.client.send_message(text=delta, **payload)
|
||||||
current_content = delta
|
current_content = delta
|
||||||
|
delta = ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||||
message_id = msg.message_id
|
message_id = msg.message_id
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
class WebChatQueueMgr:
|
class WebChatQueueMgr:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.queues = {}
|
self.queues = {}
|
||||||
@@ -31,5 +30,4 @@ class WebChatQueueMgr:
|
|||||||
"""Check if a queue exists for the given conversation ID"""
|
"""Check if a queue exists for the given conversation ID"""
|
||||||
return conversation_id in self.queues
|
return conversation_id in self.queues
|
||||||
|
|
||||||
|
|
||||||
webchat_queue_mgr = WebChatQueueMgr()
|
webchat_queue_mgr = WebChatQueueMgr()
|
||||||
|
|||||||
@@ -213,10 +213,10 @@ class WeChatPadProAdapter(Platform):
|
|||||||
def _extract_auth_key(self, data):
|
def _extract_auth_key(self, data):
|
||||||
"""Helper method to extract auth_key from response data."""
|
"""Helper method to extract auth_key from response data."""
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
auth_keys = data.get("authKeys") # 新接口
|
auth_keys = data.get("authKeys") # 新接口
|
||||||
if isinstance(auth_keys, list) and auth_keys:
|
if isinstance(auth_keys, list) and auth_keys:
|
||||||
return auth_keys[0]
|
return auth_keys[0]
|
||||||
elif isinstance(data, list) and data: # 旧接口
|
elif isinstance(data, list) and data: # 旧接口
|
||||||
return data[0]
|
return data[0]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -234,9 +234,7 @@ class WeChatPadProAdapter(Platform):
|
|||||||
try:
|
try:
|
||||||
async with session.post(url, params=params, json=payload) as response:
|
async with session.post(url, params=params, json=payload) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
logger.error(
|
logger.error(f"生成授权码失败: {response.status}, {await response.text()}")
|
||||||
f"生成授权码失败: {response.status}, {await response.text()}"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
response_data = await response.json()
|
response_data = await response.json()
|
||||||
@@ -247,9 +245,7 @@ class WeChatPadProAdapter(Platform):
|
|||||||
if self.auth_key:
|
if self.auth_key:
|
||||||
logger.info("成功获取授权码")
|
logger.info("成功获取授权码")
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(f"生成授权码成功但未找到授权码: {response_data}")
|
||||||
f"生成授权码成功但未找到授权码: {response_data}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.error(f"生成授权码失败: {response_data}")
|
logger.error(f"生成授权码失败: {response_data}")
|
||||||
except aiohttp.ClientConnectorError as e:
|
except aiohttp.ClientConnectorError as e:
|
||||||
|
|||||||
@@ -185,7 +185,6 @@ class WecomPlatformAdapter(Platform):
|
|||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"wecom",
|
"wecom",
|
||||||
"wecom 适配器",
|
"wecom 适配器",
|
||||||
id=self.config.get("id", "wecom"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
|
|||||||
@@ -48,12 +48,7 @@ class WeChatKF(BaseWeChatAPI):
|
|||||||
注意:可能会出现返回条数少于limit的情况,需结合返回的has_more字段判断是否继续请求。
|
注意:可能会出现返回条数少于limit的情况,需结合返回的has_more字段判断是否继续请求。
|
||||||
:return: 接口调用结果
|
:return: 接口调用结果
|
||||||
"""
|
"""
|
||||||
data = {
|
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid}
|
||||||
"token": token,
|
|
||||||
"cursor": cursor,
|
|
||||||
"limit": limit,
|
|
||||||
"open_kfid": open_kfid,
|
|
||||||
}
|
|
||||||
return self._post("kf/sync_msg", data=data)
|
return self._post("kf/sync_msg", data=data)
|
||||||
|
|
||||||
def get_service_state(self, open_kfid, external_userid):
|
def get_service_state(self, open_kfid, external_userid):
|
||||||
@@ -77,9 +72,7 @@ class WeChatKF(BaseWeChatAPI):
|
|||||||
}
|
}
|
||||||
return self._post("kf/service_state/get", data=data)
|
return self._post("kf/service_state/get", data=data)
|
||||||
|
|
||||||
def trans_service_state(
|
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""):
|
||||||
self, open_kfid, external_userid, service_state, servicer_userid=""
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
变更会话状态
|
变更会话状态
|
||||||
|
|
||||||
@@ -187,9 +180,7 @@ class WeChatKF(BaseWeChatAPI):
|
|||||||
"""
|
"""
|
||||||
return self._get("kf/customer/get_upgrade_service_config")
|
return self._get("kf/customer/get_upgrade_service_config")
|
||||||
|
|
||||||
def upgrade_service(
|
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None):
|
||||||
self, open_kfid, external_userid, service_type, member=None, groupchat=None
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
为客户升级为专员或客户群服务
|
为客户升级为专员或客户群服务
|
||||||
|
|
||||||
@@ -255,9 +246,7 @@ class WeChatKF(BaseWeChatAPI):
|
|||||||
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
|
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
|
||||||
return self._post("kf/get_corp_statistic", data=data)
|
return self._post("kf/get_corp_statistic", data=data)
|
||||||
|
|
||||||
def get_servicer_statistic(
|
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None):
|
||||||
self, start_time, end_time, open_kfid=None, servicer_userid=None
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
获取「客户数据统计」接待人员明细数据
|
获取「客户数据统计」接待人员明细数据
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from optionaldict import optionaldict
|
|||||||
|
|
||||||
from wechatpy.client.api.base import BaseWeChatAPI
|
from wechatpy.client.api.base import BaseWeChatAPI
|
||||||
|
|
||||||
|
|
||||||
class WeChatKFMessage(BaseWeChatAPI):
|
class WeChatKFMessage(BaseWeChatAPI):
|
||||||
"""
|
"""
|
||||||
发送微信客服消息
|
发送微信客服消息
|
||||||
@@ -126,55 +125,35 @@ class WeChatKFMessage(BaseWeChatAPI):
|
|||||||
msg={"msgtype": "news", "link": {"link": articles_data}},
|
msg={"msgtype": "news", "link": {"link": articles_data}},
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_msgmenu(
|
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""):
|
||||||
self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""
|
|
||||||
):
|
|
||||||
return self.send(
|
return self.send(
|
||||||
user_id,
|
user_id,
|
||||||
open_kfid,
|
open_kfid,
|
||||||
msgid,
|
msgid,
|
||||||
msg={
|
msg={
|
||||||
"msgtype": "msgmenu",
|
"msgtype": "msgmenu",
|
||||||
"msgmenu": {
|
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content},
|
||||||
"head_content": head_content,
|
|
||||||
"list": menu_list,
|
|
||||||
"tail_content": tail_content,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_location(
|
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""):
|
||||||
self, user_id, open_kfid, name, address, latitude, longitude, msgid=""
|
|
||||||
):
|
|
||||||
return self.send(
|
return self.send(
|
||||||
user_id,
|
user_id,
|
||||||
open_kfid,
|
open_kfid,
|
||||||
msgid,
|
msgid,
|
||||||
msg={
|
msg={
|
||||||
"msgtype": "location",
|
"msgtype": "location",
|
||||||
"msgmenu": {
|
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude},
|
||||||
"name": name,
|
|
||||||
"address": address,
|
|
||||||
"latitude": latitude,
|
|
||||||
"longitude": longitude,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_miniprogram(
|
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""):
|
||||||
self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""
|
|
||||||
):
|
|
||||||
return self.send(
|
return self.send(
|
||||||
user_id,
|
user_id,
|
||||||
open_kfid,
|
open_kfid,
|
||||||
msgid,
|
msgid,
|
||||||
msg={
|
msg={
|
||||||
"msgtype": "miniprogram",
|
"msgtype": "miniprogram",
|
||||||
"msgmenu": {
|
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath},
|
||||||
"appid": appid,
|
|
||||||
"title": title,
|
|
||||||
"thumb_media_id": thumb_media_id,
|
|
||||||
"pagepath": pagepath,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -160,9 +160,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
self.wexin_event_workers[msg.id] = future
|
self.wexin_event_workers[msg.id] = future
|
||||||
await self.convert_message(msg, future)
|
await self.convert_message(msg, future)
|
||||||
# I love shield so much!
|
# I love shield so much!
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
|
||||||
asyncio.shield(future), 60
|
|
||||||
) # wait for 60s
|
|
||||||
logger.debug(f"Got future result: {result}")
|
logger.debug(f"Got future result: {result}")
|
||||||
self.wexin_event_workers.pop(msg.id, None)
|
self.wexin_event_workers.pop(msg.id, None)
|
||||||
return result # xml. see weixin_offacc_event.py
|
return result # xml. see weixin_offacc_event.py
|
||||||
@@ -184,7 +182,6 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"weixin_official_account",
|
"weixin_official_account",
|
||||||
"微信公众平台 适配器",
|
"微信公众平台 适配器",
|
||||||
id=self.config.get("id", "weixin_official_account"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
|
|||||||
@@ -150,6 +150,7 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
|||||||
return
|
return
|
||||||
logger.info(f"微信公众平台上传语音返回: {response}")
|
logger.info(f"微信公众平台上传语音返回: {response}")
|
||||||
|
|
||||||
|
|
||||||
if active_send_mode:
|
if active_send_mode:
|
||||||
self.client.message.send_voice(
|
self.client.message.send_voice(
|
||||||
message_obj.sender.user_id,
|
message_obj.sender.user_id,
|
||||||
|
|||||||
@@ -4,11 +4,9 @@ import json
|
|||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Dict, Type, Any
|
from typing import List, Dict, Type
|
||||||
from astrbot.core.agent.tool import ToolSet
|
from astrbot.core.agent.tool import ToolSet
|
||||||
from openai.types.chat.chat_completion import ChatCompletion
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
from google.genai.types import GenerateContentResponse
|
|
||||||
from anthropic.types import Message
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call import (
|
from openai.types.chat.chat_completion_message_tool_call import (
|
||||||
ChatCompletionMessageToolCall,
|
ChatCompletionMessageToolCall,
|
||||||
)
|
)
|
||||||
@@ -32,11 +30,11 @@ class ProviderMetaData:
|
|||||||
desc: str = ""
|
desc: str = ""
|
||||||
"""提供商适配器描述."""
|
"""提供商适配器描述."""
|
||||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||||
cls_type: Type | None = None
|
cls_type: Type = None
|
||||||
|
|
||||||
default_config_tmpl: dict | None = None
|
default_config_tmpl: dict = None
|
||||||
"""平台的默认配置模板"""
|
"""平台的默认配置模板"""
|
||||||
provider_display_name: str | None = None
|
provider_display_name: str = None
|
||||||
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
|
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
|
||||||
|
|
||||||
|
|
||||||
@@ -60,21 +58,18 @@ class ToolCallMessageSegment:
|
|||||||
class AssistantMessageSegment:
|
class AssistantMessageSegment:
|
||||||
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||||
|
|
||||||
content: str | None = None
|
content: str = None
|
||||||
tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list)
|
tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list)
|
||||||
role: str = "assistant"
|
role: str = "assistant"
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
ret: dict[str, str | list[dict]] = {
|
ret = {
|
||||||
"role": self.role,
|
"role": self.role,
|
||||||
}
|
}
|
||||||
if self.content:
|
if self.content:
|
||||||
ret["content"] = self.content
|
ret["content"] = self.content
|
||||||
if self.tool_calls:
|
if self.tool_calls:
|
||||||
tool_calls_dict = [
|
ret["tool_calls"] = self.tool_calls
|
||||||
tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls
|
|
||||||
]
|
|
||||||
ret["tool_calls"] = tool_calls_dict
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@@ -120,14 +115,7 @@ class ProviderRequest:
|
|||||||
"""模型名称,为 None 时使用提供商的默认模型"""
|
"""模型名称,为 None 时使用提供商的默认模型"""
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (
|
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
|
||||||
f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
|
|
||||||
f"image_count={len(self.image_urls or [])}, "
|
|
||||||
f"func_tool={self.func_tool}, "
|
|
||||||
f"contexts={self._print_friendly_context()}, "
|
|
||||||
f"system_prompt={self.system_prompt}, "
|
|
||||||
f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, "
|
|
||||||
)
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.__repr__()
|
return self.__repr__()
|
||||||
@@ -217,17 +205,17 @@ class ProviderRequest:
|
|||||||
class LLMResponse:
|
class LLMResponse:
|
||||||
role: str
|
role: str
|
||||||
"""角色, assistant, tool, err"""
|
"""角色, assistant, tool, err"""
|
||||||
result_chain: MessageChain | None = None
|
result_chain: MessageChain = None
|
||||||
"""返回的消息链"""
|
"""返回的消息链"""
|
||||||
tools_call_args: List[Dict[str, Any]] = field(default_factory=list)
|
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
|
||||||
"""工具调用参数"""
|
"""工具调用参数"""
|
||||||
tools_call_name: List[str] = field(default_factory=list)
|
tools_call_name: List[str] = field(default_factory=list)
|
||||||
"""工具调用名称"""
|
"""工具调用名称"""
|
||||||
tools_call_ids: List[str] = field(default_factory=list)
|
tools_call_ids: List[str] = field(default_factory=list)
|
||||||
"""工具调用 ID"""
|
"""工具调用 ID"""
|
||||||
|
|
||||||
raw_completion: ChatCompletion | GenerateContentResponse | Message | None = None
|
raw_completion: ChatCompletion = None
|
||||||
_new_record: Dict[str, Any] | None = None
|
_new_record: Dict[str, any] = None
|
||||||
|
|
||||||
_completion_text: str = ""
|
_completion_text: str = ""
|
||||||
|
|
||||||
@@ -238,12 +226,12 @@ class LLMResponse:
|
|||||||
self,
|
self,
|
||||||
role: str,
|
role: str,
|
||||||
completion_text: str = "",
|
completion_text: str = "",
|
||||||
result_chain: MessageChain | None = None,
|
result_chain: MessageChain = None,
|
||||||
tools_call_args: List[Dict[str, Any]] | None = None,
|
tools_call_args: List[Dict[str, any]] = None,
|
||||||
tools_call_name: List[str] | None = None,
|
tools_call_name: List[str] = None,
|
||||||
tools_call_ids: List[str] | None = None,
|
tools_call_ids: List[str] = None,
|
||||||
raw_completion: ChatCompletion | None = None,
|
raw_completion: ChatCompletion = None,
|
||||||
_new_record: Dict[str, Any] | None = None,
|
_new_record: Dict[str, any] = None,
|
||||||
is_chunk: bool = False,
|
is_chunk: bool = False,
|
||||||
):
|
):
|
||||||
"""初始化 LLMResponse
|
"""初始化 LLMResponse
|
||||||
@@ -307,7 +295,6 @@ class LLMResponse:
|
|||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RerankResult:
|
class RerankResult:
|
||||||
index: int
|
index: int
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import os
|
|||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from typing import Dict, List, Awaitable, Callable, Any
|
from typing import Dict, List, Awaitable
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core import sp
|
from astrbot.core import sp
|
||||||
|
|
||||||
@@ -109,7 +109,7 @@ class FunctionToolManager:
|
|||||||
name: str,
|
name: str,
|
||||||
func_args: list,
|
func_args: list,
|
||||||
desc: str,
|
desc: str,
|
||||||
handler: Callable[..., Awaitable[Any]],
|
handler: Awaitable,
|
||||||
) -> FuncTool:
|
) -> FuncTool:
|
||||||
params = {
|
params = {
|
||||||
"type": "object", # hard-coded here
|
"type": "object", # hard-coded here
|
||||||
@@ -132,7 +132,7 @@ class FunctionToolManager:
|
|||||||
name: str,
|
name: str,
|
||||||
func_args: list,
|
func_args: list,
|
||||||
desc: str,
|
desc: str,
|
||||||
handler: Callable[..., Awaitable[Any]],
|
handler: Awaitable,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""添加函数调用工具
|
"""添加函数调用工具
|
||||||
|
|
||||||
@@ -220,7 +220,7 @@ class FunctionToolManager:
|
|||||||
name: str,
|
name: str,
|
||||||
cfg: dict,
|
cfg: dict,
|
||||||
event: asyncio.Event,
|
event: asyncio.Event,
|
||||||
ready_future: asyncio.Future | None = None,
|
ready_future: asyncio.Future = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -7,13 +7,7 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
|||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
|
|
||||||
from .entities import ProviderType
|
from .entities import ProviderType
|
||||||
from .provider import (
|
from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
|
||||||
Provider,
|
|
||||||
STTProvider,
|
|
||||||
TTSProvider,
|
|
||||||
EmbeddingProvider,
|
|
||||||
RerankProvider,
|
|
||||||
)
|
|
||||||
from .register import llm_tools, provider_cls_map
|
from .register import llm_tools, provider_cls_map
|
||||||
from ..persona_mgr import PersonaManager
|
from ..persona_mgr import PersonaManager
|
||||||
|
|
||||||
@@ -44,12 +38,7 @@ class ProviderManager:
|
|||||||
"""加载的 Text To Speech Provider 的实例"""
|
"""加载的 Text To Speech Provider 的实例"""
|
||||||
self.embedding_provider_insts: List[EmbeddingProvider] = []
|
self.embedding_provider_insts: List[EmbeddingProvider] = []
|
||||||
"""加载的 Embedding Provider 的实例"""
|
"""加载的 Embedding Provider 的实例"""
|
||||||
self.rerank_provider_insts: List[RerankProvider] = []
|
self.inst_map: dict[str, Provider] = {}
|
||||||
"""加载的 Rerank Provider 的实例"""
|
|
||||||
self.inst_map: dict[
|
|
||||||
str,
|
|
||||||
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
|
|
||||||
] = {}
|
|
||||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||||
self.llm_tools = llm_tools
|
self.llm_tools = llm_tools
|
||||||
|
|
||||||
@@ -98,31 +87,19 @@ class ProviderManager:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
# 不启用提供商会话隔离模式的情况
|
# 不启用提供商会话隔离模式的情况
|
||||||
|
self.curr_provider_inst = self.inst_map[provider_id]
|
||||||
prov = self.inst_map[provider_id]
|
if provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||||
if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
|
|
||||||
prov, TTSProvider
|
|
||||||
):
|
|
||||||
self.curr_tts_provider_inst = prov
|
|
||||||
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
||||||
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
|
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||||
prov, STTProvider
|
|
||||||
):
|
|
||||||
self.curr_stt_provider_inst = prov
|
|
||||||
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
||||||
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
|
elif provider_type == ProviderType.CHAT_COMPLETION:
|
||||||
prov, Provider
|
|
||||||
):
|
|
||||||
self.curr_provider_inst = prov
|
|
||||||
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
||||||
|
|
||||||
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||||
"""根据提供商 ID 获取提供商实例"""
|
"""根据提供商 ID 获取提供商实例"""
|
||||||
return self.inst_map.get(provider_id)
|
return self.inst_map.get(provider_id)
|
||||||
|
|
||||||
def get_using_provider(
|
def get_using_provider(self, provider_type: ProviderType, umo=None):
|
||||||
self, provider_type: ProviderType, umo=None
|
|
||||||
) -> Provider | STTProvider | TTSProvider | None:
|
|
||||||
"""获取正在使用的提供商实例。
|
"""获取正在使用的提供商实例。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -234,8 +211,6 @@ class ProviderManager:
|
|||||||
)
|
)
|
||||||
case "dify":
|
case "dify":
|
||||||
from .sources.dify_source import ProviderDify as ProviderDify
|
from .sources.dify_source import ProviderDify as ProviderDify
|
||||||
case "coze":
|
|
||||||
from .sources.coze_source import ProviderCoze as ProviderCoze
|
|
||||||
case "dashscope":
|
case "dashscope":
|
||||||
from .sources.dashscope_source import (
|
from .sources.dashscope_source import (
|
||||||
ProviderDashscope as ProviderDashscope,
|
ProviderDashscope as ProviderDashscope,
|
||||||
@@ -328,14 +303,12 @@ class ProviderManager:
|
|||||||
provider_metadata = provider_cls_map[provider_config["type"]]
|
provider_metadata = provider_cls_map[provider_config["type"]]
|
||||||
try:
|
try:
|
||||||
# 按任务实例化提供商
|
# 按任务实例化提供商
|
||||||
cls_type = provider_metadata.cls_type
|
|
||||||
if not cls_type:
|
|
||||||
logger.error(f"无法找到 {provider_metadata.type} 的类")
|
|
||||||
return
|
|
||||||
|
|
||||||
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||||
# STT 任务
|
# STT 任务
|
||||||
inst = cls_type(provider_config, self.provider_settings)
|
inst = provider_metadata.cls_type(
|
||||||
|
provider_config, self.provider_settings
|
||||||
|
)
|
||||||
|
|
||||||
if getattr(inst, "initialize", None):
|
if getattr(inst, "initialize", None):
|
||||||
await inst.initialize()
|
await inst.initialize()
|
||||||
@@ -354,7 +327,9 @@ class ProviderManager:
|
|||||||
|
|
||||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||||
# TTS 任务
|
# TTS 任务
|
||||||
inst = cls_type(provider_config, self.provider_settings)
|
inst = provider_metadata.cls_type(
|
||||||
|
provider_config, self.provider_settings
|
||||||
|
)
|
||||||
|
|
||||||
if getattr(inst, "initialize", None):
|
if getattr(inst, "initialize", None):
|
||||||
await inst.initialize()
|
await inst.initialize()
|
||||||
@@ -370,7 +345,7 @@ class ProviderManager:
|
|||||||
|
|
||||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||||
# 文本生成任务
|
# 文本生成任务
|
||||||
inst = cls_type(
|
inst = provider_metadata.cls_type(
|
||||||
provider_config,
|
provider_config,
|
||||||
self.provider_settings,
|
self.provider_settings,
|
||||||
self.selected_default_persona,
|
self.selected_default_persona,
|
||||||
@@ -391,16 +366,13 @@ class ProviderManager:
|
|||||||
if not self.curr_provider_inst:
|
if not self.curr_provider_inst:
|
||||||
self.curr_provider_inst = inst
|
self.curr_provider_inst = inst
|
||||||
|
|
||||||
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
elif provider_metadata.provider_type in [ProviderType.EMBEDDING, ProviderType.RERANK]:
|
||||||
inst = cls_type(provider_config, self.provider_settings)
|
inst = provider_metadata.cls_type(
|
||||||
|
provider_config, self.provider_settings
|
||||||
|
)
|
||||||
if getattr(inst, "initialize", None):
|
if getattr(inst, "initialize", None):
|
||||||
await inst.initialize()
|
await inst.initialize()
|
||||||
self.embedding_provider_insts.append(inst)
|
self.embedding_provider_insts.append(inst)
|
||||||
elif provider_metadata.provider_type == ProviderType.RERANK:
|
|
||||||
inst = cls_type(provider_config, self.provider_settings)
|
|
||||||
if getattr(inst, "initialize", None):
|
|
||||||
await inst.initialize()
|
|
||||||
self.rerank_provider_insts.append(inst)
|
|
||||||
|
|
||||||
self.inst_map[provider_config["id"]] = inst
|
self.inst_map[provider_config["id"]] = inst
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -416,7 +388,6 @@ class ProviderManager:
|
|||||||
|
|
||||||
# 和配置文件保持同步
|
# 和配置文件保持同步
|
||||||
config_ids = [provider["id"] for provider in self.providers_config]
|
config_ids = [provider["id"] for provider in self.providers_config]
|
||||||
logger.debug(f"providers in user's config: {config_ids}")
|
|
||||||
for key in list(self.inst_map.keys()):
|
for key in list(self.inst_map.keys()):
|
||||||
if key not in config_ids:
|
if key not in config_ids:
|
||||||
await self.terminate_provider(key)
|
await self.terminate_provider(key)
|
||||||
@@ -455,17 +426,11 @@ class ProviderManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.inst_map[provider_id] in self.provider_insts:
|
if self.inst_map[provider_id] in self.provider_insts:
|
||||||
prov_inst = self.inst_map[provider_id]
|
self.provider_insts.remove(self.inst_map[provider_id])
|
||||||
if isinstance(prov_inst, Provider):
|
|
||||||
self.provider_insts.remove(prov_inst)
|
|
||||||
if self.inst_map[provider_id] in self.stt_provider_insts:
|
if self.inst_map[provider_id] in self.stt_provider_insts:
|
||||||
prov_inst = self.inst_map[provider_id]
|
self.stt_provider_insts.remove(self.inst_map[provider_id])
|
||||||
if isinstance(prov_inst, STTProvider):
|
|
||||||
self.stt_provider_insts.remove(prov_inst)
|
|
||||||
if self.inst_map[provider_id] in self.tts_provider_insts:
|
if self.inst_map[provider_id] in self.tts_provider_insts:
|
||||||
prov_inst = self.inst_map[provider_id]
|
self.tts_provider_insts.remove(self.inst_map[provider_id])
|
||||||
if isinstance(prov_inst, TTSProvider):
|
|
||||||
self.tts_provider_insts.remove(prov_inst)
|
|
||||||
|
|
||||||
if self.inst_map[provider_id] == self.curr_provider_inst:
|
if self.inst_map[provider_id] == self.curr_provider_inst:
|
||||||
self.curr_provider_inst = None
|
self.curr_provider_inst = None
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class Provider(AbstractProvider):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def get_models(self) -> List[str]:
|
def get_models(self) -> List[str]:
|
||||||
"""获得支持的模型列表"""
|
"""获得支持的模型列表"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|||||||
@@ -1,314 +0,0 @@
|
|||||||
import json
|
|
||||||
import asyncio
|
|
||||||
import aiohttp
|
|
||||||
import io
|
|
||||||
from typing import Dict, List, Any, AsyncGenerator
|
|
||||||
from astrbot.core import logger
|
|
||||||
|
|
||||||
|
|
||||||
class CozeAPIClient:
|
|
||||||
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
|
|
||||||
self.api_key = api_key
|
|
||||||
self.api_base = api_base
|
|
||||||
self.session = None
|
|
||||||
|
|
||||||
async def _ensure_session(self):
|
|
||||||
"""确保HTTP session存在"""
|
|
||||||
if self.session is None:
|
|
||||||
connector = aiohttp.TCPConnector(
|
|
||||||
ssl=False if self.api_base.startswith("http://") else True,
|
|
||||||
limit=100,
|
|
||||||
limit_per_host=30,
|
|
||||||
keepalive_timeout=30,
|
|
||||||
enable_cleanup_closed=True,
|
|
||||||
)
|
|
||||||
timeout = aiohttp.ClientTimeout(
|
|
||||||
total=120, # 默认超时时间
|
|
||||||
connect=30,
|
|
||||||
sock_read=120,
|
|
||||||
)
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Accept": "text/event-stream",
|
|
||||||
}
|
|
||||||
self.session = aiohttp.ClientSession(
|
|
||||||
headers=headers, timeout=timeout, connector=connector
|
|
||||||
)
|
|
||||||
return self.session
|
|
||||||
|
|
||||||
async def upload_file(
|
|
||||||
self,
|
|
||||||
file_data: bytes,
|
|
||||||
) -> str:
|
|
||||||
"""上传文件到 Coze 并返回 file_id
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_data (bytes): 文件的二进制数据
|
|
||||||
Returns:
|
|
||||||
str: 上传成功后返回的 file_id
|
|
||||||
"""
|
|
||||||
session = await self._ensure_session()
|
|
||||||
url = f"{self.api_base}/v1/files/upload"
|
|
||||||
|
|
||||||
try:
|
|
||||||
file_io = io.BytesIO(file_data)
|
|
||||||
async with session.post(
|
|
||||||
url,
|
|
||||||
data={
|
|
||||||
"file": file_io,
|
|
||||||
},
|
|
||||||
timeout=aiohttp.ClientTimeout(total=60),
|
|
||||||
) as response:
|
|
||||||
if response.status == 401:
|
|
||||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
|
||||||
|
|
||||||
response_text = await response.text()
|
|
||||||
logger.debug(
|
|
||||||
f"文件上传响应状态: {response.status}, 内容: {response_text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status != 200:
|
|
||||||
raise Exception(
|
|
||||||
f"文件上传失败,状态码: {response.status}, 响应: {response_text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await response.json()
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise Exception(f"文件上传响应解析失败: {response_text}")
|
|
||||||
|
|
||||||
if result.get("code") != 0:
|
|
||||||
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
|
|
||||||
|
|
||||||
file_id = result["data"]["id"]
|
|
||||||
logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}")
|
|
||||||
return file_id
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.error("文件上传超时")
|
|
||||||
raise Exception("文件上传超时")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"文件上传失败: {str(e)}")
|
|
||||||
raise Exception(f"文件上传失败: {str(e)}")
|
|
||||||
|
|
||||||
async def download_image(self, image_url: str) -> bytes:
|
|
||||||
"""下载图片并返回字节数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_url (str): 图片的URL
|
|
||||||
Returns:
|
|
||||||
bytes: 图片的二进制数据
|
|
||||||
"""
|
|
||||||
session = await self._ensure_session()
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with session.get(image_url) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise Exception(f"下载图片失败,状态码: {response.status}")
|
|
||||||
|
|
||||||
image_data = await response.read()
|
|
||||||
return image_data
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"下载图片失败 {image_url}: {str(e)}")
|
|
||||||
raise Exception(f"下载图片失败: {str(e)}")
|
|
||||||
|
|
||||||
async def chat_messages(
|
|
||||||
self,
|
|
||||||
bot_id: str,
|
|
||||||
user_id: str,
|
|
||||||
additional_messages: List[Dict] | None = None,
|
|
||||||
conversation_id: str | None = None,
|
|
||||||
auto_save_history: bool = True,
|
|
||||||
stream: bool = True,
|
|
||||||
timeout: float = 120,
|
|
||||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
|
||||||
"""发送聊天消息并返回流式响应
|
|
||||||
|
|
||||||
Args:
|
|
||||||
bot_id: Bot ID
|
|
||||||
user_id: 用户ID
|
|
||||||
additional_messages: 额外消息列表
|
|
||||||
conversation_id: 会话ID
|
|
||||||
auto_save_history: 是否自动保存历史
|
|
||||||
stream: 是否流式响应
|
|
||||||
timeout: 超时时间
|
|
||||||
"""
|
|
||||||
session = await self._ensure_session()
|
|
||||||
url = f"{self.api_base}/v3/chat"
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"bot_id": bot_id,
|
|
||||||
"user_id": user_id,
|
|
||||||
"stream": stream,
|
|
||||||
"auto_save_history": auto_save_history,
|
|
||||||
}
|
|
||||||
|
|
||||||
if additional_messages:
|
|
||||||
payload["additional_messages"] = additional_messages
|
|
||||||
|
|
||||||
params = {}
|
|
||||||
if conversation_id:
|
|
||||||
params["conversation_id"] = conversation_id
|
|
||||||
|
|
||||||
logger.debug(f"Coze chat_messages payload: {payload}, params: {params}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with session.post(
|
|
||||||
url,
|
|
||||||
json=payload,
|
|
||||||
params=params,
|
|
||||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
|
||||||
) as response:
|
|
||||||
if response.status == 401:
|
|
||||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
|
||||||
|
|
||||||
if response.status != 200:
|
|
||||||
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
|
|
||||||
|
|
||||||
# SSE
|
|
||||||
buffer = ""
|
|
||||||
event_type = None
|
|
||||||
event_data = None
|
|
||||||
|
|
||||||
async for chunk in response.content:
|
|
||||||
if chunk:
|
|
||||||
buffer += chunk.decode("utf-8", errors="ignore")
|
|
||||||
lines = buffer.split("\n")
|
|
||||||
buffer = lines[-1]
|
|
||||||
|
|
||||||
for line in lines[:-1]:
|
|
||||||
line = line.strip()
|
|
||||||
|
|
||||||
if not line:
|
|
||||||
if event_type and event_data:
|
|
||||||
yield {"event": event_type, "data": event_data}
|
|
||||||
event_type = None
|
|
||||||
event_data = None
|
|
||||||
elif line.startswith("event:"):
|
|
||||||
event_type = line[6:].strip()
|
|
||||||
elif line.startswith("data:"):
|
|
||||||
data_str = line[5:].strip()
|
|
||||||
if data_str and data_str != "[DONE]":
|
|
||||||
try:
|
|
||||||
event_data = json.loads(data_str)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
event_data = {"content": data_str}
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception(f"Coze API 流式请求失败: {str(e)}")
|
|
||||||
|
|
||||||
async def clear_context(self, conversation_id: str):
|
|
||||||
"""清空会话上下文
|
|
||||||
|
|
||||||
Args:
|
|
||||||
conversation_id: 会话ID
|
|
||||||
Returns:
|
|
||||||
dict: API响应结果
|
|
||||||
"""
|
|
||||||
session = await self._ensure_session()
|
|
||||||
url = f"{self.api_base}/v3/conversation/message/clear_context"
|
|
||||||
payload = {"conversation_id": conversation_id}
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with session.post(url, json=payload) as response:
|
|
||||||
response_text = await response.text()
|
|
||||||
|
|
||||||
if response.status == 401:
|
|
||||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
|
||||||
|
|
||||||
if response.status != 200:
|
|
||||||
raise Exception(f"Coze API 请求失败,状态码: {response.status}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
return json.loads(response_text)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise Exception("Coze API 返回非JSON格式")
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
raise Exception("Coze API 请求超时")
|
|
||||||
except aiohttp.ClientError as e:
|
|
||||||
raise Exception(f"Coze API 请求失败: {str(e)}")
|
|
||||||
|
|
||||||
async def get_message_list(
|
|
||||||
self,
|
|
||||||
conversation_id: str,
|
|
||||||
order: str = "desc",
|
|
||||||
limit: int = 10,
|
|
||||||
offset: int = 0,
|
|
||||||
):
|
|
||||||
"""获取消息列表
|
|
||||||
|
|
||||||
Args:
|
|
||||||
conversation_id: 会话ID
|
|
||||||
order: 排序方式 (asc/desc)
|
|
||||||
limit: 限制数量
|
|
||||||
offset: 偏移量
|
|
||||||
Returns:
|
|
||||||
dict: API响应结果
|
|
||||||
"""
|
|
||||||
session = await self._ensure_session()
|
|
||||||
url = f"{self.api_base}/v3/conversation/message/list"
|
|
||||||
params = {
|
|
||||||
"conversation_id": conversation_id,
|
|
||||||
"order": order,
|
|
||||||
"limit": limit,
|
|
||||||
"offset": offset,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with session.get(url, params=params) as response:
|
|
||||||
response.raise_for_status()
|
|
||||||
return await response.json()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取Coze消息列表失败: {str(e)}")
|
|
||||||
raise Exception(f"获取Coze消息列表失败: {str(e)}")
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
"""关闭会话"""
|
|
||||||
if self.session:
|
|
||||||
await self.session.close()
|
|
||||||
self.session = None
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import os
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
async def test_coze_api_client():
|
|
||||||
api_key = os.getenv("COZE_API_KEY", "")
|
|
||||||
bot_id = os.getenv("COZE_BOT_ID", "")
|
|
||||||
client = CozeAPIClient(api_key=api_key)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open("README.md", "rb") as f:
|
|
||||||
file_data = f.read()
|
|
||||||
file_id = await client.upload_file(file_data)
|
|
||||||
print(f"Uploaded file_id: {file_id}")
|
|
||||||
async for event in client.chat_messages(
|
|
||||||
bot_id=bot_id,
|
|
||||||
user_id="test_user",
|
|
||||||
additional_messages=[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": json.dumps(
|
|
||||||
[
|
|
||||||
{"type": "text", "text": "这是什么"},
|
|
||||||
{"type": "file", "file_id": file_id},
|
|
||||||
],
|
|
||||||
ensure_ascii=False,
|
|
||||||
),
|
|
||||||
"content_type": "object_string",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
stream=True,
|
|
||||||
):
|
|
||||||
print(f"Event: {event}")
|
|
||||||
|
|
||||||
finally:
|
|
||||||
await client.close()
|
|
||||||
|
|
||||||
asyncio.run(test_coze_api_client())
|
|
||||||
@@ -1,635 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
from typing import AsyncGenerator, Dict
|
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
|
||||||
import astrbot.core.message.components as Comp
|
|
||||||
from astrbot.api.provider import Provider
|
|
||||||
from astrbot import logger
|
|
||||||
from astrbot.core.provider.entities import LLMResponse
|
|
||||||
from ..register import register_provider_adapter
|
|
||||||
from .coze_api_client import CozeAPIClient
|
|
||||||
|
|
||||||
|
|
||||||
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
|
|
||||||
class ProviderCoze(Provider):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
provider_config,
|
|
||||||
provider_settings,
|
|
||||||
default_persona=None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
provider_config,
|
|
||||||
provider_settings,
|
|
||||||
default_persona,
|
|
||||||
)
|
|
||||||
self.api_key = provider_config.get("coze_api_key", "")
|
|
||||||
if not self.api_key:
|
|
||||||
raise Exception("Coze API Key 不能为空。")
|
|
||||||
self.bot_id = provider_config.get("bot_id", "")
|
|
||||||
if not self.bot_id:
|
|
||||||
raise Exception("Coze Bot ID 不能为空。")
|
|
||||||
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
|
|
||||||
|
|
||||||
if not isinstance(self.api_base, str) or not self.api_base.startswith(
|
|
||||||
("http://", "https://")
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.timeout = provider_config.get("timeout", 120)
|
|
||||||
if isinstance(self.timeout, str):
|
|
||||||
self.timeout = int(self.timeout)
|
|
||||||
self.auto_save_history = provider_config.get("auto_save_history", True)
|
|
||||||
self.conversation_ids: Dict[str, str] = {}
|
|
||||||
self.file_id_cache: Dict[str, Dict[str, str]] = {}
|
|
||||||
|
|
||||||
# 创建 API 客户端
|
|
||||||
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
|
|
||||||
|
|
||||||
def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
|
|
||||||
"""生成统一的缓存键
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: 图片数据或路径
|
|
||||||
is_base64: 是否是 base64 数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 缓存键
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
if is_base64 and data.startswith("data:image/"):
|
|
||||||
try:
|
|
||||||
header, encoded = data.split(",", 1)
|
|
||||||
image_bytes = base64.b64decode(encoded)
|
|
||||||
cache_key = hashlib.md5(image_bytes).hexdigest()
|
|
||||||
return cache_key
|
|
||||||
except Exception:
|
|
||||||
cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
|
|
||||||
return cache_key
|
|
||||||
else:
|
|
||||||
if data.startswith(("http://", "https://")):
|
|
||||||
# URL图片,使用URL作为缓存键
|
|
||||||
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
|
|
||||||
return cache_key
|
|
||||||
else:
|
|
||||||
clean_path = (
|
|
||||||
data.split("_")[0]
|
|
||||||
if "_" in data and len(data.split("_")) >= 3
|
|
||||||
else data
|
|
||||||
)
|
|
||||||
|
|
||||||
if os.path.exists(clean_path):
|
|
||||||
with open(clean_path, "rb") as f:
|
|
||||||
file_content = f.read()
|
|
||||||
cache_key = hashlib.md5(file_content).hexdigest()
|
|
||||||
return cache_key
|
|
||||||
else:
|
|
||||||
cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
|
|
||||||
return cache_key
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
|
|
||||||
logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
|
|
||||||
return cache_key
|
|
||||||
|
|
||||||
async def _upload_file(
|
|
||||||
self,
|
|
||||||
file_data: bytes,
|
|
||||||
session_id: str | None = None,
|
|
||||||
cache_key: str | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""上传文件到 Coze 并返回 file_id"""
|
|
||||||
# 使用 API 客户端上传文件
|
|
||||||
file_id = await self.api_client.upload_file(file_data)
|
|
||||||
|
|
||||||
# 缓存 file_id
|
|
||||||
if session_id and cache_key:
|
|
||||||
if session_id not in self.file_id_cache:
|
|
||||||
self.file_id_cache[session_id] = {}
|
|
||||||
self.file_id_cache[session_id][cache_key] = file_id
|
|
||||||
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
|
|
||||||
|
|
||||||
return file_id
|
|
||||||
|
|
||||||
async def _download_and_upload_image(
|
|
||||||
self, image_url: str, session_id: str | None = None
|
|
||||||
) -> str:
|
|
||||||
"""下载图片并上传到 Coze,返回 file_id"""
|
|
||||||
# 计算哈希实现缓存
|
|
||||||
cache_key = self._generate_cache_key(image_url) if session_id else None
|
|
||||||
|
|
||||||
if session_id and cache_key:
|
|
||||||
if session_id not in self.file_id_cache:
|
|
||||||
self.file_id_cache[session_id] = {}
|
|
||||||
|
|
||||||
if cache_key in self.file_id_cache[session_id]:
|
|
||||||
file_id = self.file_id_cache[session_id][cache_key]
|
|
||||||
return file_id
|
|
||||||
|
|
||||||
try:
|
|
||||||
image_data = await self.api_client.download_image(image_url)
|
|
||||||
|
|
||||||
file_id = await self._upload_file(image_data, session_id, cache_key)
|
|
||||||
|
|
||||||
if session_id and cache_key:
|
|
||||||
self.file_id_cache[session_id][cache_key] = file_id
|
|
||||||
|
|
||||||
return file_id
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理图片失败 {image_url}: {str(e)}")
|
|
||||||
raise Exception(f"处理图片失败: {str(e)}")
|
|
||||||
|
|
||||||
async def _process_context_images(
|
|
||||||
self, content: str | list, session_id: str
|
|
||||||
) -> str:
|
|
||||||
"""处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
|
|
||||||
processed_content = []
|
|
||||||
if session_id not in self.file_id_cache:
|
|
||||||
self.file_id_cache[session_id] = {}
|
|
||||||
|
|
||||||
for item in content:
|
|
||||||
if not isinstance(item, dict):
|
|
||||||
processed_content.append(item)
|
|
||||||
continue
|
|
||||||
if item.get("type") == "text":
|
|
||||||
processed_content.append(item)
|
|
||||||
elif item.get("type") == "image_url":
|
|
||||||
# 处理图片逻辑
|
|
||||||
if "file_id" in item:
|
|
||||||
# 已经有 file_id
|
|
||||||
logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
|
|
||||||
processed_content.append(item)
|
|
||||||
else:
|
|
||||||
# 获取图片数据
|
|
||||||
image_data = ""
|
|
||||||
if "image_url" in item and isinstance(item["image_url"], dict):
|
|
||||||
image_data = item["image_url"].get("url", "")
|
|
||||||
elif "data" in item:
|
|
||||||
image_data = item.get("data", "")
|
|
||||||
elif "url" in item:
|
|
||||||
image_data = item.get("url", "")
|
|
||||||
|
|
||||||
if not image_data:
|
|
||||||
continue
|
|
||||||
# 计算哈希用于缓存
|
|
||||||
cache_key = self._generate_cache_key(
|
|
||||||
image_data, is_base64=image_data.startswith("data:image/")
|
|
||||||
)
|
|
||||||
|
|
||||||
# 检查缓存
|
|
||||||
if cache_key in self.file_id_cache[session_id]:
|
|
||||||
file_id = self.file_id_cache[session_id][cache_key]
|
|
||||||
processed_content.append(
|
|
||||||
{"type": "image", "file_id": file_id}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 上传图片并缓存
|
|
||||||
if image_data.startswith("data:image/"):
|
|
||||||
# base64 处理
|
|
||||||
_, encoded = image_data.split(",", 1)
|
|
||||||
image_bytes = base64.b64decode(encoded)
|
|
||||||
file_id = await self._upload_file(
|
|
||||||
image_bytes,
|
|
||||||
session_id,
|
|
||||||
cache_key,
|
|
||||||
)
|
|
||||||
elif image_data.startswith(("http://", "https://")):
|
|
||||||
# URL 图片
|
|
||||||
file_id = await self._download_and_upload_image(
|
|
||||||
image_data, session_id
|
|
||||||
)
|
|
||||||
# 为URL图片也添加缓存
|
|
||||||
self.file_id_cache[session_id][cache_key] = file_id
|
|
||||||
elif os.path.exists(image_data):
|
|
||||||
# 本地文件
|
|
||||||
with open(image_data, "rb") as f:
|
|
||||||
image_bytes = f.read()
|
|
||||||
file_id = await self._upload_file(
|
|
||||||
image_bytes,
|
|
||||||
session_id,
|
|
||||||
cache_key,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"无法处理的图片格式: {image_data[:50]}..."
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
processed_content.append(
|
|
||||||
{"type": "image", "file_id": file_id}
|
|
||||||
)
|
|
||||||
|
|
||||||
result = json.dumps(processed_content, ensure_ascii=False)
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理上下文图片失败: {str(e)}")
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
else:
|
|
||||||
return json.dumps(content, ensure_ascii=False)
|
|
||||||
|
|
||||||
async def text_chat(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
session_id=None,
|
|
||||||
image_urls=None,
|
|
||||||
func_tool=None,
|
|
||||||
contexts=None,
|
|
||||||
system_prompt=None,
|
|
||||||
tool_calls_result=None,
|
|
||||||
model=None,
|
|
||||||
**kwargs,
|
|
||||||
) -> LLMResponse:
|
|
||||||
"""文本对话, 内部使用流式接口实现非流式
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (str): 用户提示词
|
|
||||||
session_id (str): 会话ID
|
|
||||||
image_urls (List[str]): 图片URL列表
|
|
||||||
func_tool (FuncCall): 函数调用工具(不支持)
|
|
||||||
contexts (List): 上下文列表
|
|
||||||
system_prompt (str): 系统提示语
|
|
||||||
tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
|
|
||||||
model (str): 模型名称(不支持)
|
|
||||||
Returns:
|
|
||||||
LLMResponse: LLM响应对象
|
|
||||||
"""
|
|
||||||
accumulated_content = ""
|
|
||||||
final_response = None
|
|
||||||
|
|
||||||
async for llm_response in self.text_chat_stream(
|
|
||||||
prompt=prompt,
|
|
||||||
session_id=session_id,
|
|
||||||
image_urls=image_urls,
|
|
||||||
func_tool=func_tool,
|
|
||||||
contexts=contexts,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
tool_calls_result=tool_calls_result,
|
|
||||||
model=model,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
if llm_response.is_chunk:
|
|
||||||
if llm_response.completion_text:
|
|
||||||
accumulated_content += llm_response.completion_text
|
|
||||||
else:
|
|
||||||
final_response = llm_response
|
|
||||||
|
|
||||||
if final_response:
|
|
||||||
return final_response
|
|
||||||
|
|
||||||
if accumulated_content:
|
|
||||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
|
||||||
return LLMResponse(role="assistant", result_chain=chain)
|
|
||||||
else:
|
|
||||||
return LLMResponse(role="assistant", completion_text="")
|
|
||||||
|
|
||||||
async def text_chat_stream(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
session_id=None,
|
|
||||||
image_urls=None,
|
|
||||||
func_tool=None,
|
|
||||||
contexts=None,
|
|
||||||
system_prompt=None,
|
|
||||||
tool_calls_result=None,
|
|
||||||
model=None,
|
|
||||||
**kwargs,
|
|
||||||
) -> AsyncGenerator[LLMResponse, None]:
|
|
||||||
"""流式对话接口"""
|
|
||||||
# 用户ID参数(参考文档, 可以自定义)
|
|
||||||
user_id = session_id or kwargs.get("user", "default_user")
|
|
||||||
|
|
||||||
# 获取或创建会话ID
|
|
||||||
conversation_id = self.conversation_ids.get(user_id)
|
|
||||||
|
|
||||||
# 构建消息
|
|
||||||
additional_messages = []
|
|
||||||
|
|
||||||
if system_prompt:
|
|
||||||
if not self.auto_save_history or not conversation_id:
|
|
||||||
additional_messages.append(
|
|
||||||
{"role": "system", "content": system_prompt, "content_type": "text"}
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.auto_save_history and contexts:
|
|
||||||
# 如果关闭了自动保存历史,传入上下文
|
|
||||||
for ctx in contexts:
|
|
||||||
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
|
|
||||||
content = ctx["content"]
|
|
||||||
content_type = ctx.get("content_type", "text")
|
|
||||||
|
|
||||||
# 处理可能包含图片的上下文
|
|
||||||
if (
|
|
||||||
content_type == "object_string"
|
|
||||||
or (isinstance(content, str) and content.startswith("["))
|
|
||||||
or (
|
|
||||||
isinstance(content, list)
|
|
||||||
and any(
|
|
||||||
isinstance(item, dict)
|
|
||||||
and item.get("type") == "image_url"
|
|
||||||
for item in content
|
|
||||||
)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
processed_content = await self._process_context_images(
|
|
||||||
content, user_id
|
|
||||||
)
|
|
||||||
additional_messages.append(
|
|
||||||
{
|
|
||||||
"role": ctx["role"],
|
|
||||||
"content": processed_content,
|
|
||||||
"content_type": "object_string",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 纯文本
|
|
||||||
additional_messages.append(
|
|
||||||
{
|
|
||||||
"role": ctx["role"],
|
|
||||||
"content": (
|
|
||||||
content
|
|
||||||
if isinstance(content, str)
|
|
||||||
else json.dumps(content, ensure_ascii=False)
|
|
||||||
),
|
|
||||||
"content_type": "text",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
|
|
||||||
|
|
||||||
if prompt or image_urls:
|
|
||||||
if image_urls:
|
|
||||||
# 多模态
|
|
||||||
object_string_content = []
|
|
||||||
if prompt:
|
|
||||||
object_string_content.append({"type": "text", "text": prompt})
|
|
||||||
|
|
||||||
for url in image_urls:
|
|
||||||
try:
|
|
||||||
if url.startswith(("http://", "https://")):
|
|
||||||
# 网络图片
|
|
||||||
file_id = await self._download_and_upload_image(
|
|
||||||
url, user_id
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 本地文件或 base64
|
|
||||||
if url.startswith("data:image/"):
|
|
||||||
# base64
|
|
||||||
_, encoded = url.split(",", 1)
|
|
||||||
image_data = base64.b64decode(encoded)
|
|
||||||
cache_key = self._generate_cache_key(
|
|
||||||
url, is_base64=True
|
|
||||||
)
|
|
||||||
file_id = await self._upload_file(
|
|
||||||
image_data, user_id, cache_key
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 本地文件
|
|
||||||
if os.path.exists(url):
|
|
||||||
with open(url, "rb") as f:
|
|
||||||
image_data = f.read()
|
|
||||||
# 用文件路径和修改时间来缓存
|
|
||||||
file_stat = os.stat(url)
|
|
||||||
cache_key = self._generate_cache_key(
|
|
||||||
f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
|
|
||||||
is_base64=False,
|
|
||||||
)
|
|
||||||
file_id = await self._upload_file(
|
|
||||||
image_data, user_id, cache_key
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(f"图片文件不存在: {url}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
object_string_content.append(
|
|
||||||
{
|
|
||||||
"type": "image",
|
|
||||||
"file_id": file_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理图片失败 {url}: {str(e)}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if object_string_content:
|
|
||||||
content = json.dumps(object_string_content, ensure_ascii=False)
|
|
||||||
additional_messages.append(
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": content,
|
|
||||||
"content_type": "object_string",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 纯文本
|
|
||||||
if prompt:
|
|
||||||
additional_messages.append(
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": prompt,
|
|
||||||
"content_type": "text",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
accumulated_content = ""
|
|
||||||
message_started = False
|
|
||||||
|
|
||||||
async for chunk in self.api_client.chat_messages(
|
|
||||||
bot_id=self.bot_id,
|
|
||||||
user_id=user_id,
|
|
||||||
additional_messages=additional_messages,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
auto_save_history=self.auto_save_history,
|
|
||||||
stream=True,
|
|
||||||
timeout=self.timeout,
|
|
||||||
):
|
|
||||||
event_type = chunk.get("event")
|
|
||||||
data = chunk.get("data", {})
|
|
||||||
|
|
||||||
if event_type == "conversation.chat.created":
|
|
||||||
if isinstance(data, dict) and "conversation_id" in data:
|
|
||||||
self.conversation_ids[user_id] = data["conversation_id"]
|
|
||||||
|
|
||||||
elif event_type == "conversation.message.delta":
|
|
||||||
if isinstance(data, dict):
|
|
||||||
content = data.get("content", "")
|
|
||||||
if not content and "delta" in data:
|
|
||||||
content = data["delta"].get("content", "")
|
|
||||||
if not content and "text" in data:
|
|
||||||
content = data.get("text", "")
|
|
||||||
|
|
||||||
if content:
|
|
||||||
message_started = True
|
|
||||||
accumulated_content += content
|
|
||||||
yield LLMResponse(
|
|
||||||
role="assistant",
|
|
||||||
completion_text=content,
|
|
||||||
is_chunk=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif event_type == "conversation.message.completed":
|
|
||||||
if isinstance(data, dict):
|
|
||||||
msg_type = data.get("type")
|
|
||||||
if msg_type == "answer" and data.get("role") == "assistant":
|
|
||||||
final_content = data.get("content", "")
|
|
||||||
if not accumulated_content and final_content:
|
|
||||||
chain = MessageChain(chain=[Comp.Plain(final_content)])
|
|
||||||
yield LLMResponse(
|
|
||||||
role="assistant",
|
|
||||||
result_chain=chain,
|
|
||||||
is_chunk=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif event_type == "conversation.chat.completed":
|
|
||||||
if accumulated_content:
|
|
||||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
|
||||||
yield LLMResponse(
|
|
||||||
role="assistant",
|
|
||||||
result_chain=chain,
|
|
||||||
is_chunk=False,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
elif event_type == "done":
|
|
||||||
break
|
|
||||||
|
|
||||||
elif event_type == "error":
|
|
||||||
error_msg = (
|
|
||||||
data.get("message", "未知错误")
|
|
||||||
if isinstance(data, dict)
|
|
||||||
else str(data)
|
|
||||||
)
|
|
||||||
logger.error(f"Coze 流式响应错误: {error_msg}")
|
|
||||||
yield LLMResponse(
|
|
||||||
role="err",
|
|
||||||
completion_text=f"Coze 错误: {error_msg}",
|
|
||||||
is_chunk=False,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
if not message_started and not accumulated_content:
|
|
||||||
yield LLMResponse(
|
|
||||||
role="assistant",
|
|
||||||
completion_text="LLM 未响应任何内容。",
|
|
||||||
is_chunk=False,
|
|
||||||
)
|
|
||||||
elif message_started and accumulated_content:
|
|
||||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
|
||||||
yield LLMResponse(
|
|
||||||
role="assistant",
|
|
||||||
result_chain=chain,
|
|
||||||
is_chunk=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Coze 流式请求失败: {str(e)}")
|
|
||||||
yield LLMResponse(
|
|
||||||
role="err",
|
|
||||||
completion_text=f"Coze 流式请求失败: {str(e)}",
|
|
||||||
is_chunk=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def forget(self, session_id: str):
|
|
||||||
"""清空指定会话的上下文"""
|
|
||||||
user_id = session_id
|
|
||||||
conversation_id = self.conversation_ids.get(user_id)
|
|
||||||
|
|
||||||
if user_id in self.file_id_cache:
|
|
||||||
self.file_id_cache.pop(user_id, None)
|
|
||||||
|
|
||||||
if not conversation_id:
|
|
||||||
return True
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.api_client.clear_context(conversation_id)
|
|
||||||
|
|
||||||
if "code" in response and response["code"] == 0:
|
|
||||||
self.conversation_ids.pop(user_id, None)
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.warning(f"清空 Coze 会话上下文失败: {response}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"清空 Coze 会话失败: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def get_current_key(self):
|
|
||||||
"""获取当前API Key"""
|
|
||||||
return self.api_key
|
|
||||||
|
|
||||||
async def set_key(self, key: str):
|
|
||||||
"""设置新的API Key"""
|
|
||||||
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
|
|
||||||
|
|
||||||
async def get_models(self):
|
|
||||||
"""获取可用模型列表"""
|
|
||||||
return [f"bot_{self.bot_id}"]
|
|
||||||
|
|
||||||
def get_model(self):
|
|
||||||
"""获取当前模型"""
|
|
||||||
return f"bot_{self.bot_id}"
|
|
||||||
|
|
||||||
def set_model(self, model: str):
|
|
||||||
"""设置模型(在Coze中是Bot ID)"""
|
|
||||||
if model.startswith("bot_"):
|
|
||||||
self.bot_id = model[4:]
|
|
||||||
else:
|
|
||||||
self.bot_id = model
|
|
||||||
|
|
||||||
async def get_human_readable_context(
|
|
||||||
self, session_id: str, page: int = 1, page_size: int = 10
|
|
||||||
):
|
|
||||||
"""获取人类可读的上下文历史"""
|
|
||||||
user_id = session_id
|
|
||||||
conversation_id = self.conversation_ids.get(user_id)
|
|
||||||
|
|
||||||
if not conversation_id:
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = await self.api_client.get_message_list(
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
order="desc",
|
|
||||||
limit=page_size,
|
|
||||||
offset=(page - 1) * page_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
if data.get("code") != 0:
|
|
||||||
logger.warning(f"获取 Coze 消息历史失败: {data}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
messages = data.get("data", {}).get("messages", [])
|
|
||||||
|
|
||||||
readable_history = []
|
|
||||||
for msg in messages:
|
|
||||||
role = msg.get("role", "unknown")
|
|
||||||
content = msg.get("content", "")
|
|
||||||
msg_type = msg.get("type", "")
|
|
||||||
|
|
||||||
if role == "user":
|
|
||||||
readable_history.append(f"用户: {content}")
|
|
||||||
elif role == "assistant" and msg_type == "answer":
|
|
||||||
readable_history.append(f"助手: {content}")
|
|
||||||
|
|
||||||
return readable_history
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取 Coze 消息历史失败: {str(e)}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def terminate(self):
|
|
||||||
"""清理资源"""
|
|
||||||
await self.api_client.close()
|
|
||||||
@@ -98,7 +98,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
|||||||
|
|
||||||
# FishAudio的reference_id通常是32位十六进制字符串
|
# FishAudio的reference_id通常是32位十六进制字符串
|
||||||
# 例如: 626bb6d3f3364c9cbc3aa6a67300a664
|
# 例如: 626bb6d3f3364c9cbc3aa6a67300a664
|
||||||
pattern = r"^[a-fA-F0-9]{32}$"
|
pattern = r'^[a-fA-F0-9]{32}$'
|
||||||
return bool(re.match(pattern, reference_id.strip()))
|
return bool(re.match(pattern, reference_id.strip()))
|
||||||
|
|
||||||
async def _generate_request(self, text: str) -> dict:
|
async def _generate_request(self, text: str) -> dict:
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from astrbot import logger
|
|||||||
from astrbot.api.provider import Provider
|
from astrbot.api.provider import Provider
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.core.provider.entities import LLMResponse
|
from astrbot.core.provider.entities import LLMResponse
|
||||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
@@ -61,7 +61,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
default_persona,
|
default_persona,
|
||||||
)
|
)
|
||||||
self.api_keys: list = provider_config.get("key", [])
|
self.api_keys: list = provider_config.get("key", [])
|
||||||
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||||
self.timeout: int = int(provider_config.get("timeout", 180))
|
self.timeout: int = int(provider_config.get("timeout", 180))
|
||||||
|
|
||||||
self.api_base: Optional[str] = provider_config.get("api_base", None)
|
self.api_base: Optional[str] = provider_config.get("api_base", None)
|
||||||
@@ -96,9 +96,6 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
|
|
||||||
async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool:
|
async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool:
|
||||||
"""处理API错误,返回是否需要重试"""
|
"""处理API错误,返回是否需要重试"""
|
||||||
if e.message is None:
|
|
||||||
e.message = ""
|
|
||||||
|
|
||||||
if e.code == 429 or "API key not valid" in e.message:
|
if e.code == 429 or "API key not valid" in e.message:
|
||||||
keys.remove(self.chosen_api_key)
|
keys.remove(self.chosen_api_key)
|
||||||
if len(keys) > 0:
|
if len(keys) > 0:
|
||||||
@@ -122,7 +119,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
async def _prepare_query_config(
|
async def _prepare_query_config(
|
||||||
self,
|
self,
|
||||||
payloads: dict,
|
payloads: dict,
|
||||||
tools: Optional[ToolSet] = None,
|
tools: Optional[FuncCall] = None,
|
||||||
system_instruction: Optional[str] = None,
|
system_instruction: Optional[str] = None,
|
||||||
modalities: Optional[list[str]] = None,
|
modalities: Optional[list[str]] = None,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
@@ -324,15 +321,11 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_content_parts(
|
def _process_content_parts(
|
||||||
candidate: types.Candidate, llm_response: LLMResponse
|
result: types.GenerateContentResponse, llm_response: LLMResponse
|
||||||
) -> MessageChain:
|
) -> MessageChain:
|
||||||
"""处理内容部分并构建消息链"""
|
"""处理内容部分并构建消息链"""
|
||||||
if not candidate.content:
|
finish_reason = result.candidates[0].finish_reason
|
||||||
logger.warning(f"收到的 candidate.content 为空: {candidate}")
|
result_parts: Optional[types.Part] = result.candidates[0].content.parts
|
||||||
raise Exception("API 返回的 candidate.content 为空。")
|
|
||||||
|
|
||||||
finish_reason = candidate.finish_reason
|
|
||||||
result_parts: list[types.Part] | None = candidate.content.parts
|
|
||||||
|
|
||||||
if finish_reason == types.FinishReason.SAFETY:
|
if finish_reason == types.FinishReason.SAFETY:
|
||||||
raise Exception("模型生成内容未通过 Gemini 平台的安全检查")
|
raise Exception("模型生成内容未通过 Gemini 平台的安全检查")
|
||||||
@@ -350,28 +343,22 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
raise Exception("模型生成内容违反 Gemini 平台政策")
|
raise Exception("模型生成内容违反 Gemini 平台政策")
|
||||||
|
|
||||||
if not result_parts:
|
if not result_parts:
|
||||||
logger.warning(f"收到的 candidate.content.parts 为空: {candidate}")
|
logger.debug(result.candidates)
|
||||||
raise Exception("API 返回的 candidate.content.parts 为空。")
|
raise Exception("API 返回的内容为空。")
|
||||||
|
|
||||||
chain = []
|
chain = []
|
||||||
part: types.Part
|
part: types.Part
|
||||||
|
|
||||||
# 暂时这样Fallback
|
# 暂时这样Fallback
|
||||||
if all(
|
if all(
|
||||||
part.inline_data
|
part.inline_data and part.inline_data.mime_type.startswith("image/")
|
||||||
and part.inline_data.mime_type
|
|
||||||
and part.inline_data.mime_type.startswith("image/")
|
|
||||||
for part in result_parts
|
for part in result_parts
|
||||||
):
|
):
|
||||||
chain.append(Comp.Plain("这是图片"))
|
chain.append(Comp.Plain("这是图片"))
|
||||||
for part in result_parts:
|
for part in result_parts:
|
||||||
if part.text:
|
if part.text:
|
||||||
chain.append(Comp.Plain(part.text))
|
chain.append(Comp.Plain(part.text))
|
||||||
elif (
|
elif part.function_call:
|
||||||
part.function_call
|
|
||||||
and part.function_call.name is not None
|
|
||||||
and part.function_call.args is not None
|
|
||||||
):
|
|
||||||
llm_response.role = "tool"
|
llm_response.role = "tool"
|
||||||
llm_response.tools_call_name.append(part.function_call.name)
|
llm_response.tools_call_name.append(part.function_call.name)
|
||||||
llm_response.tools_call_args.append(part.function_call.args)
|
llm_response.tools_call_args.append(part.function_call.args)
|
||||||
@@ -379,16 +366,11 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
llm_response.tools_call_ids.append(
|
llm_response.tools_call_ids.append(
|
||||||
part.function_call.id or part.function_call.name
|
part.function_call.id or part.function_call.name
|
||||||
)
|
)
|
||||||
elif (
|
elif part.inline_data and part.inline_data.mime_type.startswith("image/"):
|
||||||
part.inline_data
|
|
||||||
and part.inline_data.mime_type
|
|
||||||
and part.inline_data.mime_type.startswith("image/")
|
|
||||||
and part.inline_data.data
|
|
||||||
):
|
|
||||||
chain.append(Comp.Image.fromBytes(part.inline_data.data))
|
chain.append(Comp.Image.fromBytes(part.inline_data.data))
|
||||||
return MessageChain(chain=chain)
|
return MessageChain(chain=chain)
|
||||||
|
|
||||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||||
"""非流式请求 Gemini API"""
|
"""非流式请求 Gemini API"""
|
||||||
system_instruction = next(
|
system_instruction = next(
|
||||||
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||||
@@ -414,10 +396,6 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result.candidates:
|
|
||||||
logger.error(f"请求失败, 返回的 candidates 为空: {result}")
|
|
||||||
raise Exception("请求失败, 返回的 candidates 为空。")
|
|
||||||
|
|
||||||
if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
|
if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
|
||||||
if temperature > 2:
|
if temperature > 2:
|
||||||
raise Exception("温度参数已超过最大值2,仍然发生recitation")
|
raise Exception("温度参数已超过最大值2,仍然发生recitation")
|
||||||
@@ -430,8 +408,6 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
break
|
break
|
||||||
|
|
||||||
except APIError as e:
|
except APIError as e:
|
||||||
if e.message is None:
|
|
||||||
e.message = ""
|
|
||||||
if "Developer instruction is not enabled" in e.message:
|
if "Developer instruction is not enabled" in e.message:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
|
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
|
||||||
@@ -456,13 +432,11 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
|
|
||||||
llm_response = LLMResponse("assistant")
|
llm_response = LLMResponse("assistant")
|
||||||
llm_response.raw_completion = result
|
llm_response.raw_completion = result
|
||||||
llm_response.result_chain = self._process_content_parts(
|
llm_response.result_chain = self._process_content_parts(result, llm_response)
|
||||||
result.candidates[0], llm_response
|
|
||||||
)
|
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
async def _query_stream(
|
async def _query_stream(
|
||||||
self, payloads: dict, tools: ToolSet | None
|
self, payloads: dict, tools: FuncCall
|
||||||
) -> AsyncGenerator[LLMResponse, None]:
|
) -> AsyncGenerator[LLMResponse, None]:
|
||||||
"""流式请求 Gemini API"""
|
"""流式请求 Gemini API"""
|
||||||
system_instruction = next(
|
system_instruction = next(
|
||||||
@@ -485,8 +459,6 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
)
|
)
|
||||||
break
|
break
|
||||||
except APIError as e:
|
except APIError as e:
|
||||||
if e.message is None:
|
|
||||||
e.message = ""
|
|
||||||
if "Developer instruction is not enabled" in e.message:
|
if "Developer instruction is not enabled" in e.message:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
|
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
|
||||||
@@ -506,20 +478,13 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
async for chunk in result:
|
async for chunk in result:
|
||||||
llm_response = LLMResponse("assistant", is_chunk=True)
|
llm_response = LLMResponse("assistant", is_chunk=True)
|
||||||
|
|
||||||
if not chunk.candidates:
|
|
||||||
logger.warning(f"收到的 chunk 中 candidates 为空: {chunk}")
|
|
||||||
continue
|
|
||||||
if not chunk.candidates[0].content:
|
|
||||||
logger.warning(f"收到的 chunk 中 content 为空: {chunk}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if chunk.candidates[0].content.parts and any(
|
if chunk.candidates[0].content.parts and any(
|
||||||
part.function_call for part in chunk.candidates[0].content.parts
|
part.function_call for part in chunk.candidates[0].content.parts
|
||||||
):
|
):
|
||||||
llm_response = LLMResponse("assistant", is_chunk=False)
|
llm_response = LLMResponse("assistant", is_chunk=False)
|
||||||
llm_response.raw_completion = chunk
|
llm_response.raw_completion = chunk
|
||||||
llm_response.result_chain = self._process_content_parts(
|
llm_response.result_chain = self._process_content_parts(
|
||||||
chunk.candidates[0], llm_response
|
chunk, llm_response
|
||||||
)
|
)
|
||||||
yield llm_response
|
yield llm_response
|
||||||
return
|
return
|
||||||
@@ -535,7 +500,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
final_response = LLMResponse("assistant", is_chunk=False)
|
final_response = LLMResponse("assistant", is_chunk=False)
|
||||||
final_response.raw_completion = chunk
|
final_response.raw_completion = chunk
|
||||||
final_response.result_chain = self._process_content_parts(
|
final_response.result_chain = self._process_content_parts(
|
||||||
chunk.candidates[0], final_response
|
chunk, final_response
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -601,8 +566,6 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
continue
|
continue
|
||||||
break
|
break
|
||||||
|
|
||||||
raise Exception("请求失败。")
|
|
||||||
|
|
||||||
async def text_chat_stream(
|
async def text_chat_stream(
|
||||||
self,
|
self,
|
||||||
prompt,
|
prompt,
|
||||||
@@ -658,9 +621,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
return [
|
return [
|
||||||
m.name.replace("models/", "")
|
m.name.replace("models/", "")
|
||||||
for m in models
|
for m in models
|
||||||
if m.supported_actions
|
if "generateContent" in m.supported_actions
|
||||||
and "generateContent" in m.supported_actions
|
|
||||||
and m.name
|
|
||||||
]
|
]
|
||||||
except APIError as e:
|
except APIError as e:
|
||||||
raise Exception(f"获取模型列表失败: {e.message}")
|
raise Exception(f"获取模型列表失败: {e.message}")
|
||||||
@@ -675,7 +636,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
self.chosen_api_key = key
|
self.chosen_api_key = key
|
||||||
self._init_client()
|
self._init_client()
|
||||||
|
|
||||||
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
|
async def assemble_context(self, text: str, image_urls: list[str] = None):
|
||||||
"""
|
"""
|
||||||
组装上下文。
|
组装上下文。
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -99,15 +99,12 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
for key in to_del:
|
for key in to_del:
|
||||||
del payloads[key]
|
del payloads[key]
|
||||||
|
|
||||||
# 读取并合并 custom_extra_body 配置
|
model = payloads.get("model", "")
|
||||||
custom_extra_body = self.provider_config.get("custom_extra_body", {})
|
# 针对 qwen3 模型的特殊处理:非流式调用必须设置 enable_thinking=false
|
||||||
if isinstance(custom_extra_body, dict):
|
if "qwen3" in model.lower():
|
||||||
extra_body.update(custom_extra_body)
|
extra_body["enable_thinking"] = False
|
||||||
|
|
||||||
model = payloads.get("model", "").lower()
|
|
||||||
|
|
||||||
# 针对 deepseek 模型的特殊处理:deepseek-reasoner调用必须移除 tools ,否则将被切换至 deepseek-chat
|
# 针对 deepseek 模型的特殊处理:deepseek-reasoner调用必须移除 tools ,否则将被切换至 deepseek-chat
|
||||||
if model == "deepseek-reasoner" and "tools" in payloads:
|
elif model == "deepseek-reasoner" and "tools" in payloads:
|
||||||
del payloads["tools"]
|
del payloads["tools"]
|
||||||
|
|
||||||
completion = await self.client.chat.completions.create(
|
completion = await self.client.chat.completions.create(
|
||||||
@@ -140,12 +137,6 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
|
|
||||||
# 不在默认参数中的参数放在 extra_body 中
|
# 不在默认参数中的参数放在 extra_body 中
|
||||||
extra_body = {}
|
extra_body = {}
|
||||||
|
|
||||||
# 读取并合并 custom_extra_body 配置
|
|
||||||
custom_extra_body = self.provider_config.get("custom_extra_body", {})
|
|
||||||
if isinstance(custom_extra_body, dict):
|
|
||||||
extra_body.update(custom_extra_body)
|
|
||||||
|
|
||||||
to_del = []
|
to_del = []
|
||||||
for key in payloads.keys():
|
for key in payloads.keys():
|
||||||
if key not in self.default_params:
|
if key not in self.default_params:
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
from astrbot import logger
|
|
||||||
from ..provider import RerankProvider
|
from ..provider import RerankProvider
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from ..entities import ProviderType, RerankResult
|
from ..entities import ProviderType, RerankResult
|
||||||
@@ -45,11 +44,6 @@ class VLLMRerankProvider(RerankProvider):
|
|||||||
response_data = await response.json()
|
response_data = await response.json()
|
||||||
results = response_data.get("results", [])
|
results = response_data.get("results", [])
|
||||||
|
|
||||||
if not results:
|
|
||||||
logger.warning(
|
|
||||||
f"Rerank API 返回了空的列表数据。原始响应: {response_data}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
RerankResult(
|
RerankResult(
|
||||||
index=result["index"],
|
index=result["index"],
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
# This file was originally created to adapt to glm-4v-flash, which only supports one image in the context.
|
from astrbot import logger
|
||||||
# It is no longer specifically adapted to Zhipu's models. To ensure compatibility, this
|
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
|
from astrbot.core.provider.entities import LLMResponse
|
||||||
from .openai_source import ProviderOpenAIOfficial
|
from .openai_source import ProviderOpenAIOfficial
|
||||||
|
|
||||||
|
|
||||||
@register_provider_adapter("zhipu_chat_completion", "智谱 Chat Completion 提供商适配器")
|
@register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器")
|
||||||
class ProviderZhipu(ProviderOpenAIOfficial):
|
class ProviderZhipu(ProviderOpenAIOfficial):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -19,3 +19,63 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
|||||||
provider_settings,
|
provider_settings,
|
||||||
default_persona,
|
default_persona,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def text_chat(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
session_id: str = None,
|
||||||
|
image_urls: List[str] = None,
|
||||||
|
func_tool: FuncCall = None,
|
||||||
|
contexts=None,
|
||||||
|
system_prompt=None,
|
||||||
|
model=None,
|
||||||
|
**kwargs,
|
||||||
|
) -> LLMResponse:
|
||||||
|
if contexts is None:
|
||||||
|
contexts = []
|
||||||
|
new_record = await self.assemble_context(prompt, image_urls)
|
||||||
|
context_query = []
|
||||||
|
|
||||||
|
context_query = [*contexts, new_record]
|
||||||
|
|
||||||
|
model_cfgs: dict = self.provider_config.get("model_config", {})
|
||||||
|
model = model or self.get_model()
|
||||||
|
# glm-4v-flash 只支持一张图片
|
||||||
|
if model.lower() == "glm-4v-flash" and image_urls and len(context_query) > 1:
|
||||||
|
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
|
||||||
|
logger.debug(context_query)
|
||||||
|
new_context_query_ = []
|
||||||
|
for i in range(0, len(context_query) - 1, 2):
|
||||||
|
if isinstance(context_query[i].get("content", ""), list):
|
||||||
|
continue
|
||||||
|
new_context_query_.append(context_query[i])
|
||||||
|
new_context_query_.append(context_query[i + 1])
|
||||||
|
new_context_query_.append(context_query[-1]) # 保留最后一条记录
|
||||||
|
context_query = new_context_query_
|
||||||
|
logger.debug(context_query)
|
||||||
|
|
||||||
|
if system_prompt:
|
||||||
|
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||||
|
|
||||||
|
payloads = {"messages": context_query, **model_cfgs}
|
||||||
|
try:
|
||||||
|
llm_response = await self._query(payloads, func_tool)
|
||||||
|
return llm_response
|
||||||
|
except Exception as e:
|
||||||
|
if "maximum context length" in str(e):
|
||||||
|
retry_cnt = 10
|
||||||
|
while retry_cnt > 0:
|
||||||
|
logger.warning(
|
||||||
|
f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
self.pop_record(session_id)
|
||||||
|
llm_response = await self._query(payloads, func_tool)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
if "maximum context length" in str(e):
|
||||||
|
retry_cnt -= 1
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|||||||
@@ -27,16 +27,14 @@ class Star(CommandParserMixin):
|
|||||||
star_map[cls.__module__].star_cls_type = cls
|
star_map[cls.__module__].star_cls_type = cls
|
||||||
star_map[cls.__module__].module_path = cls.__module__
|
star_map[cls.__module__].module_path = cls.__module__
|
||||||
|
|
||||||
async def text_to_image(self, text: str, return_url=True) -> str:
|
@staticmethod
|
||||||
|
async def text_to_image(text: str, return_url=True) -> str:
|
||||||
"""将文本转换为图片"""
|
"""将文本转换为图片"""
|
||||||
return await html_renderer.render_t2i(
|
return await html_renderer.render_t2i(text, return_url=return_url)
|
||||||
text,
|
|
||||||
return_url=return_url,
|
|
||||||
template_name=self.context._config.get("t2i_active_template"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
async def html_render(
|
async def html_render(
|
||||||
self, tmpl: str, data: dict, return_url=True, options: dict | None = None
|
tmpl: str, data: dict, return_url=True, options: dict | None = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""渲染 HTML"""
|
"""渲染 HTML"""
|
||||||
return await html_renderer.render_custom_template(
|
return await html_renderer.render_custom_template(
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from astrbot.core.provider.provider import (
|
|||||||
TTSProvider,
|
TTSProvider,
|
||||||
STTProvider,
|
STTProvider,
|
||||||
EmbeddingProvider,
|
EmbeddingProvider,
|
||||||
RerankProvider,
|
|
||||||
)
|
)
|
||||||
from astrbot.core.provider.entities import ProviderType
|
from astrbot.core.provider.entities import ProviderType
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
@@ -24,7 +23,7 @@ from .star import star_registry, StarMetadata, star_map
|
|||||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||||
from .filter.command import CommandFilter
|
from .filter.command import CommandFilter
|
||||||
from .filter.regex import RegexFilter
|
from .filter.regex import RegexFilter
|
||||||
from typing import Awaitable, Any, Callable
|
from typing import Awaitable
|
||||||
from astrbot.core.conversation_mgr import ConversationManager
|
from astrbot.core.conversation_mgr import ConversationManager
|
||||||
from astrbot.core.star.filter.platform_adapter_type import (
|
from astrbot.core.star.filter.platform_adapter_type import (
|
||||||
PlatformAdapterType,
|
PlatformAdapterType,
|
||||||
@@ -104,14 +103,9 @@ class Context:
|
|||||||
"""
|
"""
|
||||||
self.provider_manager.provider_insts.append(provider)
|
self.provider_manager.provider_insts.append(provider)
|
||||||
|
|
||||||
def get_provider_by_id(
|
def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||||
self, provider_id: str
|
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
|
||||||
) -> (
|
return self.provider_manager.inst_map.get(provider_id)
|
||||||
Provider | TTSProvider | STTProvider | EmbeddingProvider | RerankProvider | None
|
|
||||||
):
|
|
||||||
"""通过 ID 获取对应的 LLM Provider。"""
|
|
||||||
prov = self.provider_manager.inst_map.get(provider_id)
|
|
||||||
return prov
|
|
||||||
|
|
||||||
def get_all_providers(self) -> List[Provider]:
|
def get_all_providers(self) -> List[Provider]:
|
||||||
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
||||||
@@ -136,43 +130,34 @@ class Context:
|
|||||||
Args:
|
Args:
|
||||||
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
|
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
|
||||||
"""
|
"""
|
||||||
prov = self.provider_manager.get_using_provider(
|
return self.provider_manager.get_using_provider(
|
||||||
provider_type=ProviderType.CHAT_COMPLETION,
|
provider_type=ProviderType.CHAT_COMPLETION,
|
||||||
umo=umo,
|
umo=umo,
|
||||||
)
|
)
|
||||||
if prov and not isinstance(prov, Provider):
|
|
||||||
raise ValueError("返回的 Provider 不是 Provider 类型")
|
|
||||||
return prov
|
|
||||||
|
|
||||||
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None:
|
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider:
|
||||||
"""
|
"""
|
||||||
获取当前使用的用于 TTS 任务的 Provider。
|
获取当前使用的用于 TTS 任务的 Provider。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||||
"""
|
"""
|
||||||
prov = self.provider_manager.get_using_provider(
|
return self.provider_manager.get_using_provider(
|
||||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||||
umo=umo,
|
umo=umo,
|
||||||
)
|
)
|
||||||
if prov and not isinstance(prov, TTSProvider):
|
|
||||||
raise ValueError("返回的 Provider 不是 TTSProvider 类型")
|
|
||||||
return prov
|
|
||||||
|
|
||||||
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None:
|
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider:
|
||||||
"""
|
"""
|
||||||
获取当前使用的用于 STT 任务的 Provider。
|
获取当前使用的用于 STT 任务的 Provider。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||||
"""
|
"""
|
||||||
prov = self.provider_manager.get_using_provider(
|
return self.provider_manager.get_using_provider(
|
||||||
provider_type=ProviderType.SPEECH_TO_TEXT,
|
provider_type=ProviderType.SPEECH_TO_TEXT,
|
||||||
umo=umo,
|
umo=umo,
|
||||||
)
|
)
|
||||||
if prov and not isinstance(prov, STTProvider):
|
|
||||||
raise ValueError("返回的 Provider 不是 STTProvider 类型")
|
|
||||||
return prov
|
|
||||||
|
|
||||||
def get_config(self, umo: str | None = None) -> AstrBotConfig:
|
def get_config(self, umo: str | None = None) -> AstrBotConfig:
|
||||||
"""获取 AstrBot 的配置。"""
|
"""获取 AstrBot 的配置。"""
|
||||||
@@ -260,11 +245,7 @@ class Context:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def register_llm_tool(
|
def register_llm_tool(
|
||||||
self,
|
self, name: str, func_args: list, desc: str, func_obj: Awaitable
|
||||||
name: str,
|
|
||||||
func_args: list,
|
|
||||||
desc: str,
|
|
||||||
func_obj: Callable[..., Awaitable[Any]],
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
为函数调用(function-calling / tools-use)添加工具。
|
为函数调用(function-calling / tools-use)添加工具。
|
||||||
@@ -286,7 +267,9 @@ class Context:
|
|||||||
desc=desc,
|
desc=desc,
|
||||||
)
|
)
|
||||||
star_handlers_registry.append(md)
|
star_handlers_registry.append(md)
|
||||||
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj)
|
self.provider_manager.llm_tools.add_func(
|
||||||
|
name, func_args, desc, func_obj, func_obj
|
||||||
|
)
|
||||||
|
|
||||||
def unregister_llm_tool(self, name: str) -> None:
|
def unregister_llm_tool(self, name: str) -> None:
|
||||||
"""删除一个函数调用工具。如果再要启用,需要重新注册。"""
|
"""删除一个函数调用工具。如果再要启用,需要重新注册。"""
|
||||||
@@ -298,7 +281,7 @@ class Context:
|
|||||||
command_name: str,
|
command_name: str,
|
||||||
desc: str,
|
desc: str,
|
||||||
priority: int,
|
priority: int,
|
||||||
awaitable: Callable[..., Awaitable[Any]],
|
awaitable: Awaitable,
|
||||||
use_regex=False,
|
use_regex=False,
|
||||||
ignore_prefix=False,
|
ignore_prefix=False,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
import inspect
|
import inspect
|
||||||
import types
|
|
||||||
import typing
|
|
||||||
from typing import List, Any, Type, Dict
|
from typing import List, Any, Type, Dict
|
||||||
from . import HandlerFilter
|
from . import HandlerFilter
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
@@ -16,18 +14,6 @@ class GreedyStr(str):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def unwrap_optional(annotation) -> tuple:
|
|
||||||
"""去掉 Optional[T] / Union[T, None] / T|None,返回 T"""
|
|
||||||
args = typing.get_args(annotation)
|
|
||||||
non_none_args = [a for a in args if a is not type(None)]
|
|
||||||
if len(non_none_args) == 1:
|
|
||||||
return (non_none_args[0],)
|
|
||||||
elif len(non_none_args) > 1:
|
|
||||||
return tuple(non_none_args)
|
|
||||||
else:
|
|
||||||
return ()
|
|
||||||
|
|
||||||
|
|
||||||
# 标准指令受到 wake_prefix 的制约。
|
# 标准指令受到 wake_prefix 的制约。
|
||||||
class CommandFilter(HandlerFilter):
|
class CommandFilter(HandlerFilter):
|
||||||
"""标准指令过滤器"""
|
"""标准指令过滤器"""
|
||||||
@@ -46,16 +32,11 @@ class CommandFilter(HandlerFilter):
|
|||||||
self.init_handler_md(handler_md)
|
self.init_handler_md(handler_md)
|
||||||
self.custom_filter_list: List[CustomFilter] = []
|
self.custom_filter_list: List[CustomFilter] = []
|
||||||
|
|
||||||
# Cache for complete command names list
|
|
||||||
self._cmpl_cmd_names: list | None = None
|
|
||||||
|
|
||||||
def print_types(self):
|
def print_types(self):
|
||||||
result = ""
|
result = ""
|
||||||
for k, v in self.handler_params.items():
|
for k, v in self.handler_params.items():
|
||||||
if isinstance(v, type):
|
if isinstance(v, type):
|
||||||
result += f"{k}({v.__name__}),"
|
result += f"{k}({v.__name__}),"
|
||||||
elif isinstance(v, types.UnionType) or typing.get_origin(v) is typing.Union:
|
|
||||||
result += f"{k}({v}),"
|
|
||||||
else:
|
else:
|
||||||
result += f"{k}({type(v).__name__})={v},"
|
result += f"{k}({type(v).__name__})={v},"
|
||||||
result = result.rstrip(",")
|
result = result.rstrip(",")
|
||||||
@@ -111,8 +92,7 @@ class CommandFilter(HandlerFilter):
|
|||||||
# 没有 GreedyStr 的情况
|
# 没有 GreedyStr 的情况
|
||||||
if i >= len(params):
|
if i >= len(params):
|
||||||
if (
|
if (
|
||||||
isinstance(param_type_or_default_val, (Type, types.UnionType))
|
isinstance(param_type_or_default_val, Type)
|
||||||
or typing.get_origin(param_type_or_default_val) is typing.Union
|
|
||||||
or param_type_or_default_val is inspect.Parameter.empty
|
or param_type_or_default_val is inspect.Parameter.empty
|
||||||
):
|
):
|
||||||
# 是类型
|
# 是类型
|
||||||
@@ -149,42 +129,13 @@ class CommandFilter(HandlerFilter):
|
|||||||
elif isinstance(param_type_or_default_val, float):
|
elif isinstance(param_type_or_default_val, float):
|
||||||
result[param_name] = float(params[i])
|
result[param_name] = float(params[i])
|
||||||
else:
|
else:
|
||||||
origin = typing.get_origin(param_type_or_default_val)
|
result[param_name] = param_type_or_default_val(params[i])
|
||||||
if origin in (typing.Union, types.UnionType):
|
|
||||||
# 注解是联合类型
|
|
||||||
# NOTE: 目前没有处理联合类型嵌套相关的注解写法
|
|
||||||
nn_types = unwrap_optional(param_type_or_default_val)
|
|
||||||
if len(nn_types) == 1:
|
|
||||||
# 只有一个非 NoneType 类型
|
|
||||||
result[param_name] = nn_types[0](params[i])
|
|
||||||
else:
|
|
||||||
# 没有或者有多个非 NoneType 类型,这里我们暂时直接赋值为原始值。
|
|
||||||
# NOTE: 目前还没有做类型校验
|
|
||||||
result[param_name] = params[i]
|
|
||||||
else:
|
|
||||||
result[param_name] = param_type_or_default_val(params[i])
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"参数 {param_name} 类型错误。完整参数: {self.print_types()}"
|
f"参数 {param_name} 类型错误。完整参数: {self.print_types()}"
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_complete_command_names(self):
|
|
||||||
if self._cmpl_cmd_names is not None:
|
|
||||||
return self._cmpl_cmd_names
|
|
||||||
self._cmpl_cmd_names = [
|
|
||||||
f"{parent} {cmd}" if parent else cmd
|
|
||||||
for cmd in [self.command_name] + list(self.alias)
|
|
||||||
for parent in self.parent_command_names or [""]
|
|
||||||
]
|
|
||||||
return self._cmpl_cmd_names
|
|
||||||
|
|
||||||
def equals(self, message_str: str) -> bool:
|
|
||||||
for full_cmd in self.get_complete_command_names():
|
|
||||||
if message_str == full_cmd:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||||
if not event.is_at_or_wake_command:
|
if not event.is_at_or_wake_command:
|
||||||
return False
|
return False
|
||||||
@@ -194,11 +145,18 @@ class CommandFilter(HandlerFilter):
|
|||||||
|
|
||||||
# 检查是否以指令开头
|
# 检查是否以指令开头
|
||||||
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
|
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
|
||||||
|
candidates = [self.command_name] + list(self.alias)
|
||||||
ok = False
|
ok = False
|
||||||
for full_cmd in self.get_complete_command_names():
|
for candidate in candidates:
|
||||||
if message_str.startswith(f"{full_cmd} ") or message_str == full_cmd:
|
for parent_command_name in self.parent_command_names:
|
||||||
ok = True
|
if parent_command_name:
|
||||||
message_str = message_str[len(full_cmd) :].strip()
|
_full = f"{parent_command_name} {candidate}"
|
||||||
|
else:
|
||||||
|
_full = candidate
|
||||||
|
if message_str.startswith(f"{_full} ") or message_str == _full:
|
||||||
|
message_str = message_str[len(_full) :].strip()
|
||||||
|
ok = True
|
||||||
|
break
|
||||||
if not ok:
|
if not ok:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
group_name: str,
|
group_name: str,
|
||||||
alias: set | None = None,
|
alias: set = None,
|
||||||
parent_group: CommandGroupFilter | None = None,
|
parent_group: CommandGroupFilter = None,
|
||||||
):
|
):
|
||||||
self.group_name = group_name
|
self.group_name = group_name
|
||||||
self.alias = alias if alias else set()
|
self.alias = alias if alias else set()
|
||||||
@@ -22,9 +22,6 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
self.custom_filter_list: List[CustomFilter] = []
|
self.custom_filter_list: List[CustomFilter] = []
|
||||||
self.parent_group = parent_group
|
self.parent_group = parent_group
|
||||||
|
|
||||||
# Cache for complete command names list
|
|
||||||
self._cmpl_cmd_names: list | None = None
|
|
||||||
|
|
||||||
def add_sub_command_filter(
|
def add_sub_command_filter(
|
||||||
self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]
|
self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]
|
||||||
):
|
):
|
||||||
@@ -37,9 +34,6 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
"""遍历父节点获取完整的指令名。
|
"""遍历父节点获取完整的指令名。
|
||||||
|
|
||||||
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。"""
|
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。"""
|
||||||
if self._cmpl_cmd_names is not None:
|
|
||||||
return self._cmpl_cmd_names
|
|
||||||
|
|
||||||
parent_cmd_names = (
|
parent_cmd_names = (
|
||||||
self.parent_group.get_complete_command_names() if self.parent_group else []
|
self.parent_group.get_complete_command_names() if self.parent_group else []
|
||||||
)
|
)
|
||||||
@@ -53,7 +47,6 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
for parent_cmd_name in parent_cmd_names:
|
for parent_cmd_name in parent_cmd_names:
|
||||||
for candidate in candidates:
|
for candidate in candidates:
|
||||||
result.append(parent_cmd_name + " " + candidate)
|
result.append(parent_cmd_name + " " + candidate)
|
||||||
self._cmpl_cmd_names = result
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# 以树的形式打印出来
|
# 以树的形式打印出来
|
||||||
@@ -61,8 +54,8 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
self,
|
self,
|
||||||
sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]],
|
sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]],
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
event: AstrMessageEvent | None = None,
|
event: AstrMessageEvent = None,
|
||||||
cfg: AstrBotConfig | None = None,
|
cfg: AstrBotConfig = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
result = ""
|
result = ""
|
||||||
for sub_filter in sub_command_filters:
|
for sub_filter in sub_command_filters:
|
||||||
@@ -104,12 +97,6 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def startswith(self, message_str: str) -> bool:
|
|
||||||
return message_str.startswith(tuple(self.get_complete_command_names()))
|
|
||||||
|
|
||||||
def equals(self, message_str: str) -> bool:
|
|
||||||
return message_str in self.get_complete_command_names()
|
|
||||||
|
|
||||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||||
if not event.is_at_or_wake_command:
|
if not event.is_at_or_wake_command:
|
||||||
return False
|
return False
|
||||||
@@ -118,14 +105,18 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
if not self.custom_filter_ok(event, cfg):
|
if not self.custom_filter_ok(event, cfg):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if self.equals(event.message_str.strip()):
|
complete_command_names = self.get_complete_command_names()
|
||||||
|
if event.message_str.strip() in complete_command_names:
|
||||||
tree = (
|
tree = (
|
||||||
self.group_name
|
self.group_name
|
||||||
+ "\n"
|
+ "\n"
|
||||||
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
||||||
)
|
)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
|
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n"
|
||||||
|
+ tree
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.startswith(event.message_str)
|
# complete_command_names = [name + " " for name in complete_command_names]
|
||||||
|
# return event.message_str.startswith(tuple(complete_command_names))
|
||||||
|
return False
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import enum
|
|||||||
from . import HandlerFilter
|
from . import HandlerFilter
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.config import AstrBotConfig
|
from astrbot.core.config import AstrBotConfig
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
class PlatformAdapterType(enum.Flag):
|
class PlatformAdapterType(enum.Flag):
|
||||||
@@ -17,8 +18,6 @@ class PlatformAdapterType(enum.Flag):
|
|||||||
KOOK = enum.auto()
|
KOOK = enum.auto()
|
||||||
VOCECHAT = enum.auto()
|
VOCECHAT = enum.auto()
|
||||||
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
|
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
|
||||||
SATORI = enum.auto()
|
|
||||||
MISSKEY = enum.auto()
|
|
||||||
ALL = (
|
ALL = (
|
||||||
AIOCQHTTP
|
AIOCQHTTP
|
||||||
| QQOFFICIAL
|
| QQOFFICIAL
|
||||||
@@ -32,8 +31,6 @@ class PlatformAdapterType(enum.Flag):
|
|||||||
| KOOK
|
| KOOK
|
||||||
| VOCECHAT
|
| VOCECHAT
|
||||||
| WEIXIN_OFFICIAL_ACCOUNT
|
| WEIXIN_OFFICIAL_ACCOUNT
|
||||||
| SATORI
|
|
||||||
| MISSKEY
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -50,20 +47,15 @@ ADAPTER_NAME_2_TYPE = {
|
|||||||
"wechatpadpro": PlatformAdapterType.WECHATPADPRO,
|
"wechatpadpro": PlatformAdapterType.WECHATPADPRO,
|
||||||
"vocechat": PlatformAdapterType.VOCECHAT,
|
"vocechat": PlatformAdapterType.VOCECHAT,
|
||||||
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
|
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
|
||||||
"satori": PlatformAdapterType.SATORI,
|
|
||||||
"misskey": PlatformAdapterType.MISSKEY,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class PlatformAdapterTypeFilter(HandlerFilter):
|
class PlatformAdapterTypeFilter(HandlerFilter):
|
||||||
def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str):
|
def __init__(self, platform_adapter_type_or_str: Union[PlatformAdapterType, str]):
|
||||||
if isinstance(platform_adapter_type_or_str, str):
|
self.type_or_str = platform_adapter_type_or_str
|
||||||
self.platform_type = ADAPTER_NAME_2_TYPE.get(platform_adapter_type_or_str)
|
|
||||||
else:
|
|
||||||
self.platform_type = platform_adapter_type_or_str
|
|
||||||
|
|
||||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||||
adapter_name = event.get_platform_name()
|
adapter_name = event.get_platform_name()
|
||||||
if adapter_name in ADAPTER_NAME_2_TYPE and self.platform_type is not None:
|
if adapter_name in ADAPTER_NAME_2_TYPE:
|
||||||
return bool(ADAPTER_NAME_2_TYPE[adapter_name] & self.platform_type)
|
return ADAPTER_NAME_2_TYPE[adapter_name] & self.type_or_str
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from .star_handler import (
|
|||||||
register_permission_type,
|
register_permission_type,
|
||||||
register_custom_filter,
|
register_custom_filter,
|
||||||
register_on_astrbot_loaded,
|
register_on_astrbot_loaded,
|
||||||
register_on_platform_loaded,
|
|
||||||
register_on_llm_request,
|
register_on_llm_request,
|
||||||
register_on_llm_response,
|
register_on_llm_response,
|
||||||
register_llm_tool,
|
register_llm_tool,
|
||||||
@@ -27,7 +26,6 @@ __all__ = [
|
|||||||
"register_permission_type",
|
"register_permission_type",
|
||||||
"register_custom_filter",
|
"register_custom_filter",
|
||||||
"register_on_astrbot_loaded",
|
"register_on_astrbot_loaded",
|
||||||
"register_on_platform_loaded",
|
|
||||||
"register_on_llm_request",
|
"register_on_llm_request",
|
||||||
"register_on_llm_response",
|
"register_on_llm_response",
|
||||||
"register_llm_tool",
|
"register_llm_tool",
|
||||||
|
|||||||
@@ -5,9 +5,7 @@ from astrbot.core.star import StarMetadata, star_map
|
|||||||
_warned_register_star = False
|
_warned_register_star = False
|
||||||
|
|
||||||
|
|
||||||
def register_star(
|
def register_star(name: str, author: str, desc: str, version: str, repo: str = None):
|
||||||
name: str, author: str, desc: str, version: str, repo: str | None = None
|
|
||||||
):
|
|
||||||
"""注册一个插件(Star)。
|
"""注册一个插件(Star)。
|
||||||
|
|
||||||
[DEPRECATED] 该装饰器已废弃,将在未来版本中移除。
|
[DEPRECATED] 该装饰器已废弃,将在未来版本中移除。
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from ..filter.platform_adapter_type import (
|
|||||||
from ..filter.permission import PermissionTypeFilter, PermissionType
|
from ..filter.permission import PermissionTypeFilter, PermissionType
|
||||||
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
|
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
|
||||||
from ..filter.regex import RegexFilter
|
from ..filter.regex import RegexFilter
|
||||||
from typing import Awaitable, Any, Callable
|
from typing import Awaitable
|
||||||
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
|
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
|
||||||
from astrbot.core.provider.register import llm_tools
|
from astrbot.core.provider.register import llm_tools
|
||||||
from astrbot.core.agent.agent import Agent
|
from astrbot.core.agent.agent import Agent
|
||||||
@@ -20,19 +20,15 @@ from astrbot.core.agent.tool import FunctionTool
|
|||||||
from astrbot.core.agent.handoff import HandoffTool
|
from astrbot.core.agent.handoff import HandoffTool
|
||||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||||
from astrbot.core import logger
|
|
||||||
|
|
||||||
|
|
||||||
def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str:
|
def get_handler_full_name(awaitable: Awaitable) -> str:
|
||||||
"""获取 Handler 的全名"""
|
"""获取 Handler 的全名"""
|
||||||
return f"{awaitable.__module__}_{awaitable.__name__}"
|
return f"{awaitable.__module__}_{awaitable.__name__}"
|
||||||
|
|
||||||
|
|
||||||
def get_handler_or_create(
|
def get_handler_or_create(
|
||||||
handler: Callable[..., Awaitable[Any]],
|
handler: Awaitable, event_type: EventType, dont_add=False, **kwargs
|
||||||
event_type: EventType,
|
|
||||||
dont_add=False,
|
|
||||||
**kwargs,
|
|
||||||
) -> StarHandlerMetadata:
|
) -> StarHandlerMetadata:
|
||||||
"""获取 Handler 或者创建一个新的 Handler"""
|
"""获取 Handler 或者创建一个新的 Handler"""
|
||||||
handler_full_name = get_handler_full_name(handler)
|
handler_full_name = get_handler_full_name(handler)
|
||||||
@@ -63,35 +59,22 @@ def get_handler_or_create(
|
|||||||
|
|
||||||
|
|
||||||
def register_command(
|
def register_command(
|
||||||
command_name: str | None = None,
|
command_name: str = None, sub_command: str = None, alias: set = None, **kwargs
|
||||||
sub_command: str | None = None,
|
|
||||||
alias: set | None = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
"""注册一个 Command."""
|
"""注册一个 Command."""
|
||||||
new_command = None
|
new_command = None
|
||||||
add_to_event_filters = False
|
add_to_event_filters = False
|
||||||
if isinstance(command_name, RegisteringCommandable):
|
if isinstance(command_name, RegisteringCommandable):
|
||||||
# 子指令
|
# 子指令
|
||||||
if sub_command is not None:
|
parent_command_names = command_name.parent_group.get_complete_command_names()
|
||||||
parent_command_names = (
|
new_command = CommandFilter(
|
||||||
command_name.parent_group.get_complete_command_names()
|
sub_command, alias, None, parent_command_names=parent_command_names
|
||||||
)
|
)
|
||||||
new_command = CommandFilter(
|
command_name.parent_group.add_sub_command_filter(new_command)
|
||||||
sub_command, alias, None, parent_command_names=parent_command_names
|
|
||||||
)
|
|
||||||
command_name.parent_group.add_sub_command_filter(new_command)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"注册指令{command_name} 的子指令时未提供 sub_command 参数。"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# 裸指令
|
# 裸指令
|
||||||
if command_name is None:
|
new_command = CommandFilter(command_name, alias, None)
|
||||||
logger.warning("注册裸指令时未提供 command_name 参数。")
|
add_to_event_filters = True
|
||||||
else:
|
|
||||||
new_command = CommandFilter(command_name, alias, None)
|
|
||||||
add_to_event_filters = True
|
|
||||||
|
|
||||||
def decorator(awaitable):
|
def decorator(awaitable):
|
||||||
if not add_to_event_filters:
|
if not add_to_event_filters:
|
||||||
@@ -101,9 +84,8 @@ def register_command(
|
|||||||
handler_md = get_handler_or_create(
|
handler_md = get_handler_or_create(
|
||||||
awaitable, EventType.AdapterMessageEvent, **kwargs
|
awaitable, EventType.AdapterMessageEvent, **kwargs
|
||||||
)
|
)
|
||||||
if new_command:
|
new_command.init_handler_md(handler_md)
|
||||||
new_command.init_handler_md(handler_md)
|
handler_md.event_filters.append(new_command)
|
||||||
handler_md.event_filters.append(new_command)
|
|
||||||
return awaitable
|
return awaitable
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
@@ -181,38 +163,26 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def register_command_group(
|
def register_command_group(
|
||||||
command_group_name: str | None = None,
|
command_group_name: str = None, sub_command: str = None, alias: set = None, **kwargs
|
||||||
sub_command: str | None = None,
|
|
||||||
alias: set | None = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
"""注册一个 CommandGroup"""
|
"""注册一个 CommandGroup"""
|
||||||
new_group = None
|
new_group = None
|
||||||
if isinstance(command_group_name, RegisteringCommandable):
|
if isinstance(command_group_name, RegisteringCommandable):
|
||||||
# 子指令组
|
# 子指令组
|
||||||
if sub_command is None:
|
new_group = CommandGroupFilter(
|
||||||
logger.warning(f"{command_group_name} 指令组的子指令组 sub_command 未指定")
|
sub_command, alias, parent_group=command_group_name.parent_group
|
||||||
else:
|
)
|
||||||
new_group = CommandGroupFilter(
|
command_group_name.parent_group.add_sub_command_filter(new_group)
|
||||||
sub_command, alias, parent_group=command_group_name.parent_group
|
|
||||||
)
|
|
||||||
command_group_name.parent_group.add_sub_command_filter(new_group)
|
|
||||||
else:
|
else:
|
||||||
# 根指令组
|
# 根指令组
|
||||||
if command_group_name is None:
|
new_group = CommandGroupFilter(command_group_name, alias)
|
||||||
logger.warning("根指令组的名称未指定")
|
|
||||||
else:
|
|
||||||
new_group = CommandGroupFilter(command_group_name, alias)
|
|
||||||
|
|
||||||
def decorator(obj):
|
def decorator(obj):
|
||||||
if new_group:
|
# 根指令组
|
||||||
handler_md = get_handler_or_create(
|
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
|
||||||
obj, EventType.AdapterMessageEvent, **kwargs
|
handler_md.event_filters.append(new_group)
|
||||||
)
|
|
||||||
handler_md.event_filters.append(new_group)
|
|
||||||
|
|
||||||
return RegisteringCommandable(new_group)
|
return RegisteringCommandable(new_group)
|
||||||
raise ValueError("注册指令组失败。")
|
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@@ -220,11 +190,9 @@ def register_command_group(
|
|||||||
class RegisteringCommandable:
|
class RegisteringCommandable:
|
||||||
"""用于指令组级联注册"""
|
"""用于指令组级联注册"""
|
||||||
|
|
||||||
group: Callable[..., Callable[..., "RegisteringCommandable"]] = (
|
group: CommandGroupFilter = register_command_group
|
||||||
register_command_group
|
command: CommandFilter = register_command
|
||||||
)
|
custom_filter = register_custom_filter
|
||||||
command: Callable[..., Callable[..., None]] = register_command
|
|
||||||
custom_filter: Callable[..., Callable[..., None]] = register_custom_filter
|
|
||||||
|
|
||||||
def __init__(self, parent_group: CommandGroupFilter):
|
def __init__(self, parent_group: CommandGroupFilter):
|
||||||
self.parent_group = parent_group
|
self.parent_group = parent_group
|
||||||
@@ -299,18 +267,6 @@ def register_on_astrbot_loaded(**kwargs):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def register_on_platform_loaded(**kwargs):
|
|
||||||
"""
|
|
||||||
当平台加载完成时
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(awaitable):
|
|
||||||
_ = get_handler_or_create(awaitable, EventType.OnPlatformLoadedEvent, **kwargs)
|
|
||||||
return awaitable
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def register_on_llm_request(**kwargs):
|
def register_on_llm_request(**kwargs):
|
||||||
"""当有 LLM 请求时的事件
|
"""当有 LLM 请求时的事件
|
||||||
|
|
||||||
@@ -355,7 +311,7 @@ def register_on_llm_response(**kwargs):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def register_llm_tool(name: str | None = None, **kwargs):
|
def register_llm_tool(name: str = None, **kwargs):
|
||||||
"""为函数调用(function-calling / tools-use)添加工具。
|
"""为函数调用(function-calling / tools-use)添加工具。
|
||||||
|
|
||||||
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
|
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
|
||||||
@@ -393,10 +349,9 @@ def register_llm_tool(name: str | None = None, **kwargs):
|
|||||||
if kwargs.get("registering_agent"):
|
if kwargs.get("registering_agent"):
|
||||||
registering_agent = kwargs["registering_agent"]
|
registering_agent = kwargs["registering_agent"]
|
||||||
|
|
||||||
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
def decorator(awaitable: Awaitable):
|
||||||
llm_tool_name = name_ if name_ else awaitable.__name__
|
llm_tool_name = name_ if name_ else awaitable.__name__
|
||||||
func_doc = awaitable.__doc__ or ""
|
docstring = docstring_parser.parse(awaitable.__doc__)
|
||||||
docstring = docstring_parser.parse(func_doc)
|
|
||||||
args = []
|
args = []
|
||||||
for arg in docstring.params:
|
for arg in docstring.params:
|
||||||
if arg.type_name not in SUPPORTED_TYPES:
|
if arg.type_name not in SUPPORTED_TYPES:
|
||||||
@@ -412,18 +367,18 @@ def register_llm_tool(name: str | None = None, **kwargs):
|
|||||||
)
|
)
|
||||||
# print(llm_tool_name, registering_agent)
|
# print(llm_tool_name, registering_agent)
|
||||||
if not registering_agent:
|
if not registering_agent:
|
||||||
doc_desc = docstring.description.strip() if docstring.description else ""
|
|
||||||
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
||||||
llm_tools.add_func(llm_tool_name, args, doc_desc, md.handler)
|
llm_tools.add_func(
|
||||||
|
llm_tool_name, args, docstring.description.strip(), md.handler
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
assert isinstance(registering_agent, RegisteringAgent)
|
assert isinstance(registering_agent, RegisteringAgent)
|
||||||
# print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name)
|
# print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name)
|
||||||
if registering_agent._agent.tools is None:
|
if registering_agent._agent.tools is None:
|
||||||
registering_agent._agent.tools = []
|
registering_agent._agent.tools = []
|
||||||
|
registering_agent._agent.tools.append(llm_tools.spec_to_func(
|
||||||
desc = docstring.description.strip() if docstring.description else ""
|
llm_tool_name, args, docstring.description.strip(), awaitable
|
||||||
tool = llm_tools.spec_to_func(llm_tool_name, args, desc, awaitable)
|
))
|
||||||
registering_agent._agent.tools.append(tool)
|
|
||||||
|
|
||||||
return awaitable
|
return awaitable
|
||||||
|
|
||||||
@@ -444,8 +399,8 @@ class RegisteringAgent:
|
|||||||
def register_agent(
|
def register_agent(
|
||||||
name: str,
|
name: str,
|
||||||
instruction: str,
|
instruction: str,
|
||||||
tools: list[str | FunctionTool] | None = None,
|
tools: list[str | FunctionTool] = None,
|
||||||
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
|
run_hooks: BaseAgentRunHooks[AstrAgentContext] = None,
|
||||||
):
|
):
|
||||||
"""注册一个 Agent
|
"""注册一个 Agent
|
||||||
|
|
||||||
@@ -457,7 +412,7 @@ def register_agent(
|
|||||||
"""
|
"""
|
||||||
tools_ = tools or []
|
tools_ = tools or []
|
||||||
|
|
||||||
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
def decorator(awaitable: Awaitable):
|
||||||
AstrAgent = Agent[AstrAgentContext]
|
AstrAgent = Agent[AstrAgentContext]
|
||||||
agent = AstrAgent(
|
agent = AstrAgent(
|
||||||
name=name,
|
name=name,
|
||||||
@@ -466,7 +421,7 @@ def register_agent(
|
|||||||
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
||||||
)
|
)
|
||||||
handoff_tool = HandoffTool(agent=agent)
|
handoff_tool = HandoffTool(agent=agent)
|
||||||
handoff_tool.handler = awaitable
|
handoff_tool.handler=awaitable
|
||||||
llm_tools.func_list.append(handoff_tool)
|
llm_tools.func_list.append(handoff_tool)
|
||||||
return RegisteringAgent(agent)
|
return RegisteringAgent(agent)
|
||||||
|
|
||||||
|
|||||||
@@ -52,6 +52,10 @@ class SessionServiceManager:
|
|||||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
||||||
"""检查是否应该处理LLM请求
|
"""检查是否应该处理LLM请求
|
||||||
|
|||||||
@@ -84,10 +84,7 @@ class SessionPluginManager:
|
|||||||
session_config["disabled_plugins"] = disabled_plugins
|
session_config["disabled_plugins"] = disabled_plugins
|
||||||
session_plugin_config[session_id] = session_config
|
session_plugin_config[session_id] = session_config
|
||||||
sp.put(
|
sp.put(
|
||||||
"session_plugin_config",
|
"session_plugin_config", session_plugin_config, scope="umo", scope_id=session_id
|
||||||
session_plugin_config,
|
|
||||||
scope="umo",
|
|
||||||
scope_id=session_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -140,9 +137,6 @@ class SessionPluginManager:
|
|||||||
filtered_handlers.append(handler)
|
filtered_handlers.append(handler)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if plugin.name is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查插件是否在当前会话中启用
|
# 检查插件是否在当前会话中启用
|
||||||
if SessionPluginManager.is_plugin_enabled_for_session(
|
if SessionPluginManager.is_plugin_enabled_for_session(
|
||||||
session_id, plugin.name
|
session_id, plugin.name
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import enum
|
import enum
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Callable, Awaitable, Any, List, Dict, TypeVar, Generic
|
from typing import Awaitable, List, Dict, TypeVar, Generic
|
||||||
from .filter import HandlerFilter
|
from .filter import HandlerFilter
|
||||||
from .star import star_map
|
from .star import star_map
|
||||||
|
|
||||||
@@ -34,33 +34,26 @@ class StarHandlerRegistry(Generic[T]):
|
|||||||
) -> List[StarHandlerMetadata]:
|
) -> List[StarHandlerMetadata]:
|
||||||
handlers = []
|
handlers = []
|
||||||
for handler in self._handlers:
|
for handler in self._handlers:
|
||||||
# 过滤事件类型
|
|
||||||
if handler.event_type != event_type:
|
if handler.event_type != event_type:
|
||||||
continue
|
continue
|
||||||
# 过滤启用状态
|
|
||||||
if only_activated:
|
if only_activated:
|
||||||
plugin = star_map.get(handler.handler_module_path)
|
plugin = star_map.get(handler.handler_module_path)
|
||||||
if not (plugin and plugin.activated):
|
if not (plugin and plugin.activated):
|
||||||
continue
|
continue
|
||||||
# 过滤插件白名单
|
|
||||||
if plugins_name is not None and plugins_name != ["*"]:
|
if plugins_name is not None and plugins_name != ["*"]:
|
||||||
plugin = star_map.get(handler.handler_module_path)
|
plugin = star_map.get(handler.handler_module_path)
|
||||||
if not plugin:
|
if not plugin:
|
||||||
continue
|
continue
|
||||||
if (
|
if (
|
||||||
plugin.name not in plugins_name
|
plugin.name not in plugins_name
|
||||||
and event_type
|
and event_type != EventType.OnAstrBotLoadedEvent
|
||||||
not in (
|
|
||||||
EventType.OnAstrBotLoadedEvent,
|
|
||||||
EventType.OnPlatformLoadedEvent,
|
|
||||||
)
|
|
||||||
and not plugin.reserved
|
and not plugin.reserved
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
handlers.append(handler)
|
handlers.append(handler)
|
||||||
return handlers
|
return handlers
|
||||||
|
|
||||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata | None:
|
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
||||||
return self.star_handlers_map.get(full_name, None)
|
return self.star_handlers_map.get(full_name, None)
|
||||||
|
|
||||||
def get_handlers_by_module_name(
|
def get_handlers_by_module_name(
|
||||||
@@ -87,7 +80,7 @@ class StarHandlerRegistry(Generic[T]):
|
|||||||
return len(self._handlers)
|
return len(self._handlers)
|
||||||
|
|
||||||
|
|
||||||
star_handlers_registry = StarHandlerRegistry() # type: ignore
|
star_handlers_registry = StarHandlerRegistry()
|
||||||
|
|
||||||
|
|
||||||
class EventType(enum.Enum):
|
class EventType(enum.Enum):
|
||||||
@@ -97,7 +90,6 @@ class EventType(enum.Enum):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
OnAstrBotLoadedEvent = enum.auto() # AstrBot 加载完成
|
OnAstrBotLoadedEvent = enum.auto() # AstrBot 加载完成
|
||||||
OnPlatformLoadedEvent = enum.auto() # 平台加载完成
|
|
||||||
|
|
||||||
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
|
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
|
||||||
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
|
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
|
||||||
@@ -123,7 +115,7 @@ class StarHandlerMetadata:
|
|||||||
handler_module_path: str
|
handler_module_path: str
|
||||||
"""Handler 所在的模块路径。"""
|
"""Handler 所在的模块路径。"""
|
||||||
|
|
||||||
handler: Callable[..., Awaitable[Any]]
|
handler: Awaitable
|
||||||
"""Handler 的函数对象,应当是一个异步函数"""
|
"""Handler 的函数对象,应当是一个异步函数"""
|
||||||
|
|
||||||
event_filters: List[HandlerFilter]
|
event_filters: List[HandlerFilter]
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class PluginManager:
|
|||||||
self.updator = PluginUpdator()
|
self.updator = PluginUpdator()
|
||||||
|
|
||||||
self.context = context
|
self.context = context
|
||||||
self.context._star_manager = self # type: ignore
|
self.context._star_manager = self
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.plugin_store_path = get_astrbot_plugin_path()
|
self.plugin_store_path = get_astrbot_plugin_path()
|
||||||
@@ -478,10 +478,9 @@ class PluginManager:
|
|||||||
if isinstance(func_tool, HandoffTool):
|
if isinstance(func_tool, HandoffTool):
|
||||||
need_apply = []
|
need_apply = []
|
||||||
sub_tools = func_tool.agent.tools
|
sub_tools = func_tool.agent.tools
|
||||||
if sub_tools:
|
for sub_tool in sub_tools:
|
||||||
for sub_tool in sub_tools:
|
if isinstance(sub_tool, FunctionTool):
|
||||||
if isinstance(sub_tool, FunctionTool):
|
need_apply.append(sub_tool)
|
||||||
need_apply.append(sub_tool)
|
|
||||||
else:
|
else:
|
||||||
need_apply = [func_tool]
|
need_apply = [func_tool]
|
||||||
|
|
||||||
@@ -687,9 +686,6 @@ class PluginManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 从 star_registry 和 star_map 中删除
|
# 从 star_registry 和 star_map 中删除
|
||||||
if plugin.module_path is None or root_dir_name is None:
|
|
||||||
raise Exception(f"插件 {plugin_name} 数据不完整,无法卸载。")
|
|
||||||
|
|
||||||
await self._unbind_plugin(plugin_name, plugin.module_path)
|
await self._unbind_plugin(plugin_name, plugin.module_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -795,17 +791,15 @@ class PluginManager:
|
|||||||
if star_metadata.star_cls is None:
|
if star_metadata.star_cls is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if "__del__" in star_metadata.star_cls_type.__dict__:
|
if '__del__' in star_metadata.star_cls_type.__dict__:
|
||||||
asyncio.get_event_loop().run_in_executor(
|
asyncio.get_event_loop().run_in_executor(
|
||||||
None, star_metadata.star_cls.__del__
|
None, star_metadata.star_cls.__del__
|
||||||
)
|
)
|
||||||
elif "terminate" in star_metadata.star_cls_type.__dict__:
|
elif 'terminate' in star_metadata.star_cls_type.__dict__:
|
||||||
await star_metadata.star_cls.terminate()
|
await star_metadata.star_cls.terminate()
|
||||||
|
|
||||||
async def turn_on_plugin(self, plugin_name: str):
|
async def turn_on_plugin(self, plugin_name: str):
|
||||||
plugin = self.context.get_registered_star(plugin_name)
|
plugin = self.context.get_registered_star(plugin_name)
|
||||||
if plugin is None:
|
|
||||||
raise Exception(f"插件 {plugin_name} 不存在。")
|
|
||||||
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
|
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
|
||||||
inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", [])
|
inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", [])
|
||||||
if plugin.module_path in inactivated_plugins:
|
if plugin.module_path in inactivated_plugins:
|
||||||
|
|||||||
@@ -1,41 +1,14 @@
|
|||||||
"""
|
|
||||||
插件开发工具集
|
|
||||||
封装了许多常用的操作,方便插件开发者使用
|
|
||||||
|
|
||||||
说明:
|
|
||||||
|
|
||||||
主动发送消息: send_message(session, message_chain)
|
|
||||||
根据 session (unified_msg_origin) 主动发送消息, 前提是需要提前获得或构造 session
|
|
||||||
|
|
||||||
根据id直接主动发送消息: send_message_by_id(type, id, message_chain, platform="aiocqhttp")
|
|
||||||
根据 id (例如 qq 号, 群号等) 直接, 主动地发送消息
|
|
||||||
|
|
||||||
以上两种方式需要构造消息链, 也就是消息组件的列表
|
|
||||||
|
|
||||||
构造事件:
|
|
||||||
|
|
||||||
首先需要构造一个 AstrBotMessage 对象, 使用 create_message 方法
|
|
||||||
然后使用 create_event 方法提交事件到指定平台
|
|
||||||
"""
|
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import uuid
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union, Awaitable, Callable, Any, List, Optional, ClassVar
|
from typing import Union, Awaitable, List, Optional, ClassVar
|
||||||
from astrbot.core.message.components import BaseMessageComponent
|
from astrbot.core.message.components import BaseMessageComponent
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType
|
from astrbot.api.platform import MessageMember, AstrBotMessage
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
from astrbot.core.star.context import Context
|
from astrbot.core.star.context import Context
|
||||||
from astrbot.core.star.star import star_map
|
from astrbot.core.star.star import star_map
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import (
|
|
||||||
AiocqhttpMessageEvent,
|
|
||||||
)
|
|
||||||
from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import (
|
|
||||||
AiocqhttpAdapter,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StarTools:
|
class StarTools:
|
||||||
@@ -76,82 +49,42 @@ class StarTools:
|
|||||||
Note:
|
Note:
|
||||||
qq_official(QQ官方API平台)不支持此方法
|
qq_official(QQ官方API平台)不支持此方法
|
||||||
"""
|
"""
|
||||||
if cls._context is None:
|
|
||||||
raise ValueError("StarTools not initialized")
|
|
||||||
return await cls._context.send_message(session, message_chain)
|
return await cls._context.send_message(session, message_chain)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def send_message_by_id(
|
|
||||||
cls,
|
|
||||||
type: str,
|
|
||||||
id: str,
|
|
||||||
message_chain: MessageChain,
|
|
||||||
platform: str = "aiocqhttp",
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
根据 id(例如qq号, 群号等) 直接, 主动地发送消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
type (str): 消息类型, 可选: PrivateMessage, GroupMessage
|
|
||||||
id (str): 目标ID, 例如QQ号, 群号等
|
|
||||||
message_chain (MessageChain): 消息链
|
|
||||||
platform (str): 可选的平台名称,默认平台(aiocqhttp), 目前只支持 aiocqhttp
|
|
||||||
"""
|
|
||||||
if cls._context is None:
|
|
||||||
raise ValueError("StarTools not initialized")
|
|
||||||
platforms = cls._context.platform_manager.get_insts()
|
|
||||||
if platform == "aiocqhttp":
|
|
||||||
adapter = next(
|
|
||||||
(p for p in platforms if isinstance(p, AiocqhttpAdapter)), None
|
|
||||||
)
|
|
||||||
if adapter is None:
|
|
||||||
raise ValueError("未找到适配器: AiocqhttpAdapter")
|
|
||||||
await AiocqhttpMessageEvent.send_message(
|
|
||||||
bot=adapter.bot,
|
|
||||||
message_chain=message_chain,
|
|
||||||
is_group=(type == "GroupMessage"),
|
|
||||||
session_id=id,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"不支持的平台: {platform}")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_message(
|
async def create_message(
|
||||||
cls,
|
cls,
|
||||||
type: str,
|
type: str,
|
||||||
self_id: str,
|
self_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
|
message_id: str,
|
||||||
sender: MessageMember,
|
sender: MessageMember,
|
||||||
message: List[BaseMessageComponent],
|
message: List[BaseMessageComponent],
|
||||||
message_str: str,
|
message_str: str,
|
||||||
message_id: str = "",
|
raw_message: object,
|
||||||
raw_message: object = None,
|
|
||||||
group_id: str = "",
|
group_id: str = "",
|
||||||
) -> AstrBotMessage:
|
):
|
||||||
"""
|
"""
|
||||||
创建一个AstrBot消息对象
|
创建一个AstrBot消息对象
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
type (str): 消息类型, 例如 "GroupMessage" "FriendMessage" "OtherMessage"
|
type (str): 消息类型
|
||||||
self_id (str): 机器人自身ID
|
self_id (str): 机器人自身ID
|
||||||
session_id (str): 会话ID(通常为用户ID)(QQ号, 群号等)
|
session_id (str): 会话ID(通常为用户ID)(QQ号, 群号等)
|
||||||
sender (MessageMember): 发送者信息, 例如 MessageMember(user_id="123456", nickname="昵称")
|
message_id (str): 消息ID
|
||||||
message (List[BaseMessageComponent]): 消息组件列表, 也就是消息链, 这个不会发给 llm, 但是会经过其他处理
|
sender (MessageMember): 发送者信息
|
||||||
message_str (str): 消息字符串, 也就是纯文本消息, 也就是发送给 llm 的消息, 与消息链一致
|
message (List[BaseMessageComponent]): 消息组件列表
|
||||||
|
message_str (str): 消息字符串
|
||||||
message_id (str): 消息ID, 构造消息时可以随意填写也可不填
|
raw_message (object): 原始消息对象
|
||||||
raw_message (object): 原始消息对象, 可以随意填写也可不填
|
|
||||||
group_id (str, optional): 群组ID, 如果为私聊则为空. Defaults to "".
|
group_id (str, optional): 群组ID, 如果为私聊则为空. Defaults to "".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AstrBotMessage: 创建的消息对象
|
AstrBotMessage: 创建的消息对象
|
||||||
"""
|
"""
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.type = MessageType(type)
|
abm.type = type
|
||||||
abm.self_id = self_id
|
abm.self_id = self_id
|
||||||
abm.session_id = session_id
|
abm.session_id = session_id
|
||||||
if message_id == "":
|
|
||||||
message_id = uuid.uuid4().hex
|
|
||||||
abm.message_id = message_id
|
abm.message_id = message_id
|
||||||
abm.sender = sender
|
abm.sender = sender
|
||||||
abm.message = message
|
abm.message = message
|
||||||
@@ -160,39 +93,13 @@ class StarTools:
|
|||||||
abm.group_id = group_id
|
abm.group_id = group_id
|
||||||
return abm
|
return abm
|
||||||
|
|
||||||
@classmethod
|
# todo: 添加构造事件的方法
|
||||||
async def create_event(
|
# async def create_event(
|
||||||
cls, abm: AstrBotMessage, platform: str = "aiocqhttp", is_wake: bool = True
|
# self, platform: str, umo: str, sender_id: str, session_id: str
|
||||||
) -> None:
|
# ):
|
||||||
"""
|
# platform = self._context.get_platform(platform)
|
||||||
创建并提交事件到指定平台
|
|
||||||
当有需要创建一个事件, 触发某些处理流程时, 使用该方法
|
|
||||||
|
|
||||||
Args:
|
# todo: 添加找到对应平台并提交对应事件的方法
|
||||||
abm (AstrBotMessage): 要提交的消息对象, 请先使用 create_message 创建
|
|
||||||
platform (str): 可选的平台名称,默认平台(aiocqhttp), 目前只支持 aiocqhttp
|
|
||||||
is_wake (bool): 是否标记为唤醒事件, 默认为 True, 只有唤醒事件才会被 llm 响应
|
|
||||||
"""
|
|
||||||
if cls._context is None:
|
|
||||||
raise ValueError("StarTools not initialized")
|
|
||||||
platforms = cls._context.platform_manager.get_insts()
|
|
||||||
if platform == "aiocqhttp":
|
|
||||||
adapter = next(
|
|
||||||
(p for p in platforms if isinstance(p, AiocqhttpAdapter)), None
|
|
||||||
)
|
|
||||||
if adapter is None:
|
|
||||||
raise ValueError("未找到适配器: AiocqhttpAdapter")
|
|
||||||
event = AiocqhttpMessageEvent(
|
|
||||||
message_str=abm.message_str,
|
|
||||||
message_obj=abm,
|
|
||||||
platform_meta=adapter.metadata,
|
|
||||||
session_id=abm.session_id,
|
|
||||||
bot=adapter.bot,
|
|
||||||
)
|
|
||||||
event.is_wake = is_wake
|
|
||||||
adapter.commit_event(event)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"不支持的平台: {platform}")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def activate_llm_tool(cls, name: str) -> bool:
|
def activate_llm_tool(cls, name: str) -> bool:
|
||||||
@@ -203,8 +110,6 @@ class StarTools:
|
|||||||
Args:
|
Args:
|
||||||
name (str): 工具名称
|
name (str): 工具名称
|
||||||
"""
|
"""
|
||||||
if cls._context is None:
|
|
||||||
raise ValueError("StarTools not initialized")
|
|
||||||
return cls._context.activate_llm_tool(name)
|
return cls._context.activate_llm_tool(name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -215,17 +120,11 @@ class StarTools:
|
|||||||
Args:
|
Args:
|
||||||
name (str): 工具名称
|
name (str): 工具名称
|
||||||
"""
|
"""
|
||||||
if cls._context is None:
|
|
||||||
raise ValueError("StarTools not initialized")
|
|
||||||
return cls._context.deactivate_llm_tool(name)
|
return cls._context.deactivate_llm_tool(name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_llm_tool(
|
def register_llm_tool(
|
||||||
cls,
|
cls, name: str, func_args: list, desc: str, func_obj: Awaitable
|
||||||
name: str,
|
|
||||||
func_args: list,
|
|
||||||
desc: str,
|
|
||||||
func_obj: Callable[..., Awaitable[Any]],
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
为函数调用(function-calling/tools-use)添加工具
|
为函数调用(function-calling/tools-use)添加工具
|
||||||
@@ -236,8 +135,6 @@ class StarTools:
|
|||||||
desc (str): 工具描述
|
desc (str): 工具描述
|
||||||
func_obj (Awaitable): 函数对象,必须是异步函数
|
func_obj (Awaitable): 函数对象,必须是异步函数
|
||||||
"""
|
"""
|
||||||
if cls._context is None:
|
|
||||||
raise ValueError("StarTools not initialized")
|
|
||||||
cls._context.register_llm_tool(name, func_args, desc, func_obj)
|
cls._context.register_llm_tool(name, func_args, desc, func_obj)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -249,8 +146,6 @@ class StarTools:
|
|||||||
Args:
|
Args:
|
||||||
name (str): 工具名称
|
name (str): 工具名称
|
||||||
"""
|
"""
|
||||||
if cls._context is None:
|
|
||||||
raise ValueError("StarTools not initialized")
|
|
||||||
cls._context.unregister_llm_tool(name)
|
cls._context.unregister_llm_tool(name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -274,11 +169,8 @@ class StarTools:
|
|||||||
- 创建目录失败(权限不足或其他IO错误)
|
- 创建目录失败(权限不足或其他IO错误)
|
||||||
"""
|
"""
|
||||||
if not plugin_name:
|
if not plugin_name:
|
||||||
frame = inspect.currentframe()
|
frame = inspect.currentframe().f_back
|
||||||
module = None
|
module = inspect.getmodule(frame)
|
||||||
if frame:
|
|
||||||
frame = frame.f_back
|
|
||||||
module = inspect.getmodule(frame)
|
|
||||||
|
|
||||||
if not module:
|
if not module:
|
||||||
raise RuntimeError("无法获取调用者模块信息")
|
raise RuntimeError("无法获取调用者模块信息")
|
||||||
@@ -290,12 +182,7 @@ class StarTools:
|
|||||||
|
|
||||||
plugin_name = metadata.name
|
plugin_name = metadata.name
|
||||||
|
|
||||||
if not plugin_name:
|
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name))
|
||||||
raise ValueError("无法获取插件名称")
|
|
||||||
|
|
||||||
data_dir = Path(
|
|
||||||
os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data_dir.mkdir(parents=True, exist_ok=True)
|
data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|||||||
@@ -32,9 +32,6 @@ class PluginUpdator(RepoZipUpdator):
|
|||||||
if not repo_url:
|
if not repo_url:
|
||||||
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
|
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
|
||||||
|
|
||||||
if not plugin.root_dir_name:
|
|
||||||
raise Exception(f"插件 {plugin.name} 的根目录名未指定。")
|
|
||||||
|
|
||||||
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
||||||
|
|
||||||
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
|
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
|
||||||
|
|||||||
@@ -56,7 +56,9 @@ class AstrBotUpdator(RepoZipUpdator):
|
|||||||
try:
|
try:
|
||||||
if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli
|
if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
args = [f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]]
|
args = [
|
||||||
|
f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
args = sys.argv[1:]
|
args = sys.argv[1:]
|
||||||
os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args)
|
os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args)
|
||||||
@@ -66,13 +68,9 @@ class AstrBotUpdator(RepoZipUpdator):
|
|||||||
logger.error(f"重启失败({py}, {e}),请尝试手动重启。")
|
logger.error(f"重启失败({py}, {e}),请尝试手动重启。")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def check_update(
|
async def check_update(self, url: str, current_version: str) -> ReleaseInfo:
|
||||||
self, url: str, current_version: str, consider_prerelease: bool = True
|
|
||||||
) -> ReleaseInfo:
|
|
||||||
"""检查更新"""
|
"""检查更新"""
|
||||||
return await super().check_update(
|
return await super().check_update(self.ASTRBOT_RELEASE_API, VERSION)
|
||||||
self.ASTRBOT_RELEASE_API, VERSION, consider_prerelease
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_releases(self) -> list:
|
async def get_releases(self) -> list:
|
||||||
return await self.fetch_release_info(self.ASTRBOT_RELEASE_API)
|
return await self.fetch_release_info(self.ASTRBOT_RELEASE_API)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user