Compare commits

..

1 Commits

179 changed files with 8370 additions and 11212 deletions

View File

@@ -1,46 +1,19 @@
<!-- 如果有的话,指定 PR 旨在解决的 ISSUE 编号。 -->
<!-- If applicable, please specify the ISSUE number this PR aims to resolve. -->
<!-- 如果有的话,指定这个 PR 解决的 ISSUE -->
解决了 #XYZ
fixes #XYZ
### Motivation
---
<!--解释为什么要改动-->
### Motivation / 动机
### Modifications
<!--请描述此项更改的动机:它解决了什么问题?(例如:修复了 XX 错误,添加了 YY 功能)-->
<!--Please describe the motivation for this change: What problem does it solve? (e.g., Fixes XX bug, adds YY feature)-->
<!--简单解释你的改动-->
### Modifications / 改动点
### Check
<!--请总结你的改动:哪些核心文件被修改了?实现了什么功能?-->
<!--Please summarize your changes: What core files were modified? What functionality was implemented?-->
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容-->
### Verification Steps / 验证步骤
<!--请为审查者 (Reviewer) 提供清晰、可复现的验证步骤例如1. 导航到... 2. 点击...)。-->
<!--Please provide clear and reproducible verification steps for the Reviewer (e.g., 1. Navigate to... 2. Click...).-->
### Screenshots or Test Results / 运行截图或测试结果
<!--请粘贴截图、GIF 或测试日志,作为执行“验证步骤”的证据,证明此改动有效。-->
<!--Please paste screenshots, GIFs, or test logs here as evidence of executing the "Verification Steps" to prove this change is effective.-->
### Compatibility & Breaking Changes / 兼容性与破坏性变更
<!--请说明此变更的兼容性:哪些是破坏性变更?哪些地方做了向后兼容处理?是否提供了数据迁移方法?-->
<!--Please explain the compatibility of this change: What are the breaking changes? What backward-compatible measures were taken? Are data migration paths provided?-->
- [ ] 这是一个破坏性变更 (Breaking Change)。/ This is a breaking change.
- [ ] 这不是一个破坏性变更。/ This is NOT a breaking change.
---
### Checklist / 检查清单
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
- [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
- [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt``pyproject.toml` 文件相应位置。/ I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in `requirements.txt` and `pyproject.toml`.
- [ ] 😮 我的更改没有引入恶意代码。/ My changes do not introduce malicious code.
- [ ] 😊 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
- [ ] 👀 我的更改经过良好的测试
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt``pyproject.toml` 文件相应位置。
- [ ] 😮 我的更改没有引入恶意代码

View File

@@ -1,36 +0,0 @@
# Set to true to add reviewers to pull requests
addReviewers: true
# Set to true to add assignees to pull requests
addAssignees: false
# A list of reviewers to be added to pull requests (GitHub user name)
reviewers:
- Soulter
- Raven95676
- Larch-C
- anka-afk
- advent259141
# - zouyonghe
# A number of reviewers added to the pull request
# Set 0 to add all the reviewers (default: 0)
numberOfReviewers: 2
# A list of assignees, overrides reviewers if set
# assignees:
# - assigneeA
# A number of assignees to add to the pull request
# Set to 0 to add all of the assignees.
# Uses numberOfReviewers if unset.
# numberOfAssignees: 2
# A list of keywords to be skipped the process that add reviewers if pull requests include it
skipKeywords:
- wip
- draft
# A list of users to be skipped by both the add reviewers and add assignees processes
# skipUsers:
# - dependabot[bot]

View File

@@ -73,7 +73,7 @@ jobs:
uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
with:
python-version: '3.10'

View File

@@ -1,34 +0,0 @@
name: Code Format Check
on:
pull_request:
branches: [ master ]
push:
branches: [ master ]
jobs:
format-check:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.10'
- name: Install UV
run: pip install uv
- name: Install dependencies
run: uv sync
- name: Check code formatting with ruff
run: |
uv run ruff format --check .
- name: Check code style with ruff
run: |
uv run ruff check .

View File

@@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
- name: Install dependencies
run: |

View File

@@ -37,7 +37,6 @@ jobs:
!dist/**/*.md
- name: Create GitHub Release
if: github.event_name == 'push'
uses: ncipollo/release-action@v1
with:
tag: release-${{ github.sha }}

View File

@@ -27,33 +27,6 @@ jobs:
if: github.event_name == 'workflow_dispatch'
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
- name: Check if version is pre-release
id: check-prerelease
run: |
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
version="${{ steps.get-latest-tag.outputs.latest_tag }}"
else
version="${{ github.ref_name }}"
fi
if [[ "$version" == *"beta"* ]] || [[ "$version" == *"alpha"* ]]; then
echo "is_prerelease=true" >> $GITHUB_OUTPUT
echo "Version $version is a pre-release, will not push latest tag"
else
echo "is_prerelease=false" >> $GITHUB_OUTPUT
echo "Version $version is a stable release, will push latest tag"
fi
- name: Build Dashboard
run: |
cd dashboard
npm install
npm run build
mkdir -p dist/assets
echo $(git rev-parse HEAD) > dist/assets/version
cd ..
mkdir -p data
cp -r dashboard/dist data/
- name: Set QEMU
uses: docker/setup-qemu-action@v3
@@ -80,9 +53,9 @@ jobs:
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && format('{0}/astrbot:latest', secrets.DOCKER_HUB_USERNAME) || '' }}
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && 'ghcr.io/soulter/astrbot:latest' || '' }}
ghcr.io/soulter/astrbot:latest
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
- name: Post build notifications

View File

@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/stale@v10
- uses: actions/stale@v9
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'Stale issue message'

View File

@@ -6,6 +6,8 @@
<div align="center">
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot?style=for-the-badge&color=76bad9)](https://github.com/Soulter/AstrBot/releases/latest)
@@ -14,18 +16,18 @@
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7日消息量&cacheSeconds=3600&style=for-the-badge&color=3b618e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://astrbot.app/">文档</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.app/">查看文档</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
</div>
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
## 主要功能
## 主要功能
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
@@ -33,7 +35,7 @@ AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
5. **WebUI**。可视化配置和管理机器人,功能齐全。
## 部署方式
## ✨ 使用方式
#### Docker 部署
@@ -77,7 +79,9 @@ AstrBot 已由雨云官方上架至云应用平台,可一键部署。
#### 手动部署
首先安装 uv
> 推荐使用 `uv`。
首先,安装 uv
```bash
pip install uv
@@ -92,25 +96,6 @@ uv run main.py
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
## 🌍 社区
### QQ 群组
- 1 群322154837
- 3 群630166526
- 5 群822130018
- 6 群753075035
- 开发者群975206796
- 开发者群备份295657329
### Telegram 群组
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
### Discord 群组
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## ⚡ 消息平台支持情况
| 平台 | 支持性 |
@@ -127,20 +112,22 @@ uv run main.py
| Discord | ✔ |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
| Satori | |
| Misskey | |
| 微信对话开放平台 | 🚧 |
| WhatsApp | 🚧 |
| 小爱音响 | 🚧 |
## ⚡ 提供商支持情况
| 名称 | 支持性 | 类型 | 备注 |
| -------- | ------- | ------- | ------- |
| OpenAI | ✔ | 文本生成 | 支持任何兼容 OpenAI API 的服务 |
| Anthropic | ✔ | 文本生成 | |
| Google Gemini | ✔ | 文本生成 | |
| OpenAI API | ✔ | 文本生成 | 支持 DeepSeek、Gemini、Kimi、xAI 等兼容 OpenAI API 的服务 |
| Claude API | ✔ | 文本生成 | |
| Google Gemini API | ✔ | 文本生成 | |
| Dify | ✔ | LLMOps | |
| 阿里云百炼应用 | ✔ | LLMOps | |
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
| 硅基流动 | ✔ | 模型 API 服务平台 | |
@@ -156,6 +143,7 @@ uv run main.py
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
## ❤️ 贡献
欢迎任何 Issues/Pull Requests只需要将你的更改提交到此项目 )
@@ -174,6 +162,39 @@ pip install pre-commit
pre-commit install
```
## 🌟 支持
- Star 这个项目!
- 在[爱发电](https://afdian.com/a/soulter)支持我!
## ✨ Demo
<details><summary>👉 点击展开多张 Demo 截图 👈</summary>
<div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
_✨基于 Docker 的沙箱化代码执行器Beta 测试✨_
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ 插件系统——部分插件展示 ✨_
<img src="https://github.com/user-attachments/assets/0cdbf564-2f59-4da5-b524-ce0e7ef3d978" width=600>
_✨ WebUI ✨_
</div>
</details>
## ❤️ Special Thanks
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
@@ -182,18 +203,10 @@ pre-commit install
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>
此外,本项目的诞生离不开以下开源项目的帮助
此外,本项目的诞生离不开以下开源项目:
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
另外,一些同类型其他的活跃开源 Bot 项目:
- [nonebot/nonebot2](https://github.com/nonebot/nonebot2) - 扩展性极强的 Bot 框架
- [koishijs/koishi](https://github.com/koishijs/koishi) - 扩展性极强的 Bot 框架
- [MaiM-with-u/MaiBot](https://github.com/MaiM-with-u/MaiBot) - 注重拟人功能的 ChatBot
- [langbot-app/LangBot](https://github.com/langbot-app/LangBot) - 功能丰富的 Bot 平台
- [LroMiose/nekro-agent](https://github.com/KroMiose/nekro-agent) - 注重 Agent 的 ChatBot
- [zhenxun-org/zhenxun_bot](https://github.com/zhenxun-org/zhenxun_bot) - 功能完善的 ChatBot
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
## ⭐ Star History
@@ -201,11 +214,13 @@ pre-commit install
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我维护这个开源项目的动力 <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div>
</details>
![10k-star-banner-credit-by-kevin](https://github.com/user-attachments/assets/c97fc5fb-20b9-4bc8-9998-c20b930ab097)
_私は、高性能ですから!_

View File

@@ -7,7 +7,6 @@ from astrbot.core.star.register import (
register_permission_type as permission_type,
register_custom_filter as custom_filter,
register_on_astrbot_loaded as on_astrbot_loaded,
register_on_platform_loaded as on_platform_loaded,
register_on_llm_request as on_llm_request,
register_on_llm_response as on_llm_response,
register_llm_tool as llm_tool,
@@ -42,7 +41,6 @@ __all__ = [
"custom_filter",
"PermissionType",
"on_astrbot_loaded",
"on_platform_loaded",
"on_llm_request",
"llm_tool",
"on_decorating_result",

View File

@@ -37,10 +37,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
):
click.echo("正在安装管理面板...")
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
path="data/dashboard.zip", extract_path=str(astrbot_root)
)
click.echo("管理面板安装完成")
@@ -53,10 +50,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
version = dashboard_version.split("v")[1]
click.echo(f"管理面板版本: {version}")
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
path="data/dashboard.zip", extract_path=str(astrbot_root)
)
except Exception as e:
click.echo(f"下载管理面板失败: {e}")
@@ -65,10 +59,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
click.echo("初始化管理面板目录...")
try:
await download_dashboard(
path=str(astrbot_root / "dashboard.zip"),
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
path=str(astrbot_root / "dashboard.zip"), extract_path=str(astrbot_root)
)
click.echo("管理面板初始化完成")
except Exception as e:

View File

@@ -124,17 +124,15 @@ def build_plug_list(plugins_dir: Path) -> list:
if metadata and all(
k in metadata for k in ["name", "desc", "version", "author", "repo"]
):
result.append(
{
"name": str(metadata.get("name", "")),
"desc": str(metadata.get("desc", "")),
"version": str(metadata.get("version", "")),
"author": str(metadata.get("author", "")),
"repo": str(metadata.get("repo", "")),
"status": PluginStatus.INSTALLED,
"local_path": str(plugin_dir),
}
)
result.append({
"name": str(metadata.get("name", "")),
"desc": str(metadata.get("desc", "")),
"version": str(metadata.get("version", "")),
"author": str(metadata.get("author", "")),
"repo": str(metadata.get("repo", "")),
"status": PluginStatus.INSTALLED,
"local_path": str(plugin_dir),
})
# 获取在线插件列表
online_plugins = []
@@ -144,17 +142,15 @@ def build_plug_list(plugins_dir: Path) -> list:
resp.raise_for_status()
data = resp.json()
for plugin_id, plugin_info in data.items():
online_plugins.append(
{
"name": str(plugin_id),
"desc": str(plugin_info.get("desc", "")),
"version": str(plugin_info.get("version", "")),
"author": str(plugin_info.get("author", "")),
"repo": str(plugin_info.get("repo", "")),
"status": PluginStatus.NOT_INSTALLED,
"local_path": None,
}
)
online_plugins.append({
"name": str(plugin_id),
"desc": str(plugin_info.get("desc", "")),
"version": str(plugin_info.get("version", "")),
"author": str(plugin_info.get("author", "")),
"repo": str(plugin_info.get("repo", "")),
"status": PluginStatus.NOT_INSTALLED,
"local_path": None,
})
except Exception as e:
click.echo(f"获取在线插件列表失败: {e}", err=True)

View File

@@ -9,5 +9,5 @@ from .hooks import BaseAgentRunHooks
class Agent(Generic[TContext]):
name: str
instructions: str | None = None
tools: list[str | FunctionTool] | None = None
tools: list[str, FunctionTool] | None = None
run_hooks: BaseAgentRunHooks[TContext] | None = None

View File

@@ -92,7 +92,7 @@ class MCPClient:
self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack()
self.name: str | None = None
self.name = None
self.active: bool = True
self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = []
@@ -198,8 +198,6 @@ class MCPClient:
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
if not self.session:
raise Exception("MCP Client is not initialized")
response = await self.session.list_tools()
self.tools = response.tools
return response

View File

@@ -2,7 +2,6 @@ from dataclasses import dataclass
import typing as T
from astrbot.core.message.message_event_result import MessageChain
class AgentResponseData(T.TypedDict):
chain: MessageChain

View File

@@ -14,5 +14,4 @@ class ContextWrapper(Generic[TContext]):
context: TContext
event: AstrMessageEvent
NoContext = ContextWrapper[None]

View File

@@ -258,7 +258,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
)
yield MessageChain(
type="tool_direct_result"
).base64_image(resource.blob)
).base64_image(res.content[0].data)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
@@ -269,6 +269,17 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
)
yield MessageChain().message("返回的数据类型不受支持。")
try:
await self.agent_hooks.on_tool_end(
self.run_context,
func_tool_name,
func_tool_args,
resp,
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
elif resp is None:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
@@ -278,17 +289,27 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool_name, func_tool_args, None
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
else:
logger.warning(
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool, func_tool_args, None
)
except Exception as e:
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool_name, func_tool_args, None
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
self.run_context.event.clear_result()
except Exception as e:

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from deprecated import deprecated
from typing import Awaitable, Callable, Literal, Any, Optional
from typing import Awaitable, Literal, Any, Optional
from .mcp_client import MCPClient
@@ -8,10 +8,10 @@ from .mcp_client import MCPClient
class FunctionTool:
"""A class representing a function tool that can be used in function calling."""
name: str
name: str | None = None
parameters: dict | None = None
description: str | None = None
handler: Callable[..., Awaitable[Any]] | None = None
handler: Awaitable | None = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str | None = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
@@ -51,7 +51,7 @@ class ToolSet:
This class provides methods to add, remove, and retrieve tools, as well as
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
def __init__(self, tools: list[FunctionTool] | None = None):
def __init__(self, tools: list[FunctionTool] = None):
self.tools: list[FunctionTool] = tools or []
def empty(self) -> bool:
@@ -79,13 +79,7 @@ class ToolSet:
return None
@deprecated(reason="Use add_tool() instead", version="4.0.0")
def add_func(
self,
name: str,
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
):
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
"""Add a function tool to the set."""
params = {
"type": "object", # hard-coded here
@@ -110,7 +104,7 @@ class ToolSet:
self.remove_tool(name)
@deprecated(reason="Use get_tool() instead", version="4.0.0")
def get_func(self, name: str) -> FunctionTool | None:
def get_func(self, name: str) -> list[FunctionTool]:
"""Get all function tools."""
return self.get_tool(name)
@@ -131,11 +125,7 @@ class ToolSet:
},
}
if (
tool.parameters
and tool.parameters.get("properties")
or not omit_empty_parameter_field
):
if tool.parameters.get("properties") or not omit_empty_parameter_field:
func_def["function"]["parameters"] = tool.parameters
result.append(func_def)
@@ -145,14 +135,14 @@ class ToolSet:
"""Convert tools to Anthropic API format."""
result = []
for tool in self.tools:
input_schema = {"type": "object"}
if tool.parameters:
input_schema["properties"] = tool.parameters.get("properties", {})
input_schema["required"] = tool.parameters.get("required", [])
tool_def = {
"name": tool.name,
"description": tool.description,
"input_schema": input_schema,
"input_schema": {
"type": "object",
"properties": tool.parameters.get("properties", {}),
"required": tool.parameters.get("required", []),
},
}
result.append(tool_def)
return result
@@ -220,15 +210,14 @@ class ToolSet:
return result
tools = []
for tool in self.tools:
d = {
tools = [
{
"name": tool.name,
"description": tool.description,
"parameters": convert_schema(tool.parameters),
}
if tool.parameters:
d["parameters"] = convert_schema(tool.parameters)
tools.append(d)
for tool in self.tools
]
declarations = {}
if tools:

View File

@@ -36,21 +36,13 @@ class AstrBotConfigManager:
self.confs: dict[str, AstrBotConfig] = {}
"""uuid / "default" -> AstrBotConfig"""
self.confs["default"] = default_config
self.abconf_data = None
self._load_all_configs()
def _get_abconf_data(self) -> dict:
"""获取所有的 abconf 数据"""
if self.abconf_data is None:
self.abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
return self.abconf_data
def _load_all_configs(self):
"""Load all configurations from the shared preferences."""
abconf_data = self._get_abconf_data()
self.abconf_data = abconf_data
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
for uuid_, meta in abconf_data.items():
filename = meta["path"]
conf_path = os.path.join(get_astrbot_config_path(), filename)
@@ -80,7 +72,9 @@ class AstrBotConfigManager:
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
"""
# uuid -> { "umop": list, "path": str, "name": str }
abconf_data = self._get_abconf_data()
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
if isinstance(umo, MessageSession):
umo = str(umo)
else:
@@ -121,7 +115,6 @@ class AstrBotConfigManager:
"name": random_word,
}
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig:
"""获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。"""
@@ -154,7 +147,9 @@ class AstrBotConfigManager:
"""获取所有配置文件的元数据列表"""
conf_list = []
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
abconf_mapping = self._get_abconf_data()
abconf_mapping = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
for uuid_, meta in abconf_mapping.items():
conf_list.append(ConfInfo(**meta, id=uuid_))
return conf_list
@@ -223,7 +218,6 @@ class AstrBotConfigManager:
# 从映射中移除
del abconf_data[conf_id]
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
logger.info(f"成功删除配置文件 {conf_id}")
return True
@@ -269,7 +263,6 @@ class AstrBotConfigManager:
# 保存更新
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
logger.info(f"成功更新配置文件 {conf_id} 的信息")
return True

View File

@@ -6,7 +6,7 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.2.1"
VERSION = "4.0.0"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
# 默认配置
@@ -51,21 +51,23 @@ DEFAULT_CONFIG = {
"enable": True,
"default_provider_id": "",
"default_image_caption_provider_id": "",
"default_summarize_provider_id": "",
"context_exceed_calc_method": "token_size",
"max_token_size": 128000,
"max_context_length": 100,
"image_caption_prompt": "Please describe the image using Chinese.",
"provider_pool": ["*"], # "*" 表示使用所有可用的提供者
"wake_prefix": "",
"web_search": False,
"websearch_provider": "default",
"websearch_tavily_key": [],
"websearch_tavily_key": "",
"web_search_link": False,
"display_reasoning_text": False,
"identifier": False,
"group_name_display": False,
"datetime_system_prompt": True,
"default_personality": "default",
"persona_pool": ["*"],
"prompt_prefix": "",
"max_context_length": -1,
"dequeue_context_length": 1,
"streaming_response": False,
"show_tool_use_status": False,
@@ -104,7 +106,6 @@ DEFAULT_CONFIG = {
"t2i_strategy": "remote",
"t2i_endpoint": "",
"t2i_use_file_service": False,
"t2i_active_template": "base",
"http_proxy": "",
"no_proxy": ["localhost", "127.0.0.1", "::1"],
"dashboard": {
@@ -236,16 +237,6 @@ CONFIG_METADATA_2 = {
"discord_guild_id_for_debug": "",
"discord_activity_name": "",
},
"Misskey": {
"id": "misskey",
"type": "misskey",
"enable": False,
"misskey_instance_url": "https://misskey.example",
"misskey_token": "",
"misskey_default_visibility": "public",
"misskey_local_only": False,
"misskey_enable_chat": True,
},
"Slack": {
"id": "slack",
"type": "slack",
@@ -258,49 +249,8 @@ CONFIG_METADATA_2 = {
"slack_webhook_port": 6197,
"slack_webhook_path": "/astrbot-slack-webhook/callback",
},
"Satori": {
"id": "satori",
"type": "satori",
"enable": False,
"satori_api_base_url": "http://localhost:5140/satori/v1",
"satori_endpoint": "ws://localhost:5140/satori/v1/events",
"satori_token": "",
"satori_auto_reconnect": True,
"satori_heartbeat_interval": 10,
"satori_reconnect_delay": 5,
},
},
"items": {
"satori_api_base_url": {
"description": "Satori API 终结点",
"type": "string",
"hint": "Satori API 的基础地址。",
},
"satori_endpoint": {
"description": "Satori WebSocket 终结点",
"type": "string",
"hint": "Satori 事件的 WebSocket 端点。",
},
"satori_token": {
"description": "Satori 令牌",
"type": "string",
"hint": "用于 Satori API 身份验证的令牌。",
},
"satori_auto_reconnect": {
"description": "启用自动重连",
"type": "bool",
"hint": "断开连接时是否自动重新连接 WebSocket。",
},
"satori_heartbeat_interval": {
"description": "Satori 心跳间隔",
"type": "int",
"hint": "发送心跳消息的间隔(秒)。",
},
"satori_reconnect_delay": {
"description": "Satori 重连延迟",
"type": "int",
"hint": "尝试重新连接前的延迟时间(秒)。",
},
"slack_connection_mode": {
"description": "Slack Connection Mode",
"type": "string",
@@ -347,32 +297,6 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
},
"misskey_instance_url": {
"description": "Misskey 实例 URL",
"type": "string",
"hint": "例如 https://misskey.example填写 Bot 账号所在的 Misskey 实例地址",
},
"misskey_token": {
"description": "Misskey Access Token",
"type": "string",
"hint": "连接服务设置生成的 API 鉴权访问令牌Access token",
},
"misskey_default_visibility": {
"description": "默认帖子可见性",
"type": "string",
"options": ["public", "home", "followers"],
"hint": "机器人发帖时的默认可见性设置。public公开home主页时间线followers仅关注者。",
},
"misskey_local_only": {
"description": "仅限本站(不参与联合)",
"type": "bool",
"hint": "启用后,机器人发出的帖子将仅在本实例可见,不会联合到其他实例",
},
"misskey_enable_chat": {
"description": "启用聊天消息响应",
"type": "bool",
"hint": "启用后,机器人将会监听和响应私信聊天消息",
},
"telegram_command_register": {
"description": "Telegram 命令注册",
"type": "bool",
@@ -636,7 +560,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.openai.com/v1",
"timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
},
@@ -651,7 +574,6 @@ CONFIG_METADATA_2 = {
"api_base": "",
"timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"xAI": {
@@ -664,7 +586,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.x.ai/v1",
"timeout": 120,
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Anthropic": {
@@ -694,7 +615,6 @@ CONFIG_METADATA_2 = {
"key": ["ollama"], # ollama 的 key 默认是 ollama
"api_base": "http://localhost:11434/v1",
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"LM Studio": {
@@ -708,7 +628,6 @@ CONFIG_METADATA_2 = {
"model_config": {
"model": "llama-3.1-8b",
},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Gemini(OpenAI兼容)": {
@@ -724,7 +643,6 @@ CONFIG_METADATA_2 = {
"model": "gemini-1.5-flash",
"temperature": 0.4,
},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Gemini": {
@@ -765,7 +683,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.deepseek.com/v1",
"timeout": 120,
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"302.AI": {
@@ -778,7 +695,6 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.302.ai/v1",
"timeout": 120,
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"硅基流动": {
@@ -794,7 +710,6 @@ CONFIG_METADATA_2 = {
"model": "deepseek-ai/DeepSeek-V3",
"temperature": 0.4,
},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"PPIO派欧云": {
@@ -810,7 +725,6 @@ CONFIG_METADATA_2 = {
"model": "deepseek/deepseek-r1",
"temperature": 0.4,
},
"custom_extra_body": {},
},
"优云智算": {
"id": "compshare",
@@ -824,7 +738,6 @@ CONFIG_METADATA_2 = {
"model_config": {
"model": "moonshotai/Kimi-K2-Instruct",
},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Kimi": {
@@ -837,7 +750,6 @@ CONFIG_METADATA_2 = {
"timeout": 120,
"api_base": "https://api.moonshot.cn/v1",
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"智谱 AI": {
@@ -869,18 +781,6 @@ CONFIG_METADATA_2 = {
"timeout": 60,
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
},
"Coze": {
"id": "coze",
"provider": "coze",
"provider_type": "chat_completion",
"type": "coze",
"enable": True,
"coze_api_key": "",
"bot_id": "",
"coze_api_base": "https://api.coze.cn",
"timeout": 60,
"auto_save_history": True,
},
"阿里云百炼应用": {
"id": "dashscope",
"provider": "dashscope",
@@ -908,7 +808,6 @@ CONFIG_METADATA_2 = {
"timeout": 120,
"api_base": "https://api-inference.modelscope.cn/v1",
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"FastGPT": {
@@ -920,7 +819,6 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.fastgpt.in/api/v1",
"timeout": 60,
"custom_extra_body": {},
},
"Whisper(API)": {
"id": "whisper",
@@ -971,9 +869,6 @@ CONFIG_METADATA_2 = {
"provider_type": "text_to_speech",
"enable": False,
"edge-tts-voice": "zh-CN-XiaoxiaoNeural",
"rate": "+0%",
"volume": "+0%",
"pitch": "+0Hz",
"timeout": 20,
},
"GSV TTS(本地加载)": {
@@ -1165,12 +1060,6 @@ CONFIG_METADATA_2 = {
"render_type": "checkbox",
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
},
"custom_extra_body": {
"description": "自定义请求体参数",
"type": "dict",
"items": {},
"hint": "此处添加的键值对将被合并到发送给 API 的 extra_body 中。值可以是字符串、数字或布尔值。",
},
"provider": {
"type": "string",
"invisible": True,
@@ -1747,26 +1636,6 @@ CONFIG_METADATA_2 = {
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
"obvious": True,
},
"coze_api_key": {
"description": "Coze API Key",
"type": "string",
"hint": "Coze API 密钥,用于访问 Coze 服务。",
},
"bot_id": {
"description": "Bot ID",
"type": "string",
"hint": "Coze 机器人的 ID在 Coze 平台上创建机器人后获得。",
},
"coze_api_base": {
"description": "API Base URL",
"type": "string",
"hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn",
},
"auto_save_history": {
"description": "由 Coze 管理对话记录",
"type": "bool",
"hint": "启用后,将由 Coze 进行对话历史记录管理, 此时 AstrBot 本地保存的上下文不会生效(仅供浏览), 对 AstrBot 的上下文进行的操作也不会生效。如果为禁用, 则使用 AstrBot 管理上下文。",
},
},
},
"provider_settings": {
@@ -1793,9 +1662,6 @@ CONFIG_METADATA_2 = {
"identifier": {
"type": "bool",
},
"group_name_display": {
"type": "bool",
},
"datetime_system_prompt": {
"type": "bool",
},
@@ -1969,37 +1835,51 @@ CONFIG_METADATA_3 = {
"_special": "select_provider",
"hint": "留空时使用第一个模型。",
},
"provider_settings.default_summarize_provider_id": {
"description": "默认对话总结模型",
"type": "string",
"_special": "select_provider",
"hint": "留空代表不进行对话总结。可用于压缩上下文以减少 token 用量,并一定程度上保持历史聊天记忆。",
},
"provider_settings.default_image_caption_provider_id": {
"description": "默认图片转述模型",
"type": "string",
"_special": "select_provider",
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
},
"provider_stt_settings.enable": {
"description": "默认启用语音转文本",
"type": "bool",
},
"provider_stt_settings.provider_id": {
"description": "语音转文本模型",
"type": "string",
"hint": "留空代表不使用。",
"_special": "select_provider_stt",
"condition": {
"provider_stt_settings.enable": True,
},
},
"provider_tts_settings.enable": {
"description": "默认启用文本转语音",
"type": "bool",
},
"provider_tts_settings.provider_id": {
"description": "文本转语音模型",
"type": "string",
"hint": "留空代表不使用。",
"_special": "select_provider_tts",
},
"provider_settings.context_exceed_calc_method": {
"description": "上下文超限的触发策略",
"type": "string",
"options": ["token_size", "context_length"],
"labels": ["基于 Token 长度(估算)", "基于对话轮数"],
"hint": "如配置了对话总结模型,则触发时总结对话内容,否则丢弃最旧部分。"
},
"provider_settings.max_context_length": {
"description": "对话轮数上限",
"type": "int",
"condition": {
"provider_tts_settings.enable": True,
},
"provider_settings.context_exceed_calc_method": "context_length"
}
},
"provider_settings.max_token_size": {
"description": "Token 长度上限(估算)",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分。",
"condition": {
"provider_settings.context_exceed_calc_method": "token_size"
}
},
"provider_settings.image_caption_prompt": {
"description": "图片转述提示词",
@@ -2044,9 +1924,7 @@ CONFIG_METADATA_3 = {
},
"provider_settings.websearch_tavily_key": {
"description": "Tavily API Key",
"type": "list",
"items": {"type": "string"},
"hint": "可添加多个 Key 进行轮询。",
"type": "string",
"condition": {
"provider_settings.websearch_provider": "tavily",
},
@@ -2066,14 +1944,9 @@ CONFIG_METADATA_3 = {
"type": "bool",
},
"provider_settings.identifier": {
"description": "用户识别",
"description": "用户感知",
"type": "bool",
},
"provider_settings.group_name_display": {
"description": "显示群名称",
"type": "bool",
"hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
},
"provider_settings.datetime_system_prompt": {
"description": "现实世界时间感知",
"type": "bool",
@@ -2094,11 +1967,6 @@ CONFIG_METADATA_3 = {
"description": "不支持流式回复的平台采取分段输出",
"type": "bool",
},
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条。-1 为不限制。",
},
"provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数",
"type": "int",
@@ -2221,41 +2089,41 @@ CONFIG_METADATA_3 = {
"description": "内容安全",
"type": "object",
"items": {
"content_safety.also_use_in_response": {
"platform_settings.content_safety.also_use_in_response": {
"description": "同时检查模型的响应内容",
"type": "bool",
},
"content_safety.baidu_aip.enable": {
"platform_settings.content_safety.baidu_aip.enable": {
"description": "使用百度内容安全审核",
"type": "bool",
"hint": "您需要手动安装 baidu-aip 库。",
},
"content_safety.baidu_aip.app_id": {
"platform_settings.content_safety.baidu_aip.app_id": {
"description": "App ID",
"type": "string",
"condition": {
"content_safety.baidu_aip.enable": True,
"platform_settings.content_safety.baidu_aip.enable": True,
},
},
"content_safety.baidu_aip.api_key": {
"platform_settings.content_safety.baidu_aip.api_key": {
"description": "API Key",
"type": "string",
"condition": {
"content_safety.baidu_aip.enable": True,
"platform_settings.content_safety.baidu_aip.enable": True,
},
},
"content_safety.baidu_aip.secret_key": {
"platform_settings.content_safety.baidu_aip.secret_key": {
"description": "Secret Key",
"type": "string",
"condition": {
"content_safety.baidu_aip.enable": True,
"platform_settings.content_safety.baidu_aip.enable": True,
},
},
"content_safety.internal_keywords.enable": {
"platform_settings.content_safety.internal_keywords.enable": {
"description": "关键词检查",
"type": "bool",
},
"content_safety.internal_keywords.extra_keywords": {
"platform_settings.content_safety.internal_keywords.extra_keywords": {
"description": "额外关键词",
"type": "list",
"items": {"type": "string"},
@@ -2446,13 +2314,7 @@ CONFIG_METADATA_3_SYSTEM = {
"condition": {
"t2i_strategy": "remote",
},
"_special": "t2i_template",
},
"t2i_active_template": {
"description": "当前应用的文转图渲染模板",
"type": "string",
"hint": "此处的值由文转图模板管理页面进行维护。",
"invisible": True,
"_special": "t2i_template"
},
"log_level": {
"description": "控制台日志级别",
@@ -2485,11 +2347,6 @@ CONFIG_METADATA_3_SYSTEM = {
"type": "string",
"hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`",
},
"no_proxy": {
"description": "直连地址列表",
"type": "list",
"items": {"type": "string"},
},
},
}
},

View File

@@ -87,25 +87,17 @@ class ConversationManager:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
f = False
if not conversation_id:
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
f = True
if conversation_id:
await self.db.delete_conversation(cid=conversation_id)
curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
if curr_cid == conversation_id:
if f:
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
"""删除会话的所有对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
"""
await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
"""获取会话当前的对话 ID

View File

@@ -154,11 +154,6 @@ class BaseDatabase(abc.ABC):
"""Delete a conversation by its ID."""
...
@abc.abstractmethod
async def delete_conversations_by_user_id(self, user_id: str) -> None:
"""Delete all conversations for a specific user."""
...
@abc.abstractmethod
async def insert_platform_message_history(
self,

View File

@@ -53,7 +53,7 @@ async def do_migration_v4(
await migration_webchat_data(db_helper, platform_id_map)
# 执行偏好设置迁移
await migration_preferences(db_helper, platform_id_map)
await migration_preferences(db_helper,platform_id_map)
# 执行平台统计表迁移
await migration_platform_table(db_helper, platform_id_map)

View File

@@ -5,7 +5,6 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
_VT = TypeVar("_VT")
class SharedPreferences:
def __init__(self, path=None):
if path is None:
@@ -43,5 +42,4 @@ class SharedPreferences:
self._data.clear()
self._save_preferences()
sp = SharedPreferences()

View File

@@ -4,7 +4,6 @@ from astrbot.core.db.po import Platform, Stats
from typing import Tuple, List, Dict, Any
from dataclasses import dataclass
@dataclass
class Conversation:
"""LLM 对话存储
@@ -77,7 +76,7 @@ PRAGMA encoding = 'UTF-8';
"""
class SQLiteDatabase:
class SQLiteDatabase():
def __init__(self, db_path: str) -> None:
super().__init__()
self.db_path = db_path

View File

@@ -18,7 +18,6 @@ from astrbot.core.db.po import (
from sqlalchemy import select, update, delete, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import func
from sqlalchemy import or_
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
@@ -154,22 +153,8 @@ class SQLiteDatabase(BaseDatabase):
ConversationV2.platform_id.in_(platform_ids)
)
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
base_query = base_query.where(
or_(
ConversationV2.title.ilike(f"%{search_query}%"),
ConversationV2.content.ilike(f"%{search_query}%"),
ConversationV2.user_id.ilike(f"%{search_query}%"),
)
)
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
for msg_type in kwargs["message_types"]:
base_query = base_query.where(
ConversationV2.user_id.ilike(f"%:{msg_type}:%")
)
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
base_query = base_query.where(
ConversationV2.platform_id.in_(kwargs["platforms"])
ConversationV2.title.ilike(f"%{search_query}%")
)
# Get total count matching the filters
@@ -249,14 +234,6 @@ class SQLiteDatabase(BaseDatabase):
delete(ConversationV2).where(ConversationV2.conversation_id == cid)
)
async def delete_conversations_by_user_id(self, user_id: str) -> None:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(ConversationV2.user_id == user_id)
)
async def insert_platform_message_history(
self,
platform_id,

View File

@@ -1,3 +1,3 @@
from .vec_db import FaissVecDB
__all__ = ["FaissVecDB"]
__all__ = ["FaissVecDB"]

View File

@@ -113,8 +113,7 @@ class FaissVecDB(BaseVecDB):
reranked_results, key=lambda x: x.relevance_score, reverse=True
)
top_k_results = [
top_k_results[reranked_result.index]
for reranked_result in reranked_results
top_k_results[reranked_result.index] for reranked_result in reranked_results
]
return top_k_results

View File

@@ -22,7 +22,6 @@ class InitialLoader:
self.db = db
self.logger = logger
self.log_broker = log_broker
self.webui_dir: str | None = None
async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
@@ -36,10 +35,8 @@ class InitialLoader:
core_task = core_lifecycle.start()
webui_dir = self.webui_dir
self.dashboard_server = AstrBotDashboard(
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event, webui_dir
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
)
task = asyncio.gather(
core_task, self.dashboard_server.run()

View File

@@ -37,7 +37,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
class ComponentType(str, Enum):
class ComponentType(Enum):
Plain = "Plain" # 纯文本消息
Face = "Face" # QQ表情
Record = "Record" # 语音
@@ -108,7 +108,7 @@ class BaseMessageComponent(BaseModel):
class Plain(BaseMessageComponent):
type = ComponentType.Plain
type: ComponentType = "Plain"
text: str
convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息
@@ -128,9 +128,8 @@ class Plain(BaseMessageComponent):
async def to_dict(self):
return {"type": "text", "data": {"text": self.text}}
class Face(BaseMessageComponent):
type = ComponentType.Face
type: ComponentType = "Face"
id: int
def __init__(self, **_):
@@ -138,7 +137,7 @@ class Face(BaseMessageComponent):
class Record(BaseMessageComponent):
type = ComponentType.Record
type: ComponentType = "Record"
file: T.Optional[str] = ""
magic: T.Optional[bool] = False
url: T.Optional[str] = ""
@@ -165,24 +164,19 @@ class Record(BaseMessageComponent):
return Record(file=url, **_)
raise Exception("not a valid url")
@staticmethod
def fromBase64(bs64_data: str, **_):
return Record(file=f"base64://{bs64_data}", **_)
async def convert_to_file_path(self) -> str:
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
Returns:
str: 语音的本地路径,以绝对路径表示。
"""
if not self.file:
raise Exception(f"not a valid file: {self.file}")
if self.file.startswith("file:///"):
return self.file[8:]
elif self.file.startswith("http"):
if self.file and self.file.startswith("file:///"):
file_path = self.file[8:]
return file_path
elif self.file and self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
return os.path.abspath(file_path)
elif self.file.startswith("base64://"):
elif self.file and self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -191,7 +185,8 @@ class Record(BaseMessageComponent):
f.write(image_bytes)
return os.path.abspath(file_path)
elif os.path.exists(self.file):
return os.path.abspath(self.file)
file_path = self.file
return os.path.abspath(file_path)
else:
raise Exception(f"not a valid file: {self.file}")
@@ -202,14 +197,12 @@ class Record(BaseMessageComponent):
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
if not self.file:
raise Exception(f"not a valid file: {self.file}")
if self.file.startswith("file:///"):
if self.file and self.file.startswith("file:///"):
bs64_data = file_to_base64(self.file[8:])
elif self.file.startswith("http"):
elif self.file and self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
bs64_data = file_to_base64(file_path)
elif self.file.startswith("base64://"):
elif self.file and self.file.startswith("base64://"):
bs64_data = self.file
elif os.path.exists(self.file):
bs64_data = file_to_base64(self.file)
@@ -243,7 +236,7 @@ class Record(BaseMessageComponent):
class Video(BaseMessageComponent):
type = ComponentType.Video
type: ComponentType = "Video"
file: str
cover: T.Optional[str] = ""
c: T.Optional[int] = 2
@@ -329,7 +322,7 @@ class Video(BaseMessageComponent):
class At(BaseMessageComponent):
type = ComponentType.At
type: ComponentType = "At"
qq: T.Union[int, str] # 此处str为all时代表所有人
name: T.Optional[str] = ""
@@ -351,28 +344,28 @@ class AtAll(At):
class RPS(BaseMessageComponent): # TODO
type = ComponentType.RPS
type: ComponentType = "RPS"
def __init__(self, **_):
super().__init__(**_)
class Dice(BaseMessageComponent): # TODO
type = ComponentType.Dice
type: ComponentType = "Dice"
def __init__(self, **_):
super().__init__(**_)
class Shake(BaseMessageComponent): # TODO
type = ComponentType.Shake
type: ComponentType = "Shake"
def __init__(self, **_):
super().__init__(**_)
class Anonymous(BaseMessageComponent): # TODO
type = ComponentType.Anonymous
type: ComponentType = "Anonymous"
ignore: T.Optional[bool] = False
def __init__(self, **_):
@@ -380,7 +373,7 @@ class Anonymous(BaseMessageComponent): # TODO
class Share(BaseMessageComponent):
type = ComponentType.Share
type: ComponentType = "Share"
url: str
title: str
content: T.Optional[str] = ""
@@ -391,7 +384,7 @@ class Share(BaseMessageComponent):
class Contact(BaseMessageComponent): # TODO
type = ComponentType.Contact
type: ComponentType = "Contact"
_type: str # type 字段冲突
id: T.Optional[int] = 0
@@ -400,7 +393,7 @@ class Contact(BaseMessageComponent): # TODO
class Location(BaseMessageComponent): # TODO
type = ComponentType.Location
type: ComponentType = "Location"
lat: float
lon: float
title: T.Optional[str] = ""
@@ -411,7 +404,7 @@ class Location(BaseMessageComponent): # TODO
class Music(BaseMessageComponent):
type = ComponentType.Music
type: ComponentType = "Music"
_type: str
id: T.Optional[int] = 0
url: T.Optional[str] = ""
@@ -428,7 +421,7 @@ class Music(BaseMessageComponent):
class Image(BaseMessageComponent):
type = ComponentType.Image
type: ComponentType = "Image"
file: T.Optional[str] = ""
_type: T.Optional[str] = ""
subType: T.Optional[int] = 0
@@ -471,15 +464,14 @@ class Image(BaseMessageComponent):
Returns:
str: 图片的本地路径,以绝对路径表示。
"""
url = self.url or self.file
if not url:
raise ValueError("No valid file or URL provided")
if url.startswith("file:///"):
return url[8:]
elif url.startswith("http"):
url = self.url if self.url else self.file
if url and url.startswith("file:///"):
image_file_path = url[8:]
return image_file_path
elif url and url.startswith("http"):
image_file_path = await download_image_by_url(url)
return os.path.abspath(image_file_path)
elif url.startswith("base64://"):
elif url and url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -488,7 +480,8 @@ class Image(BaseMessageComponent):
f.write(image_bytes)
return os.path.abspath(image_file_path)
elif os.path.exists(url):
return os.path.abspath(url)
image_file_path = url
return os.path.abspath(image_file_path)
else:
raise Exception(f"not a valid file: {url}")
@@ -499,15 +492,13 @@ class Image(BaseMessageComponent):
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
url = self.url or self.file
if not url:
raise ValueError("No valid file or URL provided")
if url.startswith("file:///"):
url = self.url if self.url else self.file
if url and url.startswith("file:///"):
bs64_data = file_to_base64(url[8:])
elif url.startswith("http"):
elif url and url.startswith("http"):
image_file_path = await download_image_by_url(url)
bs64_data = file_to_base64(image_file_path)
elif url.startswith("base64://"):
elif url and url.startswith("base64://"):
bs64_data = url
elif os.path.exists(url):
bs64_data = file_to_base64(url)
@@ -541,7 +532,7 @@ class Image(BaseMessageComponent):
class Reply(BaseMessageComponent):
type = ComponentType.Reply
type: ComponentType = "Reply"
id: T.Union[str, int]
"""所引用的消息 ID"""
chain: T.Optional[T.List["BaseMessageComponent"]] = []
@@ -567,7 +558,7 @@ class Reply(BaseMessageComponent):
class RedBag(BaseMessageComponent):
type = ComponentType.RedBag
type: ComponentType = "RedBag"
title: str
def __init__(self, **_):
@@ -575,7 +566,7 @@ class RedBag(BaseMessageComponent):
class Poke(BaseMessageComponent):
type: str = ComponentType.Poke
type: str = ""
id: T.Optional[int] = 0
qq: T.Optional[int] = 0
@@ -585,7 +576,7 @@ class Poke(BaseMessageComponent):
class Forward(BaseMessageComponent):
type = ComponentType.Forward
type: ComponentType = "Forward"
id: str
def __init__(self, **_):
@@ -595,7 +586,7 @@ class Forward(BaseMessageComponent):
class Node(BaseMessageComponent):
"""群合并转发消息"""
type = ComponentType.Node
type: ComponentType = "Node"
id: T.Optional[int] = 0 # 忽略
name: T.Optional[str] = "" # qq昵称
uin: T.Optional[str] = "0" # qq号
@@ -647,7 +638,7 @@ class Node(BaseMessageComponent):
class Nodes(BaseMessageComponent):
type = ComponentType.Nodes
type: ComponentType = "Nodes"
nodes: T.List[Node]
def __init__(self, nodes: T.List[Node], **_):
@@ -673,7 +664,7 @@ class Nodes(BaseMessageComponent):
class Xml(BaseMessageComponent):
type = ComponentType.Xml
type: ComponentType = "Xml"
data: str
resid: T.Optional[int] = 0
@@ -682,7 +673,7 @@ class Xml(BaseMessageComponent):
class Json(BaseMessageComponent):
type = ComponentType.Json
type: ComponentType = "Json"
data: T.Union[str, dict]
resid: T.Optional[int] = 0
@@ -693,7 +684,7 @@ class Json(BaseMessageComponent):
class CardImage(BaseMessageComponent):
type = ComponentType.CardImage
type: ComponentType = "CardImage"
file: str
cache: T.Optional[bool] = True
minwidth: T.Optional[int] = 400
@@ -712,7 +703,7 @@ class CardImage(BaseMessageComponent):
class TTS(BaseMessageComponent):
type = ComponentType.TTS
type: ComponentType = "TTS"
text: str
def __init__(self, **_):
@@ -720,7 +711,7 @@ class TTS(BaseMessageComponent):
class Unknown(BaseMessageComponent):
type = ComponentType.Unknown
type: ComponentType = "Unknown"
text: str
def toString(self):
@@ -732,7 +723,7 @@ class File(BaseMessageComponent):
文件消息段
"""
type = ComponentType.File
type: ComponentType = "File"
name: T.Optional[str] = "" # 名字
file_: T.Optional[str] = "" # 本地路径
url: T.Optional[str] = "" # url
@@ -862,7 +853,7 @@ class File(BaseMessageComponent):
class WechatEmoji(BaseMessageComponent):
type = ComponentType.WechatEmoji
type: ComponentType = "WechatEmoji"
md5: T.Optional[str] = ""
md5_len: T.Optional[int] = 0
cdnurl: T.Optional[str] = ""

View File

@@ -19,7 +19,7 @@ class ContentSafetyCheckStage(Stage):
self.strategy_selector = StrategySelector(config)
async def process(
self, event: AstrMessageEvent, check_text: str | None = None
self, event: AstrMessageEvent, check_text: str = None
) -> Union[None, AsyncGenerator[None, None]]:
"""检查内容安全"""
text = check_text if check_text else event.get_message_str()

View File

@@ -13,7 +13,7 @@ class BaiduAipStrategy(ContentSafetyStrategy):
self.secret_key = sk
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
def check(self, content: str) -> tuple[bool, str]:
def check(self, content: str):
res = self.client.textCensorUserDefined(content)
if "conclusionType" not in res:
return False, ""

View File

@@ -16,7 +16,7 @@ class KeywordsStrategy(ContentSafetyStrategy):
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
# )
def check(self, content: str) -> tuple[bool, str]:
def check(self, content: str) -> bool:
for keyword in self.keywords:
if re.search(keyword, content):
return False, "内容安全检查不通过,匹配到敏感词。"

View File

@@ -10,7 +10,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
async def call_handler(
event: AstrMessageEvent,
handler: T.Callable[..., T.Awaitable[T.Any]],
handler: T.Awaitable,
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
@@ -36,9 +36,6 @@ async def call_handler(
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
if not ready_to_call:
return
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
@@ -80,7 +77,7 @@ async def call_event_hook(
Returns:
bool: 如果事件被终止,返回 True
#"""
# """
handlers = star_handlers_registry.get_handlers_by_event_type(
hook_type, plugins_name=event.plugins_name
)

View File

@@ -7,7 +7,6 @@ import copy
import json
import traceback
from typing import AsyncGenerator, Union
from astrbot.core.conversation_mgr import Conversation
from astrbot.core import logger
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
@@ -134,15 +133,6 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
if agent_runner.done():
llm_response = agent_runner.get_final_llm_resp()
if not llm_response:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
logger.debug(
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
)
@@ -158,7 +148,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
)
yield mcp.types.CallToolResult(content=[text_content])
else:
text_content = mcp.types.TextContent(
yield mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
@@ -210,11 +200,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
):
if not tool.mcp_client:
raise ValueError("MCP client is not available for MCP function tools.")
session = tool.mcp_client.session
if not session:
raise ValueError("MCP session is not available for MCP function tools.")
res = await session.call_tool(
res = await tool.mcp_client.session.call_tool(
name=tool.name,
arguments=tool_args,
)
@@ -285,12 +271,19 @@ async def run_agent(
except Exception as e:
logger.error(traceback.format_exc())
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
if agent_runner.streaming:
yield MessageChain().message(err_msg)
else:
astr_event.set_result(MessageEventResult().message(err_msg))
astr_event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
)
)
return
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
)
)
class LLMRequestSubStage(Stage):
@@ -306,9 +299,7 @@ class LLMRequestSubStage(Stage):
self.max_context_length - 1,
)
self.streaming_response: bool = settings["streaming_response"]
self.max_step: int = settings.get("max_agent_step", 30)
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.max_step: int = settings.get("max_agent_step", 10)
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
for bwp in self.bot_wake_prefixs:
@@ -332,7 +323,7 @@ class LLMRequestSubStage(Stage):
return _ctx.get_using_provider(umo=event.unified_msg_origin)
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
async def _get_session_conv(self, event: AstrMessageEvent):
umo = event.unified_msg_origin
conv_mgr = self.conv_manager
@@ -344,8 +335,6 @@ class LLMRequestSubStage(Stage):
if not conversation:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
raise RuntimeError("无法创建新的对话。")
return conversation
async def process(
@@ -445,18 +434,13 @@ class LLMRequestSubStage(Stage):
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
if "tool_use" not in provider_cfg:
logger.debug(
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。"
)
logger.debug(f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。")
req.func_tool = None
# 插件可用性设置
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
mp = tool.handler_module_path
if not mp:
continue
plugin = star_map.get(mp)
plugin = star_map.get(tool.handler_module_path)
if not plugin:
continue
if plugin.name in event.plugins_name or plugin.reserved:
@@ -517,14 +501,6 @@ class LLMRequestSubStage(Stage):
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
)
)
async def _handle_webchat(
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
@@ -537,23 +513,7 @@ class LLMRequestSubStage(Stage):
latest_pair = messages[-2:]
if not latest_pair:
return
content = latest_pair[0].get("content", "")
if isinstance(content, list):
# 多模态
text_parts = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif item.get("type") == "image":
text_parts.append("[图片]")
elif isinstance(item, str):
text_parts.append(item)
cleaned_text = "User: " + " ".join(text_parts).strip()
elif isinstance(content, str):
cleaned_text = "User: " + content.strip()
else:
return
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
llm_resp = await prov.text_chat(
system_prompt="You are expert in summarizing user's query.",

View File

@@ -34,14 +34,12 @@ class StarRequestSubStage(Stage):
for handler in activated_handlers:
params = handlers_parsed_params.get(handler.handler_full_name, {})
md = star_map.get(handler.handler_module_path)
if not md:
logger.warning(
f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
)
continue
logger.debug(f"plugin -> {md.name} - {handler.handler_name}")
try:
if handler.handler_module_path not in star_map:
continue
logger.debug(
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
)
wrapper = call_handler(event, handler.handler, **params)
async for ret in wrapper:
yield ret
@@ -51,7 +49,7 @@ class StarRequestSubStage(Stage):
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
if event.is_at_or_wake_command:
ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()

View File

@@ -1,15 +1,17 @@
import random
import asyncio
import math
import traceback
import astrbot.core.message.components as Comp
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext, call_event_hook
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core import logger
from astrbot.core.message.components import BaseMessageComponent, ComponentType
from astrbot.core.star.star_handler import EventType
from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.utils.path_util import path_Mapping
from astrbot.core.utils.session_lock import session_lock_manager
@@ -112,43 +114,6 @@ class RespondStage(Stage):
# 如果所有组件都为空
return True
def is_seg_reply_required(self, event: AstrMessageEvent) -> bool:
"""检查是否需要分段回复"""
if not self.enable_seg:
return False
if self.only_llm_result and not event.get_result().is_llm_result():
return False
if event.get_platform_name() in [
"qq_official",
"weixin_official_account",
"dingtalk",
]:
return False
return True
def _extract_comp(
self,
raw_chain: list[BaseMessageComponent],
extract_types: set[ComponentType],
modify_raw_chain: bool = True,
):
extracted = []
if modify_raw_chain:
remaining = []
for comp in raw_chain:
if comp.type in extract_types:
extracted.append(comp)
else:
remaining.append(comp)
raw_chain[:] = remaining
else:
extracted = [comp for comp in raw_chain if comp.type in extract_types]
return extracted
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
@@ -158,14 +123,7 @@ class RespondStage(Stage):
if result.result_content_type == ResultContentType.STREAMING_FINISH:
return
logger.info(
f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
)
if result.result_content_type == ResultContentType.STREAMING_RESULT:
if result.async_stream is None:
logger.warning("async_stream 为空,跳过发送。")
return
# 流式结果直接交付平台适配器处理
use_fallback = self.config.get("provider_settings", {}).get(
"streaming_segmented", False
@@ -190,71 +148,87 @@ class RespondStage(Stage):
except Exception as e:
logger.warning(f"空内容检查异常: {e}")
# 发送消息链
# Record 需要强制单独发送
need_separately = {ComponentType.Record}
if self.is_seg_reply_required(event):
header_comps = self._extract_comp(
result.chain,
{ComponentType.Reply, ComponentType.At},
modify_raw_chain=True,
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
non_record_comps = [
c for c in result.chain if not isinstance(c, Comp.Record)
]
if (
self.enable_seg
and (
(self.only_llm_result and result.is_llm_result())
or not self.only_llm_result
)
if not result.chain or len(result.chain) == 0:
# may fix #2670
logger.warning(
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}"
)
return
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
and event.get_platform_name()
not in ["qq_official", "weixin_official_account", "dingtalk"]
):
decorated_comps = []
if self.reply_with_mention:
for comp in result.chain:
if isinstance(comp, Comp.At):
decorated_comps.append(comp)
result.chain.remove(comp)
break
if self.reply_with_quote:
for comp in result.chain:
if isinstance(comp, Comp.Reply):
decorated_comps.append(comp)
result.chain.remove(comp)
break
# leverage lock to guarentee the order of message sending among different events
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
for rcomp in record_comps:
i = await self._calc_comp_interval(rcomp)
await asyncio.sleep(i)
try:
await event.send(MessageChain([rcomp]))
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
# 分段回复
for comp in non_record_comps:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
try:
if comp.type in need_separately:
await event.send(MessageChain([comp]))
else:
await event.send(MessageChain([*header_comps, comp]))
header_comps.clear()
await event.send(MessageChain([*decorated_comps, comp]))
decorated_comps = [] # 清空已发送的装饰组件
except Exception as e:
logger.error(
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
exc_info=True,
)
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
else:
if all(
comp.type in {ComponentType.Reply, ComponentType.At}
for comp in result.chain
):
# may fix #2670
logger.warning(
f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}"
)
return
sep_comps = self._extract_comp(
result.chain,
need_separately,
modify_raw_chain=True,
)
for comp in sep_comps:
chain = MessageChain([comp])
for rcomp in record_comps:
try:
await event.send(chain)
await event.send(MessageChain([rcomp]))
except Exception as e:
logger.error(
f"发送消息链失败: chain = {chain}, error = {e}",
exc_info=True,
)
chain = MessageChain(result.chain)
if result.chain and len(result.chain) > 0:
try:
await event.send(chain)
except Exception as e:
logger.error(
f"发送消息链失败: chain = {chain}, error = {e}",
exc_info=True,
)
logger.error(f"发送消息失败: {e} chain: {result.chain}")
if await call_event_hook(event, EventType.OnAfterMessageSentEvent):
return
try:
await event.send(MessageChain(non_record_comps))
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"发送消息失败: {e} chain: {result.chain}")
logger.info(
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name
)
for handler in handlers:
try:
logger.debug(
f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
event.clear_result()

View File

@@ -36,7 +36,6 @@ class ResultDecorateStage(Stage):
self.t2i_word_threshold = 150
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
self.t2i_use_network = self.t2i_strategy == "remote"
self.t2i_active_template = ctx.astrbot_config["t2i_active_template"]
self.forward_threshold = ctx.astrbot_config["platform_settings"][
"forward_threshold"
@@ -248,10 +247,7 @@ class ResultDecorateStage(Stage):
render_start = time.time()
try:
url = await html_renderer.render_t2i(
plain_str,
return_url=True,
use_network=self.t2i_use_network,
template_name=self.t2i_active_template,
plain_str, return_url=True, use_network=self.t2i_use_network
)
except BaseException:
logger.error("文本转图片失败,使用文本发送。")

View File

@@ -11,8 +11,7 @@ class SessionStatusCheckStage(Stage):
"""检查会话是否整体启用"""
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.conv_mgr = ctx.plugin_manager.context.conversation_manager
pass
async def process(
self, event: AstrMessageEvent
@@ -20,14 +19,4 @@ class SessionStatusCheckStage(Stage):
# 检查会话是否整体启用
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
# workaround for #2309
conv_id = await self.conv_mgr.get_curr_conversation_id(
event.unified_msg_origin
)
if not conv_id:
await self.conv_mgr.new_conversation(
event.unified_msg_origin, platform_id=event.get_platform_id()
)
event.stop_event()

View File

@@ -5,7 +5,6 @@ from astrbot.core.message.components import At, AtAll, Reply
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
@@ -171,15 +170,11 @@ class WakingCheckStage(Stage):
is_wake = True
event.is_wake = True
is_group_cmd_handler = any(
isinstance(f, CommandGroupFilter) for f in handler.event_filters
)
if not is_group_cmd_handler:
activated_handlers.append(handler)
if "parsed_params" in event.get_extra(default={}):
handlers_parsed_params[handler.handler_full_name] = (
event.get_extra("parsed_params")
)
activated_handlers.append(handler)
if "parsed_params" in event.get_extra():
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
"parsed_params"
)
event._extras.pop("parsed_params", None)

View File

@@ -4,7 +4,7 @@ import re
import hashlib
import uuid
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
from typing import List, Union, Optional, AsyncGenerator
from astrbot import logger
from astrbot.core.db.po import Conversation
@@ -24,9 +24,7 @@ from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.utils.metrics import Metric
from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
from .message_session import MessageSession, MessageSesion # noqa
_VT = TypeVar("_VT")
from .message_session import MessageSession, MessageSesion # noqa
class AstrMessageEvent(abc.ABC):
@@ -51,7 +49,7 @@ class AstrMessageEvent(abc.ABC):
"""是否唤醒(是否通过 WakingStage)"""
self.is_at_or_wake_command = False
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
self._extras: dict[str, Any] = {}
self._extras = {}
self.session = MessageSesion(
platform_name=platform_meta.id,
message_type=message_obj.type,
@@ -59,7 +57,7 @@ class AstrMessageEvent(abc.ABC):
)
self.unified_msg_origin = str(self.session)
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
self._result: MessageEventResult | None = None
self._result: MessageEventResult = None
"""消息事件的结果"""
self._has_send_oper = False
@@ -175,15 +173,13 @@ class AstrMessageEvent(abc.ABC):
"""
self._extras[key] = value
def get_extra(
self, key: str | None = None, default: _VT = None
) -> dict[str, Any] | _VT:
def get_extra(self, key=None):
"""
获取额外的信息。
"""
if key is None:
return self._extras
return self._extras.get(key, default)
return self._extras.get(key, None)
def clear_extra(self):
"""

View File

@@ -55,7 +55,7 @@ class AstrBotMessage:
self_id: str # 机器人的识别id
session_id: str # 会话id。取决于 unique_session 的设置。
message_id: str # 消息id
group: Group # 群组
group_id: str = "" # 群组id如果为私聊则为空
sender: MessageMember # 发送者
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串
@@ -64,28 +64,6 @@ class AstrBotMessage:
def __init__(self) -> None:
self.timestamp = int(time.time())
self.group = None
def __str__(self) -> str:
return str(self.__dict__)
@property
def group_id(self) -> str:
"""
向后兼容的 group_id 属性
群组id如果为私聊则为空
"""
if self.group:
return self.group.group_id
return ""
@group_id.setter
def group_id(self, value: str):
"""设置 group_id"""
if value:
if self.group:
self.group.group_id = value
else:
self.group = Group(group_id=value)
else:
self.group = None

View File

@@ -6,7 +6,6 @@ from typing import List
from asyncio import Queue
from .register import platform_cls_map
from astrbot.core import logger
from astrbot.core.star.star_handler import star_handlers_registry, star_map, EventType
from .sources.webchat.webchat_adapter import WebChatAdapter
@@ -67,39 +66,25 @@ class PlatformManager:
WeChatPadProAdapter, # noqa: F401
)
case "lark":
from .sources.lark.lark_adapter import (
LarkPlatformAdapter, # noqa: F401
)
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
case "dingtalk":
from .sources.dingtalk.dingtalk_adapter import (
DingtalkPlatformAdapter, # noqa: F401
)
case "telegram":
from .sources.telegram.tg_adapter import (
TelegramPlatformAdapter, # noqa: F401
)
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
case "wecom":
from .sources.wecom.wecom_adapter import (
WecomPlatformAdapter, # noqa: F401
)
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
case "weixin_official_account":
from .sources.weixin_official_account.weixin_offacc_adapter import (
WeixinOfficialAccountPlatformAdapter, # noqa: F401
WeixinOfficialAccountPlatformAdapter, # noqa
)
case "discord":
from .sources.discord.discord_platform_adapter import (
DiscordPlatformAdapter, # noqa: F401
)
case "misskey":
from .sources.misskey.misskey_adapter import (
MisskeyPlatformAdapter, # noqa: F401
)
case "slack":
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
case "satori":
from .sources.satori.satori_adapter import (
SatoriPlatformAdapter, # noqa: F401
)
except (ImportError, ModuleNotFoundError) as e:
logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
@@ -128,17 +113,6 @@ class PlatformManager:
)
)
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnPlatformLoadedEvent
)
for handler in handlers:
try:
logger.info(
f"hook(on_platform_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler()
except Exception:
logger.error(traceback.format_exc())
async def _task_wrapper(self, task: asyncio.Task):
try:

View File

@@ -67,19 +67,12 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
session_id: str,
messages: list[dict],
):
# session_id 必须是纯数字字符串
session_id = int(session_id) if session_id.isdigit() else None
if is_group and isinstance(session_id, int):
await bot.send_group_msg(group_id=session_id, message=messages)
elif not is_group and isinstance(session_id, int):
await bot.send_private_msg(user_id=session_id, message=messages)
elif isinstance(event, Event): # 最后兜底
if event:
await bot.send(event=event, message=messages)
elif is_group:
await bot.send_group_msg(group_id=session_id, message=messages)
else:
raise ValueError(
f"无法发送消息:缺少有效的数字 session_id({session_id}) 或 event({event})"
)
await bot.send_private_msg(user_id=session_id, message=messages)
@classmethod
async def send_message(
@@ -90,15 +83,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
is_group: bool = False,
session_id: str = None,
):
"""发送消息至 QQ 协议端aiocqhttp
Args:
bot (CQHttp): aiocqhttp 机器人实例
message_chain (MessageChain): 要发送的消息链
event (Event | None, optional): aiocqhttp 事件对象.
is_group (bool, optional): 是否为群消息.
session_id (str | None, optional): 会话 ID群号或 QQ 号
"""
"""发送消息"""
# 转发消息、文件消息不能和普通消息混在一起发送
send_one_by_one = any(
@@ -137,15 +122,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
async def send(self, message: MessageChain):
"""发送消息"""
event = getattr(self.message_obj, "raw_message", None)
is_group = bool(self.get_group_id())
session_id = self.get_group_id() if is_group else self.get_sender_id()
event = self.message_obj.raw_message
assert isinstance(event, Event), "Event must be an instance of aiocqhttp.Event"
is_group = False
if self.get_group_id():
is_group = True
session_id = self.get_group_id()
else:
session_id = self.get_sender_id()
await self.send_message(
bot=self.bot,
message_chain=message,
event=event, # 不强制要求一定是 Event
event=event,
is_group=is_group,
session_id=session_id,
)

View File

@@ -182,13 +182,11 @@ class AiocqhttpAdapter(Platform):
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(
str(event.sender["user_id"]),
event.sender.get("card") or event.sender.get("nickname", "N/A"),
str(event.sender["user_id"]), event.sender["nickname"]
)
if event["message_type"] == "group":
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id)
abm.group.group_name = event.get("group_name", "N/A")
elif event["message_type"] == "private":
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
@@ -310,22 +308,13 @@ class AiocqhttpAdapter(Platform):
continue
at_info = await self.bot.call_action(
action="get_group_member_info",
group_id=event.group_id,
action="get_stranger_info",
user_id=int(m["data"]["qq"]),
no_cache=False,
)
if at_info:
nickname = at_info.get("card", "")
if nickname == "":
at_info = await self.bot.call_action(
action="get_stranger_info",
user_id=int(m["data"]["qq"]),
no_cache=False,
)
nickname = at_info.get("nick", "") or at_info.get(
"nickname", ""
)
nickname = at_info.get("nick", "") or at_info.get(
"nickname", ""
)
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
abm.message.append(

View File

@@ -54,9 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
logger.debug(f"send image: {ret}")
except Exception as e:
logger.warning(f"钉钉图片处理失败: {e}, 跳过图片发送")
logger.error(f"钉钉图片处理失败: {e}")
logger.warning(f"跳过图片发送: {image_path}")
continue
async def send(self, message: MessageChain):
await self.send_with_client(self.client, message)
await super().send(message)

View File

@@ -41,8 +41,7 @@ class DiscordBotClient(discord.Bot):
await self.on_ready_once_callback()
except Exception as e:
logger.error(
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True
)
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True)
def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典"""
@@ -91,6 +90,7 @@ class DiscordBotClient(discord.Bot):
message_data = self._create_message_data(message)
await self.on_message_received(message_data)
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
"""从交互中提取内容"""
interaction_type = interaction.type

View File

@@ -79,12 +79,9 @@ class DiscordButton(BaseMessageComponent):
self.url = url
self.disabled = disabled
class DiscordReference(BaseMessageComponent):
"""Discord引用组件"""
type: str = "discord_reference"
def __init__(self, message_id: str, channel_id: str):
self.message_id = message_id
self.channel_id = channel_id
@@ -101,6 +98,7 @@ class DiscordView(BaseMessageComponent):
self.components = components or []
self.timeout = timeout
def to_discord_view(self) -> discord.ui.View:
"""转换为Discord View对象"""
view = discord.ui.View(timeout=self.timeout)

View File

@@ -53,13 +53,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
# 解析消息链为 Discord 所需的对象
try:
(
content,
files,
view,
embeds,
reference_message_id,
) = await self._parse_to_discord(message)
content, files, view, embeds, reference_message_id = await self._parse_to_discord(message)
except Exception as e:
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
return
@@ -212,7 +206,8 @@ class DiscordPlatformEvent(AstrMessageEvent):
if await asyncio.to_thread(path.exists):
file_bytes = await asyncio.to_thread(path.read_bytes)
files.append(
discord.File(BytesIO(file_bytes), filename=i.name)
discord.File(BytesIO(file_bytes),
filename=i.name)
)
else:
logger.warning(

View File

@@ -1,391 +0,0 @@
import asyncio
import json
from typing import Dict, Any, Optional, Awaitable
from astrbot.api import logger
from astrbot.api.event import MessageChain
from astrbot.api.platform import (
AstrBotMessage,
Platform,
PlatformMetadata,
register_platform_adapter,
)
from astrbot.core.platform.astr_message_event import MessageSession
import astrbot.api.message_components as Comp
from .misskey_api import MisskeyAPI
from .misskey_event import MisskeyPlatformEvent
from .misskey_utils import (
serialize_message_chain,
resolve_message_visibility,
is_valid_user_session_id,
is_valid_room_session_id,
add_at_mention_if_needed,
process_files,
extract_sender_info,
create_base_message,
process_at_mention,
cache_user_info,
cache_room_info,
)
@register_platform_adapter("misskey", "Misskey 平台适配器")
class MisskeyPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config or {}
self.settings = platform_settings or {}
self.instance_url = self.config.get("misskey_instance_url", "")
self.access_token = self.config.get("misskey_token", "")
self.max_message_length = self.config.get("max_message_length", 3000)
self.default_visibility = self.config.get(
"misskey_default_visibility", "public"
)
self.local_only = self.config.get("misskey_local_only", False)
self.enable_chat = self.config.get("misskey_enable_chat", True)
self.unique_session = platform_settings["unique_session"]
self.api: Optional[MisskeyAPI] = None
self._running = False
self.client_self_id = ""
self._bot_username = ""
self._user_cache = {}
def meta(self) -> PlatformMetadata:
default_config = {
"misskey_instance_url": "",
"misskey_token": "",
"max_message_length": 3000,
"misskey_default_visibility": "public",
"misskey_local_only": False,
"misskey_enable_chat": True,
}
default_config.update(self.config)
return PlatformMetadata(
name="misskey",
description="Misskey 平台适配器",
id=self.config.get("id", "misskey"),
default_config_tmpl=default_config,
)
async def run(self):
if not self.instance_url or not self.access_token:
logger.error("[Misskey] 配置不完整,无法启动")
return
self.api = MisskeyAPI(self.instance_url, self.access_token)
self._running = True
try:
user_info = await self.api.get_current_user()
self.client_self_id = str(user_info.get("id", ""))
self._bot_username = user_info.get("username", "")
logger.info(
f"[Misskey] 已连接用户: {self._bot_username} (ID: {self.client_self_id})"
)
except Exception as e:
logger.error(f"[Misskey] 获取用户信息失败: {e}")
self._running = False
return
await self._start_websocket_connection()
async def _start_websocket_connection(self):
backoff_delay = 1.0
max_backoff = 300.0
backoff_multiplier = 1.5
connection_attempts = 0
while self._running:
try:
connection_attempts += 1
if not self.api:
logger.error("[Misskey] API 客户端未初始化")
break
streaming = self.api.get_streaming_client()
streaming.add_message_handler("notification", self._handle_notification)
if self.enable_chat:
streaming.add_message_handler(
"newChatMessage", self._handle_chat_message
)
streaming.add_message_handler("_debug", self._debug_handler)
if await streaming.connect():
logger.info(
f"[Misskey] WebSocket 已连接 (尝试 #{connection_attempts})"
)
connection_attempts = 0 # 重置计数器
await streaming.subscribe_channel("main")
if self.enable_chat:
await streaming.subscribe_channel("messaging")
await streaming.subscribe_channel("messagingIndex")
logger.info("[Misskey] 聊天频道已订阅")
backoff_delay = 1.0 # 重置延迟
await streaming.listen()
else:
logger.error(
f"[Misskey] WebSocket 连接失败 (尝试 #{connection_attempts})"
)
except Exception as e:
logger.error(
f"[Misskey] WebSocket 异常 (尝试 #{connection_attempts}): {e}"
)
if self._running:
logger.info(
f"[Misskey] {backoff_delay:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})"
)
await asyncio.sleep(backoff_delay)
backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff)
async def _handle_notification(self, data: Dict[str, Any]):
try:
logger.debug(
f"[Misskey] 收到通知事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
)
notification_type = data.get("type")
if notification_type in ["mention", "reply", "quote"]:
note = data.get("note")
if note and self._is_bot_mentioned(note):
logger.info(
f"[Misskey] 处理贴文提及: {note.get('text', '')[:50]}..."
)
message = await self.convert_message(note)
event = MisskeyPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.api,
)
self.commit_event(event)
except Exception as e:
logger.error(f"[Misskey] 处理通知失败: {e}")
async def _handle_chat_message(self, data: Dict[str, Any]):
try:
logger.debug(
f"[Misskey] 收到聊天事件数据:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
)
sender_id = str(
data.get("fromUserId", "") or data.get("fromUser", {}).get("id", "")
)
if sender_id == self.client_self_id:
return
room_id = data.get("toRoomId")
if room_id:
raw_text = data.get("text", "")
logger.debug(
f"[Misskey] 检查群聊消息: '{raw_text}', 机器人用户名: '{self._bot_username}'"
)
message = await self.convert_room_message(data)
logger.info(f"[Misskey] 处理群聊消息: {message.message_str[:50]}...")
else:
message = await self.convert_chat_message(data)
logger.info(f"[Misskey] 处理私聊消息: {message.message_str[:50]}...")
event = MisskeyPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.api,
)
self.commit_event(event)
except Exception as e:
logger.error(f"[Misskey] 处理聊天消息失败: {e}")
async def _debug_handler(self, data: Dict[str, Any]):
logger.debug(
f"[Misskey] 收到未处理事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
)
def _is_bot_mentioned(self, note: Dict[str, Any]) -> bool:
text = note.get("text", "")
if not text:
return False
mentions = note.get("mentions", [])
if self._bot_username and f"@{self._bot_username}" in text:
return True
if self.client_self_id in [str(uid) for uid in mentions]:
return True
reply = note.get("reply")
if reply and isinstance(reply, dict):
reply_user_id = str(reply.get("user", {}).get("id", ""))
if reply_user_id == self.client_self_id:
return bool(self._bot_username and f"@{self._bot_username}" in text)
return False
async def send_by_session(
self, session: MessageSession, message_chain: MessageChain
) -> Awaitable[Any]:
if not self.api:
logger.error("[Misskey] API 客户端未初始化")
return await super().send_by_session(session, message_chain)
try:
session_id = session.session_id
text, has_at_user = serialize_message_chain(message_chain.chain)
if not has_at_user and session_id:
user_info = self._user_cache.get(session_id)
text = add_at_mention_if_needed(text, user_info, has_at_user)
if not text or not text.strip():
logger.warning("[Misskey] 消息内容为空,跳过发送")
return await super().send_by_session(session, message_chain)
if len(text) > self.max_message_length:
text = text[: self.max_message_length] + "..."
if session_id and is_valid_user_session_id(session_id):
from .misskey_utils import extract_user_id_from_session_id
user_id = extract_user_id_from_session_id(session_id)
await self.api.send_message(user_id, text)
elif session_id and is_valid_room_session_id(session_id):
from .misskey_utils import extract_room_id_from_session_id
room_id = extract_room_id_from_session_id(session_id)
await self.api.send_room_message(room_id, text)
else:
visibility, visible_user_ids = resolve_message_visibility(
user_id=session_id,
user_cache=self._user_cache,
self_id=self.client_self_id,
default_visibility=self.default_visibility,
)
await self.api.create_note(
text,
visibility=visibility,
visible_user_ids=visible_user_ids,
local_only=self.local_only,
)
except Exception as e:
logger.error(f"[Misskey] 发送消息失败: {e}")
return await super().send_by_session(session, message_chain)
async def convert_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
"""将 Misskey 贴文数据转换为 AstrBotMessage 对象"""
sender_info = extract_sender_info(raw_data, is_chat=False)
message = create_base_message(
raw_data,
sender_info,
self.client_self_id,
is_chat=False,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
)
message_parts = []
raw_text = raw_data.get("text", "")
if raw_text:
text_parts, processed_text = process_at_mention(
message, raw_text, self._bot_username, self.client_self_id
)
message_parts.extend(text_parts)
files = raw_data.get("files", [])
file_parts = process_files(message, files)
message_parts.extend(file_parts)
message.message_str = (
" ".join(part for part in message_parts if part.strip())
if message_parts
else ""
)
return message
async def convert_chat_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
"""将 Misskey 聊天消息数据转换为 AstrBotMessage 对象"""
sender_info = extract_sender_info(raw_data, is_chat=True)
message = create_base_message(
raw_data,
sender_info,
self.client_self_id,
is_chat=True,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=True
)
raw_text = raw_data.get("text", "")
if raw_text:
message.message.append(Comp.Plain(raw_text))
files = raw_data.get("files", [])
process_files(message, files, include_text_parts=False)
message.message_str = raw_text if raw_text else ""
return message
async def convert_room_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
"""将 Misskey 群聊消息数据转换为 AstrBotMessage 对象"""
sender_info = extract_sender_info(raw_data, is_chat=True)
room_id = raw_data.get("toRoomId", "")
message = create_base_message(
raw_data,
sender_info,
self.client_self_id,
is_chat=False,
room_id=room_id,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
)
cache_room_info(self._user_cache, raw_data, self.client_self_id)
raw_text = raw_data.get("text", "")
message_parts = []
if raw_text:
if self._bot_username and f"@{self._bot_username}" in raw_text:
text_parts, processed_text = process_at_mention(
message, raw_text, self._bot_username, self.client_self_id
)
message_parts.extend(text_parts)
else:
message.message.append(Comp.Plain(raw_text))
message_parts.append(raw_text)
files = raw_data.get("files", [])
file_parts = process_files(message, files)
message_parts.extend(file_parts)
message.message_str = (
" ".join(part for part in message_parts if part.strip())
if message_parts
else ""
)
return message
async def terminate(self):
self._running = False
if self.api:
await self.api.close()
def get_client(self) -> Any:
return self.api

View File

@@ -1,404 +0,0 @@
import json
from typing import Any, Optional, Dict, List, Callable, Awaitable
import uuid
try:
import aiohttp
import websockets
except ImportError as e:
raise ImportError(
"aiohttp and websockets are required for Misskey API. Please install them with: pip install aiohttp websockets"
) from e
from astrbot.api import logger
# Constants
API_MAX_RETRIES = 3
HTTP_OK = 200
class APIError(Exception):
"""Misskey API 基础异常"""
pass
class APIConnectionError(APIError):
"""网络连接异常"""
pass
class APIRateLimitError(APIError):
"""API 频率限制异常"""
pass
class AuthenticationError(APIError):
"""认证失败异常"""
pass
class WebSocketError(APIError):
"""WebSocket 连接异常"""
pass
class StreamingClient:
def __init__(self, instance_url: str, access_token: str):
self.instance_url = instance_url.rstrip("/")
self.access_token = access_token
self.websocket: Optional[Any] = None
self.is_connected = False
self.message_handlers: Dict[str, Callable] = {}
self.channels: Dict[str, str] = {}
self._running = False
self._last_pong = None
async def connect(self) -> bool:
try:
ws_url = self.instance_url.replace("https://", "wss://").replace(
"http://", "ws://"
)
ws_url += f"/streaming?i={self.access_token}"
self.websocket = await websockets.connect(
ws_url, ping_interval=30, ping_timeout=10
)
self.is_connected = True
self._running = True
logger.info("[Misskey WebSocket] 已连接")
return True
except Exception as e:
logger.error(f"[Misskey WebSocket] 连接失败: {e}")
self.is_connected = False
return False
async def disconnect(self):
self._running = False
if self.websocket:
await self.websocket.close()
self.websocket = None
self.is_connected = False
logger.info("[Misskey WebSocket] 连接已断开")
async def subscribe_channel(
self, channel_type: str, params: Optional[Dict] = None
) -> str:
if not self.is_connected or not self.websocket:
raise WebSocketError("WebSocket 未连接")
channel_id = str(uuid.uuid4())
message = {
"type": "connect",
"body": {"channel": channel_type, "id": channel_id, "params": params or {}},
}
await self.websocket.send(json.dumps(message))
self.channels[channel_id] = channel_type
return channel_id
async def unsubscribe_channel(self, channel_id: str):
if (
not self.is_connected
or not self.websocket
or channel_id not in self.channels
):
return
message = {"type": "disconnect", "body": {"id": channel_id}}
await self.websocket.send(json.dumps(message))
del self.channels[channel_id]
def add_message_handler(
self, event_type: str, handler: Callable[[Dict], Awaitable[None]]
):
self.message_handlers[event_type] = handler
async def listen(self):
if not self.is_connected or not self.websocket:
raise WebSocketError("WebSocket 未连接")
try:
async for message in self.websocket:
if not self._running:
break
try:
data = json.loads(message)
await self._handle_message(data)
except json.JSONDecodeError as e:
logger.warning(f"[Misskey WebSocket] 无法解析消息: {e}")
except Exception as e:
logger.error(f"[Misskey WebSocket] 处理消息失败: {e}")
except websockets.exceptions.ConnectionClosedError as e:
logger.warning(f"[Misskey WebSocket] 连接意外关闭: {e}")
self.is_connected = False
except websockets.exceptions.ConnectionClosed as e:
logger.warning(
f"[Misskey WebSocket] 连接已关闭 (代码: {e.code}, 原因: {e.reason})"
)
self.is_connected = False
except websockets.exceptions.InvalidHandshake as e:
logger.error(f"[Misskey WebSocket] 握手失败: {e}")
self.is_connected = False
except Exception as e:
logger.error(f"[Misskey WebSocket] 监听消息失败: {e}")
self.is_connected = False
async def _handle_message(self, data: Dict[str, Any]):
message_type = data.get("type")
body = data.get("body", {})
logger.debug(
f"[Misskey WebSocket] 收到消息类型: {message_type}\n数据: {json.dumps(data, indent=2, ensure_ascii=False)}"
)
if message_type == "channel":
channel_id = body.get("id")
event_type = body.get("type")
event_body = body.get("body", {})
logger.debug(
f"[Misskey WebSocket] 频道消息: {channel_id}, 事件类型: {event_type}"
)
if channel_id in self.channels:
channel_type = self.channels[channel_id]
handler_key = f"{channel_type}:{event_type}"
if handler_key in self.message_handlers:
logger.debug(f"[Misskey WebSocket] 使用处理器: {handler_key}")
await self.message_handlers[handler_key](event_body)
elif event_type in self.message_handlers:
logger.debug(f"[Misskey WebSocket] 使用事件处理器: {event_type}")
await self.message_handlers[event_type](event_body)
else:
logger.debug(
f"[Misskey WebSocket] 未找到处理器: {handler_key}{event_type}"
)
if "_debug" in self.message_handlers:
await self.message_handlers["_debug"](
{
"type": event_type,
"body": event_body,
"channel": channel_type,
}
)
elif message_type in self.message_handlers:
logger.debug(f"[Misskey WebSocket] 直接消息处理器: {message_type}")
await self.message_handlers[message_type](body)
else:
logger.debug(f"[Misskey WebSocket] 未处理的消息类型: {message_type}")
if "_debug" in self.message_handlers:
await self.message_handlers["_debug"](data)
def retry_async(max_retries: int = 3, retryable_exceptions: tuple = ()):
def decorator(func):
async def wrapper(*args, **kwargs):
last_exc = None
for _ in range(max_retries):
try:
return await func(*args, **kwargs)
except retryable_exceptions as e:
last_exc = e
continue
if last_exc:
raise last_exc
return wrapper
return decorator
class MisskeyAPI:
def __init__(self, instance_url: str, access_token: str):
self.instance_url = instance_url.rstrip("/")
self.access_token = access_token
self._session: Optional[aiohttp.ClientSession] = None
self.streaming: Optional[StreamingClient] = None
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
return False
async def close(self) -> None:
if self.streaming:
await self.streaming.disconnect()
self.streaming = None
if self._session:
await self._session.close()
self._session = None
logger.debug("[Misskey API] 客户端已关闭")
def get_streaming_client(self) -> StreamingClient:
if not self.streaming:
self.streaming = StreamingClient(self.instance_url, self.access_token)
return self.streaming
@property
def session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
headers = {"Authorization": f"Bearer {self.access_token}"}
self._session = aiohttp.ClientSession(headers=headers)
return self._session
def _handle_response_status(self, status: int, endpoint: str):
"""处理 HTTP 响应状态码"""
if status == 400:
logger.error(f"API 请求错误: {endpoint} (状态码: {status})")
raise APIError(f"Bad request for {endpoint}")
elif status in (401, 403):
logger.error(f"API 认证失败: {endpoint} (状态码: {status})")
raise AuthenticationError(f"Authentication failed for {endpoint}")
elif status == 429:
logger.warning(f"API 频率限制: {endpoint} (状态码: {status})")
raise APIRateLimitError(f"Rate limit exceeded for {endpoint}")
else:
logger.error(f"API 请求失败: {endpoint} (状态码: {status})")
raise APIConnectionError(f"HTTP {status} for {endpoint}")
async def _process_response(
self, response: aiohttp.ClientResponse, endpoint: str
) -> Any:
"""处理 API 响应"""
if response.status == HTTP_OK:
try:
result = await response.json()
if endpoint == "i/notifications":
notifications_data = (
result
if isinstance(result, list)
else result.get("notifications", [])
if isinstance(result, dict)
else []
)
if notifications_data:
logger.debug(f"获取到 {len(notifications_data)} 条新通知")
else:
logger.debug(f"API 请求成功: {endpoint}")
return result
except json.JSONDecodeError as e:
logger.error(f"响应不是有效的 JSON 格式: {e}")
raise APIConnectionError("Invalid JSON response") from e
else:
try:
error_text = await response.text()
logger.error(
f"API 请求失败: {endpoint} - 状态码: {response.status}, 响应: {error_text}"
)
except Exception:
logger.error(f"API 请求失败: {endpoint} - 状态码: {response.status}")
self._handle_response_status(response.status, endpoint)
raise APIConnectionError(f"Request failed for {endpoint}")
@retry_async(
max_retries=API_MAX_RETRIES,
retryable_exceptions=(APIConnectionError, APIRateLimitError),
)
async def _make_request(
self, endpoint: str, data: Optional[Dict[str, Any]] = None
) -> Any:
url = f"{self.instance_url}/api/{endpoint}"
payload = {"i": self.access_token}
if data:
payload.update(data)
try:
async with self.session.post(url, json=payload) as response:
return await self._process_response(response, endpoint)
except aiohttp.ClientError as e:
logger.error(f"HTTP 请求错误: {e}")
raise APIConnectionError(f"HTTP request failed: {e}") from e
async def create_note(
self,
text: str,
visibility: str = "public",
reply_id: Optional[str] = None,
visible_user_ids: Optional[List[str]] = None,
local_only: bool = False,
) -> Dict[str, Any]:
"""创建新贴文"""
data: Dict[str, Any] = {
"text": text,
"visibility": visibility,
"localOnly": local_only,
}
if reply_id:
data["replyId"] = reply_id
if visible_user_ids and visibility == "specified":
data["visibleUserIds"] = visible_user_ids
result = await self._make_request("notes/create", data)
note_id = result.get("createdNote", {}).get("id", "unknown")
logger.debug(f"发帖成功note_id: {note_id}")
return result
async def get_current_user(self) -> Dict[str, Any]:
"""获取当前用户信息"""
return await self._make_request("i", {})
async def send_message(self, user_id: str, text: str) -> Dict[str, Any]:
"""发送聊天消息"""
result = await self._make_request(
"chat/messages/create-to-user", {"toUserId": user_id, "text": text}
)
message_id = result.get("id", "unknown")
logger.debug(f"聊天发送成功message_id: {message_id}")
return result
async def send_room_message(self, room_id: str, text: str) -> Dict[str, Any]:
"""发送房间消息"""
result = await self._make_request(
"chat/messages/create-to-room", {"toRoomId": room_id, "text": text}
)
message_id = result.get("id", "unknown")
logger.debug(f"房间消息发送成功message_id: {message_id}")
return result
async def get_messages(
self, user_id: str, limit: int = 10, since_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""获取聊天消息历史"""
data: Dict[str, Any] = {"userId": user_id, "limit": limit}
if since_id:
data["sinceId"] = since_id
result = await self._make_request("chat/messages/user-timeline", data)
if isinstance(result, list):
return result
else:
logger.warning(f"获取聊天消息响应格式异常: {type(result)}")
return []
async def get_mentions(
self, limit: int = 10, since_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""获取提及通知"""
data: Dict[str, Any] = {"limit": limit}
if since_id:
data["sinceId"] = since_id
data["includeTypes"] = ["mention", "reply", "quote"]
result = await self._make_request("i/notifications", data)
if isinstance(result, list):
return result
elif isinstance(result, dict) and "notifications" in result:
return result["notifications"]
else:
logger.warning(f"获取提及通知响应格式异常: {type(result)}")
return []

View File

@@ -1,123 +0,0 @@
import asyncio
import re
from typing import AsyncGenerator
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import PlatformMetadata, AstrBotMessage
from astrbot.api.message_components import Plain
from .misskey_utils import (
serialize_message_chain,
resolve_visibility_from_raw_message,
is_valid_user_session_id,
is_valid_room_session_id,
add_at_mention_if_needed,
extract_user_id_from_session_id,
extract_room_id_from_session_id,
)
class MisskeyPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
def _is_system_command(self, message_str: str) -> bool:
"""检测是否为系统指令"""
if not message_str or not message_str.strip():
return False
system_prefixes = ["/", "!", "#", ".", "^"]
message_trimmed = message_str.strip()
return any(message_trimmed.startswith(prefix) for prefix in system_prefixes)
async def send(self, message: MessageChain):
content, has_at = serialize_message_chain(message.chain)
if not content:
logger.debug("[MisskeyEvent] 内容为空,跳过发送")
return
try:
original_message_id = getattr(self.message_obj, "message_id", None)
raw_message = getattr(self.message_obj, "raw_message", {})
if raw_message and not has_at:
user_data = raw_message.get("user", {})
user_info = {
"username": user_data.get("username", ""),
"nickname": user_data.get("name", user_data.get("username", "")),
}
content = add_at_mention_if_needed(content, user_info, has_at)
# 根据会话类型选择发送方式
if hasattr(self.client, "send_message") and is_valid_user_session_id(
self.session_id
):
user_id = extract_user_id_from_session_id(self.session_id)
await self.client.send_message(user_id, content)
elif hasattr(self.client, "send_room_message") and is_valid_room_session_id(
self.session_id
):
room_id = extract_room_id_from_session_id(self.session_id)
await self.client.send_room_message(room_id, content)
elif original_message_id and hasattr(self.client, "create_note"):
visibility, visible_user_ids = resolve_visibility_from_raw_message(
raw_message
)
await self.client.create_note(
content,
reply_id=original_message_id,
visibility=visibility,
visible_user_ids=visible_user_ids,
)
elif hasattr(self.client, "create_note"):
logger.debug("[MisskeyEvent] 创建新帖子")
await self.client.create_note(content)
await super().send(message)
except Exception as e:
logger.error(f"[MisskeyEvent] 发送失败: {e}")
async def send_streaming(
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
):
if not use_fallback:
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
buffer = ""
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
async for chain in generator:
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
buffer = await self.process_buffer(buffer, pattern)
else:
await self.send(MessageChain(chain=[comp]))
await asyncio.sleep(1.5) # 限速
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator, use_fallback)

View File

@@ -1,327 +0,0 @@
"""Misskey 平台适配器通用工具函数"""
from typing import Dict, Any, List, Tuple, Optional, Union
import astrbot.api.message_components as Comp
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]:
"""将消息链序列化为文本字符串"""
text_parts = []
has_at = False
def process_component(component):
nonlocal has_at
if isinstance(component, Comp.Plain):
return component.text
elif isinstance(component, Comp.File):
file_name = getattr(component, "name", "文件")
return f"[文件: {file_name}]"
elif isinstance(component, Comp.At):
has_at = True
return f"@{component.qq}"
elif hasattr(component, "text"):
text = getattr(component, "text", "")
if "@" in text:
has_at = True
return text
else:
return str(component)
for component in chain:
if isinstance(component, Comp.Node) and component.content:
for node_comp in component.content:
result = process_component(node_comp)
if result:
text_parts.append(result)
else:
result = process_component(component)
if result:
text_parts.append(result)
return "".join(text_parts), has_at
def resolve_message_visibility(
user_id: Optional[str],
user_cache: Dict[str, Any],
self_id: Optional[str],
default_visibility: str = "public",
) -> Tuple[str, Optional[List[str]]]:
"""解析 Misskey 消息的可见性设置"""
visibility = default_visibility
visible_user_ids = None
if user_id and user_cache:
user_info = user_cache.get(user_id)
if user_info:
original_visibility = user_info.get("visibility", default_visibility)
if original_visibility == "specified":
visibility = "specified"
original_visible_users = user_info.get("visible_user_ids", [])
users_to_include = [user_id]
if self_id:
users_to_include.append(self_id)
visible_user_ids = list(set(original_visible_users + users_to_include))
visible_user_ids = [uid for uid in visible_user_ids if uid]
else:
visibility = original_visibility
return visibility, visible_user_ids
def resolve_visibility_from_raw_message(
raw_message: Dict[str, Any], self_id: Optional[str] = None
) -> Tuple[str, Optional[List[str]]]:
"""从原始消息数据中解析可见性设置"""
visibility = "public"
visible_user_ids = None
if not raw_message:
return visibility, visible_user_ids
original_visibility = raw_message.get("visibility", "public")
if original_visibility == "specified":
visibility = "specified"
original_visible_users = raw_message.get("visibleUserIds", [])
sender_id = raw_message.get("userId", "")
users_to_include = []
if sender_id:
users_to_include.append(sender_id)
if self_id:
users_to_include.append(self_id)
visible_user_ids = list(set(original_visible_users + users_to_include))
visible_user_ids = [uid for uid in visible_user_ids if uid]
else:
visibility = original_visibility
return visibility, visible_user_ids
def is_valid_user_session_id(session_id: Union[str, Any]) -> bool:
"""检查 session_id 是否是有效的聊天用户 session_id (仅限chat%前缀)"""
if not isinstance(session_id, str) or "%" not in session_id:
return False
parts = session_id.split("%")
return (
len(parts) == 2
and parts[0] == "chat"
and bool(parts[1])
and parts[1] != "unknown"
)
def is_valid_room_session_id(session_id: Union[str, Any]) -> bool:
"""检查 session_id 是否是有效的房间 session_id (仅限room%前缀)"""
if not isinstance(session_id, str) or "%" not in session_id:
return False
parts = session_id.split("%")
return (
len(parts) == 2
and parts[0] == "room"
and bool(parts[1])
and parts[1] != "unknown"
)
def extract_user_id_from_session_id(session_id: str) -> str:
"""从 session_id 中提取用户 ID"""
if "%" in session_id:
parts = session_id.split("%")
if len(parts) >= 2:
return parts[1]
return session_id
def extract_room_id_from_session_id(session_id: str) -> str:
"""从 session_id 中提取房间 ID"""
if "%" in session_id:
parts = session_id.split("%")
if len(parts) >= 2 and parts[0] == "room":
return parts[1]
return session_id
def add_at_mention_if_needed(
text: str, user_info: Optional[Dict[str, Any]], has_at: bool = False
) -> str:
"""如果需要且没有@用户,则添加@用户"""
if has_at or not user_info:
return text
username = user_info.get("username")
nickname = user_info.get("nickname")
if username:
mention = f"@{username}"
if not text.startswith(mention):
text = f"{mention}\n{text}".strip()
elif nickname:
mention = f"@{nickname}"
if not text.startswith(mention):
text = f"{mention}\n{text}".strip()
return text
def create_file_component(file_info: Dict[str, Any]) -> Tuple[Any, str]:
"""创建文件组件和描述文本"""
file_url = file_info.get("url", "")
file_name = file_info.get("name", "未知文件")
file_type = file_info.get("type", "")
if file_type.startswith("image/"):
return Comp.Image(url=file_url, file=file_name), f"图片[{file_name}]"
elif file_type.startswith("audio/"):
return Comp.Record(url=file_url, file=file_name), f"音频[{file_name}]"
elif file_type.startswith("video/"):
return Comp.Video(url=file_url, file=file_name), f"视频[{file_name}]"
else:
return Comp.File(name=file_name, url=file_url), f"文件[{file_name}]"
def process_files(
message: AstrBotMessage, files: list, include_text_parts: bool = True
) -> list:
"""处理文件列表,添加到消息组件中并返回文本描述"""
file_parts = []
for file_info in files:
component, part_text = create_file_component(file_info)
message.message.append(component)
if include_text_parts:
file_parts.append(part_text)
return file_parts
def extract_sender_info(
raw_data: Dict[str, Any], is_chat: bool = False
) -> Dict[str, Any]:
"""提取发送者信息"""
if is_chat:
sender = raw_data.get("fromUser", {})
sender_id = str(sender.get("id", "") or raw_data.get("fromUserId", ""))
else:
sender = raw_data.get("user", {})
sender_id = str(sender.get("id", ""))
return {
"sender": sender,
"sender_id": sender_id,
"nickname": sender.get("name", sender.get("username", "")),
"username": sender.get("username", ""),
}
def create_base_message(
raw_data: Dict[str, Any],
sender_info: Dict[str, Any],
client_self_id: str,
is_chat: bool = False,
room_id: Optional[str] = None,
unique_session: bool = False,
) -> AstrBotMessage:
"""创建基础消息对象"""
message = AstrBotMessage()
message.raw_message = raw_data
message.message = []
message.sender = MessageMember(
user_id=sender_info["sender_id"],
nickname=sender_info["nickname"],
)
if room_id:
session_prefix = "room"
session_id = f"{session_prefix}%{room_id}"
if unique_session:
session_id += f"_{sender_info['sender_id']}"
message.type = MessageType.GROUP_MESSAGE
message.group_id = room_id
elif is_chat:
session_prefix = "chat"
session_id = f"{session_prefix}%{sender_info['sender_id']}"
message.type = MessageType.FRIEND_MESSAGE
else:
session_prefix = "note"
session_id = f"{session_prefix}%{sender_info['sender_id']}"
message.type = MessageType.FRIEND_MESSAGE
message.session_id = (
session_id if sender_info["sender_id"] else f"{session_prefix}%unknown"
)
message.message_id = str(raw_data.get("id", ""))
message.self_id = client_self_id
return message
def process_at_mention(
message: AstrBotMessage, raw_text: str, bot_username: str, client_self_id: str
) -> Tuple[List[str], str]:
"""处理@提及逻辑,返回消息部分列表和处理后的文本"""
message_parts = []
if not raw_text:
return message_parts, ""
if bot_username and raw_text.startswith(f"@{bot_username}"):
at_mention = f"@{bot_username}"
message.message.append(Comp.At(qq=client_self_id))
remaining_text = raw_text[len(at_mention) :].strip()
if remaining_text:
message.message.append(Comp.Plain(remaining_text))
message_parts.append(remaining_text)
return message_parts, remaining_text
else:
message.message.append(Comp.Plain(raw_text))
message_parts.append(raw_text)
return message_parts, raw_text
def cache_user_info(
user_cache: Dict[str, Any],
sender_info: Dict[str, Any],
raw_data: Dict[str, Any],
client_self_id: str,
is_chat: bool = False,
):
"""缓存用户信息"""
if is_chat:
user_cache_data = {
"username": sender_info["username"],
"nickname": sender_info["nickname"],
"visibility": "specified",
"visible_user_ids": [client_self_id, sender_info["sender_id"]],
}
else:
user_cache_data = {
"username": sender_info["username"],
"nickname": sender_info["nickname"],
"visibility": raw_data.get("visibility", "public"),
"visible_user_ids": raw_data.get("visibleUserIds", []),
}
user_cache[sender_info["sender_id"]] = user_cache_data
def cache_room_info(
user_cache: Dict[str, Any], raw_data: Dict[str, Any], client_self_id: str
):
"""缓存房间信息"""
room_data = raw_data.get("toRoom")
room_id = raw_data.get("toRoomId")
if room_data and room_id:
room_cache_key = f"room:{room_id}"
user_cache[room_cache_key] = {
"room_id": room_id,
"room_name": room_data.get("name", ""),
"room_description": room_data.get("description", ""),
"owner_id": room_data.get("ownerId", ""),
"visibility": "specified",
"visible_user_ids": [client_self_id],
}

View File

@@ -94,15 +94,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
plain_text,
image_base64,
image_path,
record_file_path,
record_file_path
) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
if (
not plain_text
and not image_base64
and not image_path
and not record_file_path
):
if not plain_text and not image_base64 and not image_path and not record_file_path:
return
payload = {
@@ -123,7 +118,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path: # group record msg
if record_file_path: # group record msg
media = await self.upload_group_and_c2c_record(
record_file_path, 3, group_openid=source.group_openid
)
@@ -139,9 +134,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path: # c2c record
if record_file_path: # c2c record
media = await self.upload_group_and_c2c_record(
record_file_path, 3, openid=source.author.user_openid
record_file_path, 3, openid = source.author.user_openid
)
payload["media"] = media
payload["msg_type"] = 7
@@ -195,55 +190,58 @@ class QQOfficialMessageEvent(AstrMessageEvent):
return await self.bot.api._http.request(route, json=payload)
async def upload_group_and_c2c_record(
self, file_source: str, file_type: int, srv_send_msg: bool = False, **kwargs
self,
file_source: str,
file_type: int,
srv_send_msg: bool = False,
**kwargs
) -> Optional[Media]:
"""
上传媒体文件
"""
# 构建基础payload
payload = {"file_type": file_type, "srv_send_msg": srv_send_msg}
payload = {
"file_type": file_type,
"srv_send_msg": srv_send_msg
}
# 处理文件数据
if os.path.exists(file_source):
# 读取本地文件
async with aiofiles.open(file_source, "rb") as f:
async with aiofiles.open(file_source, 'rb') as f:
file_content = await f.read()
# use base64 encode
payload["file_data"] = base64.b64encode(file_content).decode("utf-8")
payload["file_data"] = base64.b64encode(file_content).decode('utf-8')
else:
# 使用URL
payload["url"] = file_source
# 添加接收者信息和确定路由
if "openid" in kwargs:
payload["openid"] = kwargs["openid"]
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
elif "group_openid" in kwargs:
payload["group_openid"] = kwargs["group_openid"]
route = Route(
"POST",
"/v2/groups/{group_openid}/files",
group_openid=kwargs["group_openid"],
)
payload["group_openid"] =kwargs["group_openid"]
route = Route("POST", "/v2/groups/{group_openid}/files", group_openid=kwargs["group_openid"])
else:
return None
try:
# 使用底层HTTP请求
result = await self.bot.api._http.request(route, json=payload)
if result:
return Media(
file_uuid=result.get("file_uuid"),
file_info=result.get("file_info"),
ttl=result.get("ttl", 0),
file_id=result.get("id", ""),
file_id=result.get("id", "")
)
except Exception as e:
logger.error(f"上传请求错误: {e}")
return None
async def post_c2c_message(
self,
openid: str,
@@ -288,23 +286,19 @@ class QQOfficialMessageEvent(AstrMessageEvent):
image_base64 = image_base64.removeprefix("base64://")
elif isinstance(i, Record):
if i.file:
record_wav_path = await i.convert_to_file_path() # wav 路径
record_wav_path = await i.convert_to_file_path() # wav 路径
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
record_tecent_silk_path = os.path.join(
temp_dir, f"{uuid.uuid4()}.silk"
)
record_tecent_silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
try:
duration = await wav_to_tencent_silk(
record_wav_path, record_tecent_silk_path
)
duration = await wav_to_tencent_silk(record_wav_path, record_tecent_silk_path)
if duration > 0:
record_file_path = record_tecent_silk_path
else:
record_file_path = None
record_file_path = None
logger.error("转换音频格式时出错音频时长不大于0")
except Exception as e:
logger.error(f"处理语音时出错: {e}")
record_file_path = None
record_file_path = None
else:
logger.debug(f"qq_official 忽略 {i.type}")
return plain_text, image_base64, image_file_path, record_file_path

View File

@@ -1,748 +0,0 @@
import asyncio
import json
import time
import websockets
from websockets.asyncio.client import connect
from typing import Optional
from aiohttp import ClientSession, ClientTimeout
from websockets.asyncio.client import ClientConnection
from astrbot.api import logger
from astrbot.api.event import MessageChain
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
MessageType,
Platform,
PlatformMetadata,
register_platform_adapter,
)
from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.api.message_components import (
Plain,
Image,
At,
File,
Record,
Reply,
)
from xml.etree import ElementTree as ET
@register_platform_adapter(
"satori",
"Satori 协议适配器",
)
class SatoriPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
self.api_base_url = self.config.get(
"satori_api_base_url", "http://localhost:5140/satori/v1"
)
self.token = self.config.get("satori_token", "")
self.endpoint = self.config.get(
"satori_endpoint", "ws://localhost:5140/satori/v1/events"
)
self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)
self.metadata = PlatformMetadata(
name="satori",
description="Satori 通用协议适配器",
id=self.config["id"],
)
self.ws: Optional[ClientConnection] = None
self.session: Optional[ClientSession] = None
self.sequence = 0
self.logins = []
self.running = False
self.heartbeat_task: Optional[asyncio.Task] = None
self.ready_received = False
async def send_by_session(
self, session: MessageSession, message_chain: MessageChain
):
from .satori_event import SatoriPlatformEvent
await SatoriPlatformEvent.send_with_adapter(
self, message_chain, session.session_id
)
await super().send_by_session(session, message_chain)
def meta(self) -> PlatformMetadata:
return self.metadata
def _is_websocket_closed(self, ws) -> bool:
"""检查WebSocket连接是否已关闭"""
if not ws:
return True
try:
if hasattr(ws, "closed"):
return ws.closed
elif hasattr(ws, "close_code"):
return ws.close_code is not None
else:
return False
except AttributeError:
return False
async def run(self):
self.running = True
self.session = ClientSession(timeout=ClientTimeout(total=30))
retry_count = 0
max_retries = 10
while self.running:
try:
await self.connect_websocket()
retry_count = 0
except websockets.exceptions.ConnectionClosed as e:
logger.warning(f"Satori WebSocket 连接关闭: {e}")
retry_count += 1
except Exception as e:
logger.error(f"Satori WebSocket 连接失败: {e}")
retry_count += 1
if not self.running:
break
if retry_count >= max_retries:
logger.error(f"达到最大重试次数 ({max_retries}),停止重试")
break
if not self.auto_reconnect:
break
delay = min(self.reconnect_delay * (2 ** (retry_count - 1)), 60)
await asyncio.sleep(delay)
if self.session:
await self.session.close()
async def connect_websocket(self):
logger.info(f"Satori 适配器正在连接到 WebSocket: {self.endpoint}")
logger.info(f"Satori 适配器 HTTP API 地址: {self.api_base_url}")
if not self.endpoint.startswith(("ws://", "wss://")):
logger.error(f"无效的WebSocket URL: {self.endpoint}")
raise ValueError(f"WebSocket URL必须以ws://或wss://开头: {self.endpoint}")
try:
websocket = await connect(self.endpoint, additional_headers={})
self.ws = websocket
await asyncio.sleep(0.1)
await self.send_identify()
self.heartbeat_task = asyncio.create_task(self.heartbeat_loop())
async for message in websocket:
try:
await self.handle_message(message) # type: ignore
except Exception as e:
logger.error(f"Satori 处理消息异常: {e}")
except websockets.exceptions.ConnectionClosed as e:
logger.warning(f"Satori WebSocket 连接关闭: {e}")
raise
except Exception as e:
logger.error(f"Satori WebSocket 连接异常: {e}")
raise
finally:
if self.heartbeat_task:
self.heartbeat_task.cancel()
try:
await self.heartbeat_task
except asyncio.CancelledError:
pass
if self.ws:
try:
await self.ws.close()
except Exception as e:
logger.error(f"Satori WebSocket 关闭异常: {e}")
async def send_identify(self):
if not self.ws:
raise Exception("WebSocket连接未建立")
if self._is_websocket_closed(self.ws):
raise Exception("WebSocket连接已关闭")
identify_payload = {
"op": 3, # IDENTIFY
"body": {
"token": str(self.token) if self.token else "", # 字符串
},
}
# 只有在有序列号时才添加sn字段
if self.sequence > 0:
identify_payload["body"]["sn"] = self.sequence
try:
message_str = json.dumps(identify_payload, ensure_ascii=False)
await self.ws.send(message_str)
except websockets.exceptions.ConnectionClosed as e:
logger.error(f"发送 IDENTIFY 信令时连接关闭: {e}")
raise
except Exception as e:
logger.error(f"发送 IDENTIFY 信令失败: {e}")
raise
async def heartbeat_loop(self):
try:
while self.running and self.ws:
await asyncio.sleep(self.heartbeat_interval)
if self.ws and not self._is_websocket_closed(self.ws):
try:
ping_payload = {
"op": 1, # PING
"body": {},
}
await self.ws.send(json.dumps(ping_payload, ensure_ascii=False))
except websockets.exceptions.ConnectionClosed as e:
logger.error(f"Satori WebSocket 连接关闭: {e}")
break
except Exception as e:
logger.error(f"Satori WebSocket 发送心跳失败: {e}")
break
else:
break
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"心跳任务异常: {e}")
async def handle_message(self, message: str):
try:
data = json.loads(message)
op = data.get("op")
body = data.get("body", {})
if op == 4: # READY
self.logins = body.get("logins", [])
self.ready_received = True
# 输出连接成功的bot信息
if self.logins:
for i, login in enumerate(self.logins):
platform = login.get("platform", "")
user = login.get("user", {})
user_id = user.get("id", "")
user_name = user.get("name", "")
logger.info(
f"Satori 连接成功 - Bot {i + 1}: platform={platform}, user_id={user_id}, user_name={user_name}"
)
if "sn" in body:
self.sequence = body["sn"]
elif op == 2: # PONG
pass
elif op == 0: # EVENT
await self.handle_event(body)
if "sn" in body:
self.sequence = body["sn"]
elif op == 5: # META
if "sn" in body:
self.sequence = body["sn"]
except json.JSONDecodeError as e:
logger.error(f"解析 WebSocket 消息失败: {e}, 消息内容: {message}")
except Exception as e:
logger.error(f"处理 WebSocket 消息异常: {e}")
async def handle_event(self, event_data: dict):
try:
event_type = event_data.get("type")
sn = event_data.get("sn")
if sn:
self.sequence = sn
if event_type == "message-created":
message = event_data.get("message", {})
user = event_data.get("user", {})
channel = event_data.get("channel", {})
guild = event_data.get("guild")
login = event_data.get("login", {})
timestamp = event_data.get("timestamp")
if user.get("id") == login.get("user", {}).get("id"):
return
abm = await self.convert_satori_message(
message, user, channel, guild, login, timestamp
)
if abm:
await self.handle_msg(abm)
except Exception as e:
logger.error(f"处理事件失败: {e}")
async def convert_satori_message(
self,
message: dict,
user: dict,
channel: dict,
guild: Optional[dict],
login: dict,
timestamp: Optional[int] = None,
) -> Optional[AstrBotMessage]:
try:
abm = AstrBotMessage()
abm.message_id = message.get("id", "")
abm.raw_message = {
"message": message,
"user": user,
"channel": channel,
"guild": guild,
"login": login,
}
if guild and guild.get("id"):
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = guild.get("id", "")
abm.session_id = channel.get("id", "")
else:
abm.type = MessageType.FRIEND_MESSAGE
abm.session_id = channel.get("id", "")
abm.sender = MessageMember(
user_id=user.get("id", ""),
nickname=user.get("nick", user.get("name", "")),
)
abm.self_id = login.get("user", {}).get("id", "")
# 消息链
abm.message = []
content = message.get("content", "")
quote = message.get("quote")
content_for_parsing = content # 副本
# 提取<quote>标签
if "<quote" in content:
try:
quote_info = await self._extract_quote_element(content)
if quote_info:
quote = quote_info["quote"]
content_for_parsing = quote_info["content_without_quote"]
except Exception as e:
logger.error(f"解析<quote>标签时发生错误: {e}, 错误内容: {content}")
if quote:
# 引用消息
quote_abm = await self._convert_quote_message(quote)
if quote_abm:
sender_id = quote_abm.sender.user_id
if isinstance(sender_id, str) and sender_id.isdigit():
sender_id = int(sender_id)
elif not isinstance(sender_id, int):
sender_id = 0 # 默认值
reply_component = Reply(
id=quote_abm.message_id,
chain=quote_abm.message,
sender_id=quote_abm.sender.user_id,
sender_nickname=quote_abm.sender.nickname,
time=quote_abm.timestamp,
message_str=quote_abm.message_str,
text=quote_abm.message_str,
qq=sender_id,
)
abm.message.append(reply_component)
# 解析消息内容
content_elements = await self.parse_satori_elements(content_for_parsing)
abm.message.extend(content_elements)
abm.message_str = ""
for comp in content_elements:
if isinstance(comp, Plain):
abm.message_str += comp.text
# 优先使用Satori事件中的时间戳
if timestamp is not None:
abm.timestamp = timestamp
else:
abm.timestamp = int(time.time())
return abm
except Exception as e:
logger.error(f"转换 Satori 消息失败: {e}")
return None
def _extract_namespace_prefixes(self, content: str) -> set:
"""提取XML内容中的命名空间前缀"""
prefixes = set()
# 查找所有标签
i = 0
while i < len(content):
# 查找开始标签
if content[i] == "<" and i + 1 < len(content) and content[i + 1] != "/":
# 找到标签结束位置
tag_end = content.find(">", i)
if tag_end != -1:
# 提取标签内容
tag_content = content[i + 1 : tag_end]
# 检查是否有命名空间前缀
if ":" in tag_content and "xmlns:" not in tag_content:
# 分割标签名
parts = tag_content.split()
if parts:
tag_name = parts[0]
if ":" in tag_name:
prefix = tag_name.split(":")[0]
# 确保是有效的命名空间前缀
if (
prefix.isalnum()
or prefix.replace("_", "").isalnum()
):
prefixes.add(prefix)
i = tag_end + 1
else:
i += 1
# 查找结束标签
elif content[i] == "<" and i + 1 < len(content) and content[i + 1] == "/":
# 找到标签结束位置
tag_end = content.find(">", i)
if tag_end != -1:
# 提取标签内容
tag_content = content[i + 2 : tag_end]
# 检查是否有命名空间前缀
if ":" in tag_content:
prefix = tag_content.split(":")[0]
# 确保是有效的命名空间前缀
if prefix.isalnum() or prefix.replace("_", "").isalnum():
prefixes.add(prefix)
i = tag_end + 1
else:
i += 1
else:
i += 1
return prefixes
async def _extract_quote_element(self, content: str) -> Optional[dict]:
"""提取<quote>标签信息"""
try:
# 处理命名空间前缀问题
processed_content = content
if ":" in content and not content.startswith("<root"):
prefixes = self._extract_namespace_prefixes(content)
# 构建命名空间声明
ns_declarations = " ".join(
[
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
for prefix in prefixes
]
)
# 包装内容
processed_content = f"<root {ns_declarations}>{content}</root>"
elif not content.startswith("<root"):
processed_content = f"<root>{content}</root>"
else:
processed_content = content
root = ET.fromstring(processed_content)
# 查找<quote>标签
quote_element = None
for elem in root.iter():
tag_name = elem.tag
if "}" in tag_name:
tag_name = tag_name.split("}")[1]
if tag_name.lower() == "quote":
quote_element = elem
break
if quote_element is not None:
# 提取quote标签的属性
quote_id = quote_element.get("id", "")
# 提取<quote>标签内部的内容
inner_content = ""
if quote_element.text:
inner_content += quote_element.text
for child in quote_element:
inner_content += ET.tostring(
child, encoding="unicode", method="xml"
)
if child.tail:
inner_content += child.tail
# 构造移除了<quote>标签的内容
content_without_quote = content.replace(
ET.tostring(quote_element, encoding="unicode", method="xml"), ""
)
return {
"quote": {"id": quote_id, "content": inner_content},
"content_without_quote": content_without_quote,
}
return None
except Exception as e:
logger.error(f"提取<quote>标签时发生错误: {e}")
return None
async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]:
"""转换引用消息"""
try:
quote_abm = AstrBotMessage()
quote_abm.message_id = quote.get("id", "")
# 解析引用消息的发送者
quote_author = quote.get("author", {})
if quote_author:
quote_abm.sender = MessageMember(
user_id=quote_author.get("id", ""),
nickname=quote_author.get("nick", quote_author.get("name", "")),
)
else:
# 如果没有作者信息,使用默认值
quote_abm.sender = MessageMember(
user_id=quote.get("user_id", ""),
nickname="内容",
)
# 解析引用消息内容
quote_content = quote.get("content", "")
quote_abm.message = await self.parse_satori_elements(quote_content)
quote_abm.message_str = ""
for comp in quote_abm.message:
if isinstance(comp, Plain):
quote_abm.message_str += comp.text
quote_abm.timestamp = int(quote.get("timestamp", time.time()))
# 如果没有任何内容,使用默认文本
if not quote_abm.message_str.strip():
quote_abm.message_str = "[引用消息]"
return quote_abm
except Exception as e:
logger.error(f"转换引用消息失败: {e}")
return None
async def parse_satori_elements(self, content: str) -> list:
"""解析 Satori 消息元素"""
elements = []
if not content:
return elements
try:
# 处理命名空间前缀问题
processed_content = content
if ":" in content and not content.startswith("<root"):
prefixes = self._extract_namespace_prefixes(content)
# 构建命名空间声明
ns_declarations = " ".join(
[
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
for prefix in prefixes
]
)
# 包装内容
processed_content = f"<root {ns_declarations}>{content}</root>"
elif not content.startswith("<root"):
processed_content = f"<root>{content}</root>"
else:
processed_content = content
root = ET.fromstring(processed_content)
await self._parse_xml_node(root, elements)
except ET.ParseError as e:
logger.error(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
# 如果解析失败,将整个内容当作纯文本
if content.strip():
elements.append(Plain(text=content))
except Exception as e:
logger.error(f"解析 Satori 元素时发生未知错误: {e}")
raise e
# 如果没有解析到任何元素,将整个内容当作纯文本
if not elements and content.strip():
elements.append(Plain(text=content))
return elements
async def _parse_xml_node(self, node: ET.Element, elements: list) -> None:
"""递归解析 XML 节点"""
if node.text and node.text.strip():
elements.append(Plain(text=node.text))
for child in node:
# 获取标签名,去除命名空间前缀
tag_name = child.tag
if "}" in tag_name:
tag_name = tag_name.split("}")[1]
tag_name = tag_name.lower()
attrs = child.attrib
if tag_name == "at":
user_id = attrs.get("id") or attrs.get("name", "")
elements.append(At(qq=user_id, name=user_id))
elif tag_name in ("img", "image"):
src = attrs.get("src", "")
if not src:
continue
elements.append(Image(file=src))
elif tag_name == "file":
src = attrs.get("src", "")
name = attrs.get("name", "文件")
if src:
elements.append(File(name=name, file=src))
elif tag_name in ("audio", "record"):
src = attrs.get("src", "")
if not src:
continue
elements.append(Record(file=src))
elif tag_name == "quote":
# quote标签已经被特殊处理
pass
elif tag_name == "face":
face_id = attrs.get("id", "")
face_name = attrs.get("name", "")
face_type = attrs.get("type", "")
if face_name:
elements.append(Plain(text=f"[表情:{face_name}]"))
elif face_id and face_type:
elements.append(Plain(text=f"[表情ID:{face_id},类型:{face_type}]"))
elif face_id:
elements.append(Plain(text=f"[表情ID:{face_id}]"))
else:
elements.append(Plain(text="[表情]"))
elif tag_name == "ark":
# 作为纯文本添加到消息链中
data = attrs.get("data", "")
if data:
import html
decoded_data = html.unescape(data)
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
else:
elements.append(Plain(text="[ARK卡片]"))
elif tag_name == "json":
# JSON标签 视为ARK卡片消息
data = attrs.get("data", "")
if data:
import html
decoded_data = html.unescape(data)
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
else:
elements.append(Plain(text="[JSON卡片]"))
else:
# 未知标签,递归处理其内容
if child.text and child.text.strip():
elements.append(Plain(text=child.text))
await self._parse_xml_node(child, elements)
# 处理标签后的文本
if child.tail and child.tail.strip():
elements.append(Plain(text=child.tail))
async def handle_msg(self, message: AstrBotMessage):
from .satori_event import SatoriPlatformEvent
message_event = SatoriPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
adapter=self,
)
self.commit_event(message_event)
async def send_http_request(
self,
method: str,
path: str,
data: dict | None = None,
platform: str | None = None,
user_id: str | None = None,
) -> dict:
if not self.session:
raise Exception("HTTP session 未初始化")
headers = {
"Content-Type": "application/json",
}
if self.token:
headers["Authorization"] = f"Bearer {self.token}"
if platform and user_id:
headers["satori-platform"] = platform
headers["satori-user-id"] = user_id
elif self.logins:
current_login = self.logins[0]
headers["satori-platform"] = current_login.get("platform", "")
user = current_login.get("user", {})
headers["satori-user-id"] = user.get("id", "") if user else ""
if not path.startswith("/"):
path = "/" + path
# 使用新的API地址配置
url = f"{self.api_base_url.rstrip('/')}{path}"
try:
async with self.session.request(
method, url, json=data, headers=headers
) as response:
if response.status == 200:
result = await response.json()
return result
else:
return {}
except Exception as e:
logger.error(f"Satori HTTP 请求异常: {e}")
return {}
async def terminate(self):
self.running = False
if self.heartbeat_task:
self.heartbeat_task.cancel()
if self.ws:
try:
await self.ws.close()
except Exception as e:
logger.error(f"Satori WebSocket 关闭异常: {e}")
if self.session:
await self.session.close()

View File

@@ -1,230 +0,0 @@
from typing import TYPE_CHECKING
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, At, File, Record
if TYPE_CHECKING:
from .satori_adapter import SatoriPlatformAdapter
class SatoriPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
adapter: "SatoriPlatformAdapter",
):
# 更新平台元数据
if adapter and hasattr(adapter, "logins") and adapter.logins:
current_login = adapter.logins[0]
platform_name = current_login.get("platform", "satori")
user = current_login.get("user", {})
user_id = user.get("id", "") if user else ""
if not platform_meta.id and user_id:
platform_meta.id = f"{platform_name}({user_id})"
super().__init__(message_str, message_obj, platform_meta, session_id)
self.adapter = adapter
self.platform = None
self.user_id = None
if (
hasattr(message_obj, "raw_message")
and message_obj.raw_message
and isinstance(message_obj.raw_message, dict)
):
login = message_obj.raw_message.get("login", {})
self.platform = login.get("platform")
user = login.get("user", {})
self.user_id = user.get("id") if user else None
@classmethod
async def send_with_adapter(
cls, adapter: "SatoriPlatformAdapter", message: MessageChain, session_id: str
):
try:
content_parts = []
for component in message.chain:
if isinstance(component, Plain):
text = (
component.text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
content_parts.append(text)
elif isinstance(component, At):
if component.qq:
content_parts.append(f'<at id="{component.qq}"/>')
elif component.name:
content_parts.append(f'<at name="{component.name}"/>')
elif isinstance(component, Image):
try:
image_base64 = await component.convert_to_base64()
if image_base64:
content_parts.append(
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
)
except Exception as e:
logger.error(f"图片转换为base64失败: {e}")
elif isinstance(component, File):
content_parts.append(
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
)
elif isinstance(component, Record):
try:
record_base64 = await component.convert_to_base64()
if record_base64:
content_parts.append(
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
)
except Exception as e:
logger.error(f"语音转换为base64失败: {e}")
content = "".join(content_parts)
channel_id = session_id
data = {"channel_id": channel_id, "content": content}
platform = None
user_id = None
if hasattr(adapter, "logins") and adapter.logins:
current_login = adapter.logins[0]
platform = current_login.get("platform", "")
user = current_login.get("user", {})
user_id = user.get("id", "") if user else ""
result = await adapter.send_http_request(
"POST", "/message.create", data, platform, user_id
)
if result:
return result
else:
return None
except Exception as e:
logger.error(f"Satori 消息发送异常: {e}")
return None
async def send(self, message: MessageChain):
platform = getattr(self, "platform", None)
user_id = getattr(self, "user_id", None)
if not platform or not user_id:
if hasattr(self.adapter, "logins") and self.adapter.logins:
current_login = self.adapter.logins[0]
platform = current_login.get("platform", "")
user = current_login.get("user", {})
user_id = user.get("id", "") if user else ""
try:
content_parts = []
for component in message.chain:
if isinstance(component, Plain):
text = (
component.text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
content_parts.append(text)
elif isinstance(component, At):
if component.qq:
content_parts.append(f'<at id="{component.qq}"/>')
elif component.name:
content_parts.append(f'<at name="{component.name}"/>')
elif isinstance(component, Image):
try:
image_base64 = await component.convert_to_base64()
if image_base64:
content_parts.append(
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
)
except Exception as e:
logger.error(f"图片转换为base64失败: {e}")
elif isinstance(component, File):
content_parts.append(
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
)
elif isinstance(component, Record):
try:
record_base64 = await component.convert_to_base64()
if record_base64:
content_parts.append(
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
)
except Exception as e:
logger.error(f"语音转换为base64失败: {e}")
content = "".join(content_parts)
channel_id = self.session_id
data = {"channel_id": channel_id, "content": content}
result = await self.adapter.send_http_request(
"POST", "/message.create", data, platform, user_id
)
if not result:
logger.error("Satori 消息发送失败")
except Exception as e:
logger.error(f"Satori 消息发送异常: {e}")
await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False):
try:
content_parts = []
async for chain in generator:
if isinstance(chain, MessageChain):
if chain.type == "break":
if content_parts:
content = "".join(content_parts)
temp_chain = MessageChain([Plain(text=content)])
await self.send(temp_chain)
content_parts = []
continue
for component in chain.chain:
if isinstance(component, Plain):
content_parts.append(component.text)
elif isinstance(component, Image):
if content_parts:
content = "".join(content_parts)
temp_chain = MessageChain([Plain(text=content)])
await self.send(temp_chain)
content_parts = []
try:
image_base64 = await component.convert_to_base64()
if image_base64:
img_chain = MessageChain(
[
Plain(
text=f'<img src="data:image/jpeg;base64,{image_base64}"/>'
)
]
)
await self.send(img_chain)
except Exception as e:
logger.error(f"图片转换为base64失败: {e}")
else:
content_parts.append(str(component))
if content_parts:
content = "".join(content_parts)
temp_chain = MessageChain([Plain(text=content)])
await self.send(temp_chain)
except Exception as e:
logger.error(f"Satori 流式消息发送异常: {e}")
return await super().send_streaming(generator, use_fallback)

View File

@@ -308,9 +308,7 @@ class SlackAdapter(Platform):
base64_content = base64.b64encode(content).decode("utf-8")
return base64_content
else:
logger.error(
f"Failed to download slack file: {resp.status} {await resp.text()}"
)
logger.error(f"Failed to download slack file: {resp.status} {await resp.text()}")
raise Exception(f"下载文件失败: {resp.status}")
async def run(self) -> Awaitable[Any]:

View File

@@ -75,13 +75,7 @@ class SlackMessageEvent(AstrMessageEvent):
"text": {"type": "mrkdwn", "text": "文件上传失败"},
}
file_url = response["files"][0]["permalink"]
return {
"type": "section",
"text": {
"type": "mrkdwn",
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
},
}
return {"type": "section", "text": {"type": "mrkdwn", "text": f"文件: <{file_url}|{segment.name or '文件'}>"}}
else:
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}

View File

@@ -183,6 +183,7 @@ class TelegramPlatformAdapter(Platform):
return None
if not re.match(r"^[a-z0-9_]+$", cmd_name) or len(cmd_name) > 32:
logger.debug(f"跳过无法注册的命令: {cmd_name}")
return None
# Build description.

View File

@@ -66,9 +66,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
return chunks
@classmethod
async def send_with_client(
cls, client: ExtBot, message: MessageChain, user_name: str
):
async def send_with_client(cls, client: ExtBot, message: MessageChain, user_name: str):
image_path = None
has_reply = False
@@ -218,6 +216,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
try:
msg = await self.client.send_message(text=delta, **payload)
current_content = delta
delta = ""
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id

View File

@@ -1,6 +1,5 @@
import asyncio
class WebChatQueueMgr:
def __init__(self) -> None:
self.queues = {}
@@ -31,5 +30,4 @@ class WebChatQueueMgr:
"""Check if a queue exists for the given conversation ID"""
return conversation_id in self.queues
webchat_queue_mgr = WebChatQueueMgr()

View File

@@ -213,10 +213,10 @@ class WeChatPadProAdapter(Platform):
def _extract_auth_key(self, data):
"""Helper method to extract auth_key from response data."""
if isinstance(data, dict):
auth_keys = data.get("authKeys") # 新接口
auth_keys = data.get("authKeys") # 新接口
if isinstance(auth_keys, list) and auth_keys:
return auth_keys[0]
elif isinstance(data, list) and data: # 旧接口
elif isinstance(data, list) and data: # 旧接口
return data[0]
return None
@@ -234,9 +234,7 @@ class WeChatPadProAdapter(Platform):
try:
async with session.post(url, params=params, json=payload) as response:
if response.status != 200:
logger.error(
f"生成授权码失败: {response.status}, {await response.text()}"
)
logger.error(f"生成授权码失败: {response.status}, {await response.text()}")
return
response_data = await response.json()
@@ -247,9 +245,7 @@ class WeChatPadProAdapter(Platform):
if self.auth_key:
logger.info("成功获取授权码")
else:
logger.error(
f"生成授权码成功但未找到授权码: {response_data}"
)
logger.error(f"生成授权码成功但未找到授权码: {response_data}")
else:
logger.error(f"生成授权码失败: {response_data}")
except aiohttp.ClientConnectorError as e:

View File

@@ -185,7 +185,6 @@ class WecomPlatformAdapter(Platform):
return PlatformMetadata(
"wecom",
"wecom 适配器",
id=self.config.get("id", "wecom"),
)
@override

View File

@@ -48,12 +48,7 @@ class WeChatKF(BaseWeChatAPI):
注意可能会出现返回条数少于limit的情况需结合返回的has_more字段判断是否继续请求。
:return: 接口调用结果
"""
data = {
"token": token,
"cursor": cursor,
"limit": limit,
"open_kfid": open_kfid,
}
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid}
return self._post("kf/sync_msg", data=data)
def get_service_state(self, open_kfid, external_userid):
@@ -77,9 +72,7 @@ class WeChatKF(BaseWeChatAPI):
}
return self._post("kf/service_state/get", data=data)
def trans_service_state(
self, open_kfid, external_userid, service_state, servicer_userid=""
):
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""):
"""
变更会话状态
@@ -187,9 +180,7 @@ class WeChatKF(BaseWeChatAPI):
"""
return self._get("kf/customer/get_upgrade_service_config")
def upgrade_service(
self, open_kfid, external_userid, service_type, member=None, groupchat=None
):
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None):
"""
为客户升级为专员或客户群服务
@@ -255,9 +246,7 @@ class WeChatKF(BaseWeChatAPI):
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
return self._post("kf/get_corp_statistic", data=data)
def get_servicer_statistic(
self, start_time, end_time, open_kfid=None, servicer_userid=None
):
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None):
"""
获取「客户数据统计」接待人员明细数据

View File

@@ -26,7 +26,6 @@ from optionaldict import optionaldict
from wechatpy.client.api.base import BaseWeChatAPI
class WeChatKFMessage(BaseWeChatAPI):
"""
发送微信客服消息
@@ -126,55 +125,35 @@ class WeChatKFMessage(BaseWeChatAPI):
msg={"msgtype": "news", "link": {"link": articles_data}},
)
def send_msgmenu(
self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""
):
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""):
return self.send(
user_id,
open_kfid,
msgid,
msg={
"msgtype": "msgmenu",
"msgmenu": {
"head_content": head_content,
"list": menu_list,
"tail_content": tail_content,
},
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content},
},
)
def send_location(
self, user_id, open_kfid, name, address, latitude, longitude, msgid=""
):
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""):
return self.send(
user_id,
open_kfid,
msgid,
msg={
"msgtype": "location",
"msgmenu": {
"name": name,
"address": address,
"latitude": latitude,
"longitude": longitude,
},
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude},
},
)
def send_miniprogram(
self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""
):
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""):
return self.send(
user_id,
open_kfid,
msgid,
msg={
"msgtype": "miniprogram",
"msgmenu": {
"appid": appid,
"title": title,
"thumb_media_id": thumb_media_id,
"pagepath": pagepath,
},
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath},
},
)

View File

@@ -160,9 +160,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.wexin_event_workers[msg.id] = future
await self.convert_message(msg, future)
# I love shield so much!
result = await asyncio.wait_for(
asyncio.shield(future), 60
) # wait for 60s
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
logger.debug(f"Got future result: {result}")
self.wexin_event_workers.pop(msg.id, None)
return result # xml. see weixin_offacc_event.py
@@ -184,7 +182,6 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
return PlatformMetadata(
"weixin_official_account",
"微信公众平台 适配器",
id=self.config.get("id", "weixin_official_account"),
)
@override

View File

@@ -150,6 +150,7 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
return
logger.info(f"微信公众平台上传语音返回: {response}")
if active_send_mode:
self.client.message.send_voice(
message_obj.sender.user_id,

View File

@@ -4,11 +4,9 @@ import json
from astrbot.core.utils.io import download_image_by_url
from astrbot import logger
from dataclasses import dataclass, field
from typing import List, Dict, Type, Any
from typing import List, Dict, Type
from astrbot.core.agent.tool import ToolSet
from openai.types.chat.chat_completion import ChatCompletion
from google.genai.types import GenerateContentResponse
from anthropic.types import Message
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
@@ -32,11 +30,11 @@ class ProviderMetaData:
desc: str = ""
"""提供商适配器描述."""
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: Type | None = None
cls_type: Type = None
default_config_tmpl: dict | None = None
default_config_tmpl: dict = None
"""平台的默认配置模板"""
provider_display_name: str | None = None
provider_display_name: str = None
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
@@ -60,21 +58,18 @@ class ToolCallMessageSegment:
class AssistantMessageSegment:
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
content: str | None = None
content: str = None
tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list)
role: str = "assistant"
def to_dict(self):
ret: dict[str, str | list[dict]] = {
ret = {
"role": self.role,
}
if self.content:
ret["content"] = self.content
if self.tool_calls:
tool_calls_dict = [
tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls
]
ret["tool_calls"] = tool_calls_dict
ret["tool_calls"] = self.tool_calls
return ret
@@ -120,14 +115,7 @@ class ProviderRequest:
"""模型名称,为 None 时使用提供商的默认模型"""
def __repr__(self):
return (
f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
f"image_count={len(self.image_urls or [])}, "
f"func_tool={self.func_tool}, "
f"contexts={self._print_friendly_context()}, "
f"system_prompt={self.system_prompt}, "
f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, "
)
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
def __str__(self):
return self.__repr__()
@@ -217,17 +205,17 @@ class ProviderRequest:
class LLMResponse:
role: str
"""角色, assistant, tool, err"""
result_chain: MessageChain | None = None
result_chain: MessageChain = None
"""返回的消息链"""
tools_call_args: List[Dict[str, Any]] = field(default_factory=list)
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
"""工具调用参数"""
tools_call_name: List[str] = field(default_factory=list)
"""工具调用名称"""
tools_call_ids: List[str] = field(default_factory=list)
"""工具调用 ID"""
raw_completion: ChatCompletion | GenerateContentResponse | Message | None = None
_new_record: Dict[str, Any] | None = None
raw_completion: ChatCompletion = None
_new_record: Dict[str, any] = None
_completion_text: str = ""
@@ -238,12 +226,12 @@ class LLMResponse:
self,
role: str,
completion_text: str = "",
result_chain: MessageChain | None = None,
tools_call_args: List[Dict[str, Any]] | None = None,
tools_call_name: List[str] | None = None,
tools_call_ids: List[str] | None = None,
raw_completion: ChatCompletion | None = None,
_new_record: Dict[str, Any] | None = None,
result_chain: MessageChain = None,
tools_call_args: List[Dict[str, any]] = None,
tools_call_name: List[str] = None,
tools_call_ids: List[str] = None,
raw_completion: ChatCompletion = None,
_new_record: Dict[str, any] = None,
is_chunk: bool = False,
):
"""初始化 LLMResponse
@@ -307,7 +295,6 @@ class LLMResponse:
)
return ret
@dataclass
class RerankResult:
index: int

View File

@@ -4,7 +4,7 @@ import os
import asyncio
import aiohttp
from typing import Dict, List, Awaitable, Callable, Any
from typing import Dict, List, Awaitable
from astrbot import logger
from astrbot.core import sp
@@ -109,7 +109,7 @@ class FunctionToolManager:
name: str,
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
handler: Awaitable,
) -> FuncTool:
params = {
"type": "object", # hard-coded here
@@ -132,7 +132,7 @@ class FunctionToolManager:
name: str,
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
handler: Awaitable,
) -> None:
"""添加函数调用工具
@@ -220,7 +220,7 @@ class FunctionToolManager:
name: str,
cfg: dict,
event: asyncio.Event,
ready_future: asyncio.Future | None = None,
ready_future: asyncio.Future = None,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:

View File

@@ -7,13 +7,7 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
from .entities import ProviderType
from .provider import (
Provider,
STTProvider,
TTSProvider,
EmbeddingProvider,
RerankProvider,
)
from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
from .register import llm_tools, provider_cls_map
from ..persona_mgr import PersonaManager
@@ -44,12 +38,7 @@ class ProviderManager:
"""加载的 Text To Speech Provider 的实例"""
self.embedding_provider_insts: List[EmbeddingProvider] = []
"""加载的 Embedding Provider 的实例"""
self.rerank_provider_insts: List[RerankProvider] = []
"""加载的 Rerank Provider 的实例"""
self.inst_map: dict[
str,
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
] = {}
self.inst_map: dict[str, Provider] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools
@@ -98,31 +87,19 @@ class ProviderManager:
)
return
# 不启用提供商会话隔离模式的情况
prov = self.inst_map[provider_id]
if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
prov, TTSProvider
):
self.curr_tts_provider_inst = prov
self.curr_provider_inst = self.inst_map[provider_id]
if provider_type == ProviderType.TEXT_TO_SPEECH:
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
prov, STTProvider
):
self.curr_stt_provider_inst = prov
elif provider_type == ProviderType.SPEECH_TO_TEXT:
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
prov, Provider
):
self.curr_provider_inst = prov
elif provider_type == ProviderType.CHAT_COMPLETION:
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
"""根据提供商 ID 获取提供商实例"""
return self.inst_map.get(provider_id)
def get_using_provider(
self, provider_type: ProviderType, umo=None
) -> Provider | STTProvider | TTSProvider | None:
def get_using_provider(self, provider_type: ProviderType, umo=None):
"""获取正在使用的提供商实例。
Args:
@@ -234,8 +211,6 @@ class ProviderManager:
)
case "dify":
from .sources.dify_source import ProviderDify as ProviderDify
case "coze":
from .sources.coze_source import ProviderCoze as ProviderCoze
case "dashscope":
from .sources.dashscope_source import (
ProviderDashscope as ProviderDashscope,
@@ -328,14 +303,12 @@ class ProviderManager:
provider_metadata = provider_cls_map[provider_config["type"]]
try:
# 按任务实例化提供商
cls_type = provider_metadata.cls_type
if not cls_type:
logger.error(f"无法找到 {provider_metadata.type} 的类")
return
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = cls_type(provider_config, self.provider_settings)
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
if getattr(inst, "initialize", None):
await inst.initialize()
@@ -354,7 +327,9 @@ class ProviderManager:
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = cls_type(provider_config, self.provider_settings)
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
if getattr(inst, "initialize", None):
await inst.initialize()
@@ -370,7 +345,7 @@ class ProviderManager:
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = cls_type(
inst = provider_metadata.cls_type(
provider_config,
self.provider_settings,
self.selected_default_persona,
@@ -391,16 +366,13 @@ class ProviderManager:
if not self.curr_provider_inst:
self.curr_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
inst = cls_type(provider_config, self.provider_settings)
elif provider_metadata.provider_type in [ProviderType.EMBEDDING, ProviderType.RERANK]:
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
if getattr(inst, "initialize", None):
await inst.initialize()
self.embedding_provider_insts.append(inst)
elif provider_metadata.provider_type == ProviderType.RERANK:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.rerank_provider_insts.append(inst)
self.inst_map[provider_config["id"]] = inst
except Exception as e:
@@ -416,7 +388,6 @@ class ProviderManager:
# 和配置文件保持同步
config_ids = [provider["id"] for provider in self.providers_config]
logger.debug(f"providers in user's config: {config_ids}")
for key in list(self.inst_map.keys()):
if key not in config_ids:
await self.terminate_provider(key)
@@ -455,17 +426,11 @@ class ProviderManager:
)
if self.inst_map[provider_id] in self.provider_insts:
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, Provider):
self.provider_insts.remove(prov_inst)
self.provider_insts.remove(self.inst_map[provider_id])
if self.inst_map[provider_id] in self.stt_provider_insts:
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, STTProvider):
self.stt_provider_insts.remove(prov_inst)
self.stt_provider_insts.remove(self.inst_map[provider_id])
if self.inst_map[provider_id] in self.tts_provider_insts:
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, TTSProvider):
self.tts_provider_insts.remove(prov_inst)
self.tts_provider_insts.remove(self.inst_map[provider_id])
if self.inst_map[provider_id] == self.curr_provider_inst:
self.curr_provider_inst = None

View File

@@ -1,314 +0,0 @@
import json
import asyncio
import aiohttp
import io
from typing import Dict, List, Any, AsyncGenerator
from astrbot.core import logger
class CozeAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
self.api_key = api_key
self.api_base = api_base
self.session = None
async def _ensure_session(self):
"""确保HTTP session存在"""
if self.session is None:
connector = aiohttp.TCPConnector(
ssl=False if self.api_base.startswith("http://") else True,
limit=100,
limit_per_host=30,
keepalive_timeout=30,
enable_cleanup_closed=True,
)
timeout = aiohttp.ClientTimeout(
total=120, # 默认超时时间
connect=30,
sock_read=120,
)
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "text/event-stream",
}
self.session = aiohttp.ClientSession(
headers=headers, timeout=timeout, connector=connector
)
return self.session
async def upload_file(
self,
file_data: bytes,
) -> str:
"""上传文件到 Coze 并返回 file_id
Args:
file_data (bytes): 文件的二进制数据
Returns:
str: 上传成功后返回的 file_id
"""
session = await self._ensure_session()
url = f"{self.api_base}/v1/files/upload"
try:
file_io = io.BytesIO(file_data)
async with session.post(
url,
data={
"file": file_io,
},
timeout=aiohttp.ClientTimeout(total=60),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
response_text = await response.text()
logger.debug(
f"文件上传响应状态: {response.status}, 内容: {response_text}"
)
if response.status != 200:
raise Exception(
f"文件上传失败,状态码: {response.status}, 响应: {response_text}"
)
try:
result = await response.json()
except json.JSONDecodeError:
raise Exception(f"文件上传响应解析失败: {response_text}")
if result.get("code") != 0:
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
file_id = result["data"]["id"]
logger.debug(f"[Coze] 图片上传成功file_id: {file_id}")
return file_id
except asyncio.TimeoutError:
logger.error("文件上传超时")
raise Exception("文件上传超时")
except Exception as e:
logger.error(f"文件上传失败: {str(e)}")
raise Exception(f"文件上传失败: {str(e)}")
async def download_image(self, image_url: str) -> bytes:
"""下载图片并返回字节数据
Args:
image_url (str): 图片的URL
Returns:
bytes: 图片的二进制数据
"""
session = await self._ensure_session()
try:
async with session.get(image_url) as response:
if response.status != 200:
raise Exception(f"下载图片失败,状态码: {response.status}")
image_data = await response.read()
return image_data
except Exception as e:
logger.error(f"下载图片失败 {image_url}: {str(e)}")
raise Exception(f"下载图片失败: {str(e)}")
async def chat_messages(
self,
bot_id: str,
user_id: str,
additional_messages: List[Dict] | None = None,
conversation_id: str | None = None,
auto_save_history: bool = True,
stream: bool = True,
timeout: float = 120,
) -> AsyncGenerator[Dict[str, Any], None]:
"""发送聊天消息并返回流式响应
Args:
bot_id: Bot ID
user_id: 用户ID
additional_messages: 额外消息列表
conversation_id: 会话ID
auto_save_history: 是否自动保存历史
stream: 是否流式响应
timeout: 超时时间
"""
session = await self._ensure_session()
url = f"{self.api_base}/v3/chat"
payload = {
"bot_id": bot_id,
"user_id": user_id,
"stream": stream,
"auto_save_history": auto_save_history,
}
if additional_messages:
payload["additional_messages"] = additional_messages
params = {}
if conversation_id:
params["conversation_id"] = conversation_id
logger.debug(f"Coze chat_messages payload: {payload}, params: {params}")
try:
async with session.post(
url,
json=payload,
params=params,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
if response.status != 200:
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
# SSE
buffer = ""
event_type = None
event_data = None
async for chunk in response.content:
if chunk:
buffer += chunk.decode("utf-8", errors="ignore")
lines = buffer.split("\n")
buffer = lines[-1]
for line in lines[:-1]:
line = line.strip()
if not line:
if event_type and event_data:
yield {"event": event_type, "data": event_data}
event_type = None
event_data = None
elif line.startswith("event:"):
event_type = line[6:].strip()
elif line.startswith("data:"):
data_str = line[5:].strip()
if data_str and data_str != "[DONE]":
try:
event_data = json.loads(data_str)
except json.JSONDecodeError:
event_data = {"content": data_str}
except asyncio.TimeoutError:
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
except Exception as e:
raise Exception(f"Coze API 流式请求失败: {str(e)}")
async def clear_context(self, conversation_id: str):
"""清空会话上下文
Args:
conversation_id: 会话ID
Returns:
dict: API响应结果
"""
session = await self._ensure_session()
url = f"{self.api_base}/v3/conversation/message/clear_context"
payload = {"conversation_id": conversation_id}
try:
async with session.post(url, json=payload) as response:
response_text = await response.text()
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
if response.status != 200:
raise Exception(f"Coze API 请求失败,状态码: {response.status}")
try:
return json.loads(response_text)
except json.JSONDecodeError:
raise Exception("Coze API 返回非JSON格式")
except asyncio.TimeoutError:
raise Exception("Coze API 请求超时")
except aiohttp.ClientError as e:
raise Exception(f"Coze API 请求失败: {str(e)}")
async def get_message_list(
self,
conversation_id: str,
order: str = "desc",
limit: int = 10,
offset: int = 0,
):
"""获取消息列表
Args:
conversation_id: 会话ID
order: 排序方式 (asc/desc)
limit: 限制数量
offset: 偏移量
Returns:
dict: API响应结果
"""
session = await self._ensure_session()
url = f"{self.api_base}/v3/conversation/message/list"
params = {
"conversation_id": conversation_id,
"order": order,
"limit": limit,
"offset": offset,
}
try:
async with session.get(url, params=params) as response:
response.raise_for_status()
return await response.json()
except Exception as e:
logger.error(f"获取Coze消息列表失败: {str(e)}")
raise Exception(f"获取Coze消息列表失败: {str(e)}")
async def close(self):
"""关闭会话"""
if self.session:
await self.session.close()
self.session = None
if __name__ == "__main__":
import os
import asyncio
async def test_coze_api_client():
api_key = os.getenv("COZE_API_KEY", "")
bot_id = os.getenv("COZE_BOT_ID", "")
client = CozeAPIClient(api_key=api_key)
try:
with open("README.md", "rb") as f:
file_data = f.read()
file_id = await client.upload_file(file_data)
print(f"Uploaded file_id: {file_id}")
async for event in client.chat_messages(
bot_id=bot_id,
user_id="test_user",
additional_messages=[
{
"role": "user",
"content": json.dumps(
[
{"type": "text", "text": "这是什么"},
{"type": "file", "file_id": file_id},
],
ensure_ascii=False,
),
"content_type": "object_string",
},
],
stream=True,
):
print(f"Event: {event}")
finally:
await client.close()
asyncio.run(test_coze_api_client())

View File

@@ -1,635 +0,0 @@
import json
import os
import base64
import hashlib
from typing import AsyncGenerator, Dict
from astrbot.core.message.message_event_result import MessageChain
import astrbot.core.message.components as Comp
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.entities import LLMResponse
from ..register import register_provider_adapter
from .coze_api_client import CozeAPIClient
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
class ProviderCoze(Provider):
def __init__(
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:
raise Exception("Coze API Key 不能为空。")
self.bot_id = provider_config.get("bot_id", "")
if not self.bot_id:
raise Exception("Coze Bot ID 不能为空。")
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
if not isinstance(self.api_base, str) or not self.api_base.startswith(
("http://", "https://")
):
raise Exception(
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。"
)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.auto_save_history = provider_config.get("auto_save_history", True)
self.conversation_ids: Dict[str, str] = {}
self.file_id_cache: Dict[str, Dict[str, str]] = {}
# 创建 API 客户端
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
"""生成统一的缓存键
Args:
data: 图片数据或路径
is_base64: 是否是 base64 数据
Returns:
str: 缓存键
"""
try:
if is_base64 and data.startswith("data:image/"):
try:
header, encoded = data.split(",", 1)
image_bytes = base64.b64decode(encoded)
cache_key = hashlib.md5(image_bytes).hexdigest()
return cache_key
except Exception:
cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
return cache_key
else:
if data.startswith(("http://", "https://")):
# URL图片使用URL作为缓存键
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
return cache_key
else:
clean_path = (
data.split("_")[0]
if "_" in data and len(data.split("_")) >= 3
else data
)
if os.path.exists(clean_path):
with open(clean_path, "rb") as f:
file_content = f.read()
cache_key = hashlib.md5(file_content).hexdigest()
return cache_key
else:
cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
return cache_key
except Exception as e:
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
return cache_key
async def _upload_file(
self,
file_data: bytes,
session_id: str | None = None,
cache_key: str | None = None,
) -> str:
"""上传文件到 Coze 并返回 file_id"""
# 使用 API 客户端上传文件
file_id = await self.api_client.upload_file(file_data)
# 缓存 file_id
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
self.file_id_cache[session_id][cache_key] = file_id
logger.debug(f"[Coze] 图片上传成功并缓存file_id: {file_id}")
return file_id
async def _download_and_upload_image(
self, image_url: str, session_id: str | None = None
) -> str:
"""下载图片并上传到 Coze返回 file_id"""
# 计算哈希实现缓存
cache_key = self._generate_cache_key(image_url) if session_id else None
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
return file_id
try:
image_data = await self.api_client.download_image(image_url)
file_id = await self._upload_file(image_data, session_id, cache_key)
if session_id and cache_key:
self.file_id_cache[session_id][cache_key] = file_id
return file_id
except Exception as e:
logger.error(f"处理图片失败 {image_url}: {str(e)}")
raise Exception(f"处理图片失败: {str(e)}")
async def _process_context_images(
self, content: str | list, session_id: str
) -> str:
"""处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
try:
if isinstance(content, str):
return content
processed_content = []
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
for item in content:
if not isinstance(item, dict):
processed_content.append(item)
continue
if item.get("type") == "text":
processed_content.append(item)
elif item.get("type") == "image_url":
# 处理图片逻辑
if "file_id" in item:
# 已经有 file_id
logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
processed_content.append(item)
else:
# 获取图片数据
image_data = ""
if "image_url" in item and isinstance(item["image_url"], dict):
image_data = item["image_url"].get("url", "")
elif "data" in item:
image_data = item.get("data", "")
elif "url" in item:
image_data = item.get("url", "")
if not image_data:
continue
# 计算哈希用于缓存
cache_key = self._generate_cache_key(
image_data, is_base64=image_data.startswith("data:image/")
)
# 检查缓存
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
processed_content.append(
{"type": "image", "file_id": file_id}
)
else:
# 上传图片并缓存
if image_data.startswith("data:image/"):
# base64 处理
_, encoded = image_data.split(",", 1)
image_bytes = base64.b64decode(encoded)
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
elif image_data.startswith(("http://", "https://")):
# URL 图片
file_id = await self._download_and_upload_image(
image_data, session_id
)
# 为URL图片也添加缓存
self.file_id_cache[session_id][cache_key] = file_id
elif os.path.exists(image_data):
# 本地文件
with open(image_data, "rb") as f:
image_bytes = f.read()
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
else:
logger.warning(
f"无法处理的图片格式: {image_data[:50]}..."
)
continue
processed_content.append(
{"type": "image", "file_id": file_id}
)
result = json.dumps(processed_content, ensure_ascii=False)
return result
except Exception as e:
logger.error(f"处理上下文图片失败: {str(e)}")
if isinstance(content, str):
return content
else:
return json.dumps(content, ensure_ascii=False)
async def text_chat(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
"""文本对话, 内部使用流式接口实现非流式
Args:
prompt (str): 用户提示词
session_id (str): 会话ID
image_urls (List[str]): 图片URL列表
func_tool (FuncCall): 函数调用工具(不支持)
contexts (List): 上下文列表
system_prompt (str): 系统提示语
tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
model (str): 模型名称(不支持)
Returns:
LLMResponse: LLM响应对象
"""
accumulated_content = ""
final_response = None
async for llm_response in self.text_chat_stream(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
model=model,
**kwargs,
):
if llm_response.is_chunk:
if llm_response.completion_text:
accumulated_content += llm_response.completion_text
else:
final_response = llm_response
if final_response:
return final_response
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
return LLMResponse(role="assistant", result_chain=chain)
else:
return LLMResponse(role="assistant", completion_text="")
async def text_chat_stream(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话接口"""
# 用户ID参数(参考文档, 可以自定义)
user_id = session_id or kwargs.get("user", "default_user")
# 获取或创建会话ID
conversation_id = self.conversation_ids.get(user_id)
# 构建消息
additional_messages = []
if system_prompt:
if not self.auto_save_history or not conversation_id:
additional_messages.append(
{"role": "system", "content": system_prompt, "content_type": "text"}
)
if not self.auto_save_history and contexts:
# 如果关闭了自动保存历史,传入上下文
for ctx in contexts:
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
content = ctx["content"]
content_type = ctx.get("content_type", "text")
# 处理可能包含图片的上下文
if (
content_type == "object_string"
or (isinstance(content, str) and content.startswith("["))
or (
isinstance(content, list)
and any(
isinstance(item, dict)
and item.get("type") == "image_url"
for item in content
)
)
):
processed_content = await self._process_context_images(
content, user_id
)
additional_messages.append(
{
"role": ctx["role"],
"content": processed_content,
"content_type": "object_string",
}
)
else:
# 纯文本
additional_messages.append(
{
"role": ctx["role"],
"content": (
content
if isinstance(content, str)
else json.dumps(content, ensure_ascii=False)
),
"content_type": "text",
}
)
else:
logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
if prompt or image_urls:
if image_urls:
# 多模态
object_string_content = []
if prompt:
object_string_content.append({"type": "text", "text": prompt})
for url in image_urls:
try:
if url.startswith(("http://", "https://")):
# 网络图片
file_id = await self._download_and_upload_image(
url, user_id
)
else:
# 本地文件或 base64
if url.startswith("data:image/"):
# base64
_, encoded = url.split(",", 1)
image_data = base64.b64decode(encoded)
cache_key = self._generate_cache_key(
url, is_base64=True
)
file_id = await self._upload_file(
image_data, user_id, cache_key
)
else:
# 本地文件
if os.path.exists(url):
with open(url, "rb") as f:
image_data = f.read()
# 用文件路径和修改时间来缓存
file_stat = os.stat(url)
cache_key = self._generate_cache_key(
f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
is_base64=False,
)
file_id = await self._upload_file(
image_data, user_id, cache_key
)
else:
logger.warning(f"图片文件不存在: {url}")
continue
object_string_content.append(
{
"type": "image",
"file_id": file_id,
}
)
except Exception as e:
logger.error(f"处理图片失败 {url}: {str(e)}")
continue
if object_string_content:
content = json.dumps(object_string_content, ensure_ascii=False)
additional_messages.append(
{
"role": "user",
"content": content,
"content_type": "object_string",
}
)
else:
# 纯文本
if prompt:
additional_messages.append(
{
"role": "user",
"content": prompt,
"content_type": "text",
}
)
try:
accumulated_content = ""
message_started = False
async for chunk in self.api_client.chat_messages(
bot_id=self.bot_id,
user_id=user_id,
additional_messages=additional_messages,
conversation_id=conversation_id,
auto_save_history=self.auto_save_history,
stream=True,
timeout=self.timeout,
):
event_type = chunk.get("event")
data = chunk.get("data", {})
if event_type == "conversation.chat.created":
if isinstance(data, dict) and "conversation_id" in data:
self.conversation_ids[user_id] = data["conversation_id"]
elif event_type == "conversation.message.delta":
if isinstance(data, dict):
content = data.get("content", "")
if not content and "delta" in data:
content = data["delta"].get("content", "")
if not content and "text" in data:
content = data.get("text", "")
if content:
message_started = True
accumulated_content += content
yield LLMResponse(
role="assistant",
completion_text=content,
is_chunk=True,
)
elif event_type == "conversation.message.completed":
if isinstance(data, dict):
msg_type = data.get("type")
if msg_type == "answer" and data.get("role") == "assistant":
final_content = data.get("content", "")
if not accumulated_content and final_content:
chain = MessageChain(chain=[Comp.Plain(final_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
elif event_type == "conversation.chat.completed":
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
break
elif event_type == "done":
break
elif event_type == "error":
error_msg = (
data.get("message", "未知错误")
if isinstance(data, dict)
else str(data)
)
logger.error(f"Coze 流式响应错误: {error_msg}")
yield LLMResponse(
role="err",
completion_text=f"Coze 错误: {error_msg}",
is_chunk=False,
)
break
if not message_started and not accumulated_content:
yield LLMResponse(
role="assistant",
completion_text="LLM 未响应任何内容。",
is_chunk=False,
)
elif message_started and accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
except Exception as e:
logger.error(f"Coze 流式请求失败: {str(e)}")
yield LLMResponse(
role="err",
completion_text=f"Coze 流式请求失败: {str(e)}",
is_chunk=False,
)
async def forget(self, session_id: str):
"""清空指定会话的上下文"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if user_id in self.file_id_cache:
self.file_id_cache.pop(user_id, None)
if not conversation_id:
return True
try:
response = await self.api_client.clear_context(conversation_id)
if "code" in response and response["code"] == 0:
self.conversation_ids.pop(user_id, None)
return True
else:
logger.warning(f"清空 Coze 会话上下文失败: {response}")
return False
except Exception as e:
logger.error(f"清空 Coze 会话失败: {str(e)}")
return False
async def get_current_key(self):
"""获取当前API Key"""
return self.api_key
async def set_key(self, key: str):
"""设置新的API Key"""
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
async def get_models(self):
"""获取可用模型列表"""
return [f"bot_{self.bot_id}"]
def get_model(self):
"""获取当前模型"""
return f"bot_{self.bot_id}"
def set_model(self, model: str):
"""设置模型在Coze中是Bot ID"""
if model.startswith("bot_"):
self.bot_id = model[4:]
else:
self.bot_id = model
async def get_human_readable_context(
self, session_id: str, page: int = 1, page_size: int = 10
):
"""获取人类可读的上下文历史"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if not conversation_id:
return []
try:
data = await self.api_client.get_message_list(
conversation_id=conversation_id,
order="desc",
limit=page_size,
offset=(page - 1) * page_size,
)
if data.get("code") != 0:
logger.warning(f"获取 Coze 消息历史失败: {data}")
return []
messages = data.get("data", {}).get("messages", [])
readable_history = []
for msg in messages:
role = msg.get("role", "unknown")
content = msg.get("content", "")
msg_type = msg.get("type", "")
if role == "user":
readable_history.append(f"用户: {content}")
elif role == "assistant" and msg_type == "answer":
readable_history.append(f"助手: {content}")
return readable_history
except Exception as e:
logger.error(f"获取 Coze 消息历史失败: {str(e)}")
return []
async def terminate(self):
"""清理资源"""
await self.api_client.close()

View File

@@ -98,7 +98,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
# FishAudio的reference_id通常是32位十六进制字符串
# 例如: 626bb6d3f3364c9cbc3aa6a67300a664
pattern = r"^[a-fA-F0-9]{32}$"
pattern = r'^[a-fA-F0-9]{32}$'
return bool(re.match(pattern, reference_id.strip()))
async def _generate_request(self, text: str) -> dict:

View File

@@ -15,7 +15,7 @@ from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.utils.io import download_image_by_url
from ..register import register_provider_adapter
@@ -61,7 +61,7 @@ class ProviderGoogleGenAI(Provider):
default_persona,
)
self.api_keys: list = provider_config.get("key", [])
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None
self.timeout: int = int(provider_config.get("timeout", 180))
self.api_base: Optional[str] = provider_config.get("api_base", None)
@@ -96,9 +96,6 @@ class ProviderGoogleGenAI(Provider):
async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool:
"""处理API错误返回是否需要重试"""
if e.message is None:
e.message = ""
if e.code == 429 or "API key not valid" in e.message:
keys.remove(self.chosen_api_key)
if len(keys) > 0:
@@ -122,7 +119,7 @@ class ProviderGoogleGenAI(Provider):
async def _prepare_query_config(
self,
payloads: dict,
tools: Optional[ToolSet] = None,
tools: Optional[FuncCall] = None,
system_instruction: Optional[str] = None,
modalities: Optional[list[str]] = None,
temperature: float = 0.7,
@@ -324,15 +321,11 @@ class ProviderGoogleGenAI(Provider):
@staticmethod
def _process_content_parts(
candidate: types.Candidate, llm_response: LLMResponse
result: types.GenerateContentResponse, llm_response: LLMResponse
) -> MessageChain:
"""处理内容部分并构建消息链"""
if not candidate.content:
logger.warning(f"收到的 candidate.content 为空: {candidate}")
raise Exception("API 返回的 candidate.content 为空。")
finish_reason = candidate.finish_reason
result_parts: list[types.Part] | None = candidate.content.parts
finish_reason = result.candidates[0].finish_reason
result_parts: Optional[types.Part] = result.candidates[0].content.parts
if finish_reason == types.FinishReason.SAFETY:
raise Exception("模型生成内容未通过 Gemini 平台的安全检查")
@@ -350,28 +343,22 @@ class ProviderGoogleGenAI(Provider):
raise Exception("模型生成内容违反 Gemini 平台政策")
if not result_parts:
logger.warning(f"收到的 candidate.content.parts 为空: {candidate}")
raise Exception("API 返回的 candidate.content.parts 为空。")
logger.debug(result.candidates)
raise Exception("API 返回的内容为空。")
chain = []
part: types.Part
# 暂时这样Fallback
if all(
part.inline_data
and part.inline_data.mime_type
and part.inline_data.mime_type.startswith("image/")
part.inline_data and part.inline_data.mime_type.startswith("image/")
for part in result_parts
):
chain.append(Comp.Plain("这是图片"))
for part in result_parts:
if part.text:
chain.append(Comp.Plain(part.text))
elif (
part.function_call
and part.function_call.name is not None
and part.function_call.args is not None
):
elif part.function_call:
llm_response.role = "tool"
llm_response.tools_call_name.append(part.function_call.name)
llm_response.tools_call_args.append(part.function_call.args)
@@ -379,16 +366,11 @@ class ProviderGoogleGenAI(Provider):
llm_response.tools_call_ids.append(
part.function_call.id or part.function_call.name
)
elif (
part.inline_data
and part.inline_data.mime_type
and part.inline_data.mime_type.startswith("image/")
and part.inline_data.data
):
elif part.inline_data and part.inline_data.mime_type.startswith("image/"):
chain.append(Comp.Image.fromBytes(part.inline_data.data))
return MessageChain(chain=chain)
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
"""非流式请求 Gemini API"""
system_instruction = next(
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
@@ -414,10 +396,6 @@ class ProviderGoogleGenAI(Provider):
config=config,
)
if not result.candidates:
logger.error(f"请求失败, 返回的 candidates 为空: {result}")
raise Exception("请求失败, 返回的 candidates 为空。")
if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
if temperature > 2:
raise Exception("温度参数已超过最大值2仍然发生recitation")
@@ -430,8 +408,6 @@ class ProviderGoogleGenAI(Provider):
break
except APIError as e:
if e.message is None:
e.message = ""
if "Developer instruction is not enabled" in e.message:
logger.warning(
f"{self.get_model()} 不支持 system prompt已自动去除(影响人格设置)"
@@ -456,13 +432,11 @@ class ProviderGoogleGenAI(Provider):
llm_response = LLMResponse("assistant")
llm_response.raw_completion = result
llm_response.result_chain = self._process_content_parts(
result.candidates[0], llm_response
)
llm_response.result_chain = self._process_content_parts(result, llm_response)
return llm_response
async def _query_stream(
self, payloads: dict, tools: ToolSet | None
self, payloads: dict, tools: FuncCall
) -> AsyncGenerator[LLMResponse, None]:
"""流式请求 Gemini API"""
system_instruction = next(
@@ -485,8 +459,6 @@ class ProviderGoogleGenAI(Provider):
)
break
except APIError as e:
if e.message is None:
e.message = ""
if "Developer instruction is not enabled" in e.message:
logger.warning(
f"{self.get_model()} 不支持 system prompt已自动去除(影响人格设置)"
@@ -506,20 +478,13 @@ class ProviderGoogleGenAI(Provider):
async for chunk in result:
llm_response = LLMResponse("assistant", is_chunk=True)
if not chunk.candidates:
logger.warning(f"收到的 chunk 中 candidates 为空: {chunk}")
continue
if not chunk.candidates[0].content:
logger.warning(f"收到的 chunk 中 content 为空: {chunk}")
continue
if chunk.candidates[0].content.parts and any(
part.function_call for part in chunk.candidates[0].content.parts
):
llm_response = LLMResponse("assistant", is_chunk=False)
llm_response.raw_completion = chunk
llm_response.result_chain = self._process_content_parts(
chunk.candidates[0], llm_response
chunk, llm_response
)
yield llm_response
return
@@ -535,7 +500,7 @@ class ProviderGoogleGenAI(Provider):
final_response = LLMResponse("assistant", is_chunk=False)
final_response.raw_completion = chunk
final_response.result_chain = self._process_content_parts(
chunk.candidates[0], final_response
chunk, final_response
)
break
@@ -601,8 +566,6 @@ class ProviderGoogleGenAI(Provider):
continue
break
raise Exception("请求失败。")
async def text_chat_stream(
self,
prompt,
@@ -658,9 +621,7 @@ class ProviderGoogleGenAI(Provider):
return [
m.name.replace("models/", "")
for m in models
if m.supported_actions
and "generateContent" in m.supported_actions
and m.name
if "generateContent" in m.supported_actions
]
except APIError as e:
raise Exception(f"获取模型列表失败: {e.message}")
@@ -675,7 +636,7 @@ class ProviderGoogleGenAI(Provider):
self.chosen_api_key = key
self._init_client()
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
async def assemble_context(self, text: str, image_urls: list[str] = None):
"""
组装上下文。
"""

View File

@@ -99,15 +99,12 @@ class ProviderOpenAIOfficial(Provider):
for key in to_del:
del payloads[key]
# 读取并合并 custom_extra_body 配置
custom_extra_body = self.provider_config.get("custom_extra_body", {})
if isinstance(custom_extra_body, dict):
extra_body.update(custom_extra_body)
model = payloads.get("model", "").lower()
model = payloads.get("model", "")
# 针对 qwen3 模型的特殊处理:非流式调用必须设置 enable_thinking=false
if "qwen3" in model.lower():
extra_body["enable_thinking"] = False
# 针对 deepseek 模型的特殊处理deepseek-reasoner调用必须移除 tools ,否则将被切换至 deepseek-chat
if model == "deepseek-reasoner" and "tools" in payloads:
elif model == "deepseek-reasoner" and "tools" in payloads:
del payloads["tools"]
completion = await self.client.chat.completions.create(
@@ -140,12 +137,6 @@ class ProviderOpenAIOfficial(Provider):
# 不在默认参数中的参数放在 extra_body 中
extra_body = {}
# 读取并合并 custom_extra_body 配置
custom_extra_body = self.provider_config.get("custom_extra_body", {})
if isinstance(custom_extra_body, dict):
extra_body.update(custom_extra_body)
to_del = []
for key in payloads.keys():
if key not in self.default_params:

View File

@@ -1,5 +1,4 @@
import aiohttp
from astrbot import logger
from ..provider import RerankProvider
from ..register import register_provider_adapter
from ..entities import ProviderType, RerankResult
@@ -45,11 +44,6 @@ class VLLMRerankProvider(RerankProvider):
response_data = await response.json()
results = response_data.get("results", [])
if not results:
logger.warning(
f"Rerank API 返回了空的列表数据。原始响应: {response_data}"
)
return [
RerankResult(
index=result["index"],

View File

@@ -27,16 +27,14 @@ class Star(CommandParserMixin):
star_map[cls.__module__].star_cls_type = cls
star_map[cls.__module__].module_path = cls.__module__
async def text_to_image(self, text: str, return_url=True) -> str:
@staticmethod
async def text_to_image(text: str, return_url=True) -> str:
"""将文本转换为图片"""
return await html_renderer.render_t2i(
text,
return_url=return_url,
template_name=self.context._config.get("t2i_active_template"),
)
return await html_renderer.render_t2i(text, return_url=return_url)
@staticmethod
async def html_render(
self, tmpl: str, data: dict, return_url=True, options: dict | None = None
tmpl: str, data: dict, return_url=True, options: dict | None = None
) -> str:
"""渲染 HTML"""
return await html_renderer.render_custom_template(

View File

@@ -6,7 +6,6 @@ from astrbot.core.provider.provider import (
TTSProvider,
STTProvider,
EmbeddingProvider,
RerankProvider,
)
from astrbot.core.provider.entities import ProviderType
from astrbot.core.db import BaseDatabase
@@ -24,7 +23,7 @@ from .star import star_registry, StarMetadata, star_map
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable, Any, Callable
from typing import Awaitable
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
@@ -104,14 +103,9 @@ class Context:
"""
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(
self, provider_id: str
) -> (
Provider | TTSProvider | STTProvider | EmbeddingProvider | RerankProvider | None
):
"""通过 ID 获取对应的 LLM Provider。"""
prov = self.provider_manager.inst_map.get(provider_id)
return prov
def get_provider_by_id(self, provider_id: str) -> Provider | None:
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
return self.provider_manager.inst_map.get(provider_id)
def get_all_providers(self) -> List[Provider]:
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
@@ -136,43 +130,34 @@ class Context:
Args:
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
"""
prov = self.provider_manager.get_using_provider(
return self.provider_manager.get_using_provider(
provider_type=ProviderType.CHAT_COMPLETION,
umo=umo,
)
if prov and not isinstance(prov, Provider):
raise ValueError("返回的 Provider 不是 Provider 类型")
return prov
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None:
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider:
"""
获取当前使用的用于 TTS 任务的 Provider。
Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
"""
prov = self.provider_manager.get_using_provider(
return self.provider_manager.get_using_provider(
provider_type=ProviderType.TEXT_TO_SPEECH,
umo=umo,
)
if prov and not isinstance(prov, TTSProvider):
raise ValueError("返回的 Provider 不是 TTSProvider 类型")
return prov
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None:
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider:
"""
获取当前使用的用于 STT 任务的 Provider。
Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
"""
prov = self.provider_manager.get_using_provider(
return self.provider_manager.get_using_provider(
provider_type=ProviderType.SPEECH_TO_TEXT,
umo=umo,
)
if prov and not isinstance(prov, STTProvider):
raise ValueError("返回的 Provider 不是 STTProvider 类型")
return prov
def get_config(self, umo: str | None = None) -> AstrBotConfig:
"""获取 AstrBot 的配置。"""
@@ -260,11 +245,7 @@ class Context:
"""
def register_llm_tool(
self,
name: str,
func_args: list,
desc: str,
func_obj: Callable[..., Awaitable[Any]],
self, name: str, func_args: list, desc: str, func_obj: Awaitable
) -> None:
"""
为函数调用function-calling / tools-use添加工具。
@@ -286,7 +267,9 @@ class Context:
desc=desc,
)
star_handlers_registry.append(md)
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj)
self.provider_manager.llm_tools.add_func(
name, func_args, desc, func_obj, func_obj
)
def unregister_llm_tool(self, name: str) -> None:
"""删除一个函数调用工具。如果再要启用,需要重新注册。"""
@@ -298,7 +281,7 @@ class Context:
command_name: str,
desc: str,
priority: int,
awaitable: Callable[..., Awaitable[Any]],
awaitable: Awaitable,
use_regex=False,
ignore_prefix=False,
):

View File

@@ -32,9 +32,6 @@ class CommandFilter(HandlerFilter):
self.init_handler_md(handler_md)
self.custom_filter_list: List[CustomFilter] = []
# Cache for complete command names list
self._cmpl_cmd_names: list | None = None
def print_types(self):
result = ""
for k, v in self.handler_params.items():
@@ -139,22 +136,6 @@ class CommandFilter(HandlerFilter):
)
return result
def get_complete_command_names(self):
if self._cmpl_cmd_names is not None:
return self._cmpl_cmd_names
self._cmpl_cmd_names = [
f"{parent} {cmd}" if parent else cmd
for cmd in [self.command_name] + list(self.alias)
for parent in self.parent_command_names or [""]
]
return self._cmpl_cmd_names
def equals(self, message_str: str) -> bool:
for full_cmd in self.get_complete_command_names():
if message_str == full_cmd:
return True
return False
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
if not event.is_at_or_wake_command:
return False
@@ -164,11 +145,18 @@ class CommandFilter(HandlerFilter):
# 检查是否以指令开头
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
candidates = [self.command_name] + list(self.alias)
ok = False
for full_cmd in self.get_complete_command_names():
if message_str.startswith(f"{full_cmd} ") or message_str == full_cmd:
ok = True
message_str = message_str[len(full_cmd) :].strip()
for candidate in candidates:
for parent_command_name in self.parent_command_names:
if parent_command_name:
_full = f"{parent_command_name} {candidate}"
else:
_full = candidate
if message_str.startswith(f"{_full} ") or message_str == _full:
message_str = message_str[len(_full) :].strip()
ok = True
break
if not ok:
return False

View File

@@ -13,8 +13,8 @@ class CommandGroupFilter(HandlerFilter):
def __init__(
self,
group_name: str,
alias: set | None = None,
parent_group: CommandGroupFilter | None = None,
alias: set = None,
parent_group: CommandGroupFilter = None,
):
self.group_name = group_name
self.alias = alias if alias else set()
@@ -22,9 +22,6 @@ class CommandGroupFilter(HandlerFilter):
self.custom_filter_list: List[CustomFilter] = []
self.parent_group = parent_group
# Cache for complete command names list
self._cmpl_cmd_names: list | None = None
def add_sub_command_filter(
self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]
):
@@ -37,9 +34,6 @@ class CommandGroupFilter(HandlerFilter):
"""遍历父节点获取完整的指令名。
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。"""
if self._cmpl_cmd_names is not None:
return self._cmpl_cmd_names
parent_cmd_names = (
self.parent_group.get_complete_command_names() if self.parent_group else []
)
@@ -53,7 +47,6 @@ class CommandGroupFilter(HandlerFilter):
for parent_cmd_name in parent_cmd_names:
for candidate in candidates:
result.append(parent_cmd_name + " " + candidate)
self._cmpl_cmd_names = result
return result
# 以树的形式打印出来
@@ -61,8 +54,8 @@ class CommandGroupFilter(HandlerFilter):
self,
sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]],
prefix: str = "",
event: AstrMessageEvent | None = None,
cfg: AstrBotConfig | None = None,
event: AstrMessageEvent = None,
cfg: AstrBotConfig = None,
) -> str:
result = ""
for sub_filter in sub_command_filters:
@@ -104,12 +97,6 @@ class CommandGroupFilter(HandlerFilter):
return False
return True
def startswith(self, message_str: str) -> bool:
return message_str.startswith(tuple(self.get_complete_command_names()))
def equals(self, message_str: str) -> bool:
return message_str in self.get_complete_command_names()
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
if not event.is_at_or_wake_command:
return False
@@ -118,14 +105,18 @@ class CommandGroupFilter(HandlerFilter):
if not self.custom_filter_ok(event, cfg):
return False
if self.equals(event.message_str.strip()):
complete_command_names = self.get_complete_command_names()
if event.message_str.strip() in complete_command_names:
tree = (
self.group_name
+ "\n"
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
)
raise ValueError(
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n"
+ tree
)
return self.startswith(event.message_str)
# complete_command_names = [name + " " for name in complete_command_names]
# return event.message_str.startswith(tuple(complete_command_names))
return False

View File

@@ -2,6 +2,7 @@ import enum
from . import HandlerFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.config import AstrBotConfig
from typing import Union
class PlatformAdapterType(enum.Flag):
@@ -17,8 +18,6 @@ class PlatformAdapterType(enum.Flag):
KOOK = enum.auto()
VOCECHAT = enum.auto()
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
SATORI = enum.auto()
MISSKEY = enum.auto()
ALL = (
AIOCQHTTP
| QQOFFICIAL
@@ -32,8 +31,6 @@ class PlatformAdapterType(enum.Flag):
| KOOK
| VOCECHAT
| WEIXIN_OFFICIAL_ACCOUNT
| SATORI
| MISSKEY
)
@@ -50,20 +47,15 @@ ADAPTER_NAME_2_TYPE = {
"wechatpadpro": PlatformAdapterType.WECHATPADPRO,
"vocechat": PlatformAdapterType.VOCECHAT,
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
"satori": PlatformAdapterType.SATORI,
"misskey": PlatformAdapterType.MISSKEY,
}
class PlatformAdapterTypeFilter(HandlerFilter):
def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str):
if isinstance(platform_adapter_type_or_str, str):
self.platform_type = ADAPTER_NAME_2_TYPE.get(platform_adapter_type_or_str)
else:
self.platform_type = platform_adapter_type_or_str
def __init__(self, platform_adapter_type_or_str: Union[PlatformAdapterType, str]):
self.type_or_str = platform_adapter_type_or_str
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
adapter_name = event.get_platform_name()
if adapter_name in ADAPTER_NAME_2_TYPE and self.platform_type is not None:
return bool(ADAPTER_NAME_2_TYPE[adapter_name] & self.platform_type)
if adapter_name in ADAPTER_NAME_2_TYPE:
return ADAPTER_NAME_2_TYPE[adapter_name] & self.type_or_str
return False

View File

@@ -8,7 +8,6 @@ from .star_handler import (
register_permission_type,
register_custom_filter,
register_on_astrbot_loaded,
register_on_platform_loaded,
register_on_llm_request,
register_on_llm_response,
register_llm_tool,
@@ -27,7 +26,6 @@ __all__ = [
"register_permission_type",
"register_custom_filter",
"register_on_astrbot_loaded",
"register_on_platform_loaded",
"register_on_llm_request",
"register_on_llm_response",
"register_llm_tool",

View File

@@ -5,9 +5,7 @@ from astrbot.core.star import StarMetadata, star_map
_warned_register_star = False
def register_star(
name: str, author: str, desc: str, version: str, repo: str | None = None
):
def register_star(name: str, author: str, desc: str, version: str, repo: str = None):
"""注册一个插件(Star)。
[DEPRECATED] 该装饰器已废弃,将在未来版本中移除。

View File

@@ -12,7 +12,7 @@ from ..filter.platform_adapter_type import (
from ..filter.permission import PermissionTypeFilter, PermissionType
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
from ..filter.regex import RegexFilter
from typing import Awaitable, Any, Callable
from typing import Awaitable
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
from astrbot.core.agent.agent import Agent
@@ -20,19 +20,15 @@ from astrbot.core.agent.tool import FunctionTool
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core import logger
def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str:
def get_handler_full_name(awaitable: Awaitable) -> str:
"""获取 Handler 的全名"""
return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create(
handler: Callable[..., Awaitable[Any]],
event_type: EventType,
dont_add=False,
**kwargs,
handler: Awaitable, event_type: EventType, dont_add=False, **kwargs
) -> StarHandlerMetadata:
"""获取 Handler 或者创建一个新的 Handler"""
handler_full_name = get_handler_full_name(handler)
@@ -63,35 +59,22 @@ def get_handler_or_create(
def register_command(
command_name: str | None = None,
sub_command: str | None = None,
alias: set | None = None,
**kwargs,
command_name: str = None, sub_command: str = None, alias: set = None, **kwargs
):
"""注册一个 Command."""
new_command = None
add_to_event_filters = False
if isinstance(command_name, RegisteringCommandable):
# 子指令
if sub_command is not None:
parent_command_names = (
command_name.parent_group.get_complete_command_names()
)
new_command = CommandFilter(
sub_command, alias, None, parent_command_names=parent_command_names
)
command_name.parent_group.add_sub_command_filter(new_command)
else:
logger.warning(
f"注册指令{command_name} 的子指令时未提供 sub_command 参数。"
)
parent_command_names = command_name.parent_group.get_complete_command_names()
new_command = CommandFilter(
sub_command, alias, None, parent_command_names=parent_command_names
)
command_name.parent_group.add_sub_command_filter(new_command)
else:
# 裸指令
if command_name is None:
logger.warning("注册裸指令时未提供 command_name 参数。")
else:
new_command = CommandFilter(command_name, alias, None)
add_to_event_filters = True
new_command = CommandFilter(command_name, alias, None)
add_to_event_filters = True
def decorator(awaitable):
if not add_to_event_filters:
@@ -101,9 +84,8 @@ def register_command(
handler_md = get_handler_or_create(
awaitable, EventType.AdapterMessageEvent, **kwargs
)
if new_command:
new_command.init_handler_md(handler_md)
handler_md.event_filters.append(new_command)
new_command.init_handler_md(handler_md)
handler_md.event_filters.append(new_command)
return awaitable
return decorator
@@ -181,38 +163,26 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
def register_command_group(
command_group_name: str | None = None,
sub_command: str | None = None,
alias: set | None = None,
**kwargs,
command_group_name: str = None, sub_command: str = None, alias: set = None, **kwargs
):
"""注册一个 CommandGroup"""
new_group = None
if isinstance(command_group_name, RegisteringCommandable):
# 子指令组
if sub_command is None:
logger.warning(f"{command_group_name} 指令组的子指令组 sub_command 未指定")
else:
new_group = CommandGroupFilter(
sub_command, alias, parent_group=command_group_name.parent_group
)
command_group_name.parent_group.add_sub_command_filter(new_group)
new_group = CommandGroupFilter(
sub_command, alias, parent_group=command_group_name.parent_group
)
command_group_name.parent_group.add_sub_command_filter(new_group)
else:
# 根指令组
if command_group_name is None:
logger.warning("根指令组的名称未指定")
else:
new_group = CommandGroupFilter(command_group_name, alias)
new_group = CommandGroupFilter(command_group_name, alias)
def decorator(obj):
# 根指令组
if new_group:
handler_md = get_handler_or_create(
obj, EventType.AdapterMessageEvent, **kwargs
)
handler_md.event_filters.append(new_group)
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
handler_md.event_filters.append(new_group)
return RegisteringCommandable(new_group)
return RegisteringCommandable(new_group)
return decorator
@@ -297,18 +267,6 @@ def register_on_astrbot_loaded(**kwargs):
return decorator
def register_on_platform_loaded(**kwargs):
"""
当平台加载完成时
"""
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnPlatformLoadedEvent, **kwargs)
return awaitable
return decorator
def register_on_llm_request(**kwargs):
"""当有 LLM 请求时的事件
@@ -353,7 +311,7 @@ def register_on_llm_response(**kwargs):
return decorator
def register_llm_tool(name: str | None = None, **kwargs):
def register_llm_tool(name: str = None, **kwargs):
"""为函数调用function-calling / tools-use添加工具。
请务必按照以下格式编写一个工具包括函数注释AstrBot 会尝试解析该函数注释)
@@ -391,10 +349,9 @@ def register_llm_tool(name: str | None = None, **kwargs):
if kwargs.get("registering_agent"):
registering_agent = kwargs["registering_agent"]
def decorator(awaitable: Callable[..., Awaitable[Any]]):
def decorator(awaitable: Awaitable):
llm_tool_name = name_ if name_ else awaitable.__name__
func_doc = awaitable.__doc__ or ""
docstring = docstring_parser.parse(func_doc)
docstring = docstring_parser.parse(awaitable.__doc__)
args = []
for arg in docstring.params:
if arg.type_name not in SUPPORTED_TYPES:
@@ -410,18 +367,18 @@ def register_llm_tool(name: str | None = None, **kwargs):
)
# print(llm_tool_name, registering_agent)
if not registering_agent:
doc_desc = docstring.description.strip() if docstring.description else ""
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
llm_tools.add_func(llm_tool_name, args, doc_desc, md.handler)
llm_tools.add_func(
llm_tool_name, args, docstring.description.strip(), md.handler
)
else:
assert isinstance(registering_agent, RegisteringAgent)
# print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name)
if registering_agent._agent.tools is None:
registering_agent._agent.tools = []
desc = docstring.description.strip() if docstring.description else ""
tool = llm_tools.spec_to_func(llm_tool_name, args, desc, awaitable)
registering_agent._agent.tools.append(tool)
registering_agent._agent.tools.append(llm_tools.spec_to_func(
llm_tool_name, args, docstring.description.strip(), awaitable
))
return awaitable
@@ -442,8 +399,8 @@ class RegisteringAgent:
def register_agent(
name: str,
instruction: str,
tools: list[str | FunctionTool] | None = None,
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
tools: list[str | FunctionTool] = None,
run_hooks: BaseAgentRunHooks[AstrAgentContext] = None,
):
"""注册一个 Agent
@@ -455,7 +412,7 @@ def register_agent(
"""
tools_ = tools or []
def decorator(awaitable: Callable[..., Awaitable[Any]]):
def decorator(awaitable: Awaitable):
AstrAgent = Agent[AstrAgentContext]
agent = AstrAgent(
name=name,
@@ -464,7 +421,7 @@ def register_agent(
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
)
handoff_tool = HandoffTool(agent=agent)
handoff_tool.handler = awaitable
handoff_tool.handler=awaitable
llm_tools.func_list.append(handoff_tool)
return RegisteringAgent(agent)

View File

@@ -52,6 +52,10 @@ class SessionServiceManager:
"session_service_config", session_config, scope="umo", scope_id=session_id
)
logger.info(
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
)
@staticmethod
def should_process_llm_request(event: AstrMessageEvent) -> bool:
"""检查是否应该处理LLM请求

View File

@@ -84,10 +84,7 @@ class SessionPluginManager:
session_config["disabled_plugins"] = disabled_plugins
session_plugin_config[session_id] = session_config
sp.put(
"session_plugin_config",
session_plugin_config,
scope="umo",
scope_id=session_id,
"session_plugin_config", session_plugin_config, scope="umo", scope_id=session_id
)
logger.info(
@@ -140,9 +137,6 @@ class SessionPluginManager:
filtered_handlers.append(handler)
continue
if plugin.name is None:
continue
# 检查插件是否在当前会话中启用
if SessionPluginManager.is_plugin_enabled_for_session(
session_id, plugin.name

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import enum
from dataclasses import dataclass, field
from typing import Callable, Awaitable, Any, List, Dict, TypeVar, Generic
from typing import Awaitable, List, Dict, TypeVar, Generic
from .filter import HandlerFilter
from .star import star_map
@@ -34,33 +34,26 @@ class StarHandlerRegistry(Generic[T]):
) -> List[StarHandlerMetadata]:
handlers = []
for handler in self._handlers:
# 过滤事件类型
if handler.event_type != event_type:
continue
# 过滤启用状态
if only_activated:
plugin = star_map.get(handler.handler_module_path)
if not (plugin and plugin.activated):
continue
# 过滤插件白名单
if plugins_name is not None and plugins_name != ["*"]:
plugin = star_map.get(handler.handler_module_path)
if not plugin:
continue
if (
plugin.name not in plugins_name
and event_type
not in (
EventType.OnAstrBotLoadedEvent,
EventType.OnPlatformLoadedEvent,
)
and event_type != EventType.OnAstrBotLoadedEvent
and not plugin.reserved
):
continue
handlers.append(handler)
return handlers
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata | None:
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
return self.star_handlers_map.get(full_name, None)
def get_handlers_by_module_name(
@@ -87,7 +80,7 @@ class StarHandlerRegistry(Generic[T]):
return len(self._handlers)
star_handlers_registry = StarHandlerRegistry() # type: ignore
star_handlers_registry = StarHandlerRegistry()
class EventType(enum.Enum):
@@ -97,7 +90,6 @@ class EventType(enum.Enum):
"""
OnAstrBotLoadedEvent = enum.auto() # AstrBot 加载完成
OnPlatformLoadedEvent = enum.auto() # 平台加载完成
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
@@ -123,7 +115,7 @@ class StarHandlerMetadata:
handler_module_path: str
"""Handler 所在的模块路径。"""
handler: Callable[..., Awaitable[Any]]
handler: Awaitable
"""Handler 的函数对象,应当是一个异步函数"""
event_filters: List[HandlerFilter]

View File

@@ -43,7 +43,7 @@ class PluginManager:
self.updator = PluginUpdator()
self.context = context
self.context._star_manager = self # type: ignore
self.context._star_manager = self
self.config = config
self.plugin_store_path = get_astrbot_plugin_path()
@@ -478,10 +478,9 @@ class PluginManager:
if isinstance(func_tool, HandoffTool):
need_apply = []
sub_tools = func_tool.agent.tools
if sub_tools:
for sub_tool in sub_tools:
if isinstance(sub_tool, FunctionTool):
need_apply.append(sub_tool)
for sub_tool in sub_tools:
if isinstance(sub_tool, FunctionTool):
need_apply.append(sub_tool)
else:
need_apply = [func_tool]
@@ -687,9 +686,6 @@ class PluginManager:
)
# 从 star_registry 和 star_map 中删除
if plugin.module_path is None or root_dir_name is None:
raise Exception(f"插件 {plugin_name} 数据不完整,无法卸载。")
await self._unbind_plugin(plugin_name, plugin.module_path)
try:
@@ -795,17 +791,15 @@ class PluginManager:
if star_metadata.star_cls is None:
return
if "__del__" in star_metadata.star_cls_type.__dict__:
if '__del__' in star_metadata.star_cls_type.__dict__:
asyncio.get_event_loop().run_in_executor(
None, star_metadata.star_cls.__del__
)
elif "terminate" in star_metadata.star_cls_type.__dict__:
elif 'terminate' in star_metadata.star_cls_type.__dict__:
await star_metadata.star_cls.terminate()
async def turn_on_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
if plugin is None:
raise Exception(f"插件 {plugin_name} 不存在。")
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", [])
if plugin.module_path in inactivated_plugins:

View File

@@ -1,41 +1,14 @@
"""
插件开发工具集
封装了许多常用的操作,方便插件开发者使用
说明:
主动发送消息: send_message(session, message_chain)
根据 session (unified_msg_origin) 主动发送消息, 前提是需要提前获得或构造 session
根据id直接主动发送消息: send_message_by_id(type, id, message_chain, platform="aiocqhttp")
根据 id (例如 qq 号, 群号等) 直接, 主动地发送消息
以上两种方式需要构造消息链, 也就是消息组件的列表
构造事件:
首先需要构造一个 AstrBotMessage 对象, 使用 create_message 方法
然后使用 create_event 方法提交事件到指定平台
"""
import inspect
import os
import uuid
from pathlib import Path
from typing import Union, Awaitable, Callable, Any, List, Optional, ClassVar
from typing import Union, Awaitable, List, Optional, ClassVar
from astrbot.core.message.components import BaseMessageComponent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType
from astrbot.api.platform import MessageMember, AstrBotMessage
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.star.context import Context
from astrbot.core.star.star import star_map
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import (
AiocqhttpMessageEvent,
)
from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import (
AiocqhttpAdapter,
)
class StarTools:
@@ -76,82 +49,42 @@ class StarTools:
Note:
qq_official(QQ官方API平台)不支持此方法
"""
if cls._context is None:
raise ValueError("StarTools not initialized")
return await cls._context.send_message(session, message_chain)
@classmethod
async def send_message_by_id(
cls,
type: str,
id: str,
message_chain: MessageChain,
platform: str = "aiocqhttp",
):
"""
根据 id(例如qq号, 群号等) 直接, 主动地发送消息
Args:
type (str): 消息类型, 可选: PrivateMessage, GroupMessage
id (str): 目标ID, 例如QQ号, 群号等
message_chain (MessageChain): 消息链
platform (str): 可选的平台名称,默认平台(aiocqhttp), 目前只支持 aiocqhttp
"""
if cls._context is None:
raise ValueError("StarTools not initialized")
platforms = cls._context.platform_manager.get_insts()
if platform == "aiocqhttp":
adapter = next(
(p for p in platforms if isinstance(p, AiocqhttpAdapter)), None
)
if adapter is None:
raise ValueError("未找到适配器: AiocqhttpAdapter")
await AiocqhttpMessageEvent.send_message(
bot=adapter.bot,
message_chain=message_chain,
is_group=(type == "GroupMessage"),
session_id=id,
)
else:
raise ValueError(f"不支持的平台: {platform}")
@classmethod
async def create_message(
cls,
type: str,
self_id: str,
session_id: str,
message_id: str,
sender: MessageMember,
message: List[BaseMessageComponent],
message_str: str,
message_id: str = "",
raw_message: object = None,
raw_message: object,
group_id: str = "",
) -> AstrBotMessage:
):
"""
创建一个AstrBot消息对象
Args:
type (str): 消息类型, 例如 "GroupMessage" "FriendMessage" "OtherMessage"
type (str): 消息类型
self_id (str): 机器人自身ID
session_id (str): 会话ID(通常为用户ID)(QQ号, 群号等)
sender (MessageMember): 发送者信息, 例如 MessageMember(user_id="123456", nickname="昵称")
message (List[BaseMessageComponent]): 消息组件列表, 也就是消息链, 这个不会发给 llm, 但是会经过其他处理
message_str (str): 消息字符串, 也就是纯文本消息, 也就是发送给 llm 的消息, 与消息链一致
message_id (str): 消息ID, 构造消息时可以随意填写也可不填
raw_message (object): 原始消息对象, 可以随意填写也可不填
message_id (str): 消息ID
sender (MessageMember): 发送者信息
message (List[BaseMessageComponent]): 消息组件列表
message_str (str): 消息字符串
raw_message (object): 原始消息对象
group_id (str, optional): 群组ID, 如果为私聊则为空. Defaults to "".
Returns:
AstrBotMessage: 创建的消息对象
"""
abm = AstrBotMessage()
abm.type = MessageType(type)
abm.type = type
abm.self_id = self_id
abm.session_id = session_id
if message_id == "":
message_id = uuid.uuid4().hex
abm.message_id = message_id
abm.sender = sender
abm.message = message
@@ -160,39 +93,13 @@ class StarTools:
abm.group_id = group_id
return abm
@classmethod
async def create_event(
cls, abm: AstrBotMessage, platform: str = "aiocqhttp", is_wake: bool = True
) -> None:
"""
创建并提交事件到指定平台
当有需要创建一个事件, 触发某些处理流程时, 使用该方法
# todo: 添加构造事件的方法
# async def create_event(
# self, platform: str, umo: str, sender_id: str, session_id: str
# ):
# platform = self._context.get_platform(platform)
Args:
abm (AstrBotMessage): 要提交的消息对象, 请先使用 create_message 创建
platform (str): 可选的平台名称,默认平台(aiocqhttp), 目前只支持 aiocqhttp
is_wake (bool): 是否标记为唤醒事件, 默认为 True, 只有唤醒事件才会被 llm 响应
"""
if cls._context is None:
raise ValueError("StarTools not initialized")
platforms = cls._context.platform_manager.get_insts()
if platform == "aiocqhttp":
adapter = next(
(p for p in platforms if isinstance(p, AiocqhttpAdapter)), None
)
if adapter is None:
raise ValueError("未找到适配器: AiocqhttpAdapter")
event = AiocqhttpMessageEvent(
message_str=abm.message_str,
message_obj=abm,
platform_meta=adapter.metadata,
session_id=abm.session_id,
bot=adapter.bot,
)
event.is_wake = is_wake
adapter.commit_event(event)
else:
raise ValueError(f"不支持的平台: {platform}")
# todo: 添加找到对应平台并提交对应事件的方法
@classmethod
def activate_llm_tool(cls, name: str) -> bool:
@@ -203,8 +110,6 @@ class StarTools:
Args:
name (str): 工具名称
"""
if cls._context is None:
raise ValueError("StarTools not initialized")
return cls._context.activate_llm_tool(name)
@classmethod
@@ -215,17 +120,11 @@ class StarTools:
Args:
name (str): 工具名称
"""
if cls._context is None:
raise ValueError("StarTools not initialized")
return cls._context.deactivate_llm_tool(name)
@classmethod
def register_llm_tool(
cls,
name: str,
func_args: list,
desc: str,
func_obj: Callable[..., Awaitable[Any]],
cls, name: str, func_args: list, desc: str, func_obj: Awaitable
) -> None:
"""
为函数调用function-calling/tools-use添加工具
@@ -236,8 +135,6 @@ class StarTools:
desc (str): 工具描述
func_obj (Awaitable): 函数对象,必须是异步函数
"""
if cls._context is None:
raise ValueError("StarTools not initialized")
cls._context.register_llm_tool(name, func_args, desc, func_obj)
@classmethod
@@ -249,8 +146,6 @@ class StarTools:
Args:
name (str): 工具名称
"""
if cls._context is None:
raise ValueError("StarTools not initialized")
cls._context.unregister_llm_tool(name)
@classmethod
@@ -274,11 +169,8 @@ class StarTools:
- 创建目录失败权限不足或其他IO错误
"""
if not plugin_name:
frame = inspect.currentframe()
module = None
if frame:
frame = frame.f_back
module = inspect.getmodule(frame)
frame = inspect.currentframe().f_back
module = inspect.getmodule(frame)
if not module:
raise RuntimeError("无法获取调用者模块信息")
@@ -290,12 +182,7 @@ class StarTools:
plugin_name = metadata.name
if not plugin_name:
raise ValueError("无法获取插件名称")
data_dir = Path(
os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)
)
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name))
try:
data_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -32,9 +32,6 @@ class PluginUpdator(RepoZipUpdator):
if not repo_url:
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
if not plugin.root_dir_name:
raise Exception(f"插件 {plugin.name} 的根目录名未指定。")
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")

View File

@@ -56,7 +56,9 @@ class AstrBotUpdator(RepoZipUpdator):
try:
if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli
if os.name == "nt":
args = [f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]]
args = [
f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]
]
else:
args = sys.argv[1:]
os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args)
@@ -66,13 +68,9 @@ class AstrBotUpdator(RepoZipUpdator):
logger.error(f"重启失败({py}, {e}),请尝试手动重启。")
raise e
async def check_update(
self, url: str, current_version: str, consider_prerelease: bool = True
) -> ReleaseInfo:
async def check_update(self, url: str, current_version: str) -> ReleaseInfo:
"""检查更新"""
return await super().check_update(
self.ASTRBOT_RELEASE_API, VERSION, consider_prerelease
)
return await super().check_update(self.ASTRBOT_RELEASE_API, VERSION)
async def get_releases(self) -> list:
return await self.fetch_release_info(self.ASTRBOT_RELEASE_API)

View File

@@ -1,33 +1,9 @@
import codecs
import json
from astrbot.core import logger
from aiohttp import ClientSession, ClientResponse
from aiohttp import ClientSession
from typing import Dict, List, Any, AsyncGenerator
async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]:
decoder = codecs.getincrementaldecoder("utf-8")()
buffer = ""
async for chunk in resp.content.iter_chunked(8192):
buffer += decoder.decode(chunk)
while "\n\n" in buffer:
block, buffer = buffer.split("\n\n", 1)
if block.strip().startswith("data:"):
try:
yield json.loads(block[5:])
except json.JSONDecodeError:
logger.warning(f"Drop invalid dify json data: {block[5:]}")
continue
# flush any remaining text
buffer += decoder.decode(b"", final=True)
if buffer.strip().startswith("data:"):
try:
yield json.loads(buffer[5:])
except json.JSONDecodeError:
logger.warning(f"Drop invalid dify json data: {buffer[5:]}")
pass
class DifyAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
self.api_key = api_key
@@ -57,11 +33,31 @@ class DifyAPIClient:
) as resp:
if resp.status != 200:
text = await resp.text()
raise Exception(
f"Dify /chat-messages 接口请求失败:{resp.status}. {text}"
)
async for event in _stream_sse(resp):
yield event
raise Exception(f"chat_messages 请求失败:{resp.status}. {text}")
buffer = ""
while True:
# 保持原有的8192字节限制防止数据过大导致高水位报错
chunk = await resp.content.read(8192)
if not chunk:
break
buffer += chunk.decode("utf-8")
blocks = buffer.split("\n\n")
# 处理完整的数据块
for block in blocks[:-1]:
if block.strip() and block.startswith("data:"):
try:
json_str = block[5:] # 移除 "data:" 前缀
json_obj = json.loads(json_str)
yield json_obj
except json.JSONDecodeError as e:
logger.error(f"JSON解析错误: {str(e)}")
logger.error(f"原始数据块: {json_str}")
# 保留最后一个可能不完整的块
buffer = blocks[-1] if blocks else ""
async def workflow_run(
self,
@@ -81,11 +77,31 @@ class DifyAPIClient:
) as resp:
if resp.status != 200:
text = await resp.text()
raise Exception(
f"Dify /workflows/run 接口请求失败:{resp.status}. {text}"
)
async for event in _stream_sse(resp):
yield event
raise Exception(f"workflow_run 请求失败:{resp.status}. {text}")
buffer = ""
while True:
# 保持原有的8192字节限制防止数据过大导致高水位报错
chunk = await resp.content.read(8192)
if not chunk:
break
buffer += chunk.decode("utf-8")
blocks = buffer.split("\n\n")
# 处理完整的数据块
for block in blocks[:-1]:
if block.strip() and block.startswith("data:"):
try:
json_str = block[5:] # 移除 "data:" 前缀
json_obj = json.loads(json_str)
yield json_obj
except json.JSONDecodeError as e:
logger.error(f"JSON解析错误: {str(e)}")
logger.error(f"原始数据块: {json_str}")
# 保留最后一个可能不完整的块
buffer = blocks[-1] if blocks else ""
async def file_upload(
self,
@@ -93,15 +109,12 @@ class DifyAPIClient:
user: str,
) -> Dict[str, Any]:
url = f"{self.api_base}/files/upload"
with open(file_path, "rb") as f:
payload = {
"user": user,
"file": f,
}
async with self.session.post(
url, data=payload, headers=self.headers
) as resp:
return await resp.json() # {"id": "xxx", ...}
payload = {
"user": user,
"file": open(file_path, "rb"),
}
async with self.session.post(url, data=payload, headers=self.headers) as resp:
return await resp.json() # {"id": "xxx", ...}
async def close(self):
await self.session.close()

View File

@@ -227,11 +227,9 @@ async def download_dashboard(
path = os.path.join(get_astrbot_data_path(), "dashboard.zip")
if latest or len(str(version)) != 40:
logger.info("准备下载最新发行版本的 AstrBot WebUI")
ver_name = "latest" if latest else version
dashboard_release_url = f"https://astrbot-registry.soulter.top/download/astrbot-dashboard/{ver_name}/dist.zip"
logger.info(
f"准备下载指定发行版本的 AstrBot WebUI 文件: {dashboard_release_url}"
)
try:
await download_file(dashboard_release_url, path, show_progress=True)
except BaseException as _:
@@ -243,10 +241,24 @@ async def download_dashboard(
dashboard_release_url = f"{proxy}/{dashboard_release_url}"
await download_file(dashboard_release_url, path, show_progress=True)
else:
url = f"https://github.com/AstrBotDevs/astrbot-release-harbour/releases/download/release-{version}/dist.zip"
logger.info(f"准备下载指定版本的 AstrBot WebUI: {url}")
logger.info(f"准备下载指定版本的 AstrBot WebUI: {version}")
url = (
"https://api.github.com/repos/AstrBotDevs/astrbot-release-harbour/releases"
)
if proxy:
url = f"{proxy}/{url}"
await download_file(url, path, show_progress=True)
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url) as resp:
if resp.status == 200:
releases = await resp.json()
for release in releases:
if version in release["tag_name"]:
download_url = release["assets"][0]["browser_download_url"]
await download_file(download_url, path, show_progress=True)
else:
logger.warning(f"未找到指定的版本的 Dashboard 构建文件: {version}")
return
with zipfile.ZipFile(path, "r") as z:
z.extractall(extract_path)

View File

@@ -1,5 +1,6 @@
import aiohttp
import asyncio
import os
import ssl
import certifi
import logging
@@ -7,9 +8,10 @@ import random
from . import RenderStrategy
from astrbot.core.config import VERSION
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.utils.t2i.template_manager import TemplateManager
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"
CUSTOM_T2I_TEMPLATE_PATH = os.path.join(get_astrbot_data_path(), "t2i_template.html")
logger = logging.getLogger("astrbot")
@@ -21,17 +23,26 @@ class NetworkRenderStrategy(RenderStrategy):
self.BASE_RENDER_URL = ASTRBOT_T2I_DEFAULT_ENDPOINT
else:
self.BASE_RENDER_URL = self._clean_url(base_url)
self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template", "base.html")
with open(self.TEMPLATE_PATH, "r", encoding="utf-8") as f:
self.DEFAULT_TEMPLATE = f.read()
self.endpoints = [self.BASE_RENDER_URL]
self.template_manager = TemplateManager()
async def initialize(self):
if self.BASE_RENDER_URL == ASTRBOT_T2I_DEFAULT_ENDPOINT:
asyncio.create_task(self.get_official_endpoints())
async def get_template(self, name: str = "base") -> str:
"""通过名称获取文转图 HTML 模板"""
return self.template_manager.get_template(name)
async def get_template(self) -> str:
"""获取文转图 HTML 模板
Returns:
str: 文转图 HTML 模板字符串
"""
if os.path.exists(CUSTOM_T2I_TEMPLATE_PATH):
with open(CUSTOM_T2I_TEMPLATE_PATH, "r", encoding="utf-8") as f:
return f.read()
return self.DEFAULT_TEMPLATE
async def get_official_endpoints(self):
"""获取官方的 t2i 端点列表。"""
@@ -113,15 +124,11 @@ class NetworkRenderStrategy(RenderStrategy):
logger.error(f"All endpoints failed: {last_exception}")
raise RuntimeError(f"All endpoints failed: {last_exception}")
async def render(
self, text: str, return_url: bool = False, template_name: str | None = "base"
) -> str:
async def render(self, text: str, return_url: bool = False) -> str:
"""
返回图像的文件路径
"""
if not template_name:
template_name = "base"
tmpl_str = await self.get_template(name=template_name)
tmpl_str = await self.get_template()
text = text.replace("`", "\\`")
return await self.render_custom_template(
tmpl_str, {"text": text, "version": f"v{VERSION}"}, return_url

View File

@@ -34,18 +34,12 @@ class HtmlRenderer:
)
async def render_t2i(
self,
text: str,
use_network: bool = True,
return_url: bool = False,
template_name: str | None = None,
self, text: str, use_network: bool = True, return_url: bool = False
):
"""使用默认文转图模板。"""
if use_network:
try:
return await self.network_strategy.render(
text, return_url=return_url, template_name=template_name
)
return await self.network_strategy.render(text, return_url=return_url)
except BaseException as e:
logger.error(
f"Failed to render image via AstrBot API: {e}. Falling back to local rendering."

View File

@@ -1,184 +0,0 @@
<!doctype html>
<html>
<head>
<meta charset="utf-8"/>
<title>Astrbot PowerShell {{ version }} </title>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.16.10/dist/katex.min.css" integrity="sha384-wcIxkf4k558AjM3Yz3BBFQUbk/zgIYC2R0QpeeYb+TwlBVMrlgLqwRjRtGZiK7ww" crossorigin="anonymous">
<script src="https://cdn.jsdelivr.net/npm/highlight.js@11.9.0/lib/common.min.js"></script>
<script>hljs.highlightAll();</script>
<script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.10/dist/katex.min.js" integrity="sha384-hIoBPJpTUs74ddyc4bFZSM1TVlQDA60VBbJS0oA934VSz82sBx1X7kSx2ATBDIyd" crossorigin="anonymous"></script>
<script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.10/dist/contrib/auto-render.min.js" integrity="sha384-43gviWU0YVjaDtb/GhzOouOXtZMP/7XUzwPTstBeZFe/+rCMvRwr4yROQP43s0Xk" crossorigin="anonymous"
onload="renderMathInElement(document.getElementById('content'),{delimiters: [{left: '$$', right: '$$', display: true},{left: '$', right: '$', display: false}]});"></script>
<style>
:root {
--bg-color: #010409;
--text-color: #e6edf3;
--title-bar-color: #161b22;
--title-text-color: #e6edf3;
--font-family: 'Consolas', 'Microsoft YaHei Mono', 'Dengxian Mono', 'Courier New', monospace;
--glow-color: rgba(200, 220, 255, 0.7);
}
@keyframes scanline {
0% {
background-position: 0 0;
}
100% {
background-position: 0 100%;
}
}
body {
background-color: var(--bg-color);
color: var(--text-color);
font-family: var(--font-family);
margin: 0;
padding: 0;
line-height: 1.6;
font-size: 18px;
/* The CRT glow effect from the image */
text-shadow: 0 0 15px var(--glow-color), 0 0 7px rgba(255, 255, 255, 1);
position: relative;
overflow: hidden;
}
body::after {
content: " ";
display: block;
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
background: linear-gradient(to bottom, transparent 50%, rgba(0, 0, 0, 0.3) 50%);
background-size: 100% 4px;
z-index: 2;
pointer-events: none;
animation: scanline 8s linear infinite;
}
.header {
background-color: var(--title-bar-color);
padding: 12px 18px;
color: var(--title-text-color);
font-size: 16px;
border-bottom: 1px solid #30363d;
text-shadow: none; /* No glow for title bar */
}
.header .title {
font-weight: bold;
font-size: 28px;
}
.header .version {
opacity: 0.8;
margin-left: 1rem;
}
main {
padding: 1rem 1.5rem;
}
#content {
/* min-width and max-width removed as per request */
}
/* --- Markdown Styles adjusted for terminal look --- */
h1, h2, h3, h4, h5, h6 {
line-height: 1.4;
margin-top: 20px;
margin-bottom: 10px;
padding-bottom: 5px;
border-bottom: 1px solid #30363d;
color: var(--text-color);
}
h1 { font-size: 2rem; }
h2 { font-size: 1.7rem; }
h3 { font-size: 1.4rem; }
p {
margin-top: 1rem;
margin-bottom: 1rem;
}
strong {
color: var(--text-color);
font-weight: bold;
}
img {
max-width: 100%;
border: 1px solid #30363d;
display: block;
margin: 1rem auto;
}
hr {
border: 0;
border-top: 1px dashed #30363d;
margin: 2rem 0;
}
code {
font-family: var(--font-family);
padding: 0.2em 0.4em;
margin: 0;
font-size: 90%;
background-color: #161b22;
border-radius: 4px;
}
pre {
font-family: var(--font-family);
border-radius: 4px;
background: #0d1117;
padding: 1rem;
overflow-x: auto;
border: 1px solid #30363d;
}
pre > code {
padding: 0;
margin: 0;
font-size: 100%;
background-color: transparent;
border-radius: 0;
text-shadow: none; /* Disable glow inside code blocks for clarity */
}
a {
color: #58a6ff;
text-decoration: underline;
}
a:hover {
text-decoration: underline;
}
blockquote {
border-left: 4px solid #30363d;
padding: 0.5rem 1rem;
margin: 1.5rem 0;
color: #8b949e;
background-color: #161b22;
}
</style>
</head>
<body>
<div class="header">
<span class="title">> Astrbot PowerShell</span>
<span class="version">{{ version }}</span>
</div>
<main>
<div id="content"></div>
</main>
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<script>
document.getElementById('content').innerHTML = marked.parse(`{{ text | safe }}`);
</script>
</body>
</html>

View File

@@ -1,112 +0,0 @@
# astrbot/core/utils/t2i/template_manager.py
import os
import shutil
from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_path
class TemplateManager:
"""
负责管理 t2i HTML 模板的 CRUD 和重置操作。
采用“用户覆盖内置”策略:用户模板存储在 data 目录中,并优先于内置模板加载。
所有创建、更新、删除操作仅影响用户目录,以确保更新框架时用户数据安全。
"""
CORE_TEMPLATES = ["base.html", "astrbot_powershell.html"]
def __init__(self):
self.builtin_template_dir = os.path.join(
get_astrbot_path(), "astrbot", "core", "utils", "t2i", "template"
)
self.user_template_dir = os.path.join(get_astrbot_data_path(), "t2i_templates")
os.makedirs(self.user_template_dir, exist_ok=True)
self._initialize_user_templates()
def _copy_core_templates(self, overwrite: bool = False):
"""从内置目录复制核心模板到用户目录。"""
for filename in self.CORE_TEMPLATES:
src = os.path.join(self.builtin_template_dir, filename)
dst = os.path.join(self.user_template_dir, filename)
if os.path.exists(src) and (overwrite or not os.path.exists(dst)):
shutil.copyfile(src, dst)
def _initialize_user_templates(self):
"""如果用户目录下缺少核心模板,则进行复制。"""
self._copy_core_templates(overwrite=False)
def _get_user_template_path(self, name: str) -> str:
"""获取用户模板的完整路径,防止路径遍历漏洞。"""
if ".." in name or "/" in name or "\\" in name:
raise ValueError("模板名称包含非法字符。")
return os.path.join(self.user_template_dir, f"{name}.html")
def _read_file(self, path: str) -> str:
"""读取文件内容。"""
with open(path, "r", encoding="utf-8") as f:
return f.read()
def list_templates(self) -> list[dict]:
"""
列出所有可用模板。
该列表是内置模板和用户模板的合并视图,用户模板将覆盖同名的内置模板。
"""
dirs_to_scan = [self.builtin_template_dir, self.user_template_dir]
all_names = {
os.path.splitext(f)[0]
for d in dirs_to_scan
for f in os.listdir(d)
if f.endswith(".html")
}
return [
{"name": name, "is_default": name == "base"} for name in sorted(all_names)
]
def get_template(self, name: str) -> str:
"""
获取指定模板的内容。
优先从用户目录加载,如果不存在则回退到内置目录。
"""
user_path = self._get_user_template_path(name)
if os.path.exists(user_path):
return self._read_file(user_path)
builtin_path = os.path.join(self.builtin_template_dir, f"{name}.html")
if os.path.exists(builtin_path):
return self._read_file(builtin_path)
raise FileNotFoundError("模板不存在。")
def create_template(self, name: str, content: str):
"""在用户目录中创建一个新的模板文件。"""
path = self._get_user_template_path(name)
if os.path.exists(path):
raise FileExistsError("同名模板已存在。")
with open(path, "w", encoding="utf-8") as f:
f.write(content)
def update_template(self, name: str, content: str):
"""
更新一个模板。此操作始终写入用户目录。
如果更新的是一个内置模板,此操作实际上会在用户目录中创建一个修改后的副本,
从而实现对内置模板的“覆盖”。
"""
path = self._get_user_template_path(name)
with open(path, "w", encoding="utf-8") as f:
f.write(content)
def delete_template(self, name: str):
"""
仅删除用户目录中的模板文件。
如果删除的是一个覆盖了内置模板的用户模板,这将有效地“恢复”到内置版本。
"""
path = self._get_user_template_path(name)
if not os.path.exists(path):
raise FileNotFoundError("用户模板不存在,无法删除。")
os.remove(path)
def reset_default_template(self):
"""
将核心模板从内置目录强制重置到用户目录。
"""
self._copy_core_templates(overwrite=True)

View File

@@ -107,38 +107,16 @@ class RepoZipUpdator:
"""Semver 版本比较"""
return VersionComparator.compare_version(v1, v2)
async def check_update(
self, url: str, current_version: str, consider_prerelease: bool = True
) -> ReleaseInfo | None:
async def check_update(self, url: str, current_version: str) -> ReleaseInfo | None:
update_data = await self.fetch_release_info(url)
sel_release_data = None
if consider_prerelease:
tag_name = update_data[0]["tag_name"]
sel_release_data = update_data[0]
else:
for data in update_data:
# 跳过带有 alpha、beta 等预发布标签的版本
if re.search(
r"[\-_.]?(alpha|beta|rc|dev)[\-_.]?\d*$",
data["tag_name"],
re.IGNORECASE,
):
continue
tag_name = data["tag_name"]
sel_release_data = data
break
if not sel_release_data or not tag_name:
logger.error("未找到合适的发布版本")
return None
tag_name = update_data[0]["tag_name"]
if self.compare_version(current_version, tag_name) >= 0:
return None
return ReleaseInfo(
version=tag_name,
published_at=sel_release_data["published_at"],
body=f"{tag_name}\n\n{sel_release_data['body']}",
published_at=update_data[0]["published_at"],
body=update_data[0]["body"],
)
async def download_from_repo_url(self, target_path: str, repo_url: str, proxy=""):

View File

@@ -1,27 +1,17 @@
import uuid
import json
import os
import asyncio
from contextlib import asynccontextmanager
from .route import Route, Response, RouteContext
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from quart import request, Response as QuartResponse, g, make_response
from astrbot.core.db import BaseDatabase
import asyncio
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.platform.astr_message_event import MessageSession
@asynccontextmanager
async def track_conversation(convs: dict, conv_id: str):
convs[conv_id] = True
try:
yield
finally:
convs.pop(conv_id, None)
class ChatRoute(Route):
def __init__(
self,
@@ -50,8 +40,6 @@ class ChatRoute(Route):
self.conv_mgr = core_lifecycle.conversation_manager
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
self.running_convs: dict[str, bool] = {}
async def get_file(self):
filename = request.args.get("filename")
if not filename:
@@ -151,63 +139,38 @@ class ChatRoute(Route):
)
async def stream():
client_disconnected = False
try:
async with track_conversation(self.running_convs, webchat_conv_id):
while True:
try:
result = await asyncio.wait_for(back_queue.get(), timeout=1)
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
client_disconnected = True
except Exception as e:
logger.error(f"WebChat stream error: {e}")
while True:
try:
result = await asyncio.wait_for(back_queue.get(), timeout=10)
except asyncio.TimeoutError:
continue
if not result:
continue
if not result:
continue
result_text = result["data"]
type = result.get("type")
streaming = result.get("streaming", False)
result_text = result["data"]
type = result.get("type")
streaming = result.get("streaming", False)
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
await asyncio.sleep(0.05)
try:
if not client_disconnected:
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
except Exception as e:
if not client_disconnected:
logger.debug(
f"[WebChat] 用户 {username} 断开聊天长连接。 {e}"
)
client_disconnected = True
if type == "end":
break
elif (streaming and type == "complete") or not streaming:
# append bot message
new_his = {"type": "bot", "message": result_text}
await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,
content=new_his,
sender_id="bot",
sender_name="bot",
)
try:
if not client_disconnected:
await asyncio.sleep(0.05)
except asyncio.CancelledError:
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
client_disconnected = True
if type == "end":
break
elif (
(streaming and type == "complete")
or not streaming
or type == "break"
):
# append bot message
new_his = {"type": "bot", "message": result_text}
await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,
content=new_his,
sender_id="bot",
sender_name="bot",
)
except BaseException as e:
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
except BaseException as _:
logger.debug(f"用户 {username} 断开聊天长连接。")
return
# Put message to conversation-specific queue
chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
@@ -234,7 +197,6 @@ class ChatRoute(Route):
"Connection": "keep-alive",
},
)
response.timeout = None # fix SSE auto disconnect issue
return response
async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str:
@@ -324,7 +286,6 @@ class ChatRoute(Route):
.ok(
data={
"history": history_res,
"is_running": self.running_convs.get(webchat_conv_id, False),
}
)
.__dict__

View File

@@ -16,10 +16,10 @@ from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_registry
from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry
from astrbot.core import logger
from astrbot.core import logger, html_renderer
from astrbot.core.provider import Provider
from astrbot.core.provider.provider import RerankProvider
import asyncio
from astrbot.core.utils.t2i.network_strategy import CUSTOM_T2I_TEMPLATE_PATH
def try_cast(value: str, type_: str):
@@ -51,6 +51,24 @@ def validate_config(
def validate(data: dict, metadata: dict = schema, path=""):
for key, value in data.items():
if key not in metadata:
# 无 schema 的配置项,执行类型猜测
if isinstance(value, str):
try:
data[key] = int(value)
continue
except ValueError:
pass
try:
data[key] = float(value)
continue
except ValueError:
pass
if value.lower() == "true":
data[key] = True
elif value.lower() == "false":
data[key] = False
continue
meta = metadata[key]
if "type" not in meta:
@@ -109,12 +127,12 @@ def validate_config(
)
if is_core:
meta_all = {
**schema["platform_group"]["metadata"],
**schema["provider_group"]["metadata"],
**schema["misc_config_group"]["metadata"],
}
validate(data, meta_all)
for key, group in schema.items():
group_meta = group.get("metadata")
if not group_meta:
continue
# logger.info(f"验证配置: 组 {key} ...")
validate(data, group_meta, path=f"{key}.")
else:
validate(data, schema)
@@ -124,7 +142,6 @@ def validate_config(
def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
"""验证并保存配置"""
errors = None
logger.info(f"Saving config, is_core={is_core}")
try:
if is_core:
errors, post_config = validate_config(
@@ -138,7 +155,6 @@ def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False)
raise ValueError(f"验证配置时出现异常: {e}")
if errors:
raise ValueError(f"格式校验未通过: {errors}")
config.save_config(post_config)
@@ -169,9 +185,56 @@ class ConfigRoute(Route):
"/config/provider/check_one": ("GET", self.check_one_provider_status),
"/config/provider/list": ("GET", self.get_provider_config_list),
"/config/provider/model_list": ("GET", self.get_provider_model_list),
"/config/astrbot/t2i-template/get": ("GET", self.get_t2i_template),
"/config/astrbot/t2i-template/save": ("POST", self.post_t2i_template),
"/config/astrbot/t2i-template/delete": ("DELETE", self.delete_t2i_template),
}
self.register_routes()
async def get_t2i_template(self):
"""获取 T2I 模板"""
try:
template = await html_renderer.network_strategy.get_template()
has_custom_template = os.path.exists(CUSTOM_T2I_TEMPLATE_PATH)
return (
Response()
.ok({"template": template, "has_custom_template": has_custom_template})
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"获取模板失败: {str(e)}").__dict__
async def post_t2i_template(self):
"""保存 T2I 模板"""
try:
post_data = await request.json
if not post_data or "template" not in post_data:
return Response().error("缺少模板内容").__dict__
template_content = post_data["template"]
# 保存自定义模板到文件
with open(CUSTOM_T2I_TEMPLATE_PATH, "w", encoding="utf-8") as f:
f.write(template_content)
return Response().ok(message="模板保存成功").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"保存模板失败: {str(e)}").__dict__
async def delete_t2i_template(self):
"""删除自定义 T2I 模板,恢复默认模板"""
try:
if os.path.exists(CUSTOM_T2I_TEMPLATE_PATH):
os.remove(CUSTOM_T2I_TEMPLATE_PATH)
return Response().ok(message="已恢复默认模板").__dict__
else:
return Response().ok(message="未找到自定义模板文件").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"删除模板失败: {str(e)}").__dict__
async def get_abconf_list(self):
"""获取所有 AstrBot 配置文件的列表"""
abconf_list = self.acm.get_conf_list()
@@ -418,19 +481,6 @@ class ConfigRoute(Route):
)
status_info["status"] = "unavailable"
status_info["error"] = f"STT test failed: {str(e)}"
elif provider_capability_type == ProviderType.RERANK:
try:
assert isinstance(provider, RerankProvider)
await provider.rerank("Apple", documents=["apple", "banana"])
status_info["status"] = "available"
except Exception as e:
logger.error(
f"Error testing rerank provider {provider_name}: {e}",
exc_info=True,
)
status_info["status"] = "unavailable"
status_info["error"] = f"Rerank test failed: {str(e)}"
else:
logger.debug(
f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}"
@@ -702,13 +752,6 @@ class ConfigRoute(Route):
if conf_id not in self.acm.confs:
raise ValueError(f"配置文件 {conf_id} 不存在")
astrbot_config = self.acm.confs[conf_id]
# 保留服务端的 t2i_active_template 值
if "t2i_active_template" in astrbot_config:
post_configs["t2i_active_template"] = astrbot_config[
"t2i_active_template"
]
save_config(post_configs, astrbot_config, is_core=True)
except Exception as e:
raise e

Some files were not shown because too many files have changed in this diff Show More