Compare commits

...

207 Commits

Author SHA1 Message Date
Soulter
07ba9c772c chore: bump version to 4.5.0 2025-10-26 21:40:11 +08:00
Soulter
0622d88b22 fix: revert 3106 (#3153)
* fix: revert 3106

Co-authored-by: Dt8333 <25431943+Dt8333@users.noreply.github.com>
Co-authored-by: LIghtJUNction <lightjunction.me@gmail.com>
Co-authored-by: exynos <110159911+xiaoxi68@users.noreply.github.com>

* Update astrbot/dashboard/routes/update.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* fix: remove unnecessary version file handling in download_dashboard function

* fix: revert

---------

Co-authored-by: Dt8333 <25431943+Dt8333@users.noreply.github.com>
Co-authored-by: LIghtJUNction <lightjunction.me@gmail.com>
Co-authored-by: exynos <110159911+xiaoxi68@users.noreply.github.com>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-10-26 21:26:48 +08:00
Soulter
594f0fed55 style: adjust padding for card text in ExtensionPage for improved layout 2025-10-26 21:19:07 +08:00
Soulter
04b0d9b88d Merge pull request #3155 from AstrBotDevs/feat/plugin-display-name-and-logo
feat: add support for plugin display name and logo, and some extension card style fix
2025-10-26 20:54:24 +08:00
Soulter
1f2af8ef94 Update dashboard/src/components/shared/ExtensionCard.vue
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-10-26 20:52:37 +08:00
Soulter
598ea2d857 refactor: update ExtensionCard styling and improve layout for better responsiveness 2025-10-26 20:49:27 +08:00
Soulter
6dd9bbb516 feat: enhance plugin metadata with display name and logo support 2025-10-26 20:30:54 +08:00
Soulter
3cd0b47dc6 feat: add GitHub link button to ExtensionCard for extensions with a repository 2025-10-26 19:41:00 +08:00
Soulter
65c71b5f20 refactor: remove Google search engine integration from main module and dependencies (#3154) 2025-10-26 18:54:01 +08:00
exynos
1152b11202 feat(thinking_filter): 适配第三方 Gemini 思考片段过滤 (#3139)
* feat(thinking_filter): 适配第三方 Gemini 思考片段过滤

* feat(thinking_filter): Gemini 思考过滤、序列化回退与空白清理重构

* 使用 ruff 格式化并修复导入空行
2025-10-26 18:43:58 +08:00
Soulter
51246ea31b fix: apply configuration option to enable/disable WebUI in AstrBotDashboard (#3152) 2025-10-26 17:29:04 +08:00
Soulter
7e5592dd32 fix: comment out existing configuration preview section in AddNewPlatform component 2025-10-26 17:07:04 +08:00
Soulter
c6b28caebf Merge pull request #3151 from AstrBotDevs/feat/platform-abconf-interaction
feat: enhance AddNewPlatform and ConfigPage components with improved configuration management and UI interactions
2025-10-26 17:04:34 +08:00
Soulter
ca002f6fff feat: enhance AddNewPlatform dialog with scroll functionality and toggle for configuration section 2025-10-26 17:03:07 +08:00
Soulter
14ec392091 fix: update message styling in AddNewPlatform component for better visibility 2025-10-26 17:00:36 +08:00
Soulter
5e2eb91ac0 feat: enhance AddNewPlatform and ConfigPage components with improved configuration management and UI interactions 2025-10-26 16:57:01 +08:00
Soulter
c1626613ce fix: update repository references from Soulter/AstrBot to AstrBotDevs/AstrBot across documentation and codebase (#3150)
* fix: update repository references from Soulter/AstrBot to AstrBotDevs/AstrBot across documentation and codebase

- Updated README_ja.md to reflect new GitHub repository links.
- Modified AstrBotUpdator to download from the new repository.
- Changed download URLs in io.py for dashboard releases.
- Updated changelogs to point to the new issue links.
- Adjusted Docker compose file to reference the new repository.
- Updated Vue components in the dashboard to link to the new repository.
- Changed main.py to provide the correct download instructions for the new repository.

* fix: improve error handling for configId selection in AddNewPlatform component

* Update astrbot/core/utils/io.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-10-26 16:17:24 +08:00
LIghtJUNction
42042d9e73 Merge branch 'master' of https://github.com/AstrBotDevs/AstrBot 2025-10-26 15:41:36 +08:00
LIghtJUNction
22c3b53ab8 fix(io.py): path改回传入文件地址,而不是传入文件夹地址 2025-10-26 15:41:20 +08:00
Soulter
090c32c90e feat: enhance AddNewPlatform dialog with data preparation on enter and improve code formatting 2025-10-26 15:40:15 +08:00
LIghtJUNction
4f4a9b9e55 fix(io.py): download_dashboard如果发现没有dist/assets/version文件,下载完毕自动写入(以防万一) 2025-10-26 15:35:25 +08:00
Soulter
6c7d7c9015 Merge pull request #3147 from AstrBotDevs/feat/kb-markitdown
feat: refactor knowledge base parsers and add MarkitdownParser for docx, xls, xlsx support
2025-10-26 13:18:52 +08:00
Soulter
562e62a8c0 feat: add new dependencies for PDF processing, file handling, and text ranking 2025-10-26 13:02:32 +08:00
Soulter
0823f7aa48 在检查字面量集合的成员资格时使用 set
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-10-25 22:04:17 +08:00
Soulter
eb201c0420 feat: refactor knowledge base parsers and add MarkitdownParser for docx, xls, xlsx support 2025-10-25 22:00:54 +08:00
Soulter
6cfed9a39d Merge pull request #3143 from lxfight/feature/knowledge-base 2025-10-25 18:19:15 +08:00
Soulter
33618c4a6b feat: add dynamic embedding dimension retrieval for providers and enhance error handling 2025-10-25 16:39:11 +08:00
LI SONGSONG 🍂
ace0a7c219 docs: update link and description 2025-10-25 16:07:27 +08:00
Soulter
f7d018cf94 feat: add pre-checks for embedding and rerank providers in KnowledgeBaseRoute 2025-10-25 15:22:35 +08:00
Soulter
8ae2a556e4 feat: remove tips from knowledge base creation form and add persistent hints for field modifications 2025-10-25 15:06:07 +08:00
lxfight
4188deb386 fix: 简化日志错误信息格式 2025-10-25 14:13:23 +08:00
lxfight
82cf4ed909 fix: 使用ruff格式化文件代码 2025-10-25 14:10:26 +08:00
lxfight
88fc437abc feat: 优化知识库选择界面,添加自定义滚动条样式 2025-10-25 13:59:09 +08:00
lxfight
57f868cab1 Merge branch 'feature/knowledge-base' of https://github.com/lxfight/AstrBot into feature/knowledge-base 2025-10-25 13:53:03 +08:00
lxfight
6cb5527894 feat: 添加会话知识库配置的 API 接口,支持获取、设置和删除会话配置,优化知识库选择界面 2025-10-25 13:52:57 +08:00
Soulter
016783a1e5 feat: implement RecursiveCharacterChunker and update KnowledgeBaseManager to use it 2025-10-25 13:46:06 +08:00
lxfight
594ccff9c8 fix: 添加数据库连接检查和知识库终止功能,增强错误处理和清理逻辑,修复知识库无法删除的问题 2025-10-25 11:56:37 +08:00
Soulter
30792f0584 Merge pull request #114 from lxfight/lwl-dev/knowledge-base
refactor: 知识库优化
2025-10-25 00:42:16 +08:00
Soulter
8f021eb35a feat: refactor document storage to use SQLModel and enhance database operations 2025-10-24 23:17:37 +08:00
Soulter
1969abc340 feat: add route for legacy knowledge base and update UI with banner suggestion 2025-10-24 22:01:55 +08:00
Soulter
b1b53ab983 Merge remote-tracking branch 'origin/master' into lwl-dev/knowledge-base 2025-10-24 21:48:47 +08:00
Soulter
9b5af23982 feat: remove beta label from knowledge base navigation and adjust margin in KBList component 2025-10-24 21:46:53 +08:00
Soulter
4cedc6d3c8 feat: add t-SNE visualization for FAISS index and enhance knowledge base retrieval with debug mode 2025-10-24 21:22:46 +08:00
Soulter
4e9cce76da feat: add timing logs for dense and sparse retrieval processes and adjust top K results in sparse retriever 2025-10-24 17:51:30 +08:00
Soulter
9b004f3d2f feat: update document retrieval to include limit and offset parameters 2025-10-24 17:38:22 +08:00
Soulter
9430e3090d feat: add progress callback for document upload and enhance upload progress tracking 2025-10-24 17:13:44 +08:00
Soulter
ba44f9117b feat: enhance document upload process with batch settings and improved chunk handling 2025-10-24 16:37:37 +08:00
Soulter
eb56710a72 feat: add chunk size, overlap, and top K parameters to knowledge base response 2025-10-24 15:10:47 +08:00
Soulter
38e3f27899 feat: update knowledge base retrieval configuration and UI adjustments 2025-10-24 15:06:07 +08:00
Soulter
3c58d96db5 feat: add configuration for final knowledge base retrieval count and update related components 2025-10-24 14:45:07 +08:00
Soulter
a6be0cc135 feat: refresh knowledge base and document after uploading a document 2025-10-24 14:28:27 +08:00
Soulter
a53510bc41 refactor: comment out file path handling in KBHelper and search input in DocumentDetail 2025-10-24 14:27:01 +08:00
Soulter
1fd482e899 feat: update chunk deletion to include document ID and refresh metadata 2025-10-24 14:18:32 +08:00
Soulter
2f130ba009 feat: delete chunk and delete document 2025-10-24 13:59:17 +08:00
Soulter
e6d9db9395 feat: disable embedding provider selection in settings tab 2025-10-24 12:53:59 +08:00
Soulter
e0ac743cdb perf: remove rerank functionality from settings tab and related form data 2025-10-24 12:13:51 +08:00
Soulter
b0d3fc11f0 feat: remove sessions tab and related components from knowledge base detail view 2025-10-24 00:48:17 +08:00
Soulter
7e0a50fbf2 feat: enhance knowledge base retrieval with chunk metadata and pagination support; remove unused chunk model 2025-10-24 00:44:40 +08:00
Soulter
59df244173 improve 2025-10-23 21:20:41 +08:00
Soulter
deb31a02cf docs: Update badge links in README.md 2025-10-23 09:53:54 +08:00
Soulter
e3aa1315ae stage 2025-10-23 00:31:15 +08:00
Soulter
65bc5efa19 feat: 集成知识库管理器,优化知识库上下文注入流程,移除冗余代码 2025-10-22 21:59:00 +08:00
Dt8333
abc4bc24b4 fix(dashboard): webchat input textarea is disabled when session controller is active
Removed the disable attribute of Input in isConvRunning. Added an activeSSE counter to correctly determine the current session state and prevent new input from causing interface display errors during session_waiter execution. Set isStreaming after streaming input ends to restore the text box.

#3037 #2892
2025-10-22 20:32:40 +08:00
Soulter
5df3f06f83 fix: persona information is not appearing in the PersonaForm when editing 2025-10-22 17:09:21 +08:00
Soulter
0e1de82bd7 fix: correct indentation in pre-commit config for pyupgrade hook 2025-10-22 17:08:54 +08:00
Soulter
f31e41b3f1 docs: update readme 2025-10-22 13:10:44 +08:00
LIghtJUNction
fe8d2718c4 新增pyupgrade钩子
代码风格统一化
2025-10-21 11:17:20 +08:00
Soulter
8afefada0a fix: image_caption btn 2025-10-21 11:07:39 +08:00
LIghtJUNction
745e1c37c0 Add ruff-check hook to pre-commit config
跟随官方推荐
2025-10-21 11:07:00 +08:00
LIghtJUNction
fdb5988cec 更新 .pre-commit-config.yaml 2025-10-21 11:02:30 +08:00
Soulter
36ffcf3cc3 fix: typing error 2025-10-21 10:56:44 +08:00
Soulter
a0f8f3ae32 style: ruff format 2025-10-21 00:21:42 +08:00
Soulter
130f52f315 chore(monaco-editor): bump monaco-editor version to 0.54.0 2025-10-21 00:18:29 +08:00
lxfight
a05868cc45 feat: 更新知识库管理器以支持重排序模型提供商,调整相关组件的默认配置和提示信息 2025-10-20 22:38:06 +08:00
lxfight
2fc77aed15 feat: 添加知识库检索功能,支持根据知识库 ID 列出相关会话;更新相关界面和国际化文本 2025-10-20 22:23:35 +08:00
lxfight
c56edb4da6 feat: 添加知识库配置功能,支持会话管理中的知识库选择与设置 2025-10-20 21:46:39 +08:00
Soulter
6672190760 feat: add star count display and fetch functionality in sidebar 2025-10-20 18:19:21 +08:00
exynos
f122b17097 fix(update): 取消 WebUI 与核心版本对比,消除“webui有新版本!”的误报 (#3106)
* fix(update): 取消 WebUI 与核心版本对比,消除“webui有新版本!”的误报

不再比较 dv 与核心版本

* fix(update): 保留dv逻辑,新增installed标识以避免误报

新增安装状态布尔值,保留“dv 是否存在”的信息

* Fix dashboard version update check logic

---------

Co-authored-by: LIghtJUNction <lightjunction.me@gmail.com>
2025-10-20 16:15:42 +08:00
Soulter
2c5f68e696 refactor: 重构创建平台时的流程及一些 UI 优化 (#3102)
* refactor: 支持在平台直接选择配置文件

* add webchat

* feat: 支持新建平台时现场预览、创建和编辑配置文件

* fix: update configuration file descriptions and visibility based on updating mode

* perf: use incremental decoder

* perf: update descriptions

* fix: UI update issues in config file dialog

* fix: update UI elements for better readability and organization

* feat: enhance sidebar navigation with group feature and dynamic resizing

Co-authored-by:  IGCrystal <3811541171@qq.com>

* refactor: persona selector

* perf: 修改部分默认行为

* fix: adjust ExtensionCard layout and improve responsiveness

* refactor: 配置文件绑定消息平台重构为消息平台绑定配文件

* style: add custom styling for v-select selection text

* fix: correct subtitle text in provider.json

* refactor: update conversation management terminology and improve session ID handling

* refactor: add Conversation ID localization and update table header reference

* Update astrbot/core/db/migration/migra_45_to_46.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* style: format logger warning for better readability

* refactor: comment out WebChat configuration for future reference

---------

Co-authored-by: IGCrystal <3811541171@qq.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-20 12:01:06 +08:00
MoonShadow1976
e1ca645a32 feat: 增强工具调用参数处理机制 (#3036)
* feat: 增强工具调用参数处理机制

在工具调用时添加参数过滤功能,只传递函数实际需要的参数
解决问题:https://github.com/AstrBotDevs/AstrBot/issues/2988

* feat: 利用现有工具定义信息处理非期望的参数

不使用`inspect`库,利用现有工具定义信息处理非期望的参数

* ruff format for code

合并结果:
移除了多余参数避免报错,代码执行器可以正常工作。
2025-10-20 02:51:16 +08:00
lxfight
333bf56ddc feat:知识库卡片渲染统计信息。 2025-10-19 22:40:01 +08:00
lxfight
b240594859 feat:添加Beta 版本的知识库管理器前端页面;添加i18n相关文件内容。 2025-10-19 21:55:21 +08:00
lxfight
beccae933f fix:修复KBSessionConfig的导入问题 2025-10-19 21:36:01 +08:00
lxfight
e6aa1d2c54 feat:删除v2版本的知识库前端代码;删除i18n相关文件 2025-10-19 21:16:00 +08:00
magisk317
5e808bab65 fix(platform): prevent 'NoneType' object is not iterable in _outline_chain and set_result (#3103)
Guard against cases where message chain is None during pipeline execution. This change enhances error-resilience for logging and processing message chains.

- Updated AstrMessageEvent._outline_chain to return an empty string when input chain is None
- Updated AstrMessageEvent.set_result to ensure result.chain is always at least an empty list

This prevents TypeError when result.chain or chain is unexpectedly None, improving pipeline stability when handling external plugins or corner cases.

Co-authored-by: engine-labs-app[bot] <140088366+engine-labs-app[bot]@users.noreply.github.com>
Co-authored-by: cto-new[bot] <140088366+cto-new[bot]@users.noreply.github.com>
2025-10-19 20:16:14 +08:00
Dt8333
361d78247b fix(core): 修复人格预设对话的重复注入 (#3088)
备份Context避免供应商适配器移除Context内字段导致将预设会话存入历史。深拷贝人格预设会话防止运行时被意外修改。

#3063
2025-10-19 20:13:57 +08:00
a490077
3550103e45 feat: QQ 官方机器人增加沙盒模式选项,让本地部署能跳过 IP 白名单验证 (#3087)
* QQ官方机器人增加沙箱模式选项,让本地部署能跳过IP白名单验证

* chore: ruff format

---------

Co-authored-by: 郭鹏 <gp@pp052.top>
Co-authored-by: Soulter <905617992@qq.com>
2025-10-19 20:09:08 +08:00
PaloMiku
8b0d4d4de4 feat: 优化 Misskey 适配器的通知和聊天消息处理,改进 @用户提及逻辑 (#3075) 2025-10-19 20:05:55 +08:00
shangxue
dc71c04b67 feat(satori): 添加对合并转发消息功能的支持 (#3050)
* Update satori_event.py

* Update satori_event.py

* Update satori_event.py

* Update satori_adapter.py

* style: format code for better readability in satori_adapter.py and satori_event.py

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-10-19 20:05:03 +08:00
lxfight
a0254ed817 refactor: 优化知识库管理器和数据库操作的代码格式 2025-10-19 19:36:26 +08:00
lxfight
2563ecf3c5 feat: 实现知识库前端组件和路由
- 实现知识库 V2 主页面和 4 个子面板组件
- 文档管理面板:支持上传、删除、查看文档分块
- 检索测试面板:支持测试知识库检索效果
- 全局设置面板:配置嵌入模型、重排序、检索参数
- 会话配置面板:管理会话与知识库的绑定关系
- 重构 Alkaid 路由为嵌套结构,添加知识库 V2 路由
- 在翻译系统中注册知识库 V2 多语言支持
- 默认进入 Alkaid 时跳转到原生知识库页面
2025-10-19 18:43:58 +08:00
lxfight
c04738d9fe feat: 实现知识库前端界面(英文国际化)
- 添加知识库 V2 完整英文翻译文件
- 包括:主页、文档管理、检索测试、全局设置、会话配置
- 在 Alkaid 导航中添加 "Native Knowledge Base" 入口
- 区分 "Native Knowledge Base" 和 "Knowledge Base (Plugin)"
2025-10-19 18:43:35 +08:00
lxfight
1266b4d086 feat: 实现知识库前端界面(中文国际化)
- 添加知识库 V2 完整中文翻译文件
- 包括:主页、文档管理、检索测试、全局设置、会话配置
- 在 Alkaid 导航中添加"原生知识库"入口
- 区分"原生知识库"和"知识库(插件)"两个入口
2025-10-19 18:42:43 +08:00
lxfight
99cf0a1522 feat: 添加知识库 Dashboard API 路由
- 实现知识库管理 API(创建、删除、列表、更新)
- 实现文档管理 API(上传、删除、列表、分块信息)
- 实现知识库检索测试 API(支持调试和验证)
- 实现会话配置 API(绑定/解绑知识库、配置检索参数)
- 实现全局配置 API(启用/禁用、模型选择、检索参数)
- 在 Dashboard 服务器中注册知识库路由
2025-10-19 18:41:54 +08:00
lxfight
98a75e923d feat: 集成知识库到核心生命周期和消息流水线
- 在 AstrBotCoreLifecycle 中初始化知识库管理器
- 将知识库注入器添加到消息处理上下文
- 在消息流水线中添加 KBEnhanceStage(知识库增强阶段)
- 实现会话删除时的知识库配置级联清理机制
- 添加会话管理器的回调注册机制,支持零侵入扩展
2025-10-19 18:41:34 +08:00
lxfight
ad96d676e6 feat: 实现知识库核心后端模块
- 实现完整的知识库数据模型(知识库、文档、文档块、会话配置)
- 实现基于 SQLite 的向量数据库存储和检索
- 实现文档解析器(PDF、TXT)和固定大小分块器
- 实现混合检索系统(密集向量检索 + BM25 稀疏检索 + RRF 融合)
- 实现知识库生命周期管理和消息注入器
- 支持会话级别的知识库配置和关联
2025-10-19 18:40:55 +08:00
lxfight
79333bbc35 feat: 添加知识库核心依赖和配置
- 添加 pypdf、aiofiles、rank-bm25 依赖包支持文档解析和检索
- 在 default.py 中添加知识库完整配置项
- 配置包括嵌入模型、重排序、存储路径、分块策略、检索参数等
- 默认禁用知识库功能,需用户主动启用
2025-10-19 18:39:10 +08:00
Soulter
5c5b0f4fde fix: 修复未安装知识库插件时的错误引导 2025-10-18 10:36:11 +08:00
Dt8333
ed6cdfedbb fix: 修复 dashboard 的部分编译错误 (#3041)
* chore(dashboard): adding missing dependency

* fix(dashboard): 修复vertical-header中 $router 类型错误
2025-10-16 10:32:08 +08:00
PaloMiku
23f13ef05f feat:Misskey 适配器支持文件上传、投票内容感知功能和重构部分代码 (#2986)
* feat: 为 Misskey 适配器修正一些问题,添加投票信息读取支持

* feat: 增强 Misskey 平台适配器,添加随机重连延迟和通道重新订阅功能

* feat: 添加文件上传功能并优化消息发送接口,支持同时发送文件和文本

* feat: 增强文件上传功能,支持 MIME 类型检测和外部 URL 回退

* feat: 增加 Misskey 文件上传功能开关,支持配置文件上传启用与并发限制

* feat: 添加 Misskey 文件上传目标文件夹配置,支持将文件上传到指定文件夹

* feat: 优化 Misskey 平台适配器,增强文件上传和消息发送功能,支持更多可选字段

* feat: 代码优化结构与功能

* feat(misskey): 增强消息发送逻辑和工具函数

- 重构了 `misskey_event.py` 中的 `send` 方法,使用新的适配器方法 `send_by_session`,以改进消息处理(包括文件上传)。
- 添加了详细的日志记录,以提高消息发送过程的可追溯性。
- 在 `misskey_utils.py` 中引入了 `FileIDExtractor` 和 `MessagePayloadBuilder` 类,以简化文件 ID 提取和消息载荷构建。
- 在 `misskey_utils.py` 中实现了 MIME 类型检测和文件扩展名解析,以支持多种文件上传。
- 增强了 `resolve_component_url_or_path`,以更好地处理不同类型的组件上传文件。
- 在 `upload_local_with_retries` 中添加了重试逻辑,以优雅地处理不允许的文件类型。

* feat(misskey): 限制文件上传并发数,优化消息处理逻辑

* feat(misskey): ruff formatted

* feat: 大幅优化 misskey 文件上传逻辑,简化上传流程并增强可见性解析

* feat(misskey): 移除 Url上传方式,精简日志

* fix(misskey): 修复错把URL文件当本地文件上传的问题,明确处理 URL 和本地文件的方式

* fix(misskey): 修复 session_id 解析逻辑,确保与 user_cache 键格式匹配

* perf: streaming the file with a file object in FormData to reduce peak memory usage.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* style: format debug log message for local file upload in MisskeyAPI

* refactor: remove unnecessary thread executor for reading file bytes in MisskeyAPI

---------

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-10-16 10:27:04 +08:00
Soulter
f9c59d9706 docs: fix typo 2025-10-16 09:17:09 +08:00
Soulter
e1cec42227 chore: add Node.js setup step in CI workflow 2025-10-15 23:32:53 +08:00
Soulter
8d79c50d53 chore: update CI workflow to use pnpm for package management 2025-10-15 23:12:38 +08:00
Soulter
d77830b97f feat: add markdown-it type definitions as a dev dependency 2025-10-15 23:01:38 +08:00
Soulter
394540f689 docs: Update support status for various platforms 2025-10-15 18:48:25 +08:00
Soulter
7d776e0ce2 chore: bump version to 4.3.5 2025-10-15 12:19:26 +08:00
Soulter
17df1692b9 fix: 修复 /alter_cmd reset scene <num> xxx 不可用的问题 2025-10-15 12:16:13 +08:00
Soulter
9ab652641d feat: 支持配置工具调用超时时间并适配 ModelScope 的 MCP Server 配置 (#3039)
* feat: 支持配置工具调用超时时间并适配 ModelScope 的 MCP Server 配置。

closes: #2939

* fix: Remove unnecessary blank lines in _quick_test_mcp_connection function
2025-10-15 12:06:57 +08:00
shangxue
9119f7166f feat: satori 适配器支持 video、reply 消息类型 (#3035)
* Update satori_event.py

* style: format

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-10-15 10:45:35 +08:00
Soulter
da7d9d8eb9 feat: Add tutorial link for wecom_ai_bot platform 2025-10-15 10:42:31 +08:00
Soulter
80fccc90b7 feat: 支持接入企业微信智能机器人平台 (#3034)
* stage

* stage

* feat: 支持图片收发

* feat: add support for wecom_ai_bot in getPlatformIcon function
2025-10-14 23:20:56 +08:00
Soulter
dcebc70f1a chore: Add new auto-assign users to configuration 2025-10-14 12:16:22 +08:00
dependabot[bot]
259e7bc322 chore(deps): bump github/codeql-action in the github-actions group (#3032)
Bumps the github-actions group with 1 update: [github/codeql-action](https://github.com/github/codeql-action).


Updates `github/codeql-action` from 3 to 4
- [Release notes](https://github.com/github/codeql-action/releases)
- [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md)
- [Commits](https://github.com/github/codeql-action/compare/v3...v4)

---
updated-dependencies:
- dependency-name: github/codeql-action
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-14 09:35:57 +08:00
Soulter
37bdb6c6f6 feat: 内置网页搜索功能支持接入百度 AI 搜索 (#3031)
* feat: 内置网页搜索功能支持接入百度 AI 搜索

* fix: 修正配置文件中的拼写错误,更新为正确的键名

* Fix Baidu AI Search initialization logic
2025-10-14 09:35:34 +08:00
Soulter
dc71afdd3f docs: Revise README for clarity and updated support info
Updated README.md to improve clarity and fix formatting issues. Removed outdated developer group information and added support details for new platforms and services.
2025-10-14 09:13:54 +08:00
Soulter
44638108d0 docs: readme 2025-10-14 08:53:23 +08:00
RC-CHN
93fcac498c feat: 添加并优化服务提供商独立测试功能 (#3024)
* feat: 添加并优化服务提供商独立测试功能

* feat: add small size to action buttons in ItemCard and ProviderPage for better UI consistency

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-10-13 13:03:20 +08:00
Soulter
79e2743aac chore: bump version to 4.3.3 2025-10-12 11:42:18 +08:00
anka
5e9c7cdd91 fix: 当没有填写 api key 时,设置为空字符串 (#2834)
* fix: 修复空key导致的无法创建Provider对象的问题

* style: format code

* Update astrbot/core/provider/provider.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-10-12 10:50:01 +08:00
Dt8333
6f73e5087d feat(core): 在新对话中重用先前的对话人格设置 (#3005)
* feat(core): reuse persona conf in new conversation

#2985

* refactor(core): simplify persona retrieval logic

* style: code format

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-10-12 10:42:35 +08:00
Yaron
8c120b020e fix: 修复阿里云百炼平台 TTS 下接入 CosyVoice V2, Qwen TTS 生成报错的问题 (#2964)
* fix: 修复了CosyVoice V2,Qwen TTS生成报错的问题。Fixed compatability problems with CosyVoice V2, Qwen TTS.

* fix: 将urlopen的同步请求替换为aiohttp的异步请求以下载音频

* fix: cozyvoice 报错显示

* fix: 添加阿里云百炼 TTS API Key 获取提示信息

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-10-12 01:03:06 +08:00
Dt8333
12fc6f9d38 fix(LTM): fix LTM not removed when removing conversation (#3002)
#2983
2025-10-12 00:16:42 +08:00
Dt8333
a6e8483b4c fix: 修复session-management中人格错误的显示为默认人格的问题 (#3000)
* fix: 修复session-management中人格错误的显示为默认人格的问题

#2985

* refactor: 使用命名表达式简化赋值和条件

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* style: format edited code with ruff

format code edited by sourcery-ai

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-10-12 00:12:04 +08:00
Soulter
7191d28ada fix: 启动了 TTS 但未配置 TTS 模型时,At 和 Reply 发送人无效
fixes: #2996
2025-10-10 12:11:03 +08:00
Soulter
e6b5e3d282 feat: tokenpony provider 2025-10-09 16:00:31 +08:00
ctrlkk
1413d6b5fe fix: 让事件钩子被暂停时跳出循环,而不是继续执行 (#2989) 2025-10-09 15:01:45 +08:00
ctrlkk
dcd8a1094c feat: 优化 SQLite 参数配置,对话和会话管理增加输入防抖机制 (#2969)
* feat: 优化 SQLite 数据库初始化设置并增强会话搜索功能,会话管理增加输入防抖

* fix: adjust SQLite cache and mmap size

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-10-06 17:13:53 +08:00
Futureppo
e64b31b9ba fix: Correct default modalities for DeepSeek provider (#2963)
* 更新 package.json

* 更新 ExtensionPage.vue

* fix(provider): Correct default modalities for DeepSeek provider
2025-10-06 16:30:05 +08:00
Dt8333
080f347511 feat: clean browser cache after update (#2958)
* feat: clean browser cache after update

* fix: move const to module

* fix: remove self prefix (a stupid mistake)
2025-10-06 16:29:18 +08:00
Dt8333
eaaff4298d fix(Python-Interpreter): fix incorrect file read method (#2970)
fix getting file by property(Sync) in an async handler

#2960
2025-10-06 16:12:05 +08:00
Soulter
dd5a02e8ef chore: bump version to 4.3.2 2025-10-05 01:01:13 +08:00
Soulter
3211ec57ee fix: handle Google search initialization and errors gracefully 2025-10-05 00:55:47 +08:00
Soulter
6796afdaee fix: googlesearch 2025-10-05 00:54:24 +08:00
Soulter
cc6fe57773 fix: on_tool_end无法获得工具返回的结果 (#2956)
fixes: #2940
2025-10-05 00:37:51 +08:00
Soulter
1dfc831938 fix: 修复 reset 没有清除群聊上下文感知数据的问题 (#2954) 2025-10-05 00:05:42 +08:00
Futureppo
cafeda4abf feat: 为插件市场的搜索增加拼音与首字母搜索功能 (#2936)
* 更新 package.json

* 更新 ExtensionPage.vue
2025-10-03 09:42:57 +08:00
Soulter
d951b99718 fix: 发送阶段将 Plain 为空的消息段移除 2025-10-03 00:45:07 +08:00
Soulter
0ad87209e5 chore: bump version to 4.3.1 2025-10-02 17:25:09 +08:00
Soulter
1b50c5404d fix: enhance knowledge base plugin status check to handle empty data response 2025-10-02 17:25:00 +08:00
Soulter
3007f67cab fix: update Dockerfile to remove npm installation and streamline package setup
closes: #2284
2025-10-02 16:59:11 +08:00
Soulter
ee08659f01 chore: bump version to 4.3.0 2025-10-02 16:37:54 +08:00
Soulter
baf5ad0fab fix: 修复接入智谱提供商后,工具调用无限循环的问题,并停止支持 glm-4v-flash (#2931)
fixes: #2912
2025-10-02 16:03:24 +08:00
kterna
8bdd748aec feat: 支持注册消息平台适配器的 logo (#2109)
* feat: 添加平台适配器 logo 支持

* 优化平台logo注册逻辑,增加缓存机制并支持并行处理

* 去除判断绝对路径

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-10-02 14:36:15 +08:00
Soulter
cef0c22f52 feat: update prompt prefix handling to support placeholder replacement 2025-10-02 14:20:52 +08:00
Soulter
13d3fc5cfe fix: fix type checking error and op, deop, wl, dwl command 2025-10-02 00:18:12 +08:00
Soulter
b91141e2be fix: add plugin activation check and corresponding messages in Knowledge Base 2025-10-01 22:14:03 +08:00
Soulter
f8a4b54165 fix: 修复插件指令注解为联合类型时处理异常的问题 (#2925)
* fix: 修复插件指令注解为联合类型时处理异常的问题

* fix: 修复参数类型检查以支持 typing.Union

* Update astrbot/core/star/filter/command.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update astrbot/core/star/filter/command.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: 修复参数类型检查以支持 typing.Union 的处理逻辑

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-01 21:46:49 +08:00
Soulter
afe007ca0b refactor: 优化 packages/astrbot 内置插件的代码结构以提高可维护性和可读性 (#2924)
* refactor: code structure for improved readability and maintainability

* style: ruff format

* Update packages/astrbot/commands/provider.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Update packages/astrbot/commands/persona.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Update packages/astrbot/commands/llm.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Update packages/astrbot/commands/conversation.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* fix: improve error handling message formatting in key switching

* fix: update LLM command to use safe get for provider settings

* feat: implement ProcessLLMRequest class for handling LLM requests and persona injection

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-10-01 21:29:15 +08:00
Soulter
8a9a044f95 fix: 修复注册指令组指令时的 Pyright 类型检查提示 (#2923) 2025-10-01 20:03:04 +08:00
u0_ani-nya.com
5eaf03e227 perf: 对于 Telegram 群聊,将回复机器人的消息视为唤醒机器人 (#2926)
* reply as at for tg

Add handling for bot replies in group messages.

* style: type checking and ruff format

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-10-01 19:04:37 +08:00
Seayon
a8437d9331 feat: 支持在 Telegram 和飞书下请求 LLM 前预表态功能 (#2737)
*  feat(platform): 为 Telegram 和飞书添加消息表情回应功能

支持在收到命令时自动添加表情回应,提升用户交互体验
新增平台特异配置项,允许自定义启用状态和表情列表

* Update astrbot/core/platform/astr_message_event.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* style: ruff format

* fix: 优化平台特异配置的预回应表情处理逻辑

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-09-30 17:29:34 +08:00
晴空
e0392fa98b fix: 用 mi-googlesearch-python 库代替失效的 googlesearch-python 库 (#2909)
* googlesearch-python库失效,用mi-googlesearch-python库平替,恢复谷歌搜索

* Update googlesearch-python dependency version
2025-09-29 12:54:16 +08:00
ctrlkk
68ff8951de feat: 添加分页和搜索功能以获取会话列表,优化前端与后端的数据交互 (#2906)
* feat: 添加分页和搜索功能以获取会话列表,优化前端与后端的数据交互

* fix: 修复会话计数显示,使用总项数替代会话数组长度

* fix: 将参数类型和名称与实现内容匹配。

* perf: convert for loop into list comprehension

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* fix: type checking error

* fix: 优化 persona_id 的获取逻辑

---------

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-09-28 23:25:30 +08:00
KroMiose
9c6b31e71c Update README.md (#2904) 2025-09-28 14:50:02 +08:00
Soulter
50f74f5ba2 fix: 修复"开启 TTS 时同时输出语音和文字内容"功能不可用的问题 (#2900)
fixes: #2844
2025-09-28 10:48:57 +08:00
Soulter
b9de2aef60 chore: bump version to 4.2.1 2025-09-27 23:36:25 +08:00
Soulter
7a47598538 fix: 修复指令无法使用的问题
fixes: #2897
2025-09-27 23:35:35 +08:00
Soulter
3c8c28ebd5 chore: bump version to 4.2.0 2025-09-27 20:45:50 +08:00
Soulter
524285f767 feat: add cancel button with localized text to AddNewPlatform and update close button in AddNewProvider
fixes: #2889
2025-09-27 20:41:45 +08:00
Soulter
c2a34475f1 feat: 支持删除指定会话以及部分会话管理优化 (#2895)
* feat: add toast notification system with snackbar component

* feat: add session deletion functionality

* feat: support batch operations for updating session persona, provider, LLM, and TTS statuses

fix: #2263

* feat: 修复对话状态关闭,删除对话管理库会导致对话无法恢复

fixes: #2309
2025-09-27 20:36:30 +08:00
Soulter
a69195a02b fix: webchat streaming queue interrupted after user closing tab (#2892)
* feat: add toast notification system with snackbar component

* feat: enhance chat functionality with conversation running state and notifications

* fix: update bot message avatar rendering during streaming

* feat: implement conversation tracking context manager for webchat

* fix: update conversation tracking to remove conversation ID on exit
2025-09-27 17:57:12 +08:00
RC-CHN
19d7438499 fix: unit tests (#2760)
* fix:修复了main和plugin_manager部分单元测试

* fix: 修复了dashboard部分测试

* remove: 删除暂无用的配置测试脚本

* perf:拆分插件增查删改为独立的单元测试

* refactor: 重构插件管理器测试,使用临时环境隔离测试实例

* test: 增加对仪表板文件检查的单元测试,涵盖不同情况

* style: format code

* remove: 删除未使用的导入语句

* delete: remove unused test file for pipeline

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-27 14:43:04 +08:00
anka
ccb380ce06 feat: 支持接入 Coze (#2858)
* feat: 适配 coze 供应商
1. 支持文件上传
2. 支持多模态
3. 支持流式传输
4. 支持 API 端的上下文保存历史记录
5. 支持类似 dify 的 forget 接口

* style: format code

* fix: type checking error

* fix: 修复:
1. 使用coze api端的上下文时, 现在不会重复传递上下文
2. 使用 AstrBot 的上下文时, 正确处理其中的图片信息
3. 上传图片时, 提供一个非持久化的缓存避免重复上传(在解析上下文并将文件转化为file_id传递给coze api时, 如果没有缓存会导致很多的网络资源浪费)
4. 修复reset等指令不能正确重置上下文的问题

* fix: 移除某些地方多余的针对 dify 的断言, 以兼容 Coze

* style: 修改配置项显示/webchat平台对于非预期的类型的处理

* fix: 让conversation_id放到请求中正确的位置

* refactor: extract coze api client

* refactor: improve image processing logic in ProviderCoze

* chore: remove file ext guessing

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-27 14:23:29 +08:00
Ding Jiatong
a35c439bbd fix: 使用增量解码器修复 Dify 流式返回结果偶现的解码错误 (#2888)
* fix: 修复linux下utf-8解码错误的问题

* feat: use incremental decoder

* fix: add type hint for response parameter in _stream_sse and refactor file upload method

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-26 23:04:58 +08:00
Soulter
09d1f96603 fix: 修复 /alter_cmd 指令无法控制指令组、子指令组和子指令组下子指令的问题 (#2873)
* fix: revert changes in command_group.py at 782c036 to fix command group permission check

* fix: 不传递 GroupCommand handler

* perf: alter_cmd 指令支持对子指令、指令组进行配置

* chore: remove test commands and subcommands from test_group

* chore: add cache for complete command names list in CommandFilter and CommandGroupFilter

---------

Co-authored-by: Dt8333 <25431943+Dt8333@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-09-26 14:16:50 +08:00
鸦羽
26aa18d980 Merge pull request #2881 from Raven95676/fix/2879
fix: add missing id field
2025-09-26 11:31:28 +08:00
Raven95676
d10b542797 chore: format 2025-09-26 11:05:32 +08:00
Raven95676
ce4e4fb8dd fix: add missing id field 2025-09-26 10:59:11 +08:00
Soulter
8f4a31cf8c chore: bump version to 4.1.7 2025-09-23 22:16:36 +08:00
Soulter
23549f13d6 Feature: 支持批量删除对话历史 (#2859)
* feat: 支持批量删除对话

closes: #2784

* feat: 添加加载状态禁用功能,优化用户交互体验
2025-09-23 22:10:56 +08:00
Soulter
869d11f9a6 perf: 优化验证配置时的性能,移除配置隐式类型转换
fixes: #2646
2025-09-23 21:04:14 +08:00
Soulter
02e73b82ee fix: 修复无法打开更新对话框的问题 2025-09-23 20:29:10 +08:00
Soulter
f85f87f545 feat: WebChat 支持手动填写模型名
closes: #2830
2025-09-23 15:32:54 +08:00
Soulter
1fff5713f3 refactor: 解耦 PlatformPage 和 ProviderPage 的部分组件 2025-09-23 15:32:54 +08:00
Soulter
8453ec36f0 docs: Revise links for documentation and blog in README
Updated links in the README for documentation and blog.
2025-09-23 14:12:05 +08:00
Soulter
d5b3ce8424 fix: update download_dashboard to log specific dashboard release URLs 2025-09-23 13:10:33 +08:00
Soulter
80cbbfa5ca chore: bump version to 4.1.6 2025-09-23 13:02:06 +08:00
Soulter
9177bb660f fix: improve error handling in run_agent for streaming responses 2025-09-23 10:34:24 +08:00
Soulter
a3df39a01a perf: unified button styles
closes: #2748
2025-09-23 10:27:52 +08:00
Soulter
25dce05cbb refactor: improve webchat UI (#2853) 2025-09-23 10:19:26 +08:00
Soulter
1542ea3e03 fix: context.get_provider_by_id issue 2025-09-22 17:22:50 +08:00
Soulter
6084abbcfe feat: add user_id search capability in get_filtered_conversations 2025-09-21 22:45:55 +08:00
Soulter
ed19b63914 chore: bump version to v4.1.5 2025-09-21 21:47:14 +08:00
Soulter
4efeb85296 chore: remove uv.lock file 2025-09-21 21:47:06 +08:00
shangxue
fc76665615 feat: Satori适配器引用消息无法正确识别 (#2686)
* Update PlatformPage.vue

* Update PlatformPage.vue

* Update PlatformPage.vue

* Update satori_adapter.py

* Update satori_event.py

* Update default.py

* Update satori_adapter.py

* Update satori_adapter.py

* style: format code

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-21 21:45:35 +08:00
Soulter
3a044bb71a fix: 修复 Telegram 下流式传输时,第一次输出的内容会被覆盖掉的问题 (#2838)
fixes: #2481
2025-09-21 21:24:47 +08:00
Soulter
cddd606562 perf: 优化 ExtensionPage 2025-09-21 21:10:03 +08:00
Soulter
7a5bc51c11 fix: 识别引用消息的图片时优先使用默认图片转述提供商 (#2836)
* fix: 识别引用消息的图片时优先使用默认图片转述提供商

closes: #2821

* fix: 添加日志记录以处理未找到图片标题提供者的情况

* style: format code
2025-09-21 20:55:32 +08:00
Soulter
9f939b4b6f fix: 修复对话管理页面的关键词搜索功能失效的问题并优化一些 UI 样式 (#2837)
* fix: 修复对话管理页面的关键词搜索功能失效的问题并优化一些 UI 样式

fixes: #2782

* style: format code

* fix: remove debug print statements from conversation retrieval methods
2025-09-21 20:55:15 +08:00
Soulter
80a86f5b1b fix: 修复 astrbot.core.star 等包下的 type checking error (#2787)
* fix: 修复 astrbot.core.star 等包下的 type checking error

* refactor: improve type checking and annotations

* chore: ruff format
2025-09-21 18:10:04 +08:00
yitaikarma
a0ce1855ab fix: 优化统计页内存占用和消息数据趋势的样式 (#2826)
* fix: 调整统计页内存占用和消息趋势分析的布局,优化响应式显示

* fix: 隐藏增长率为零时的趋势图标
2025-09-21 17:06:47 +08:00
anka
a4b43b884a fix: 修复aiocqhttp适配器at会获取群昵称而消息不会获取的逻辑不一致 (#2769)
* fix: 修复at会获取群昵称而消息不会获取的逻辑不一致

* style: format code
2025-09-19 13:04:51 +08:00
PaloMiku
824c0f6667 feat: 新增 Misskey 平台适配器 (#2774)
* feat: add Misskey platform adapter

* fix: 修复 Misskey 配置项的大小写问题

* feat: 添加消息链序列化功能和可见性解析逻辑

* chore: 删除损坏的 Misskey 平台适配器工具函数文件

* docs: 更新 Misskey 消息适配器设置描述信息

* feat: Misskey 单用户连续上下文对话支持

* feat: 为 Astrbot 添加 Misskey 平台适配器的 ID 配置

* feat: 重构 Misskey 平台适配器,提取通用工具函数并优化消息处理逻辑

* refactor: 清理 Misskey 平台适配器和 API 代码,移除冗余注释

* fix: 修复了使用中和使用者反馈的多个问题

* fix: 修改提及格式,确保提及在新行开始,提升帖子美观和易读性。

* feat: 添加默认可见性和本地仅限设置,优化 Misskey 平台适配器的配置

* fix: 更新 Misskey 平台适配器配置,使用前缀以防止和其他适配器未来可能的冲突问题

* chore: rename 'misskey' to 'Misskey' in config

* feat: Misskey 适配器添加聊天消息响应功能,重构接收和发送逻辑为 Websockets 处理

* fix: 增强 Misskey WebSocket 消息日志输出

* refactor: 优化 Misskey 适配器的消息处理和日志输出

* fix: 增强 Misskey WebSocket 重连接逻辑

* feat: 增强 Misskey 适配器的消息处理,支持房间消息和相关功能,重构通用函数,清理代码重复冗余

* fix: 不屏蔽唤醒前缀对默认 LLM 的唤醒

* fix: 透传所有的群聊消息事件

* fix: 修复 message_type

* perf: 实现 send_streaming 以支援流式请求

* docs(README): update README.md

* fix: super().send(message) 被忽略

* fix: 修正 session 结构

: 作为分隔符可能会导致 umo 组装出现问题

---------

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-09-18 23:34:41 +08:00
Soulter
a030fe8491 feat: add audioop-lts dependencies (#2809)
pydub needs audioop as a requirement but this builtin package has been removed in 3.13
2025-09-18 23:32:04 +08:00
Soulter
3a9429e8ef fix: on_tool_end hook unavailable 2025-09-17 15:48:57 +08:00
anka
c4eb1ab748 chore: bump version to 4.1.4 2025-09-16 20:09:11 +08:00
anka
29ed19d600 Merge pull request #2783 from AstrBotDevs/revert-2778-fix-handler-type
Revert "fix: parameter type/default handling in CommandFilter"
2025-09-16 20:01:23 +08:00
anka
0cc65513a5 Revert "fix: parameter type/default handling in CommandFilter" 2025-09-16 20:01:05 +08:00
Soulter
debc048659 chore: bump version to 4.1.3 2025-09-16 13:16:21 +08:00
邹永赫
92f5c918dd Merge pull request #2778 from MliKiowa/fix-handler-type
fix: parameter type/default handling in CommandFilter
2025-09-16 13:43:53 +09:00
手瓜一十雪
9519f1e8e2 fix: parameter type/default handling in CommandFilter
Adjusts logic to prioritize type annotations over default values when setting handler_params in CommandFilter. This ensures that parameter types are correctly inferred when available.
2025-09-16 11:49:27 +08:00
Soulter
a8f874bf05 fix: 修复分段回复时,引用消息单独发送导致第一条消息内容为空的问题 (#2757) 2025-09-16 10:45:39 +08:00
anka
9d9917e45b feat: 增加群名称识别到 system prompt, 并提供相应的配置 (#2770)
* feat🤖: 增加群名称识别到system prompt, 并提供相应的配置

* feat: 优化实现方式, 重构AstrBotMessage, 向后兼容

* style: format
2025-09-16 10:23:08 +08:00
Soulter
91ee0a870d fix: handle image value correctly for mcp BlobResourceContents (#2753) 2025-09-16 08:22:18 +08:00
dependabot[bot]
6cbbffc5a9 chore(deps): bump the github-actions group with 2 updates (#2771)
Bumps the github-actions group with 2 updates: [actions/checkout](https://github.com/actions/checkout) and [actions/setup-python](https://github.com/actions/setup-python).


Updates `actions/checkout` from 4 to 5
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v5)

Updates `actions/setup-python` from 5 to 6
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](https://github.com/actions/setup-python/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/setup-python
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-16 08:19:31 +08:00
Yokami
8f26fd34d1 feat: add copy button for service providers (#2767) 2025-09-15 22:17:00 +08:00
Soulter
fda655f6d7 fix: 修复配置默认 TTS 或者 STT 模型之后仍无法生效的问题 (#2758)
fixes: #2731
2025-09-15 22:08:40 +08:00
266 changed files with 25675 additions and 12311 deletions

View File

@@ -16,7 +16,7 @@ body:
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
不熟悉 JSON 现在可以从 [这里](https://plugins.astrbot.app/#/submit) 获取你的 JSON 啦!获取到了记得复制粘贴过来哦!
不熟悉 JSON ?可以从 [此处](https://plugins.astrbot.app/submit) 生成 JSON ,生成后记得复制粘贴过来.
- type: textarea
id: plugin-info

View File

@@ -11,6 +11,8 @@ reviewers:
- Larch-C
- anka-afk
- advent259141
- Fridemn
- LIghtJUNction
# - zouyonghe
# A number of reviewers added to the pull request

View File

@@ -12,10 +12,10 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.10'

View File

@@ -60,7 +60,7 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
uses: github/codeql-action/init@v4
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
@@ -88,6 +88,6 @@ jobs:
exit 1
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3
uses: github/codeql-action/analyze@v4
with:
category: "/language:${{matrix.language}}"

View File

@@ -13,11 +13,18 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v5
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 'latest'
- name: npm install, build
run: |
cd dashboard
npm install
npm run build
npm install pnpm -g
pnpm install
pnpm i --save-dev @types/markdown-it
pnpm run build
- name: Inject Commit SHA
id: get_sha

2
.gitignore vendored
View File

@@ -31,3 +31,5 @@ packages/python_interpreter/workplace
.idea
pytest.ini
.astrbot
uv.lock

View File

@@ -6,8 +6,20 @@ ci:
autoupdate_schedule: weekly
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.2
hooks:
- id: ruff
- id: ruff-format
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.14.1
hooks:
# Run the linter.
- id: ruff-check
types_or: [ python, pyi ]
args: [ --fix ]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi ]
- repo: https://github.com/asottile/pyupgrade
rev: v3.21.0
hooks:
- id: pyupgrade
args: [--py310-plus]

View File

@@ -4,8 +4,6 @@ WORKDIR /AstrBot
COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \
nodejs \
npm \
gcc \
build-essential \
python3-dev \
@@ -13,23 +11,20 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libssl-dev \
ca-certificates \
bash \
ffmpeg \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN apt-get update && apt-get install -y curl gnupg && \
curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
apt-get install -y nodejs && \
rm -rf /var/lib/apt/lists/*
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
# 释出 ffmpeg
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
RUN uv pip install socksio uv pilk --no-cache-dir --system
EXPOSE 6185
EXPOSE 6186
CMD [ "python", "main.py" ]

152
README.md
View File

@@ -1,28 +1,38 @@
<img width="430" height="31" alt="image" src="https://github.com/user-attachments/assets/474c822c-fab7-41be-8c23-6dae252823ed" /><p align="center">
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
</p>
<div align="center">
<br>
<div>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot?style=for-the-badge&color=76bad9)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://astrbot.app/">查看文档</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp" alt="FeaturedHelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</div>
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
<br>
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
</div>
<br>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://astrbot.app/">文档</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">路线图</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">问题提交</a>
</div>
AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架。
## 主要功能
@@ -34,7 +44,7 @@ AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架
## 部署方式
#### Docker 部署
#### Docker 部署(推荐 🥳)
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
@@ -62,7 +72,7 @@ AstrBot 已由雨云官方上架至云应用平台,可一键部署。
社区贡献的部署方式。
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
#### Windows 一键安装器部署
@@ -100,7 +110,6 @@ uv run main.py
- 5 群822130018
- 6 群753075035
- 开发者群975206796
- 开发者群备份295657329
### Telegram 群组
@@ -110,49 +119,83 @@ uv run main.py
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## ⚡ 消息平台支持情况
**官方维护**
| 平台 | 支持性 |
| -------- | ------- |
| QQ(官方机器人接口) | ✔ |
| QQ(官方平台) | ✔ |
| QQ(OneBot) | ✔ |
| Telegram | ✔ |
| 企业微信 | ✔ |
| 企微应用 | ✔ |
| 企微智能机器人 | ✔ |
| 微信客服 | ✔ |
| 微信公众号 | ✔ |
| 飞书 | ✔ |
| 钉钉 | ✔ |
| Slack | ✔ |
| Discord | ✔ |
| Satori | ✔ |
| Misskey | ✔ |
| Whatsapp | 将支持 |
| LINE | 将支持 |
**社区维护**
| 平台 | 支持性 |
| -------- | ------- |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
| [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) | ✔ |
| [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11) | ✔ |
## ⚡ 提供商支持情况
| 名称 | 支持性 | 类型 | 备注 |
| -------- | ------- | ------- | ------- |
| OpenAI | ✔ | 文本生成 | 支持任何兼容 OpenAI API 的服务 |
| Anthropic | ✔ | 文本生成 | |
| Google Gemini | ✔ | 文本生成 | |
| Dify | ✔ | LLMOps | |
| 阿里云百炼应用 | ✔ | LLMOps | |
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
| 硅基流动 | ✔ | 模型 API 服务平台 | |
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
| OneAPI | ✔ | LLM 分发系统 | |
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
| OpenAI TTS API | ✔ | 文本转语音 | |
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | 文本转语音 | GPT-Sovits-Inference |
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
**大模型服务**
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI | ✔ | 支持任何兼容 OpenAI API 的服务 |
| Anthropic | ✔ | |
| Google Gemini | ✔ | |
| Moonshot AI | ✔ | |
| 智谱 AI | ✔ | |
| DeepSeek | ✔ | |
| Ollama | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| LM Studio | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | |
| [小马算力](https://www.tokenpony.cn/3YPyf) | ✔ | |
| 硅基流动 | ✔ | |
| PPIO 派欧云 | ✔ | |
| ModelScope | ✔ | |
| OneAPI | ✔ | |
| Dify | ✔ | |
| 阿里云百炼应用 | ✔ | |
| Coze | ✔ | |
**语音转文本服务**
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| Whisper | ✔ | 支持 API、本地部署 |
| SenseVoice | ✔ | 本地部署 |
**文本转语音服务**
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI TTS | ✔ | |
| Gemini TTS | ✔ | |
| GSVI | ✔ | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | GPT-Sovits |
| FishAudio | ✔ | |
| Edge TTS | ✔ | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | |
| Azure TTS | ✔ | |
| Minimax TTS | ✔ | |
| 火山引擎 TTS | ✔ | |
## ❤️ 贡献
@@ -167,12 +210,11 @@ uv run main.py
AstrBot 使用 `ruff` 进行代码格式化和检查。
```bash
git clone https://github.com/Soulter/AstrBot
git clone https://github.com/AstrBotDevs/AstrBot
pip install pre-commit
pre-commit install
```
## ❤️ Special Thanks
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
@@ -185,29 +227,17 @@ pre-commit install
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
另外,一些同类型其他的活跃开源 Bot 项目:
- [nonebot/nonebot2](https://github.com/nonebot/nonebot2) - 扩展性极强的 Bot 框架
- [koishijs/koishi](https://github.com/koishijs/koishi) - 扩展性极强的 Bot 框架
- [MaiM-with-u/MaiBot](https://github.com/MaiM-with-u/MaiBot) - 注重拟人功能的 ChatBot
- [langbot-app/LangBot](https://github.com/langbot-app/LangBot) - 功能丰富的 Bot 平台
- [LroMiose/nekro-agent](https://github.com/KroMiose/nekro-agent) - 注重 Agent 的 ChatBot
- [zhenxun-org/zhenxun_bot](https://github.com/zhenxun-org/zhenxun_bot) - 功能完善的 ChatBot
## ⭐ Star History
> [!TIP]
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我维护这个开源项目的动力 <3
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我维护这个开源项目的动力 <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
[![Star History Chart](https://api.star-history.com/svg?repos=astrbotdevs/astrbot&type=Date)](https://star-history.com/#astrbotdevs/astrbot&Date)
</div>
</details>
_私は、高性能ですから!_

View File

@@ -10,16 +10,16 @@ _✨ Easy-to-use Multi-platform LLM Chatbot & Development Framework ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/AstrBotDevs/AstrBot)](https://github.com/AstrBotDevs/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
[![codecov](https://codecov.io/gh/AstrBotDevs/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/AstrBotDevs/AstrBot)
<a href="https://astrbot.app/">Documentation</a>
<a href="https://github.com/Soulter/AstrBot/issues">Issue Tracking</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracking</a>
</div>
AstrBot is a loosely coupled, asynchronous chatbot and development framework that supports multi-platform deployment, featuring an easy-to-use plugin system and comprehensive Large Language Model (LLM) integration capabilities.
@@ -49,7 +49,7 @@ Requires Python (>3.10). See docs: [Windows Installer Guide](https://astrbot.app
#### Replit Deployment
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
#### CasaOS Deployment
@@ -67,8 +67,8 @@ See docs: [Source Code Deployment](https://astrbot.app/deploy/astrbot/cli.html)
| QQ (Official Bot) | ✔ | Private/Group chats | Text, Images |
| QQ (OneBot) | ✔ | Private/Group chats | Text, Images, Voice |
| WeChat (Personal) | ✔ | Private/Group chats | Text, Images, Voice |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
| [WeChat Work](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
| [Telegram](https://github.com/AstrBotDevs/AstrBot_plugin_telegram) | ✔ | Private/Group chats | Text, Images |
| [WeChat Work](https://github.com/AstrBotDevs/AstrBot_plugin_wecom) | ✔ | Private chats | Text, Images, Voice |
| Feishu | ✔ | Group chats | Text, Images |
| WeChat Open Platform | 🚧 | Planned | - |
| Discord | 🚧 | Planned | - |
@@ -157,7 +157,7 @@ _✨ Built-in Web Chat Interface ✨_
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
[![Star History Chart](https://api.star-history.com/svg?repos=AstrBotDevs/AstrBot&type=Date)](https://star-history.com/#AstrBotDevs/AstrBot&Date)
</div>
@@ -169,7 +169,7 @@ _✨ Built-in Web Chat Interface ✨_
<!-- ## ✨ ATRI [Beta]
Available as plugin: [astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
Available as plugin: [astrbot_plugin_atri](https://github.com/AstrBotDevs/AstrBot_plugin_atri)
1. Qwen1.5-7B-Chat Lora model fine-tuned with ATRI character data
2. Long-term memory

View File

@@ -10,16 +10,16 @@ _✨ 簡単に使えるマルチプラットフォーム LLM チャットボッ
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/AstrBotDevs/AstrBot)](https://github.com/AstrBotDevs/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
[![codecov](https://codecov.io/gh/AstrBotDevs/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/AstrBotDevs/AstrBot)
<a href="https://astrbot.app/">ドキュメントを見る</a>
<a href="https://github.com/Soulter/AstrBot/issues">問題を報告する</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">問題を報告する</a>
</div>
AstrBot は、疎結合、非同期、複数のメッセージプラットフォームに対応したデプロイ、使いやすいプラグインシステム、および包括的な大規模言語モデルLLM接続機能を備えたチャットボットおよび開発フレームワークです。
@@ -50,7 +50,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
#### Replit デプロイ
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
#### CasaOS デプロイ

0
astrbot.lock Normal file
View File

View File

@@ -9,5 +9,5 @@ from .hooks import BaseAgentRunHooks
class Agent(Generic[TContext]):
name: str
instructions: str | None = None
tools: list[str, FunctionTool] | None = None
tools: list[str | FunctionTool] | None = None
run_hooks: BaseAgentRunHooks[TContext] | None = None

View File

@@ -40,8 +40,15 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
timeout = cfg.get("timeout", 10)
try:
if "transport" in cfg:
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
async with aiohttp.ClientSession() as session:
if cfg.get("transport") == "streamable_http":
if transport_type == "streamable_http":
test_payload = {
"jsonrpc": "2.0",
"method": "initialize",
@@ -92,7 +99,7 @@ class MCPClient:
self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack()
self.name = None
self.name: str | None = None
self.active: bool = True
self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = []
@@ -121,7 +128,14 @@ class MCPClient:
if not success:
raise Exception(error_msg)
if cfg.get("transport") != "streamable_http":
if "transport" in cfg:
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
if transport_type != "streamable_http":
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
@@ -134,7 +148,7 @@ class MCPClient:
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
*streams,
@@ -159,7 +173,7 @@ class MCPClient:
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
read_stream=read_s,
@@ -198,6 +212,8 @@ class MCPClient:
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
if not self.session:
raise Exception("MCP Client is not initialized")
response = await self.session.list_tools()
self.tools = response.tools
return response

View File

@@ -198,9 +198,49 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
func_tool = req.func_tool.get_func(func_tool_name)
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
if not func_tool:
logger.warning(f"未找到指定的工具: {func_tool_name},将跳过。")
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: 未找到工具 {func_tool_name}",
)
)
continue
valid_params = {} # 参数过滤:只传递函数实际需要的参数
# 获取实际的 handler 函数
if func_tool.handler:
logger.debug(
f"工具 {func_tool_name} 期望的参数: {func_tool.parameters}"
)
if func_tool.parameters and func_tool.parameters.get("properties"):
expected_params = set(func_tool.parameters["properties"].keys())
valid_params = {
k: v
for k, v in func_tool_args.items()
if k in expected_params
}
# 记录被忽略的参数
ignored_params = set(func_tool_args.keys()) - set(
valid_params.keys()
)
if ignored_params:
logger.warning(
f"工具 {func_tool_name} 忽略非期望参数: {ignored_params}"
)
else:
# 如果没有 handler如 MCP 工具),使用所有参数
valid_params = func_tool_args
logger.warning(f"工具 {func_tool_name} 没有 handler使用所有参数")
try:
await self.agent_hooks.on_tool_start(
self.run_context, func_tool, func_tool_args
self.run_context, func_tool, valid_params
)
except Exception as e:
logger.error(f"Error in on_tool_start hook: {e}", exc_info=True)
@@ -208,11 +248,14 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
executor = self.tool_executor.execute(
tool=func_tool,
run_context=self.run_context,
**func_tool_args,
**valid_params, # 只传递有效的参数
)
async for resp in executor:
_final_resp: CallToolResult | None = None
async for resp in executor: # type: ignore
if isinstance(resp, CallToolResult):
res = resp
_final_resp = resp
if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
@@ -258,7 +301,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
)
yield MessageChain(
type="tool_direct_result"
).base64_image(res.content[0].data)
).base64_image(resource.blob)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
@@ -269,17 +312,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
)
yield MessageChain().message("返回的数据类型不受支持。")
try:
await self.agent_hooks.on_tool_end(
self.run_context,
func_tool_name,
func_tool_args,
resp,
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
elif resp is None:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
@@ -289,27 +321,18 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool_name, func_tool_args, None
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
else:
# 不应该出现其他类型
logger.warning(
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool_name, func_tool_args, None
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool, func_tool_args, _final_resp
)
except Exception as e:
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
self.run_context.event.clear_result()
except Exception as e:

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from deprecated import deprecated
from typing import Awaitable, Literal, Any, Optional
from typing import Awaitable, Callable, Literal, Any, Optional
from .mcp_client import MCPClient
@@ -8,10 +8,10 @@ from .mcp_client import MCPClient
class FunctionTool:
"""A class representing a function tool that can be used in function calling."""
name: str | None = None
name: str
parameters: dict | None = None
description: str | None = None
handler: Awaitable | None = None
handler: Callable[..., Awaitable[Any]] | None = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str | None = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
@@ -51,7 +51,7 @@ class ToolSet:
This class provides methods to add, remove, and retrieve tools, as well as
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
def __init__(self, tools: list[FunctionTool] = None):
def __init__(self, tools: list[FunctionTool] | None = None):
self.tools: list[FunctionTool] = tools or []
def empty(self) -> bool:
@@ -79,7 +79,13 @@ class ToolSet:
return None
@deprecated(reason="Use add_tool() instead", version="4.0.0")
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
def add_func(
self,
name: str,
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
):
"""Add a function tool to the set."""
params = {
"type": "object", # hard-coded here
@@ -104,7 +110,7 @@ class ToolSet:
self.remove_tool(name)
@deprecated(reason="Use get_tool() instead", version="4.0.0")
def get_func(self, name: str) -> list[FunctionTool]:
def get_func(self, name: str) -> FunctionTool | None:
"""Get all function tools."""
return self.get_tool(name)
@@ -125,7 +131,11 @@ class ToolSet:
},
}
if tool.parameters.get("properties") or not omit_empty_parameter_field:
if (
tool.parameters
and tool.parameters.get("properties")
or not omit_empty_parameter_field
):
func_def["function"]["parameters"] = tool.parameters
result.append(func_def)
@@ -135,14 +145,14 @@ class ToolSet:
"""Convert tools to Anthropic API format."""
result = []
for tool in self.tools:
input_schema = {"type": "object"}
if tool.parameters:
input_schema["properties"] = tool.parameters.get("properties", {})
input_schema["required"] = tool.parameters.get("required", [])
tool_def = {
"name": tool.name,
"description": tool.description,
"input_schema": {
"type": "object",
"properties": tool.parameters.get("properties", {}),
"required": tool.parameters.get("required", []),
},
"input_schema": input_schema,
}
result.append(tool_def)
return result
@@ -210,14 +220,15 @@ class ToolSet:
return result
tools = [
{
tools = []
for tool in self.tools:
d = {
"name": tool.name,
"description": tool.description,
"parameters": convert_schema(tool.parameters),
}
for tool in self.tools
]
if tool.parameters:
d["parameters"] = convert_schema(tool.parameters)
tools.append(d)
declarations = {}
if tools:

View File

@@ -9,3 +9,4 @@ class AstrAgentContext:
first_provider_request: ProviderRequest
curr_provider_request: ProviderRequest
streaming: bool
tool_call_timeout: int = 60 # Default tool call timeout in seconds

View File

@@ -5,6 +5,7 @@ from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
from astrbot.core.config.default import DEFAULT_CONFIG
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.utils.astrbot_path import get_astrbot_config_path
from typing import TypeVar, TypedDict
@@ -15,14 +16,12 @@ class ConfInfo(TypedDict):
"""Configuration information for a specific session or platform."""
id: str # UUID of the configuration or "default"
umop: list[str] # Unified Message Origin Pattern
name: str
path: str # File name to the configuration file
DEFAULT_CONFIG_CONF_INFO = ConfInfo(
id="default",
umop=["::"],
name="default",
path=ASTRBOT_CONFIG_PATH,
)
@@ -31,8 +30,14 @@ DEFAULT_CONFIG_CONF_INFO = ConfInfo(
class AstrBotConfigManager:
"""A class to manage the system configuration of AstrBot, aka ACM"""
def __init__(self, default_config: AstrBotConfig, sp: SharedPreferences):
def __init__(
self,
default_config: AstrBotConfig,
ucr: UmopConfigRouter,
sp: SharedPreferences,
):
self.sp = sp
self.ucr = ucr
self.confs: dict[str, AstrBotConfig] = {}
"""uuid / "default" -> AstrBotConfig"""
self.confs["default"] = default_config
@@ -63,24 +68,15 @@ class AstrBotConfigManager:
)
continue
def _is_umo_match(self, p1: str, p2: str) -> bool:
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
p1_ls = p1.split(":")
p2_ls = p2.split(":")
if len(p1_ls) != 3 or len(p2_ls) != 3:
return False # 非法格式
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo:
"""获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default")
Returns:
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
"""
# uuid -> { "umop": list, "path": str, "name": str }
# uuid -> { "path": str, "name": str }
abconf_data = self._get_abconf_data()
if isinstance(umo, MessageSession):
umo = str(umo)
else:
@@ -89,10 +85,13 @@ class AstrBotConfigManager:
except Exception:
return DEFAULT_CONFIG_CONF_INFO
for uuid_, meta in abconf_data.items():
for pattern in meta["umop"]:
if self._is_umo_match(pattern, umo):
return ConfInfo(**meta, id=uuid_)
conf_id = self.ucr.get_conf_id_for_umop(umo)
if conf_id:
meta = abconf_data.get(conf_id)
if meta and isinstance(meta, dict):
# the bind relation between umo and conf is defined in ucr now, so we remove "umop" here
meta.pop("umop", None)
return ConfInfo(**meta, id=conf_id)
return DEFAULT_CONFIG_CONF_INFO
@@ -100,23 +99,14 @@ class AstrBotConfigManager:
self,
abconf_path: str,
abconf_id: str,
umo_parts: list[str] | list[MessageSession],
abconf_name: str | None = None,
) -> None:
"""保存配置文件的映射关系"""
for part in umo_parts:
if isinstance(part, MessageSession):
part = str(part)
elif not isinstance(part, str):
raise ValueError(
"umo_parts must be a list of strings or MessageSession instances"
)
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
random_word = abconf_name or uuid.uuid4().hex[:8]
abconf_data[abconf_id] = {
"umop": umo_parts,
"path": abconf_path,
"name": random_word,
}
@@ -153,29 +143,26 @@ class AstrBotConfigManager:
def get_conf_list(self) -> list[ConfInfo]:
"""获取所有配置文件的元数据列表"""
conf_list = []
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
abconf_mapping = self._get_abconf_data()
for uuid_, meta in abconf_mapping.items():
if not isinstance(meta, dict):
continue
meta.pop("umop", None)
conf_list.append(ConfInfo(**meta, id=uuid_))
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
return conf_list
def create_conf(
self,
umo_parts: list[str] | list[MessageSession],
config: dict = DEFAULT_CONFIG,
name: str | None = None,
) -> str:
"""
umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
umo_parts 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。
"""
conf_uuid = str(uuid.uuid4())
conf_file_name = f"abconf_{conf_uuid}.json"
conf_path = os.path.join(get_astrbot_config_path(), conf_file_name)
conf = AstrBotConfig(config_path=conf_path, default_config=config)
conf.save_config()
self._save_conf_mapping(conf_file_name, conf_uuid, umo_parts, abconf_name=name)
self._save_conf_mapping(conf_file_name, conf_uuid, abconf_name=name)
self.confs[conf_uuid] = conf
return conf_uuid
@@ -228,15 +215,12 @@ class AstrBotConfigManager:
logger.info(f"成功删除配置文件 {conf_id}")
return True
def update_conf_info(
self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None
) -> bool:
def update_conf_info(self, conf_id: str, name: str | None = None) -> bool:
"""更新配置文件信息
Args:
conf_id: 配置文件的 UUID
name: 新的配置文件名称 (可选)
umo_parts: 新的 UMO 部分列表 (可选)
Returns:
bool: 更新是否成功
@@ -255,18 +239,6 @@ class AstrBotConfigManager:
if name is not None:
abconf_data[conf_id]["name"] = name
# 更新 UMO 部分
if umo_parts is not None:
# 验证 UMO 部分格式
for part in umo_parts:
if isinstance(part, MessageSession):
part = str(part)
elif not isinstance(part, str):
raise ValueError(
"umo_parts must be a list of strings or MessageSession instances"
)
abconf_data[conf_id]["umop"] = umo_parts
# 保存更新
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data

View File

@@ -6,7 +6,7 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.1.2"
VERSION = "4.5.0"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
# 默认配置
@@ -57,19 +57,22 @@ DEFAULT_CONFIG = {
"web_search": False,
"websearch_provider": "default",
"websearch_tavily_key": [],
"websearch_baidu_app_builder_key": "",
"web_search_link": False,
"display_reasoning_text": False,
"identifier": False,
"group_name_display": False,
"datetime_system_prompt": True,
"default_personality": "default",
"persona_pool": ["*"],
"prompt_prefix": "",
"prompt_prefix": "{{prompt}}",
"max_context_length": -1,
"dequeue_context_length": 1,
"streaming_response": False,
"show_tool_use_status": False,
"streaming_segmented": False,
"max_agent_step": 30,
"tool_call_timeout": 60,
},
"provider_stt_settings": {
"enable": False,
@@ -115,6 +118,15 @@ DEFAULT_CONFIG = {
"port": 6185,
},
"platform": [],
"platform_specific": {
# 平台特异配置:按平台分类,平台下按功能分组
"lark": {
"pre_ack_emoji": {"enable": False, "emojis": ["Typing"]},
},
"telegram": {
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
},
},
"wake_prefix": ["/"],
"log_level": "INFO",
"pip_install_arg": "",
@@ -122,8 +134,11 @@ DEFAULT_CONFIG = {
"persona": [], # deprecated
"timezone": "Asia/Shanghai",
"callback_api_base": "",
"default_kb_collection": "", # 默认知识库名称
"default_kb_collection": "", # 默认知识库名称, 已经过时
"plugin_set": ["*"], # "*" 表示使用所有可用的插件, 空列表表示不使用任何插件
"kb_names": [], # 默认知识库名称列表
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
"kb_final_top_k": 5, # 知识库检索最终返回结果数量
}
@@ -150,10 +165,11 @@ CONFIG_METADATA_2 = {
"enable": False,
"appid": "",
"secret": "",
"is_sandbox": False,
"callback_server_host": "0.0.0.0",
"port": 6196,
},
"QQ 个人号(aiocqhttp)": {
"QQ 个人号(OneBot v11)": {
"id": "default",
"type": "aiocqhttp",
"enable": False,
@@ -161,7 +177,7 @@ CONFIG_METADATA_2 = {
"ws_reverse_port": 6199,
"ws_reverse_token": "",
},
"微信个人号(WeChatPadPro)": {
"WeChatPadPro": {
"id": "wechatpadpro",
"type": "wechatpadpro",
"enable": False,
@@ -197,6 +213,18 @@ CONFIG_METADATA_2 = {
"callback_server_host": "0.0.0.0",
"port": 6195,
},
"企业微信智能机器人": {
"id": "wecom_ai_bot",
"type": "wecom_ai_bot",
"enable": True,
"wecomaibot_init_respond_text": "💭 思考中...",
"wecomaibot_friend_message_welcome_text": "",
"wecom_ai_bot_name": "",
"token": "",
"encoding_aes_key": "",
"callback_server_host": "0.0.0.0",
"port": 6198,
},
"飞书(Lark)": {
"id": "lark",
"type": "lark",
@@ -235,6 +263,24 @@ CONFIG_METADATA_2 = {
"discord_guild_id_for_debug": "",
"discord_activity_name": "",
},
"Misskey": {
"id": "misskey",
"type": "misskey",
"enable": False,
"misskey_instance_url": "https://misskey.example",
"misskey_token": "",
"misskey_default_visibility": "public",
"misskey_local_only": False,
"misskey_enable_chat": True,
# download / security options
"misskey_allow_insecure_downloads": False,
"misskey_download_timeout": 15,
"misskey_download_chunk_size": 65536,
"misskey_max_download_bytes": None,
"misskey_enable_file_upload": True,
"misskey_upload_concurrency": 3,
"misskey_upload_folder": "",
},
"Slack": {
"id": "slack",
"type": "slack",
@@ -252,43 +298,61 @@ CONFIG_METADATA_2 = {
"type": "satori",
"enable": False,
"satori_api_base_url": "http://localhost:5140/satori/v1",
"satori_endpoint": "ws://127.0.0.1:5140/satori/v1/events",
"satori_endpoint": "ws://localhost:5140/satori/v1/events",
"satori_token": "",
"satori_auto_reconnect": True,
"satori_heartbeat_interval": 10,
"satori_reconnect_delay": 5,
},
# "WebChat": {
# "id": "webchat",
# "type": "webchat",
# "enable": False,
# "webchat_link_path": "",
# "webchat_present_type": "fullscreen",
# },
},
"items": {
# "webchat_link_path": {
# "description": "链接路径",
# "_special": "webchat_link_path",
# "type": "string",
# },
# "webchat_present_type": {
# "_special": "webchat_present_type",
# "description": "展现形式",
# "type": "string",
# "options": ["fullscreen", "embedded"],
# },
"satori_api_base_url": {
"description": "Satori API Base URL",
"description": "Satori API 终结点",
"type": "string",
"hint": "The base URL for the Satori API.",
"hint": "Satori API 的基础地址。",
},
"satori_endpoint": {
"description": "Satori WebSocket Endpoint",
"description": "Satori WebSocket 终结点",
"type": "string",
"hint": "The WebSocket endpoint for Satori events.",
"hint": "Satori 事件的 WebSocket 端点。",
},
"satori_token": {
"description": "Satori Token",
"description": "Satori 令牌",
"type": "string",
"hint": "The token used for authenticating with the Satori API.",
"hint": "用于 Satori API 身份验证的令牌。",
},
"satori_auto_reconnect": {
"description": "Enable Auto Reconnect",
"description": "启用自动重连",
"type": "bool",
"hint": "Whether to automatically reconnect the WebSocket on disconnection.",
"hint": "断开连接时是否自动重新连接 WebSocket。",
},
"satori_heartbeat_interval": {
"description": "Satori Heartbeat Interval",
"description": "Satori 心跳间隔",
"type": "int",
"hint": "The interval (in seconds) for sending heartbeat messages.",
"hint": "发送心跳消息的间隔(秒)。",
},
"satori_reconnect_delay": {
"description": "Satori Reconnect Delay",
"description": "Satori 重连延迟",
"type": "int",
"hint": "The delay (in seconds) before attempting to reconnect.",
"hint": "尝试重新连接前的延迟时间(秒)。",
},
"slack_connection_mode": {
"description": "Slack Connection Mode",
@@ -336,6 +400,67 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
},
"misskey_instance_url": {
"description": "Misskey 实例 URL",
"type": "string",
"hint": "例如 https://misskey.example填写 Bot 账号所在的 Misskey 实例地址",
},
"misskey_token": {
"description": "Misskey Access Token",
"type": "string",
"hint": "连接服务设置生成的 API 鉴权访问令牌Access token",
},
"misskey_default_visibility": {
"description": "默认帖子可见性",
"type": "string",
"options": ["public", "home", "followers"],
"hint": "机器人发帖时的默认可见性设置。public公开home主页时间线followers仅关注者。",
},
"misskey_local_only": {
"description": "仅限本站(不参与联合)",
"type": "bool",
"hint": "启用后,机器人发出的帖子将仅在本实例可见,不会联合到其他实例",
},
"misskey_enable_chat": {
"description": "启用聊天消息响应",
"type": "bool",
"hint": "启用后,机器人将会监听和响应私信聊天消息",
},
"misskey_enable_file_upload": {
"description": "启用文件上传到 Misskey",
"type": "bool",
"hint": "启用后,适配器会尝试将消息链中的文件上传到 Misskey。URL 文件会先尝试服务器端上传,异步上传失败时会回退到下载后本地上传。",
},
"misskey_allow_insecure_downloads": {
"description": "允许不安全下载(禁用 SSL 验证)",
"type": "bool",
"hint": "当远端服务器存在证书问题导致无法正常下载时,自动禁用 SSL 验证作为回退方案。适用于某些图床的证书配置问题。启用有安全风险,仅在必要时使用。",
},
"misskey_download_timeout": {
"description": "远端下载超时时间(秒)",
"type": "int",
"hint": "下载远程文件时的超时时间(秒),用于异步上传回退到本地上传的场景。",
},
"misskey_download_chunk_size": {
"description": "流式下载分块大小(字节)",
"type": "int",
"hint": "流式下载和计算 MD5 时使用的每次读取字节数,过小会增加开销,过大会占用内存。",
},
"misskey_max_download_bytes": {
"description": "最大允许下载字节数(超出则中止)",
"type": "int",
"hint": "如果希望限制下载文件的最大大小以防止 OOM请填写最大字节数留空或 null 表示不限制。",
},
"misskey_upload_concurrency": {
"description": "并发上传限制",
"type": "int",
"hint": "同时进行的文件上传任务上限(整数,默认 3",
},
"misskey_upload_folder": {
"description": "上传到网盘的目标文件夹 ID",
"type": "string",
"hint": "可选:填写 Misskey 网盘中目标文件夹的 ID上传的文件将放置到该文件夹内。留空则使用账号网盘根目录。",
},
"telegram_command_register": {
"description": "Telegram 命令注册",
"type": "bool",
@@ -387,24 +512,38 @@ CONFIG_METADATA_2 = {
"hint": "启用后,机器人可以接收到频道的私聊消息。",
},
"ws_reverse_host": {
"description": "反向 Websocket 主机地址(AstrBot 为服务器端)",
"description": "反向 Websocket 主机",
"type": "string",
"hint": "aiocqhttp 适配器的反向 Websocket 服务器 IP 地址,不包含端口号",
"hint": "AstrBot 将作为服务器端",
},
"ws_reverse_port": {
"description": "反向 Websocket 端口",
"type": "int",
"hint": "aiocqhttp 适配器的反向 Websocket 端口。",
},
"ws_reverse_token": {
"description": "反向 Websocket Token",
"type": "string",
"hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。",
"hint": "反向 Websocket Token。未设置则不启用 Token 验证。",
},
"wecom_ai_bot_name": {
"description": "企业微信智能机器人的名字",
"type": "string",
"hint": "请务必填写正确,否则无法使用一些指令。",
},
"wecomaibot_init_respond_text": {
"description": "企业微信智能机器人初始响应文本",
"type": "string",
"hint": "当机器人收到消息时,首先回复的文本内容。留空则使用默认值。",
},
"wecomaibot_friend_message_welcome_text": {
"description": "企业微信智能机器人私聊欢迎语",
"type": "string",
"hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。",
},
"lark_bot_name": {
"description": "飞书机器人的名字",
"type": "string",
"hint": "请务必填,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
"hint": "请务必填写正确,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
},
"discord_token": {
"description": "Discord Bot Token",
@@ -729,7 +868,7 @@ CONFIG_METADATA_2 = {
"timeout": 120,
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
"modalities": ["text", "tool_use"],
},
"302.AI": {
"id": "302ai",
@@ -775,6 +914,21 @@ CONFIG_METADATA_2 = {
},
"custom_extra_body": {},
},
"小马算力": {
"id": "tokenpony",
"provider": "tokenpony",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.tokenpony.cn/v1",
"timeout": 120,
"model_config": {
"model": "kimi-k2-instruct-0905",
"temperature": 0.7,
},
"custom_extra_body": {},
},
"优云智算": {
"id": "compshare",
"provider": "compshare",
@@ -832,6 +986,18 @@ CONFIG_METADATA_2 = {
"timeout": 60,
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
},
"Coze": {
"id": "coze",
"provider": "coze",
"provider_type": "chat_completion",
"type": "coze",
"enable": True,
"coze_api_key": "",
"bot_id": "",
"coze_api_base": "https://api.coze.cn",
"timeout": 60,
"auto_save_history": True,
},
"阿里云百炼应用": {
"id": "dashscope",
"provider": "dashscope",
@@ -983,6 +1149,7 @@ CONFIG_METADATA_2 = {
"timeout": "20",
},
"阿里云百炼 TTS(API)": {
"hint": "API Key 从 https://bailian.console.aliyun.com/?tab=model#/api-key 获取。模型和音色的选择文档请参考: 阿里云百炼语音合成音色名称。具体可参考 https://help.aliyun.com/zh/model-studio/speech-synthesis-and-speech-recognition",
"id": "dashscope_tts",
"provider": "dashscope",
"type": "dashscope_tts",
@@ -1250,6 +1417,7 @@ CONFIG_METADATA_2 = {
"description": "嵌入维度",
"type": "int",
"hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。",
"_special": "get_embedding_dim",
},
"embedding_model": {
"description": "嵌入模型",
@@ -1362,11 +1530,7 @@ CONFIG_METADATA_2 = {
"description": "服务订阅密钥",
"hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)",
},
"dashscope_tts_voice": {
"description": "语音合成模型",
"type": "string",
"hint": "阿里云百炼语音合成模型名称。具体可参考 https://help.aliyun.com/zh/model-studio/developer-reference/cosyvoice-python-api 等内容",
},
"dashscope_tts_voice": {"description": "音色", "type": "string"},
"gm_resp_image_modal": {
"description": "启用图片模态",
"type": "bool",
@@ -1698,6 +1862,26 @@ CONFIG_METADATA_2 = {
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
"obvious": True,
},
"coze_api_key": {
"description": "Coze API Key",
"type": "string",
"hint": "Coze API 密钥,用于访问 Coze 服务。",
},
"bot_id": {
"description": "Bot ID",
"type": "string",
"hint": "Coze 机器人的 ID在 Coze 平台上创建机器人后获得。",
},
"coze_api_base": {
"description": "API Base URL",
"type": "string",
"hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn",
},
"auto_save_history": {
"description": "由 Coze 管理对话记录",
"type": "bool",
"hint": "启用后,将由 Coze 进行对话历史记录管理, 此时 AstrBot 本地保存的上下文不会生效(仅供浏览), 对 AstrBot 的上下文进行的操作也不会生效。如果为禁用, 则使用 AstrBot 管理上下文。",
},
},
},
"provider_settings": {
@@ -1724,6 +1908,9 @@ CONFIG_METADATA_2 = {
"identifier": {
"type": "bool",
},
"group_name_display": {
"type": "bool",
},
"datetime_system_prompt": {
"type": "bool",
},
@@ -1752,6 +1939,10 @@ CONFIG_METADATA_2 = {
"description": "工具调用轮数上限",
"type": "int",
},
"tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
},
},
},
"provider_stt_settings": {
@@ -1874,6 +2065,9 @@ CONFIG_METADATA_2 = {
"default_kb_collection": {
"type": "string",
},
"kb_names": {"type": "list", "items": {"type": "string"}},
"kb_fusion_top_k": {"type": "int", "default": 20},
"kb_final_top_k": {"type": "int", "default": 5},
},
},
}
@@ -1903,17 +2097,33 @@ CONFIG_METADATA_3 = {
"_special": "select_provider",
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
},
"provider_stt_settings.enable": {
"description": "启用语音转文本",
"type": "bool",
"hint": "STT 总开关。",
},
"provider_stt_settings.provider_id": {
"description": "语音转文本模型",
"description": "默认语音转文本模型",
"type": "string",
"hint": "留空代表不使用",
"hint": "用户也可使用 /provider 指令单独选择会话的 STT 模型",
"_special": "select_provider_stt",
"condition": {
"provider_stt_settings.enable": True,
},
},
"provider_tts_settings.enable": {
"description": "启用文本转语音",
"type": "bool",
"hint": "TTS 总开关。当关闭时,会话启用 TTS 也不会生效。",
},
"provider_tts_settings.provider_id": {
"description": "文本转语音模型",
"description": "默认文本转语音模型",
"type": "string",
"hint": "留空代表不使用",
"hint": "用户也可使用 /provider 单独选择会话的 TTS 模型",
"_special": "select_provider_tts",
"condition": {
"provider_tts_settings.enable": True,
},
},
"provider_settings.image_caption_prompt": {
"description": "图片转述提示词",
@@ -1936,10 +2146,22 @@ CONFIG_METADATA_3 = {
"description": "知识库",
"type": "object",
"items": {
"default_kb_collection": {
"description": "默认使用的知识库",
"type": "string",
"kb_names": {
"description": "知识库列表",
"type": "list",
"items": {"type": "string"},
"_special": "select_knowledgebase",
"hint": "支持多选",
},
"kb_fusion_top_k": {
"description": "融合检索结果数",
"type": "int",
"hint": "多个知识库检索结果融合后的返回结果数量",
},
"kb_final_top_k": {
"description": "最终返回结果数",
"type": "int",
"hint": "从知识库中检索到的结果数量,越大可能获得越多相关信息,但也可能引入噪音。建议根据实际需求调整",
},
},
},
@@ -1954,7 +2176,7 @@ CONFIG_METADATA_3 = {
"provider_settings.websearch_provider": {
"description": "网页搜索提供商",
"type": "string",
"options": ["default", "tavily"],
"options": ["default", "tavily", "baidu_ai_search"],
},
"provider_settings.websearch_tavily_key": {
"description": "Tavily API Key",
@@ -1965,6 +2187,14 @@ CONFIG_METADATA_3 = {
"provider_settings.websearch_provider": "tavily",
},
},
"provider_settings.websearch_baidu_app_builder_key": {
"description": "百度千帆智能云 APP Builder API Key",
"type": "string",
"hint": "参考https://console.bce.baidu.com/iam/#/iam/apikey/list",
"condition": {
"provider_settings.websearch_provider": "baidu_ai_search",
},
},
"provider_settings.web_search_link": {
"description": "显示来源引用",
"type": "bool",
@@ -1983,6 +2213,11 @@ CONFIG_METADATA_3 = {
"description": "用户识别",
"type": "bool",
},
"provider_settings.group_name_display": {
"description": "显示群名称",
"type": "bool",
"hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
},
"provider_settings.datetime_system_prompt": {
"description": "现实世界时间感知",
"type": "bool",
@@ -1995,6 +2230,10 @@ CONFIG_METADATA_3 = {
"description": "工具调用轮数上限",
"type": "int",
},
"provider_settings.tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
},
"provider_settings.streaming_response": {
"description": "流式回复",
"type": "bool",
@@ -2016,12 +2255,14 @@ CONFIG_METADATA_3 = {
"provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ",
"type": "string",
"hint": "如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
},
"provider_settings.prompt_prefix": {
"description": "额外前缀提示词",
"description": "用户提示词",
"type": "string",
"hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。",
},
"provider_settings.dual_output": {
"provider_tts_settings.dual_output": {
"description": "开启 TTS 时同时输出语音和文字内容",
"type": "bool",
},
@@ -2202,6 +2443,32 @@ CONFIG_METADATA_3 = {
"description": "用户权限不足时是否回复",
"type": "bool",
},
"platform_specific.lark.pre_ack_emoji.enable": {
"description": "[飞书] 启用预回应表情",
"type": "bool",
},
"platform_specific.lark.pre_ack_emoji.emojis": {
"description": "表情列表(飞书表情枚举名)",
"type": "list",
"items": {"type": "string"},
"hint": "表情枚举名参考https://open.feishu.cn/document/server-docs/im-v1/message-reaction/emojis-introduce",
"condition": {
"platform_specific.lark.pre_ack_emoji.enable": True,
},
},
"platform_specific.telegram.pre_ack_emoji.enable": {
"description": "[Telegram] 启用预回应表情",
"type": "bool",
},
"platform_specific.telegram.pre_ack_emoji.emojis": {
"description": "表情列表Unicode",
"type": "list",
"items": {"type": "string"},
"hint": "Telegram 仅支持固定反应集合参考https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9",
"condition": {
"platform_specific.telegram.pre_ack_emoji.enable": True,
},
},
},
},
},

View File

@@ -7,7 +7,7 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json
import json
from astrbot.core import sp
from typing import Dict, List
from typing import Dict, List, Callable, Awaitable
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation, ConversationV2
@@ -20,6 +20,38 @@ class ConversationManager:
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
# 会话删除回调函数列表(用于级联清理,如知识库配置)
self._on_session_deleted_callbacks: List[Callable[[str], Awaitable[None]]] = []
def register_on_session_deleted(
self, callback: Callable[[str], Awaitable[None]]
) -> None:
"""注册会话删除回调函数
其他模块可以注册回调来响应会话删除事件,实现级联清理。
例如:知识库模块可以注册回调来清理会话的知识库配置。
Args:
callback: 回调函数接收会话ID (unified_msg_origin) 作为参数
"""
self._on_session_deleted_callbacks.append(callback)
async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:
"""触发会话删除回调
Args:
unified_msg_origin: 会话ID
"""
for callback in self._on_session_deleted_callbacks:
try:
await callback(unified_msg_origin)
except Exception as e:
from astrbot.core import logger
logger.error(
f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}"
)
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
"""将 ConversationV2 对象转换为 Conversation 对象"""
created_at = int(conv_v2.created_at.timestamp())
@@ -87,17 +119,28 @@ class ConversationManager:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
f = False
if not conversation_id:
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
f = True
if conversation_id:
await self.db.delete_conversation(cid=conversation_id)
if f:
curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
if curr_cid == conversation_id:
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
"""删除会话的所有对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
"""
await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
# 触发会话删除回调(级联清理)
await self._trigger_session_deleted(unified_msg_origin)
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
"""获取会话当前的对话 ID

View File

@@ -17,7 +17,6 @@ import os
from .event_bus import EventBus
from . import astrbot_config, html_renderer
from asyncio import Queue
from typing import List
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
@@ -26,14 +25,17 @@ from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core import LogBroker
from astrbot.core.db import BaseDatabase
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger, sp
from astrbot.core.config.default import VERSION
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
class AstrBotCoreLifecycle:
@@ -84,11 +86,21 @@ class AstrBotCoreLifecycle:
await html_renderer.initialize()
# 初始化 UMOP 配置路由器
self.umop_config_router = UmopConfigRouter(sp=sp)
# 初始化 AstrBot 配置管理器
self.astrbot_config_mgr = AstrBotConfigManager(
default_config=self.astrbot_config, sp=sp
default_config=self.astrbot_config, ucr=self.umop_config_router, sp=sp
)
# 4.5 to 4.6 migration for umop_config_router
try:
await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
except Exception as e:
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
logger.error(traceback.format_exc())
# 初始化事件队列
self.event_queue = Queue()
@@ -110,6 +122,9 @@ class AstrBotCoreLifecycle:
# 初始化平台消息历史管理器
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
# 初始化知识库管理器
self.kb_manager = KnowledgeBaseManager(self.provider_manager)
# 初始化提供给插件的上下文
self.star_context = Context(
self.event_queue,
@@ -121,6 +136,7 @@ class AstrBotCoreLifecycle:
self.platform_message_history_manager,
self.persona_mgr,
self.astrbot_config_mgr,
self.kb_manager,
)
# 初始化插件管理器
@@ -132,8 +148,9 @@ class AstrBotCoreLifecycle:
# 根据配置实例化各个 Provider
await self.provider_manager.initialize()
# 初始化消息事件流水线调度器
await self.kb_manager.initialize()
# 初始化消息事件流水线调度器
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
# 初始化更新器
@@ -148,7 +165,7 @@ class AstrBotCoreLifecycle:
self.start_time = int(time.time())
# 初始化当前任务列表
self.curr_tasks: List[asyncio.Task] = []
self.curr_tasks: list[asyncio.Task] = []
# 根据配置实例化各个平台适配器
await self.platform_manager.initialize()
@@ -233,6 +250,7 @@ class AstrBotCoreLifecycle:
await self.provider_manager.terminate()
await self.platform_manager.terminate()
await self.kb_manager.terminate()
self.dashboard_shutdown_event.set()
# 再次遍历curr_tasks等待每个任务真正结束
@@ -248,12 +266,13 @@ class AstrBotCoreLifecycle:
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
await self.provider_manager.terminate()
await self.platform_manager.terminate()
await self.kb_manager.terminate()
self.dashboard_shutdown_event.set()
threading.Thread(
target=self.astrbot_updator._reboot, name="restart", daemon=True
).start()
def load_platform(self) -> List[asyncio.Task]:
def load_platform(self) -> list[asyncio.Task]:
"""加载平台实例并返回所有平台实例的异步任务列表"""
tasks = []
platform_insts = self.platform_manager.get_insts()

View File

@@ -154,12 +154,17 @@ class BaseDatabase(abc.ABC):
"""Delete a conversation by its ID."""
...
@abc.abstractmethod
async def delete_conversations_by_user_id(self, user_id: str) -> None:
"""Delete all conversations for a specific user."""
...
@abc.abstractmethod
async def insert_platform_message_history(
self,
platform_id: str,
user_id: str,
content: list[dict],
content: dict,
sender_id: str | None = None,
sender_name: str | None = None,
) -> None:
@@ -282,3 +287,14 @@ class BaseDatabase(abc.ABC):
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
# """Get all LLM messages for a specific conversation."""
# ...
@abc.abstractmethod
async def get_session_conversations(
self,
page: int = 1,
page_size: int = 20,
search_query: str | None = None,
platform: str | None = None,
) -> tuple[list[dict], int]:
"""Get paginated session conversations with joined conversation and persona details, support search and platform filter."""
...

View File

@@ -0,0 +1,44 @@
from astrbot.api import logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.umop_config_router import UmopConfigRouter
async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
abconf_data = acm.abconf_data
if not isinstance(abconf_data, dict):
# should be unreachable
logger.warning(
f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}"
)
return
# 如果任何一项带有 umop则说明需要迁移
need_migration = False
for conf_id, conf_info in abconf_data.items():
if isinstance(conf_info, dict) and "umop" in conf_info:
need_migration = True
break
if not need_migration:
return
logger.info("Starting migration from version 4.5 to 4.6")
# extract umo->conf_id mapping
umo_to_conf_id = {}
for conf_id, conf_info in abconf_data.items():
if isinstance(conf_info, dict) and "umop" in conf_info:
umop_ls = conf_info.pop("umop")
if not isinstance(umop_ls, list):
continue
for umo in umop_ls:
if isinstance(umo, str) and umo not in umo_to_conf_id:
umo_to_conf_id[umo] = conf_id
# update the abconf data
await sp.global_put("abconf_mapping", abconf_data)
# update the umop config router
await ucr.update_routing_data(umo_to_conf_id)
logger.info("Migration from version 45 to 46 completed successfully")

View File

@@ -75,7 +75,9 @@ class Persona(SQLModel, table=True):
__tablename__ = "personas"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
persona_id: str = Field(max_length=255, nullable=False)
system_prompt: str = Field(sa_type=Text, nullable=False)
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
@@ -135,7 +137,9 @@ class PlatformMessageHistory(SQLModel, table=True):
__tablename__ = "platform_message_history"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False) # An id of group, user in platform
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
@@ -158,8 +162,8 @@ class Attachment(SQLModel, table=True):
__tablename__ = "attachments"
inner_attachment_id: int = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}
inner_attachment_id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
attachment_id: str = Field(
max_length=36,

View File

@@ -15,9 +15,8 @@ from astrbot.core.db.po import (
SQLModel,
)
from sqlalchemy import select, update, delete, text
from sqlmodel import select, update, delete, text, func, or_, desc, col
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import func
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
@@ -33,6 +32,12 @@ class SQLiteDatabase(BaseDatabase):
"""Initialize the database by creating tables if they do not exist."""
async with self.engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
await conn.execute(text("PRAGMA journal_mode=WAL"))
await conn.execute(text("PRAGMA synchronous=NORMAL"))
await conn.execute(text("PRAGMA cache_size=20000"))
await conn.execute(text("PRAGMA temp_store=MEMORY"))
await conn.execute(text("PRAGMA mmap_size=134217728"))
await conn.execute(text("PRAGMA optimize"))
await conn.commit()
# ====
@@ -41,10 +46,10 @@ class SQLiteDatabase(BaseDatabase):
async def insert_platform_stats(
self,
platform_id: str,
platform_type: str,
count: int = 1,
timestamp: datetime = None,
platform_id,
platform_type,
count=1,
timestamp=None,
) -> None:
"""Insert a new platform statistic record."""
async with self.get_db() as session:
@@ -75,7 +80,9 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(
select(func.count(PlatformStat.platform_id)).select_from(PlatformStat)
select(func.count(col(PlatformStat.platform_id))).select_from(
PlatformStat
)
)
count = result.scalar_one_or_none()
return count if count is not None else 0
@@ -95,7 +102,7 @@ class SQLiteDatabase(BaseDatabase):
"""),
{"start_time": start_time},
)
return result.scalars().all()
return list(result.scalars().all())
# ====
# Conversation Management
@@ -111,7 +118,7 @@ class SQLiteDatabase(BaseDatabase):
if platform_id:
query = query.where(ConversationV2.platform_id == platform_id)
# order by
query = query.order_by(ConversationV2.created_at.desc())
query = query.order_by(desc(ConversationV2.created_at))
result = await session.execute(query)
return result.scalars().all()
@@ -129,7 +136,7 @@ class SQLiteDatabase(BaseDatabase):
offset = (page - 1) * page_size
result = await session.execute(
select(ConversationV2)
.order_by(ConversationV2.created_at.desc())
.order_by(desc(ConversationV2.created_at))
.offset(offset)
.limit(page_size)
)
@@ -150,11 +157,26 @@ class SQLiteDatabase(BaseDatabase):
if platform_ids:
base_query = base_query.where(
ConversationV2.platform_id.in_(platform_ids)
col(ConversationV2.platform_id).in_(platform_ids)
)
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
base_query = base_query.where(
ConversationV2.title.ilike(f"%{search_query}%")
or_(
col(ConversationV2.title).ilike(f"%{search_query}%"),
col(ConversationV2.content).ilike(f"%{search_query}%"),
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
)
)
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
for msg_type in kwargs["message_types"]:
base_query = base_query.where(
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%")
)
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
base_query = base_query.where(
col(ConversationV2.platform_id).in_(kwargs["platforms"])
)
# Get total count matching the filters
@@ -165,7 +187,7 @@ class SQLiteDatabase(BaseDatabase):
# Get paginated results
offset = (page - 1) * page_size
result_query = (
base_query.order_by(ConversationV2.created_at.desc())
base_query.order_by(desc(ConversationV2.created_at))
.offset(offset)
.limit(page_size)
)
@@ -211,7 +233,7 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
query = update(ConversationV2).where(
ConversationV2.conversation_id == cid
col(ConversationV2.conversation_id) == cid
)
values = {}
if title is not None:
@@ -231,9 +253,126 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(ConversationV2.conversation_id == cid)
delete(ConversationV2).where(
col(ConversationV2.conversation_id) == cid
)
)
async def delete_conversations_by_user_id(self, user_id: str) -> None:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(col(ConversationV2.user_id) == user_id)
)
async def get_session_conversations(
self,
page=1,
page_size=20,
search_query=None,
platform=None,
) -> tuple[list[dict], int]:
"""Get paginated session conversations with joined conversation and persona details."""
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
base_query = (
select(
col(Preference.scope_id).label("session_id"),
func.json_extract(Preference.value, "$.val").label(
"conversation_id"
), # type: ignore
col(ConversationV2.persona_id).label("persona_id"),
col(ConversationV2.title).label("title"),
col(Persona.persona_id).label("persona_name"),
)
.select_from(Preference)
.outerjoin(
ConversationV2,
func.json_extract(Preference.value, "$.val")
== ConversationV2.conversation_id,
)
.outerjoin(
Persona, col(ConversationV2.persona_id) == Persona.persona_id
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
# 搜索筛选
if search_query:
search_pattern = f"%{search_query}%"
base_query = base_query.where(
or_(
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
)
)
# 平台筛选
if platform:
platform_pattern = f"{platform}:%"
base_query = base_query.where(
col(Preference.scope_id).like(platform_pattern)
)
# 排序
base_query = base_query.order_by(Preference.scope_id)
# 分页结果
result_query = base_query.offset(offset).limit(page_size)
result = await session.execute(result_query)
rows = result.fetchall()
# 查询总数(应用相同的筛选条件)
count_base_query = (
select(func.count(col(Preference.scope_id)))
.select_from(Preference)
.outerjoin(
ConversationV2,
func.json_extract(Preference.value, "$.val")
== ConversationV2.conversation_id,
)
.outerjoin(
Persona, col(ConversationV2.persona_id) == Persona.persona_id
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
# 应用相同的搜索和平台筛选条件到计数查询
if search_query:
search_pattern = f"%{search_query}%"
count_base_query = count_base_query.where(
or_(
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
)
)
if platform:
platform_pattern = f"{platform}:%"
count_base_query = count_base_query.where(
col(Preference.scope_id).like(platform_pattern)
)
total_result = await session.execute(count_base_query)
total = total_result.scalar() or 0
sessions_data = [
{
"session_id": row.session_id,
"conversation_id": row.conversation_id,
"persona_id": row.persona_id,
"title": row.title,
"persona_name": row.persona_name,
}
for row in rows
]
return sessions_data, total
async def insert_platform_message_history(
self,
platform_id,
@@ -267,9 +406,9 @@ class SQLiteDatabase(BaseDatabase):
cutoff_time = now - timedelta(seconds=offset_sec)
await session.execute(
delete(PlatformMessageHistory).where(
PlatformMessageHistory.platform_id == platform_id,
PlatformMessageHistory.user_id == user_id,
PlatformMessageHistory.created_at < cutoff_time,
col(PlatformMessageHistory.platform_id) == platform_id,
col(PlatformMessageHistory.user_id) == user_id,
col(PlatformMessageHistory.created_at) < cutoff_time,
)
)
@@ -286,7 +425,7 @@ class SQLiteDatabase(BaseDatabase):
PlatformMessageHistory.platform_id == platform_id,
PlatformMessageHistory.user_id == user_id,
)
.order_by(PlatformMessageHistory.created_at.desc())
.order_by(desc(PlatformMessageHistory.created_at))
)
result = await session.execute(query.offset(offset).limit(page_size))
return result.scalars().all()
@@ -308,7 +447,7 @@ class SQLiteDatabase(BaseDatabase):
"""Get an attachment by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(Attachment).where(Attachment.id == attachment_id)
query = select(Attachment).where(Attachment.attachment_id == attachment_id)
result = await session.execute(query)
return result.scalar_one_or_none()
@@ -351,7 +490,7 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = update(Persona).where(Persona.persona_id == persona_id)
query = update(Persona).where(col(Persona.persona_id) == persona_id)
values = {}
if system_prompt is not None:
values["system_prompt"] = system_prompt
@@ -371,7 +510,7 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
await session.execute(
delete(Persona).where(Persona.persona_id == persona_id)
delete(Persona).where(col(Persona.persona_id) == persona_id)
)
async def insert_preference_or_update(self, scope, scope_id, key, value):
@@ -426,9 +565,9 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin():
await session.execute(
delete(Preference).where(
Preference.scope == scope,
Preference.scope_id == scope_id,
Preference.key == key,
col(Preference.scope) == scope,
col(Preference.scope_id) == scope_id,
col(Preference.key) == key,
)
)
await session.commit()
@@ -440,7 +579,8 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin():
await session.execute(
delete(Preference).where(
Preference.scope == scope, Preference.scope_id == scope_id
col(Preference.scope) == scope,
col(Preference.scope_id) == scope_id,
)
)
await session.commit()
@@ -467,7 +607,7 @@ class SQLiteDatabase(BaseDatabase):
DeprecatedPlatformStat(
name=data.platform_id,
count=data.count,
timestamp=data.timestamp.timestamp(),
timestamp=int(data.timestamp.timestamp()),
)
)
return deprecated_stats
@@ -525,7 +665,7 @@ class SQLiteDatabase(BaseDatabase):
DeprecatedPlatformStat(
name=platform_id,
count=count,
timestamp=start_time.timestamp(),
timestamp=int(start_time.timestamp()),
)
)
return deprecated_stats

View File

@@ -16,14 +16,42 @@ class BaseVecDB:
pass
@abc.abstractmethod
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
async def insert(
self, content: str, metadata: dict | None = None, id: str | None = None
) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
...
@abc.abstractmethod
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
async def insert_batch(
self,
contents: list[str],
metadatas: list[dict] | None = None,
ids: list[str] | None = None,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
) -> int:
"""
批量插入文本和其对应向量,自动生成 ID 并保持一致性。
Args:
progress_callback: 进度回调函数,接收参数 (current, total)
"""
...
@abc.abstractmethod
async def retrieve(
self,
query: str,
top_k: int = 5,
fetch_k: int = 20,
rerank: bool = False,
metadata_filters: dict | None = None,
) -> list[Result]:
"""
搜索最相似的文档。
Args:
@@ -44,3 +72,6 @@ class BaseVecDB:
bool: 删除是否成功
"""
...
@abc.abstractmethod
async def close(self): ...

View File

@@ -1,59 +1,224 @@
import aiosqlite
import os
import json
from datetime import datetime
from contextlib import asynccontextmanager
from sqlalchemy import Text, Column
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import Field, SQLModel, select, col, func, text, MetaData
from astrbot.core import logger
class BaseDocModel(SQLModel, table=False):
metadata = MetaData()
class Document(BaseDocModel, table=True):
"""SQLModel for documents table."""
__tablename__ = "documents" # type: ignore
id: int | None = Field(
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
)
doc_id: str = Field(nullable=False)
text: str = Field(nullable=False)
metadata_: str | None = Field(default=None, sa_column=Column("metadata", Text))
created_at: datetime | None = Field(default=None)
updated_at: datetime | None = Field(default=None)
class DocumentStorage:
def __init__(self, db_path: str):
self.db_path = db_path
self.connection = None
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
self.engine: AsyncEngine | None = None
self.async_session_maker: sessionmaker | None = None
self.sqlite_init_path = os.path.join(
os.path.dirname(__file__), "sqlite_init.sql"
)
async def initialize(self):
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
if not os.path.exists(self.db_path):
await self.connect()
async with self.connection.cursor() as cursor:
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
sql_script = f.read()
await cursor.executescript(sql_script)
await self.connection.commit()
else:
await self.connect()
await self.connect()
async with self.engine.begin() as conn: # type: ignore
# Create tables using SQLModel
await conn.run_sync(BaseDocModel.metadata.create_all)
try:
await conn.execute(
text(
"ALTER TABLE documents ADD COLUMN kb_doc_id TEXT "
"GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED"
)
)
await conn.execute(
text(
"ALTER TABLE documents ADD COLUMN user_id TEXT "
"GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED"
)
)
# Create indexes
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)"
)
)
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)"
)
)
except BaseException:
pass
await conn.commit()
async def connect(self):
"""Connect to the SQLite database."""
self.connection = await aiosqlite.connect(self.db_path)
if self.engine is None:
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
future=True,
)
self.async_session_maker = sessionmaker(
self.engine, # type: ignore
class_=AsyncSession,
expire_on_commit=False,
) # type: ignore
async def get_documents(self, metadata_filters: dict, ids: list = None):
@asynccontextmanager
async def get_session(self):
"""Context manager for database sessions."""
async with self.async_session_maker() as session: # type: ignore
yield session
async def get_documents(
self,
metadata_filters: dict,
ids: list | None = None,
offset: int | None = 0,
limit: int | None = 100,
) -> list[dict]:
"""Retrieve documents by metadata filters and ids.
Args:
metadata_filters (dict): The metadata filters to apply.
ids (list | None): Optional list of document IDs to filter.
offset (int | None): Offset for pagination.
limit (int | None): Limit for pagination.
Returns:
list: The list of document IDs(primary key, not doc_id) that match the filters.
list: The list of documents that match the filters.
"""
# metadata filter -> SQL WHERE clause
where_clauses = []
values = []
for key, val in metadata_filters.items():
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
values.append(val)
if ids is not None and len(ids) > 0:
ids = [str(i) for i in ids if i != -1]
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
values.extend(ids)
where_sql = " AND ".join(where_clauses) or "1=1"
if self.engine is None:
logger.warning(
"Database connection is not initialized, returning empty result"
)
return []
result = []
async with self.connection.cursor() as cursor:
sql = "SELECT * FROM documents WHERE " + where_sql
await cursor.execute(sql, values)
for row in await cursor.fetchall():
result.append(await self.tuple_to_dict(row))
return result
async with self.get_session() as session:
query = select(Document)
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
).params(**{f"filter_{key}": val})
if ids is not None and len(ids) > 0:
valid_ids = [int(i) for i in ids if i != -1]
if valid_ids:
query = query.where(col(Document.id).in_(valid_ids))
if limit is not None:
query = query.limit(limit)
if offset is not None:
query = query.offset(offset)
result = await session.execute(query)
documents = result.scalars().all()
return [self._document_to_dict(doc) for doc in documents]
async def insert_document(self, doc_id: str, text: str, metadata: dict) -> int:
"""Insert a single document and return its integer ID.
Args:
doc_id (str): The document ID (UUID string).
text (str): The document text.
metadata (dict): The document metadata.
Returns:
int: The integer ID of the inserted document.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
async with session.begin():
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
session.add(document)
await session.flush() # Flush to get the ID
return document.id # type: ignore
async def insert_documents_batch(
self, doc_ids: list[str], texts: list[str], metadatas: list[dict]
) -> list[int]:
"""Batch insert documents and return their integer IDs.
Args:
doc_ids (list[str]): List of document IDs (UUID strings).
texts (list[str]): List of document texts.
metadatas (list[dict]): List of document metadata.
Returns:
list[int]: List of integer IDs of the inserted documents.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
async with session.begin():
import json
documents = []
for doc_id, text, metadata in zip(doc_ids, texts, metadatas):
document = Document(
doc_id=doc_id,
text=text,
metadata_=json.dumps(metadata),
created_at=datetime.now(),
updated_at=datetime.now(),
)
documents.append(document)
session.add(document)
await session.flush() # Flush to get all IDs
return [doc.id for doc in documents] # type: ignore
async def delete_document_by_doc_id(self, doc_id: str):
"""Delete a document by its doc_id.
Args:
doc_id (str): The doc_id of the document to delete.
"""
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
async with session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
await session.delete(document)
async def get_document_by_doc_id(self, doc_id: str):
"""Retrieve a document by its doc_id.
@@ -62,28 +227,91 @@ class DocumentStorage:
doc_id (str): The doc_id of the document to retrieve.
Returns:
dict: The document data.
dict: The document data or None if not found.
"""
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
row = await cursor.fetchone()
if row:
return await self.tuple_to_dict(row)
else:
return None
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
return self._document_to_dict(document)
return None
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
"""Retrieve a document by its doc_id.
"""Update a document by its doc_id.
Args:
doc_id (str): The doc_id.
new_text (str): The new text to update the document with.
"""
async with self.connection.cursor() as cursor:
await cursor.execute(
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
async with session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
document.text = new_text
document.updated_at = datetime.now()
session.add(document)
async def delete_documents(self, metadata_filters: dict):
"""Delete documents by their metadata filters.
Args:
metadata_filters (dict): The metadata filters to apply.
"""
if self.engine is None:
logger.warning(
"Database connection is not initialized, skipping delete operation"
)
await self.connection.commit()
return
async with self.get_session() as session:
async with session.begin():
query = select(Document)
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
).params(**{f"filter_{key}": val})
result = await session.execute(query)
documents = result.scalars().all()
for doc in documents:
await session.delete(doc)
async def count_documents(self, metadata_filters: dict | None = None) -> int:
"""Count documents in the database.
Args:
metadata_filters (dict | None): Metadata filters to apply.
Returns:
int: The count of documents.
"""
if self.engine is None:
logger.warning("Database connection is not initialized, returning 0")
return 0
async with self.get_session() as session:
query = select(func.count(col(Document.id)))
if metadata_filters:
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
).params(**{f"filter_{key}": val})
result = await session.execute(query)
count = result.scalar_one_or_none()
return count if count is not None else 0
async def get_user_ids(self) -> list[str]:
"""Retrieve all user IDs from the documents table.
@@ -91,11 +319,38 @@ class DocumentStorage:
Returns:
list: A list of user IDs.
"""
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT DISTINCT user_id FROM documents")
rows = await cursor.fetchall()
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
query = text(
"SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL"
)
result = await session.execute(query)
rows = result.fetchall()
return [row[0] for row in rows]
def _document_to_dict(self, document: Document) -> dict:
"""Convert a Document model to a dictionary.
Args:
document (Document): The document to convert.
Returns:
dict: The converted dictionary.
"""
return {
"id": document.id,
"doc_id": document.doc_id,
"text": document.text,
"metadata": document.metadata_,
"created_at": document.created_at.isoformat()
if isinstance(document.created_at, datetime)
else document.created_at,
"updated_at": document.updated_at.isoformat()
if isinstance(document.updated_at, datetime)
else document.updated_at,
}
async def tuple_to_dict(self, row):
"""Convert a tuple to a dictionary.
@@ -104,6 +359,8 @@ class DocumentStorage:
Returns:
dict: The converted dictionary.
Note: This method is kept for backward compatibility but is no longer used internally.
"""
return {
"id": row[0],
@@ -116,6 +373,7 @@ class DocumentStorage:
async def close(self):
"""Close the connection to the SQLite database."""
if self.connection:
await self.connection.close()
self.connection = None
if self.engine:
await self.engine.dispose()
self.engine = None
self.async_session_maker = None

View File

@@ -9,7 +9,7 @@ import numpy as np
class EmbeddingStorage:
def __init__(self, dimension: int, path: str = None):
def __init__(self, dimension: int, path: str | None = None):
self.dimension = dimension
self.path = path
self.index = None
@@ -18,7 +18,6 @@ class EmbeddingStorage:
else:
base_index = faiss.IndexFlatL2(dimension)
self.index = faiss.IndexIDMap(base_index)
self.storage = {}
async def insert(self, vector: np.ndarray, id: int):
"""插入向量
@@ -29,12 +28,29 @@ class EmbeddingStorage:
Raises:
ValueError: 如果向量的维度与存储的维度不匹配
"""
assert self.index is not None, "FAISS index is not initialized."
if vector.shape[0] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
)
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
self.storage[id] = vector
await self.save_index()
async def insert_batch(self, vectors: np.ndarray, ids: list[int]):
"""批量插入向量
Args:
vectors (np.ndarray): 要插入的向量数组
ids (list[int]): 向量的ID列表
Raises:
ValueError: 如果向量的维度与存储的维度不匹配
"""
assert self.index is not None, "FAISS index is not initialized."
if vectors.shape[1] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}"
)
self.index.add_with_ids(vectors, np.array(ids))
await self.save_index()
async def search(self, vector: np.ndarray, k: int) -> tuple:
@@ -46,10 +62,22 @@ class EmbeddingStorage:
Returns:
tuple: (距离, 索引)
"""
assert self.index is not None, "FAISS index is not initialized."
faiss.normalize_L2(vector)
distances, indices = self.index.search(vector, k)
return distances, indices
async def delete(self, ids: list[int]):
"""删除向量
Args:
ids (list[int]): 要删除的向量ID列表
"""
assert self.index is not None, "FAISS index is not initialized."
id_array = np.array(ids, dtype=np.int64)
self.index.remove_ids(id_array)
await self.save_index()
async def save_index(self):
"""保存索引

View File

@@ -1,11 +1,12 @@
import uuid
import json
import time
import numpy as np
from .document_storage import DocumentStorage
from .embedding_storage import EmbeddingStorage
from ..base import Result, BaseVecDB
from astrbot.core.provider.provider import EmbeddingProvider
from astrbot.core.provider.provider import RerankProvider
from astrbot import logger
class FaissVecDB(BaseVecDB):
@@ -44,18 +45,56 @@ class FaissVecDB(BaseVecDB):
vector = await self.embedding_provider.get_embedding(content)
vector = np.array(vector, dtype=np.float32)
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute(
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
(str_id, content, json.dumps(metadata)),
)
await self.document_storage.connection.commit()
result = await self.document_storage.get_document_by_doc_id(str_id)
int_id = result["id"]
# 插入向量到 FAISS
await self.embedding_storage.insert(vector, int_id)
return int_id
# 使用 DocumentStorage 的方法插入文档
int_id = await self.document_storage.insert_document(str_id, content, metadata)
# 插入向量到 FAISS
await self.embedding_storage.insert(vector, int_id)
return int_id
async def insert_batch(
self,
contents: list[str],
metadatas: list[dict] | None = None,
ids: list[str] | None = None,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
) -> list[int]:
"""
批量插入文本和其对应向量,自动生成 ID 并保持一致性。
Args:
progress_callback: 进度回调函数,接收参数 (current, total)
"""
metadatas = metadatas or [{} for _ in contents]
ids = ids or [str(uuid.uuid4()) for _ in contents]
start = time.time()
logger.debug(f"Generating embeddings for {len(contents)} contents...")
vectors = await self.embedding_provider.get_embeddings_batch(
contents,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
)
end = time.time()
logger.debug(
f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds."
)
# 使用 DocumentStorage 的批量插入方法
int_ids = await self.document_storage.insert_documents_batch(
ids, contents, metadatas
)
# 批量插入向量到 FAISS
vectors_array = np.array(vectors).astype("float32")
await self.embedding_storage.insert_batch(vectors_array, int_ids)
return int_ids
async def retrieve(
self,
@@ -119,23 +158,42 @@ class FaissVecDB(BaseVecDB):
return top_k_results
async def delete(self, doc_id: int):
async def delete(self, doc_id: str):
"""
删除一条文档
删除一条文档chunk
"""
await self.document_storage.connection.execute(
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
)
await self.document_storage.connection.commit()
# 获得对应的 int id
result = await self.document_storage.get_document_by_doc_id(doc_id)
int_id = result["id"] if result else None
if int_id is None:
return
# 使用 DocumentStorage 的删除方法
await self.document_storage.delete_document_by_doc_id(doc_id)
await self.embedding_storage.delete([int_id])
async def close(self):
await self.document_storage.close()
async def count_documents(self) -> int:
async def count_documents(self, metadata_filter: dict | None = None) -> int:
"""
计算文档数量
Args:
metadata_filter (dict | None): 元数据过滤器
"""
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute("SELECT COUNT(*) FROM documents")
count = await cursor.fetchone()
return count[0] if count else 0
count = await self.document_storage.count_documents(
metadata_filters=metadata_filter or {}
)
return count
async def delete_documents(self, metadata_filters: dict):
"""
根据元数据过滤器删除文档
"""
docs = await self.document_storage.get_documents(
metadata_filters=metadata_filters, offset=None, limit=None
)
doc_ids: list[int] = [doc["id"] for doc in docs]
await self.embedding_storage.delete(doc_ids)
await self.document_storage.delete_documents(metadata_filters=metadata_filters)

View File

@@ -23,7 +23,12 @@ class FileTokenService:
for token in expired_tokens:
self.staged_files.pop(token, None)
async def register_file(self, file_path: str, timeout: float = None) -> str:
async def check_token_expired(self, file_token: str) -> bool:
async with self.lock:
await self._cleanup_expired_tokens()
return file_token not in self.staged_files
async def register_file(self, file_path: str, timeout: float | None = None) -> str:
"""向令牌服务注册一个文件。
Args:

View File

@@ -41,10 +41,13 @@ class InitialLoader:
self.dashboard_server = AstrBotDashboard(
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event, webui_dir
)
task = asyncio.gather(
core_task, self.dashboard_server.run()
) # 启动核心任务和仪表板服务器
coro = self.dashboard_server.run()
if coro:
# 启动核心任务和仪表板服务器
task = asyncio.gather(core_task, coro)
else:
task = core_task
try:
await task # 整个AstrBot在这里运行
except asyncio.CancelledError:

View File

@@ -0,0 +1,11 @@
"""
文档分块模块
"""
from .base import BaseChunker
from .fixed_size import FixedSizeChunker
__all__ = [
"BaseChunker",
"FixedSizeChunker",
]

View File

@@ -0,0 +1,24 @@
"""文档分块器基类
定义了文档分块处理的抽象接口。
"""
from abc import ABC, abstractmethod
class BaseChunker(ABC):
"""分块器基类
所有分块器都应该继承此类并实现 chunk 方法。
"""
@abstractmethod
async def chunk(self, text: str, **kwargs) -> list[str]:
"""将文本分块
Args:
text: 输入文本
Returns:
list[str]: 分块后的文本列表
"""

View File

@@ -0,0 +1,57 @@
"""固定大小分块器
按照固定的字符数将文本分块,支持重叠区域。
"""
from .base import BaseChunker
class FixedSizeChunker(BaseChunker):
"""固定大小分块器
按照固定的字符数分块,并支持块之间的重叠。
"""
def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
"""初始化分块器
Args:
chunk_size: 块的大小(字符数)
chunk_overlap: 块之间的重叠字符数
"""
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
async def chunk(self, text: str, **kwargs) -> list[str]:
"""固定大小分块
Args:
text: 输入文本
chunk_size: 每个文本块的最大大小
chunk_overlap: 每个文本块之间的重叠部分大小
Returns:
list[str]: 分块后的文本列表
"""
chunk_size = kwargs.get("chunk_size", self.chunk_size)
chunk_overlap = kwargs.get("chunk_overlap", self.chunk_overlap)
chunks = []
start = 0
text_len = len(text)
while start < text_len:
end = start + chunk_size
chunk = text[start:end]
if chunk:
chunks.append(chunk)
# 移动窗口,保留重叠部分
start = end - chunk_overlap
# 防止无限循环: 如果重叠过大,直接移到end
if start >= end or chunk_overlap >= chunk_size:
start = end
return chunks

View File

@@ -0,0 +1,155 @@
from collections.abc import Callable
from .base import BaseChunker
class RecursiveCharacterChunker(BaseChunker):
def __init__(
self,
chunk_size: int = 500,
chunk_overlap: int = 100,
length_function: Callable[[str], int] = len,
is_separator_regex: bool = False,
separators: list[str] | None = None,
):
"""
初始化递归字符文本分割器
Args:
chunk_size: 每个文本块的最大大小
chunk_overlap: 每个文本块之间的重叠部分大小
length_function: 计算文本长度的函数
is_separator_regex: 分隔符是否为正则表达式
separators: 用于分割文本的分隔符列表,按优先级排序
"""
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.length_function = length_function
self.is_separator_regex = is_separator_regex
# 默认分隔符列表,按优先级从高到低
self.separators = separators or [
"\n\n", # 段落
"\n", # 换行
"", # 中文句子
"", # 中文逗号
". ", # 句子
", ", # 逗号分隔
" ", # 单词
"", # 字符
]
async def chunk(self, text: str, **kwargs) -> list[str]:
"""
递归地将文本分割成块
Args:
text: 要分割的文本
chunk_size: 每个文本块的最大大小
chunk_overlap: 每个文本块之间的重叠部分大小
Returns:
分割后的文本块列表
"""
if not text:
return []
overlap = kwargs.get("chunk_overlap", self.chunk_overlap)
chunk_size = kwargs.get("chunk_size", self.chunk_size)
text_length = self.length_function(text)
if text_length <= chunk_size:
return [text]
for separator in self.separators:
if separator == "":
return self._split_by_character(text, chunk_size, overlap)
if separator in text:
splits = text.split(separator)
# 重新添加分隔符(除了最后一个片段)
splits = [s + separator for s in splits[:-1]] + [splits[-1]]
splits = [s for s in splits if s]
if len(splits) == 1:
continue
# 递归合并分割后的文本块
final_chunks = []
current_chunk = []
current_chunk_length = 0
for split in splits:
split_length = self.length_function(split)
# 如果单个分割部分已经超过了chunk_size需要递归分割
if split_length > chunk_size:
# 先处理当前积累的块
if current_chunk:
combined_text = "".join(current_chunk)
final_chunks.extend(
await self.chunk(
combined_text,
chunk_size=chunk_size,
chunk_overlap=overlap,
)
)
current_chunk = []
current_chunk_length = 0
# 递归分割过大的部分
final_chunks.extend(
await self.chunk(
split, chunk_size=chunk_size, chunk_overlap=overlap
)
)
# 如果添加这部分会使当前块超过chunk_size
elif current_chunk_length + split_length > chunk_size:
# 合并当前块并添加到结果中
combined_text = "".join(current_chunk)
final_chunks.append(combined_text)
# 处理重叠部分
overlap_start = max(0, len(combined_text) - overlap)
if overlap_start > 0:
overlap_text = combined_text[overlap_start:]
current_chunk = [overlap_text, split]
current_chunk_length = (
self.length_function(overlap_text) + split_length
)
else:
current_chunk = [split]
current_chunk_length = split_length
else:
# 添加到当前块
current_chunk.append(split)
current_chunk_length += split_length
# 处理剩余的块
if current_chunk:
final_chunks.append("".join(current_chunk))
return final_chunks
return [text]
def _split_by_character(
self, text: str, chunk_size: int | None = None, overlap: int | None = None
) -> list[str]:
"""
按字符级别分割文本
Args:
text: 要分割的文本
Returns:
分割后的文本块列表
"""
chunk_size = chunk_size or self.chunk_size
overlap = overlap or self.chunk_overlap
result = []
for i in range(0, len(text), chunk_size - overlap):
end = min(i + chunk_size, len(text))
result.append(text[i:end])
if end == len(text):
break
return result

View File

@@ -0,0 +1,299 @@
from contextlib import asynccontextmanager
from pathlib import Path
from sqlmodel import col, desc
from sqlalchemy import text, func, select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from astrbot.core import logger
from astrbot.core.knowledge_base.models import (
BaseKBModel,
KBDocument,
KBMedia,
KnowledgeBase,
)
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
class KBSQLiteDatabase:
def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
"""初始化知识库数据库
Args:
db_path: 数据库文件路径, 默认为 data/knowledge_base/kb.db
"""
self.db_path = db_path
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
self.inited = False
# 确保目录存在
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
# 创建异步引擎
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
pool_pre_ping=True,
pool_recycle=3600,
)
# 创建会话工厂
self.async_session = async_sessionmaker(
self.engine,
class_=AsyncSession,
expire_on_commit=False,
)
@asynccontextmanager
async def get_db(self):
"""获取数据库会话
用法:
async with kb_db.get_db() as session:
# 执行数据库操作
result = await session.execute(stmt)
"""
async with self.async_session() as session:
yield session
async def initialize(self) -> None:
"""初始化数据库,创建表并配置 SQLite 参数"""
async with self.engine.begin() as conn:
# 创建所有知识库相关表
await conn.run_sync(BaseKBModel.metadata.create_all)
# 配置 SQLite 性能优化参数
await conn.execute(text("PRAGMA journal_mode=WAL"))
await conn.execute(text("PRAGMA synchronous=NORMAL"))
await conn.execute(text("PRAGMA cache_size=20000"))
await conn.execute(text("PRAGMA temp_store=MEMORY"))
await conn.execute(text("PRAGMA mmap_size=134217728"))
await conn.execute(text("PRAGMA optimize"))
await conn.commit()
self.inited = True
async def migrate_to_v1(self) -> None:
"""执行知识库数据库 v1 迁移
创建所有必要的索引以优化查询性能
"""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
# 创建知识库表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_kb_id "
"ON knowledge_bases(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_name "
"ON knowledge_bases(kb_name)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_kb_created_at "
"ON knowledge_bases(created_at)"
)
)
# 创建文档表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_doc_id "
"ON kb_documents(doc_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_kb_id "
"ON kb_documents(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_name "
"ON kb_documents(doc_name)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_type "
"ON kb_documents(file_type)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_doc_created_at "
"ON kb_documents(created_at)"
)
)
# 创建多媒体表索引
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_media_id "
"ON kb_media(media_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_doc_id "
"ON kb_media(doc_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_media_type "
"ON kb_media(media_type)"
)
)
await session.commit()
async def close(self) -> None:
"""关闭数据库连接"""
await self.engine.dispose()
logger.info(f"知识库数据库已关闭: {self.db_path}")
async def get_kb_by_id(self, kb_id: str) -> KnowledgeBase | None:
"""根据 ID 获取知识库"""
async with self.get_db() as session:
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_kb_by_name(self, kb_name: str) -> KnowledgeBase | None:
"""根据名称获取知识库"""
async with self.get_db() as session:
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_name) == kb_name)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
"""列出所有知识库"""
async with self.get_db() as session:
stmt = (
select(KnowledgeBase)
.offset(offset)
.limit(limit)
.order_by(desc(KnowledgeBase.created_at))
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def count_kbs(self) -> int:
"""统计知识库数量"""
async with self.get_db() as session:
stmt = select(func.count(col(KnowledgeBase.id)))
result = await session.execute(stmt)
return result.scalar() or 0
# ===== 文档查询 =====
async def get_document_by_id(self, doc_id: str) -> KBDocument | None:
"""根据 ID 获取文档"""
async with self.get_db() as session:
stmt = select(KBDocument).where(col(KBDocument.doc_id) == doc_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_documents_by_kb(
self, kb_id: str, offset: int = 0, limit: int = 100
) -> list[KBDocument]:
"""列出知识库的所有文档"""
async with self.get_db() as session:
stmt = (
select(KBDocument)
.where(col(KBDocument.kb_id) == kb_id)
.offset(offset)
.limit(limit)
.order_by(desc(KBDocument.created_at))
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def count_documents_by_kb(self, kb_id: str) -> int:
"""统计知识库的文档数量"""
async with self.get_db() as session:
stmt = select(func.count(col(KBDocument.id))).where(
col(KBDocument.kb_id) == kb_id
)
result = await session.execute(stmt)
return result.scalar() or 0
async def get_document_with_metadata(self, doc_id: str) -> dict | None:
async with self.get_db() as session:
stmt = (
select(KBDocument, KnowledgeBase)
.join(KnowledgeBase, col(KBDocument.kb_id) == col(KnowledgeBase.kb_id))
.where(col(KBDocument.doc_id) == doc_id)
)
result = await session.execute(stmt)
row = result.first()
if not row:
return None
return {
"document": row[0],
"knowledge_base": row[1],
}
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB):
"""删除单个文档及其相关数据"""
# 在知识库表中删除
async with self.get_db() as session:
async with session.begin():
# 删除文档记录
delete_stmt = delete(KBDocument).where(col(KBDocument.doc_id) == doc_id)
await session.execute(delete_stmt)
await session.commit()
# 在 vec db 中删除相关向量
await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id})
# ===== 多媒体查询 =====
async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
"""列出文档的所有多媒体资源"""
async with self.get_db() as session:
stmt = select(KBMedia).where(col(KBMedia.doc_id) == doc_id)
result = await session.execute(stmt)
return list(result.scalars().all())
async def get_media_by_id(self, media_id: str) -> KBMedia | None:
"""根据 ID 获取多媒体资源"""
async with self.get_db() as session:
stmt = select(KBMedia).where(col(KBMedia.media_id) == media_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def update_kb_stats(self, kb_id: str, vec_db: FaissVecDB) -> None:
"""更新知识库统计信息"""
chunk_cnt = await vec_db.count_documents()
async with self.get_db() as session:
async with session.begin():
update_stmt = (
update(KnowledgeBase)
.where(col(KnowledgeBase.kb_id) == kb_id)
.values(
doc_count=select(func.count(col(KBDocument.id)))
.where(col(KBDocument.kb_id) == kb_id)
.scalar_subquery(),
chunk_count=chunk_cnt,
)
)
await session.execute(update_stmt)
await session.commit()

View File

@@ -0,0 +1,348 @@
import uuid
import aiofiles
import json
from pathlib import Path
from .models import KnowledgeBase, KBDocument, KBMedia
from .kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from astrbot.core.provider.manager import ProviderManager
from .parsers.util import select_parser
from .chunking.base import BaseChunker
from astrbot.core import logger
class KBHelper:
vec_db: BaseVecDB
kb: KnowledgeBase
def __init__(
self,
kb_db: KBSQLiteDatabase,
kb: KnowledgeBase,
provider_manager: ProviderManager,
kb_root_dir: str,
chunker: BaseChunker,
):
self.kb_db = kb_db
self.kb = kb
self.prov_mgr = provider_manager
self.kb_root_dir = kb_root_dir
self.chunker = chunker
self.kb_dir = Path(self.kb_root_dir) / self.kb.kb_id
self.kb_medias_dir = Path(self.kb_dir) / "medias" / self.kb.kb_id
self.kb_files_dir = Path(self.kb_dir) / "files" / self.kb.kb_id
self.kb_medias_dir.mkdir(parents=True, exist_ok=True)
self.kb_files_dir.mkdir(parents=True, exist_ok=True)
async def initialize(self):
await self._ensure_vec_db()
async def get_ep(self) -> EmbeddingProvider:
if not self.kb.embedding_provider_id:
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id(
self.kb.embedding_provider_id
) # type: ignore
if not ep:
raise ValueError(
f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider"
)
return ep
async def get_rp(self) -> RerankProvider | None:
if not self.kb.rerank_provider_id:
return None
rp: RerankProvider = await self.prov_mgr.get_provider_by_id(
self.kb.rerank_provider_id
) # type: ignore
if not rp:
raise ValueError(
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider"
)
return rp
async def _ensure_vec_db(self) -> FaissVecDB:
if not self.kb.embedding_provider_id:
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
ep = await self.get_ep()
rp = await self.get_rp()
vec_db = FaissVecDB(
doc_store_path=str(self.kb_dir / "doc.db"),
index_store_path=str(self.kb_dir / "index.faiss"),
embedding_provider=ep,
rerank_provider=rp,
)
await vec_db.initialize()
self.vec_db = vec_db
return vec_db
async def delete_vec_db(self):
"""删除知识库的向量数据库和所有相关文件"""
import shutil
await self.terminate()
if self.kb_dir.exists():
shutil.rmtree(self.kb_dir)
async def terminate(self):
if self.vec_db:
await self.vec_db.close()
async def upload_document(
self,
file_name: str,
file_content: bytes,
file_type: str,
chunk_size: int = 512,
chunk_overlap: int = 50,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
) -> KBDocument:
"""上传并处理文档(带原子性保证和失败清理)
流程:
1. 保存原始文件
2. 解析文档内容
3. 提取多媒体资源
4. 分块处理
5. 生成向量并存储
6. 保存元数据(事务)
7. 更新统计
Args:
progress_callback: 进度回调函数,接收参数 (stage, current, total)
- stage: 当前阶段 ('parsing', 'chunking', 'embedding')
- current: 当前进度
- total: 总数
"""
await self._ensure_vec_db()
doc_id = str(uuid.uuid4())
media_paths: list[Path] = []
# file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
# async with aiofiles.open(file_path, "wb") as f:
# await f.write(file_content)
try:
# 阶段1: 解析文档
if progress_callback:
await progress_callback("parsing", 0, 100)
parser = await select_parser(f".{file_type}")
parse_result = await parser.parse(file_content, file_name)
text_content = parse_result.text
media_items = parse_result.media
if progress_callback:
await progress_callback("parsing", 100, 100)
# 保存媒体文件
saved_media = []
for media_item in media_items:
media = await self._save_media(
doc_id=doc_id,
media_type=media_item.media_type,
file_name=media_item.file_name,
content=media_item.content,
mime_type=media_item.mime_type,
)
saved_media.append(media)
media_paths.append(Path(media.file_path))
# 阶段2: 分块
if progress_callback:
await progress_callback("chunking", 0, 100)
chunks_text = await self.chunker.chunk(
text_content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
contents = []
metadatas = []
for idx, chunk_text in enumerate(chunks_text):
contents.append(chunk_text)
metadatas.append(
{
"kb_id": self.kb.kb_id,
"kb_doc_id": doc_id,
"chunk_index": idx,
}
)
if progress_callback:
await progress_callback("chunking", 100, 100)
# 阶段3: 生成向量(带进度回调)
async def embedding_progress_callback(current, total):
if progress_callback:
await progress_callback("embedding", current, total)
await self.vec_db.insert_batch(
contents=contents,
metadatas=metadatas,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=embedding_progress_callback,
)
# 保存文档的元数据
doc = KBDocument(
doc_id=doc_id,
kb_id=self.kb.kb_id,
doc_name=file_name,
file_type=file_type,
file_size=len(file_content),
# file_path=str(file_path),
file_path="",
chunk_count=len(chunks_text),
media_count=0,
)
async with self.kb_db.get_db() as session:
async with session.begin():
session.add(doc)
for media in saved_media:
session.add(media)
await session.commit()
await session.refresh(doc)
vec_db: FaissVecDB = self.vec_db # type: ignore
await self.kb_db.update_kb_stats(kb_id=self.kb.kb_id, vec_db=vec_db)
await self.refresh_kb()
await self.refresh_document(doc_id)
return doc
except Exception as e:
logger.error(f"上传文档失败: {e}")
# if file_path.exists():
# file_path.unlink()
for media_path in media_paths:
try:
if media_path.exists():
media_path.unlink()
except Exception as me:
logger.warning(f"清理多媒体文件失败 {media_path}: {me}")
raise e
async def list_documents(
self, offset: int = 0, limit: int = 100
) -> list[KBDocument]:
"""列出知识库的所有文档"""
docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit)
return docs
async def get_document(self, doc_id: str) -> KBDocument | None:
"""获取单个文档"""
doc = await self.kb_db.get_document_by_id(doc_id)
return doc
async def delete_document(self, doc_id: str):
"""删除单个文档及其相关数据"""
await self.kb_db.delete_document_by_id(
doc_id=doc_id,
vec_db=self.vec_db, # type: ignore
)
await self.kb_db.update_kb_stats(
kb_id=self.kb.kb_id,
vec_db=self.vec_db, # type: ignore
)
await self.refresh_kb()
async def delete_chunk(self, chunk_id: str, doc_id: str):
"""删除单个文本块及其相关数据"""
vec_db: FaissVecDB = self.vec_db # type: ignore
await vec_db.delete(chunk_id)
await self.kb_db.update_kb_stats(
kb_id=self.kb.kb_id,
vec_db=self.vec_db, # type: ignore
)
await self.refresh_kb()
await self.refresh_document(doc_id)
async def refresh_kb(self):
if self.kb:
kb = await self.kb_db.get_kb_by_id(self.kb.kb_id)
if kb:
self.kb = kb
async def refresh_document(self, doc_id: str) -> None:
"""更新文档的元数据"""
doc = await self.get_document(doc_id)
if not doc:
raise ValueError(f"无法找到 ID 为 {doc_id} 的文档")
chunk_count = await self.get_chunk_count_by_doc_id(doc_id)
doc.chunk_count = chunk_count
async with self.kb_db.get_db() as session:
async with session.begin():
session.add(doc)
await session.commit()
await session.refresh(doc)
async def get_chunks_by_doc_id(
self, doc_id: str, offset: int = 0, limit: int = 100
) -> list[dict]:
"""获取文档的所有块及其元数据"""
vec_db: FaissVecDB = self.vec_db # type: ignore
chunks = await vec_db.document_storage.get_documents(
metadata_filters={"kb_doc_id": doc_id}, offset=offset, limit=limit
)
result = []
for chunk in chunks:
chunk_md = json.loads(chunk["metadata"])
result.append(
{
"chunk_id": chunk["doc_id"],
"doc_id": chunk_md["kb_doc_id"],
"kb_id": chunk_md["kb_id"],
"chunk_index": chunk_md["chunk_index"],
"content": chunk["text"],
"char_count": len(chunk["text"]),
}
)
return result
async def get_chunk_count_by_doc_id(self, doc_id: str) -> int:
"""获取文档的块数量"""
vec_db: FaissVecDB = self.vec_db # type: ignore
count = await vec_db.count_documents(metadata_filter={"kb_doc_id": doc_id})
return count
async def _save_media(
self,
doc_id: str,
media_type: str,
file_name: str,
content: bytes,
mime_type: str,
) -> KBMedia:
"""保存多媒体资源"""
media_id = str(uuid.uuid4())
ext = Path(file_name).suffix
# 保存文件
file_path = self.kb_medias_dir / doc_id / f"{media_id}{ext}"
file_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(file_path, "wb") as f:
await f.write(content)
media = KBMedia(
media_id=media_id,
doc_id=doc_id,
kb_id=self.kb.kb_id,
media_type=media_type,
file_name=file_name,
file_path=str(file_path),
file_size=len(content),
mime_type=mime_type,
)
return media

View File

@@ -0,0 +1,287 @@
import traceback
from pathlib import Path
from astrbot.core import logger
from astrbot.core.provider.manager import ProviderManager
from .retrieval.manager import RetrievalManager, RetrievalResult
from .retrieval.sparse_retriever import SparseRetriever
from .retrieval.rank_fusion import RankFusion
from .kb_db_sqlite import KBSQLiteDatabase
# from .chunking.fixed_size import FixedSizeChunker
from .chunking.recursive import RecursiveCharacterChunker
from .kb_helper import KBHelper
from .models import KnowledgeBase
FILES_PATH = "data/knowledge_base"
DB_PATH = Path(FILES_PATH) / "kb.db"
"""Knowledge Base storage root directory"""
CHUNKER = RecursiveCharacterChunker()
class KnowledgeBaseManager:
kb_db: KBSQLiteDatabase
retrieval_manager: RetrievalManager
def __init__(
self,
provider_manager: ProviderManager,
):
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
self.provider_manager = provider_manager
self._session_deleted_callback_registered = False
self.kb_insts: dict[str, KBHelper] = {}
async def initialize(self):
"""初始化知识库模块"""
try:
logger.info("正在初始化知识库模块...")
# 初始化数据库
await self._init_kb_database()
# 初始化检索管理器
sparse_retriever = SparseRetriever(self.kb_db)
rank_fusion = RankFusion(self.kb_db)
self.retrieval_manager = RetrievalManager(
sparse_retriever=sparse_retriever,
rank_fusion=rank_fusion,
kb_db=self.kb_db,
)
await self.load_kbs()
except ImportError as e:
logger.error(f"知识库模块导入失败: {e}")
logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
except Exception as e:
logger.error(f"知识库模块初始化失败: {e}")
logger.error(traceback.format_exc())
async def _init_kb_database(self):
self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix())
await self.kb_db.initialize()
await self.kb_db.migrate_to_v1()
logger.info(f"KnowledgeBase database initialized: {DB_PATH}")
async def load_kbs(self):
"""加载所有知识库实例"""
kb_records = await self.kb_db.list_kbs()
for record in kb_records:
kb_helper = KBHelper(
kb_db=self.kb_db,
kb=record,
provider_manager=self.provider_manager,
kb_root_dir=FILES_PATH,
chunker=CHUNKER,
)
await kb_helper.initialize()
self.kb_insts[record.kb_id] = kb_helper
async def create_kb(
self,
kb_name: str,
description: str | None = None,
emoji: str | None = None,
embedding_provider_id: str | None = None,
rerank_provider_id: str | None = None,
chunk_size: int | None = None,
chunk_overlap: int | None = None,
top_k_dense: int | None = None,
top_k_sparse: int | None = None,
top_m_final: int | None = None,
) -> KBHelper:
"""创建新的知识库实例"""
kb = KnowledgeBase(
kb_name=kb_name,
description=description,
emoji=emoji or "📚",
embedding_provider_id=embedding_provider_id,
rerank_provider_id=rerank_provider_id,
chunk_size=chunk_size if chunk_size is not None else 512,
chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
top_k_dense=top_k_dense if top_k_dense is not None else 50,
top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
top_m_final=top_m_final if top_m_final is not None else 5,
)
async with self.kb_db.get_db() as session:
session.add(kb)
await session.commit()
await session.refresh(kb)
kb_helper = KBHelper(
kb_db=self.kb_db,
kb=kb,
provider_manager=self.provider_manager,
kb_root_dir=FILES_PATH,
chunker=CHUNKER,
)
await kb_helper.initialize()
self.kb_insts[kb.kb_id] = kb_helper
return kb_helper
async def get_kb(self, kb_id: str) -> KBHelper | None:
"""获取知识库实例"""
if kb_id in self.kb_insts:
return self.kb_insts[kb_id]
async def get_kb_by_name(self, kb_name: str) -> KBHelper | None:
"""通过名称获取知识库实例"""
for kb_helper in self.kb_insts.values():
if kb_helper.kb.kb_name == kb_name:
return kb_helper
return None
async def delete_kb(self, kb_id: str) -> bool:
"""删除知识库实例"""
kb_helper = await self.get_kb(kb_id)
if not kb_helper:
return False
await kb_helper.delete_vec_db()
async with self.kb_db.get_db() as session:
await session.delete(kb_helper.kb)
await session.commit()
self.kb_insts.pop(kb_id, None)
return True
async def list_kbs(self) -> list[KnowledgeBase]:
"""列出所有知识库实例"""
kbs = [kb_helper.kb for kb_helper in self.kb_insts.values()]
return kbs
async def update_kb(
self,
kb_id: str,
kb_name: str,
description: str | None = None,
emoji: str | None = None,
embedding_provider_id: str | None = None,
rerank_provider_id: str | None = None,
chunk_size: int | None = None,
chunk_overlap: int | None = None,
top_k_dense: int | None = None,
top_k_sparse: int | None = None,
top_m_final: int | None = None,
) -> KBHelper | None:
"""更新知识库实例"""
kb_helper = await self.get_kb(kb_id)
if not kb_helper:
return None
kb = kb_helper.kb
if kb_name is not None:
kb.kb_name = kb_name
if description is not None:
kb.description = description
if emoji is not None:
kb.emoji = emoji
if embedding_provider_id is not None:
kb.embedding_provider_id = embedding_provider_id
kb.rerank_provider_id = rerank_provider_id # 允许设置为 None
if chunk_size is not None:
kb.chunk_size = chunk_size
if chunk_overlap is not None:
kb.chunk_overlap = chunk_overlap
if top_k_dense is not None:
kb.top_k_dense = top_k_dense
if top_k_sparse is not None:
kb.top_k_sparse = top_k_sparse
if top_m_final is not None:
kb.top_m_final = top_m_final
async with self.kb_db.get_db() as session:
session.add(kb)
await session.commit()
await session.refresh(kb)
return kb_helper
async def retrieve(
self,
query: str,
kb_names: list[str],
top_k_fusion: int = 20,
top_m_final: int = 5,
) -> dict | None:
"""从指定知识库中检索相关内容"""
kb_ids = []
kb_id_helper_map = {}
for kb_name in kb_names:
if kb_helper := await self.get_kb_by_name(kb_name):
kb_ids.append(kb_helper.kb.kb_id)
kb_id_helper_map[kb_helper.kb.kb_id] = kb_helper
if not kb_ids:
return {}
results = await self.retrieval_manager.retrieve(
query=query,
kb_ids=kb_ids,
kb_id_helper_map=kb_id_helper_map,
top_k_fusion=top_k_fusion,
top_m_final=top_m_final,
)
if not results:
return None
context_text = self._format_context(results)
results_dict = [
{
"chunk_id": r.chunk_id,
"doc_id": r.doc_id,
"kb_id": r.kb_id,
"kb_name": r.kb_name,
"doc_name": r.doc_name,
"chunk_index": r.metadata.get("chunk_index", 0),
"content": r.content,
"score": r.score,
"char_count": r.metadata.get("char_count", 0),
}
for r in results
]
return {
"context_text": context_text,
"results": results_dict,
}
def _format_context(self, results: list[RetrievalResult]) -> str:
"""格式化知识上下文
Args:
results: 检索结果列表
Returns:
str: 格式化的上下文文本
"""
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
for i, result in enumerate(results, 1):
lines.append(f"【知识 {i}")
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
lines.append(f"内容: {result.content}")
lines.append(f"相关度: {result.score:.2f}")
lines.append("")
return "\n".join(lines)
async def terminate(self):
"""终止所有知识库实例,关闭数据库连接"""
for kb_id, kb_helper in self.kb_insts.items():
try:
await kb_helper.terminate()
except Exception as e:
logger.error(f"关闭知识库 {kb_id} 失败: {e}")
self.kb_insts.clear()
# 关闭元数据数据库
if hasattr(self, "kb_db") and self.kb_db:
try:
await self.kb_db.close()
except Exception as e:
logger.error(f"关闭知识库元数据数据库失败: {e}")

View File

@@ -0,0 +1,114 @@
import uuid
from datetime import datetime, timezone
from sqlmodel import Field, SQLModel, Text, UniqueConstraint, MetaData
class BaseKBModel(SQLModel, table=False):
metadata = MetaData()
class KnowledgeBase(BaseKBModel, table=True):
"""知识库表
存储知识库的基本信息和统计数据。
"""
__tablename__ = "knowledge_bases" # type: ignore
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
kb_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
index=True,
)
kb_name: str = Field(max_length=100, nullable=False)
description: str | None = Field(default=None, sa_type=Text)
emoji: str | None = Field(default="📚", max_length=10)
embedding_provider_id: str | None = Field(default=None, max_length=100)
rerank_provider_id: str | None = Field(default=None, max_length=100)
# 分块配置参数
chunk_size: int | None = Field(default=512, nullable=True)
chunk_overlap: int | None = Field(default=50, nullable=True)
# 检索配置参数
top_k_dense: int | None = Field(default=50, nullable=True)
top_k_sparse: int | None = Field(default=50, nullable=True)
top_m_final: int | None = Field(default=5, nullable=True)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
doc_count: int = Field(default=0, nullable=False)
chunk_count: int = Field(default=0, nullable=False)
__table_args__ = (
UniqueConstraint(
"kb_name",
name="uix_kb_name",
),
)
class KBDocument(BaseKBModel, table=True):
"""文档表
存储上传到知识库的文档元数据。
"""
__tablename__ = "kb_documents" # type: ignore
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
doc_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
index=True,
)
kb_id: str = Field(max_length=36, nullable=False, index=True)
doc_name: str = Field(max_length=255, nullable=False)
file_type: str = Field(max_length=20, nullable=False)
file_size: int = Field(nullable=False)
file_path: str = Field(max_length=512, nullable=False)
chunk_count: int = Field(default=0, nullable=False)
media_count: int = Field(default=0, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
class KBMedia(BaseKBModel, table=True):
"""多媒体资源表
存储从文档中提取的图片、视频等多媒体资源。
"""
__tablename__ = "kb_media" # type: ignore
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
media_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
index=True,
)
doc_id: str = Field(max_length=36, nullable=False, index=True)
kb_id: str = Field(max_length=36, nullable=False, index=True)
media_type: str = Field(max_length=20, nullable=False)
file_name: str = Field(max_length=255, nullable=False)
file_path: str = Field(max_length=512, nullable=False)
file_size: int = Field(nullable=False)
mime_type: str = Field(max_length=100, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

View File

@@ -0,0 +1,15 @@
"""
文档解析器模块
"""
from .base import BaseParser, MediaItem, ParseResult
from .text_parser import TextParser
from .pdf_parser import PDFParser
__all__ = [
"BaseParser",
"MediaItem",
"ParseResult",
"TextParser",
"PDFParser",
]

View File

@@ -0,0 +1,50 @@
"""文档解析器基类和数据结构
定义了文档解析器的抽象接口和相关数据类。
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass
class MediaItem:
"""多媒体项
表示从文档中提取的多媒体资源。
"""
media_type: str # image, video
file_name: str
content: bytes
mime_type: str
@dataclass
class ParseResult:
"""解析结果
包含解析后的文本内容和提取的多媒体资源。
"""
text: str
media: list[MediaItem]
class BaseParser(ABC):
"""文档解析器基类
所有文档解析器都应该继承此类并实现 parse 方法。
"""
@abstractmethod
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
"""解析文档
Args:
file_content: 文件内容
file_name: 文件名
Returns:
ParseResult: 解析结果
"""

View File

@@ -0,0 +1,25 @@
import io
import os
from astrbot.core.knowledge_base.parsers.base import (
BaseParser,
ParseResult,
)
from markitdown_no_magika import MarkItDown, StreamInfo
class MarkitdownParser(BaseParser):
"""解析 docx, xls, xlsx 格式"""
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
md = MarkItDown(enable_plugins=False)
bio = io.BytesIO(file_content)
stream_info = StreamInfo(
extension=os.path.splitext(file_name)[1].lower(),
filename=file_name,
)
result = md.convert(bio, stream_info=stream_info)
return ParseResult(
text=result.markdown,
media=[],
)

View File

@@ -0,0 +1,100 @@
"""PDF 文件解析器
支持解析 PDF 文件中的文本和图片资源。
"""
import io
from pypdf import PdfReader
from astrbot.core.knowledge_base.parsers.base import (
BaseParser,
MediaItem,
ParseResult,
)
class PDFParser(BaseParser):
"""PDF 文档解析器
提取 PDF 中的文本内容和嵌入的图片资源。
"""
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
"""解析 PDF 文件
Args:
file_content: 文件内容
file_name: 文件名
Returns:
ParseResult: 包含文本和图片的解析结果
"""
pdf_file = io.BytesIO(file_content)
reader = PdfReader(pdf_file)
text_parts = []
media_items = []
# 提取文本
for page in reader.pages:
text = page.extract_text()
if text:
text_parts.append(text)
# 提取图片
image_counter = 0
for page_num, page in enumerate(reader.pages):
try:
# 安全检查 Resources
if "/Resources" not in page:
continue
resources = page["/Resources"]
if not resources or "/XObject" not in resources: # type: ignore
continue
xobjects = resources["/XObject"].get_object() # type: ignore
if not xobjects:
continue
for obj_name in xobjects:
try:
obj = xobjects[obj_name]
if obj.get("/Subtype") != "/Image":
continue
# 提取图片数据
image_data = obj.get_data()
# 确定格式
filter_type = obj.get("/Filter", "")
if filter_type == "/DCTDecode":
ext = "jpg"
mime_type = "image/jpeg"
elif filter_type == "/FlateDecode":
ext = "png"
mime_type = "image/png"
else:
ext = "png"
mime_type = "image/png"
image_counter += 1
media_items.append(
MediaItem(
media_type="image",
file_name=f"page_{page_num}_img_{image_counter}.{ext}",
content=image_data,
mime_type=mime_type,
)
)
except Exception:
# 单个图片提取失败不影响整体
continue
except Exception:
# 页面处理失败不影响其他页面
continue
full_text = "\n\n".join(text_parts)
return ParseResult(text=full_text, media=media_items)

View File

@@ -0,0 +1,41 @@
"""文本文件解析器
支持解析 TXT 和 Markdown 文件。
"""
from astrbot.core.knowledge_base.parsers.base import BaseParser, ParseResult
class TextParser(BaseParser):
"""TXT/MD 文本解析器
支持多种字符编码的自动检测。
"""
async def parse(self, file_content: bytes, file_name: str) -> ParseResult:
"""解析文本文件
尝试使用多种编码解析文件内容。
Args:
file_content: 文件内容
file_name: 文件名
Returns:
ParseResult: 解析结果,不包含多媒体资源
Raises:
ValueError: 如果无法解码文件
"""
# 尝试多种编码
for encoding in ["utf-8", "gbk", "gb2312", "gb18030"]:
try:
text = file_content.decode(encoding)
break
except UnicodeDecodeError:
continue
else:
raise ValueError(f"无法解码文件: {file_name}")
# 文本文件无多媒体资源
return ParseResult(text=text, media=[])

View File

@@ -0,0 +1,13 @@
from .base import BaseParser
async def select_parser(ext: str) -> BaseParser:
if ext in {".md", ".txt", ".markdown", ".xlsx", ".docx", ".xls"}:
from .markitdown_parser import MarkitdownParser
return MarkitdownParser()
elif ext == ".pdf":
from .pdf_parser import PDFParser
return PDFParser()
raise ValueError(f"暂时不支持的文件格式: {ext}")

View File

@@ -0,0 +1,16 @@
"""
检索模块
"""
from .manager import RetrievalManager, RetrievalResult
from .sparse_retriever import SparseRetriever, SparseResult
from .rank_fusion import RankFusion, FusedResult
__all__ = [
"RetrievalManager",
"RetrievalResult",
"SparseRetriever",
"SparseResult",
"RankFusion",
"FusedResult",
]

View File

@@ -0,0 +1,767 @@
———
》),
)÷(1-
”,
)、
:
&
*
一一
~~~~
.
.一
./
--
=″
[⑤]]
[①D]
ng昉
//
[②e]
[②g]
}
,也
[①⑥]
[②B]
[①a]
[④a]
[①③]
[③h]
③]
[②b]
×××
[①⑧]
[⑤b]
[②c]
[④b]
[②③]
[③a]
[④c]
[①⑤]
[①⑦]
[①g]
∈[
[①⑨]
[①④]
[①c]
[②f]
[②⑧]
[②①]
[①C]
[③c]
[③g]
[②⑤]
[②②]
一.
[①h]
.数
[①B]
数/
[①i]
[③e]
[①①]
[④d]
[④e]
[③b]
[⑤a]
[①A]
[②⑧]
[②⑦]
[①d]
[②j]
://
′∈
[②④
[⑤e]
...
...................
…………………………………………………③
[③F]
[①o]
]∧′=[
∪φ∈
②c
[③①]
[①E]
Ψ
.日
[②d]
[②
[②⑦]
[②②]
[③e]
[①i]
[①B]
[①h]
[①d]
[①g]
[①②]
[②a]
[⑩]
[①e]
[②h]
[②⑥]
[③d]
[②⑩]
元/吨
[②⑩]
[①]
::
[②]
[③]
[④]
[⑤]
[⑥]
[⑦]
[⑧]
[⑨]
……
——
?
,
'
?
·
———
──
?
<
>
[
]
(
)
-
+
×
/
В
"
;
#
@
γ
μ
φ
φ.
×
Δ
sub
exp
sup
sub
Lex
+ξ
-β
<±
<Δ
<λ
<φ
=
=☆
>λ
_
~±
[⑤f]
[⑤d]
[②i]
[②G]
[①f]
......
[③⑩]
第二
一番
一直
一个
一些
许多
有的是
也就是说
末##末
哎呀
哎哟
俺们
按照
吧哒
罢了
本着
比方
比如
鄙人
彼此
别的
别说
并且
不比
不成
不单
不但
不独
不管
不光
不过
不仅
不拘
不论
不怕
不然
不如
不特
不惟
不问
不只
朝着
趁着
除此之外
除非
除了
此间
此外
从而
但是
当着
的话
等等
叮咚
对于
多少
而况
而且
而是
而外
而言
而已
尔后
反过来
反过来说
反之
非但
非徒
否则
嘎登
各个
各位
各种
各自
根据
故此
固然
关于
果然
果真
哈哈
何处
何况
何时
哼唷
呼哧
还是
还有
换句话说
换言之
或是
或者
极了
及其
及至
即便
即或
即令
即若
即使
几时
既然
既是
继而
加之
假如
假若
假使
鉴于
较之
接着
结果
紧接着
进而
尽管
经过
就是
就是说
具体地说
具体说来
开始
开外
可见
可是
可以
况且
来着
例如
连同
两者
另外
另一方面
慢说
漫说
每当
莫若
某个
某些
哪边
哪儿
哪个
哪里
哪年
哪怕
哪天
哪些
哪样
那边
那儿
那个
那会儿
那里
那么
那么些
那么样
那时
那些
那样
乃至
你们
宁可
宁肯
宁愿
啪达
旁人
凭借
其次
其二
其他
其它
其一
其余
其中
起见
起见
岂但
恰恰相反
前后
前者
然而
然后
然则
人家
任何
任凭
如此
如果
如何
如其
如若
如上所述
若非
若是
上下
尚且
设若
设使
甚而
甚么
甚至
省得
时候
什么
什么样
使得
是的
首先
谁知
顺着
似的
虽然
虽说
虽则
随着
所以
他们
他人
它们
她们
倘或
倘然
倘若
倘使
通过
同时
万一
为何
为了
为什么
为着
嗡嗡
我们
呜呼
乌乎
无论
无宁
毋宁
相对而言
向着
沿
沿着
要不
要不然
要不是
要么
要是
也罢
也好
一般
一旦
一方面
一来
一切
一样
一则
依照
以便
以及
以免
以至
以至于
以致
抑或
因此
因而
因为
由此可见
由于
有的
有关
有些
于是
于是乎
与此同时
与否
与其
越是
云云
再说
再者
在下
咱们
怎么
怎么办
怎么样
怎样
照着
这边
这儿
这个
这会儿
这就是说
这里
这么
这么点儿
这么些
这么样
这时
这些
这样
正如
之类
之所以
之一
只是
只限
只要
只有
至于
诸位
着呢
自从
自个儿
自各儿
自己
自家
自身
综上所述
总的来看
总的来说
总的说来
总而言之
总之
纵令
纵然
纵使
遵照
作为
喔唷

View File

@@ -0,0 +1,273 @@
"""检索管理器
协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
"""
import time
from dataclasses import dataclass
from typing import List
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
from astrbot.core.provider.provider import RerankProvider
from astrbot.core.db.vec_db.base import Result
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from ..kb_helper import KBHelper
from astrbot import logger
@dataclass
class RetrievalResult:
"""检索结果"""
chunk_id: str
doc_id: str
doc_name: str
kb_id: str
kb_name: str
content: str
score: float
metadata: dict
class RetrievalManager:
"""检索管理器
职责:
- 协调稠密检索、稀疏检索和 Rerank
- 结果融合和排序
"""
def __init__(
self,
sparse_retriever: SparseRetriever,
rank_fusion: RankFusion,
kb_db: KBSQLiteDatabase,
):
"""初始化检索管理器
Args:
vec_db_factory: 向量数据库工厂
sparse_retriever: 稀疏检索器
rank_fusion: 结果融合器
kb_db: 知识库数据库实例
"""
self.sparse_retriever = sparse_retriever
self.rank_fusion = rank_fusion
self.kb_db = kb_db
async def retrieve(
self,
query: str,
kb_ids: List[str],
kb_id_helper_map: dict[str, KBHelper],
top_k_fusion: int = 20,
top_m_final: int = 5,
) -> List[RetrievalResult]:
"""混合检索
流程:
1. 稠密检索 (向量相似度)
2. 稀疏检索 (BM25)
3. 结果融合 (RRF)
4. Rerank 重排序
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
top_m_final: 最终返回数量
enable_rerank: 是否启用 Rerank
Returns:
List[RetrievalResult]: 检索结果列表
"""
if not kb_ids:
return []
kb_options: dict = {}
new_kb_ids = []
for kb_id in kb_ids:
kb_helper = kb_id_helper_map.get(kb_id)
if kb_helper:
kb = kb_helper.kb
kb_options[kb_id] = {
"top_k_dense": kb.top_k_dense or 50,
"top_k_sparse": kb.top_k_sparse or 50,
"top_m_final": kb.top_m_final or 5,
"vec_db": kb_helper.vec_db,
"rerank_provider_id": kb.rerank_provider_id,
}
new_kb_ids.append(kb_id)
else:
logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索")
kb_ids = new_kb_ids
# 1. 稠密检索
time_start = time.time()
dense_results = await self._dense_retrieve(
query=query,
kb_ids=kb_ids,
kb_options=kb_options,
)
time_end = time.time()
logger.debug(
f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results."
)
# 2. 稀疏检索
time_start = time.time()
sparse_results = await self.sparse_retriever.retrieve(
query=query,
kb_ids=kb_ids,
kb_options=kb_options,
)
time_end = time.time()
logger.debug(
f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results."
)
# 3. 结果融合
time_start = time.time()
fused_results = await self.rank_fusion.fuse(
dense_results=dense_results,
sparse_results=sparse_results,
top_k=top_k_fusion,
)
time_end = time.time()
logger.debug(
f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results."
)
# 4. 转换为 RetrievalResult (获取元数据)
retrieval_results = []
for fr in fused_results:
metadata_dict = await self.kb_db.get_document_with_metadata(fr.doc_id)
if metadata_dict:
retrieval_results.append(
RetrievalResult(
chunk_id=fr.chunk_id,
doc_id=fr.doc_id,
doc_name=metadata_dict["document"].doc_name,
kb_id=fr.kb_id,
kb_name=metadata_dict["knowledge_base"].kb_name,
content=fr.content,
score=fr.score,
metadata={
"chunk_index": fr.chunk_index,
"char_count": len(fr.content),
},
)
)
# 5. Rerank
first_rerank = None
for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
if (
vec_db
and vec_db.rerank_provider
and rerank_pi
and rerank_pi == vec_db.rerank_provider.meta().id
):
first_rerank = vec_db.rerank_provider
break
if first_rerank and retrieval_results:
retrieval_results = await self._rerank(
query=query,
results=retrieval_results,
top_k=top_m_final,
rerank_provider=first_rerank,
)
return retrieval_results[:top_m_final]
async def _dense_retrieve(
self,
query: str,
kb_ids: List[str],
kb_options: dict,
):
"""稠密检索 (向量相似度)
为每个知识库使用独立的向量数据库进行检索,然后合并结果。
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
top_k: 返回结果数量
Returns:
List[Result]: 检索结果列表
"""
all_results: list[Result] = []
for kb_id in kb_ids:
if kb_id not in kb_options:
continue
try:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
dense_k = int(kb_options[kb_id]["top_k_dense"])
vec_results = await vec_db.retrieve(
query=query,
k=dense_k,
fetch_k=dense_k * 2,
rerank=False, # 稠密检索阶段不进行 rerank
metadata_filters={"kb_id": kb_id},
)
all_results.extend(vec_results)
except Exception as e:
from astrbot.core import logger
logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
continue
# 按相似度排序并返回 top_k
all_results.sort(key=lambda x: x.similarity, reverse=True)
# return all_results[: len(all_results) // len(kb_ids)]
return all_results
async def _rerank(
self,
query: str,
results: List[RetrievalResult],
top_k: int,
rerank_provider: RerankProvider,
) -> List[RetrievalResult]:
"""Rerank 重排序
Args:
query: 查询文本
results: 检索结果列表
top_k: 返回结果数量
Returns:
List[RetrievalResult]: 重排序后的结果列表
"""
if not results:
return []
# 准备文档列表
docs = [r.content for r in results]
# 调用 Rerank Provider
rerank_results = await rerank_provider.rerank(
query=query,
documents=docs,
)
# 更新分数并重新排序
reranked_list = []
for rerank_result in rerank_results:
idx = rerank_result.index
if idx < len(results):
result = results[idx]
result.score = rerank_result.relevance_score
reranked_list.append(result)
reranked_list.sort(key=lambda x: x.score, reverse=True)
return reranked_list[:top_k]

View File

@@ -0,0 +1,138 @@
"""检索结果融合器
使用 Reciprocal Rank Fusion (RRF) 算法融合稠密检索和稀疏检索的结果
"""
import json
from dataclasses import dataclass
from astrbot.core.db.vec_db.base import Result
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult
@dataclass
class FusedResult:
"""融合后的检索结果"""
chunk_id: str
chunk_index: int
doc_id: str
kb_id: str
content: str
score: float
class RankFusion:
"""检索结果融合器
职责:
- 融合稠密检索和稀疏检索的结果
- 使用 Reciprocal Rank Fusion (RRF) 算法
"""
def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60):
"""初始化结果融合器
Args:
kb_db: 知识库数据库实例
k: RRF 参数,用于平滑排名
"""
self.kb_db = kb_db
self.k = k
async def fuse(
self,
dense_results: list[Result],
sparse_results: list[SparseResult],
top_k: int = 20,
) -> list[FusedResult]:
"""融合稠密和稀疏检索结果
RRF 公式:
score(doc) = sum(1 / (k + rank_i))
Args:
dense_results: 稠密检索结果
sparse_results: 稀疏检索结果
top_k: 返回结果数量
Returns:
List[FusedResult]: 融合后的结果列表
"""
# 1. 构建排名映射
dense_ranks = {
r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results)
} # 这里的 doc_id 实际上是 chunk_id
sparse_ranks = {r.chunk_id: (idx + 1) for idx, r in enumerate(sparse_results)}
# 2. 收集所有唯一的 ID
# 需要统一为 chunk_id
all_chunk_ids = set()
vec_doc_id_to_dense: dict[str, Result] = {} # vec_doc_id -> Result
chunk_id_to_sparse: dict[str, SparseResult] = {} # chunk_id -> SparseResult
# 处理稀疏检索结果
for r in sparse_results:
all_chunk_ids.add(r.chunk_id)
chunk_id_to_sparse[r.chunk_id] = r
# 处理稠密检索结果 (需要转换 vec_doc_id 到 chunk_id)
for r in dense_results:
vec_doc_id = r.data["doc_id"]
all_chunk_ids.add(vec_doc_id)
vec_doc_id_to_dense[vec_doc_id] = r
# 3. 计算 RRF 分数
rrf_scores: dict[str, float] = {}
for identifier in all_chunk_ids:
score = 0.0
# 来自稠密检索的贡献
if identifier in dense_ranks:
score += 1.0 / (self.k + dense_ranks[identifier])
# 来自稀疏检索的贡献
if identifier in sparse_ranks:
score += 1.0 / (self.k + sparse_ranks[identifier])
rrf_scores[identifier] = score
# 4. 排序
sorted_ids = sorted(
rrf_scores.keys(), key=lambda cid: rrf_scores[cid], reverse=True
)[:top_k]
# 5. 构建融合结果
fused_results = []
for identifier in sorted_ids:
# 优先从稀疏检索获取完整信息
if identifier in chunk_id_to_sparse:
sr = chunk_id_to_sparse[identifier]
fused_results.append(
FusedResult(
chunk_id=sr.chunk_id,
chunk_index=sr.chunk_index,
doc_id=sr.doc_id,
kb_id=sr.kb_id,
content=sr.content,
score=rrf_scores[identifier],
)
)
elif identifier in vec_doc_id_to_dense:
# 从向量检索获取信息,需要从数据库获取块的详细信息
vec_result = vec_doc_id_to_dense[identifier]
chunk_md = json.loads(vec_result.data["metadata"])
fused_results.append(
FusedResult(
chunk_id=identifier,
chunk_index=chunk_md["chunk_index"],
doc_id=chunk_md["kb_doc_id"],
kb_id=chunk_md["kb_id"],
content=vec_result.data["text"],
score=rrf_scores[identifier],
)
)
return fused_results

View File

@@ -0,0 +1,130 @@
"""稀疏检索器
使用 BM25 算法进行基于关键词的文档检索
"""
import jieba
import os
import json
from dataclasses import dataclass
from rank_bm25 import BM25Okapi
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
@dataclass
class SparseResult:
"""稀疏检索结果"""
chunk_index: int
chunk_id: str
doc_id: str
kb_id: str
content: str
score: float
class SparseRetriever:
"""BM25 稀疏检索器
职责:
- 基于关键词的文档检索
- 使用 BM25 算法计算相关度
"""
def __init__(self, kb_db: KBSQLiteDatabase):
"""初始化稀疏检索器
Args:
kb_db: 知识库数据库实例
"""
self.kb_db = kb_db
self._index_cache = {} # 缓存 BM25 索引
with open(
os.path.join(os.path.dirname(__file__), "hit_stopwords.txt"),
encoding="utf-8",
) as f:
self.hit_stopwords = {
word.strip() for word in set(f.read().splitlines()) if word.strip()
}
async def retrieve(
self,
query: str,
kb_ids: list[str],
kb_options: dict,
) -> list[SparseResult]:
"""执行稀疏检索
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
kb_options: 每个知识库的检索选项
Returns:
List[SparseResult]: 检索结果列表
"""
# 1. 获取所有相关块
top_k_sparse = 0
chunks = []
for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options.get(kb_id, {}).get("vec_db")
if not vec_db:
continue
result = await vec_db.document_storage.get_documents(
metadata_filters={}, limit=None, offset=None
)
chunk_mds = [json.loads(doc["metadata"]) for doc in result]
result = [
{
"chunk_id": doc["doc_id"],
"chunk_index": chunk_md["chunk_index"],
"doc_id": chunk_md["kb_doc_id"],
"kb_id": kb_id,
"text": doc["text"],
}
for doc, chunk_md in zip(result, chunk_mds)
]
chunks.extend(result)
top_k_sparse += kb_options.get(kb_id, {}).get("top_k_sparse", 50)
if not chunks:
return []
# 2. 准备文档和索引
corpus = [chunk["text"] for chunk in chunks]
tokenized_corpus = [list(jieba.cut(doc)) for doc in corpus]
tokenized_corpus = [
[word for word in doc if word not in self.hit_stopwords]
for doc in tokenized_corpus
]
# 3. 构建 BM25 索引
bm25 = BM25Okapi(tokenized_corpus)
# 4. 执行检索
tokenized_query = list(jieba.cut(query))
tokenized_query = [
word for word in tokenized_query if word not in self.hit_stopwords
]
scores = bm25.get_scores(tokenized_query)
# 5. 排序并返回 Top-K
results = []
for idx, score in enumerate(scores):
chunk = chunks[idx]
results.append(
SparseResult(
chunk_id=chunk["chunk_id"],
chunk_index=chunk["chunk_index"],
doc_id=chunk["doc_id"],
kb_id=chunk["kb_id"],
content=chunk["text"],
score=float(score),
)
)
results.sort(key=lambda x: x.score, reverse=True)
# return results[: len(results) // len(kb_ids)]
return results[:top_k_sparse]

View File

@@ -19,7 +19,7 @@ class ContentSafetyCheckStage(Stage):
self.strategy_selector = StrategySelector(config)
async def process(
self, event: AstrMessageEvent, check_text: str = None
self, event: AstrMessageEvent, check_text: str | None = None
) -> Union[None, AsyncGenerator[None, None]]:
"""检查内容安全"""
text = check_text if check_text else event.get_message_str()

View File

@@ -13,7 +13,7 @@ class BaiduAipStrategy(ContentSafetyStrategy):
self.secret_key = sk
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
def check(self, content: str):
def check(self, content: str) -> tuple[bool, str]:
res = self.client.textCensorUserDefined(content)
if "conclusionType" not in res:
return False, ""

View File

@@ -16,7 +16,7 @@ class KeywordsStrategy(ContentSafetyStrategy):
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
# )
def check(self, content: str) -> bool:
def check(self, content: str) -> tuple[bool, str]:
for keyword in self.keywords:
if re.search(keyword, content):
return False, "内容安全检查不通过,匹配到敏感词。"

View File

@@ -10,7 +10,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
async def call_handler(
event: AstrMessageEvent,
handler: T.Awaitable,
handler: T.Callable[..., T.Awaitable[T.Any]],
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
@@ -36,6 +36,9 @@ async def call_handler(
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
if not ready_to_call:
return
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
@@ -94,5 +97,6 @@ async def call_event_hook(
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return True
return event.is_stopped()

View File

@@ -1,5 +1,6 @@
import traceback
import asyncio
import random
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
@@ -22,6 +23,26 @@ class PreProcessStage(Stage):
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""在处理事件之前的预处理"""
# 平台特异配置platform_specific.<platform>.pre_ack_emoji
supported = {"telegram", "lark"}
platform = event.get_platform_name()
cfg = (
self.config.get("platform_specific", {})
.get(platform, {})
.get("pre_ack_emoji", {})
) or {}
emojis = cfg.get("emojis") or []
if (
cfg.get("enable", False)
and platform in supported
and emojis
and event.is_at_or_wake_command
):
try:
await event.react(random.choice(emojis))
except Exception as e:
logger.warning(f"{platform} 预回应表情发送失败: {e}")
# 路径映射
if mappings := self.platform_settings.get("path_mapping", []):
# 支持 RecordImage 消息段的路径映射。
@@ -46,6 +67,9 @@ class PreProcessStage(Stage):
ctx = self.plugin_manager.context
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
if not stt_provider:
logger.warning(
f"会话 {event.unified_msg_origin} 未配置语音转文本模型。"
)
return
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):

View File

@@ -6,7 +6,9 @@ import asyncio
import copy
import json
import traceback
from typing import AsyncGenerator, Union
from datetime import timedelta
from collections.abc import AsyncGenerator
from astrbot.core.conversation_mgr import Conversation
from astrbot.core import logger
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
@@ -31,6 +33,7 @@ from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from ...context import PipelineContext, call_event_hook, call_handler
from ..stage import Stage
from ..utils import inject_kb_context
from astrbot.core.provider.register import llm_tools
from astrbot.core.star.star_handler import star_map
from astrbot.core.astr_agent_context import AstrAgentContext
@@ -42,7 +45,7 @@ except (ModuleNotFoundError, ImportError):
AgentContextWrapper = ContextWrapper[AstrAgentContext]
AgentRunner = ToolLoopAgentRunner[AgentContextWrapper]
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@@ -100,7 +103,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
request = ProviderRequest(
prompt=input_,
system_prompt=tool.description,
system_prompt=tool.description or "",
image_urls=[], # 暂时不传递原始 agent 的上下文
contexts=[], # 暂时不传递原始 agent 的上下文
func_tool=toolset,
@@ -133,6 +136,15 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
if agent_runner.done():
llm_response = agent_runner.get_final_llm_resp()
if not llm_response:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
logger.debug(
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
)
@@ -148,7 +160,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
)
yield mcp.types.CallToolResult(content=[text_content])
else:
yield mcp.types.TextContent(
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
@@ -175,21 +187,33 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
handler=awaitable,
**tool_args,
)
async for resp in wrapper:
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
# async for resp in wrapper:
while True:
try:
resp = await asyncio.wait_for(
anext(wrapper),
timeout=run_context.context.tool_call_timeout,
)
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
yield None
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
yield None
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds."
)
except StopAsyncIteration:
break
@classmethod
async def _execute_mcp(
@@ -200,16 +224,23 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
):
if not tool.mcp_client:
raise ValueError("MCP client is not available for MCP function tools.")
res = await tool.mcp_client.session.call_tool(
session = tool.mcp_client.session
if not session:
raise ValueError("MCP session is not available for MCP function tools.")
res = await session.call_tool(
name=tool.name,
arguments=tool_args,
read_timeout_seconds=timedelta(
seconds=run_context.context.tool_call_timeout
),
)
if not res:
return
yield res
class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]):
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
@@ -271,19 +302,12 @@ async def run_agent(
except Exception as e:
logger.error(traceback.format_exc())
astr_event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
)
)
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
if agent_runner.streaming:
yield MessageChain().message(err_msg)
else:
astr_event.set_result(MessageEventResult().message(err_msg))
return
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
)
)
class LLMRequestSubStage(Stage):
@@ -300,6 +324,7 @@ class LLMRequestSubStage(Stage):
)
self.streaming_response: bool = settings["streaming_response"]
self.max_step: int = settings.get("max_agent_step", 30)
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
@@ -313,7 +338,7 @@ class LLMRequestSubStage(Stage):
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
def _select_provider(self, event: AstrMessageEvent):
"""选择使用的 LLM 提供商"""
sel_provider = event.get_extra("selected_provider")
_ctx = self.ctx.plugin_manager.context
@@ -325,7 +350,7 @@ class LLMRequestSubStage(Stage):
return _ctx.get_using_provider(umo=event.unified_msg_origin)
async def _get_session_conv(self, event: AstrMessageEvent):
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
umo = event.unified_msg_origin
conv_mgr = self.conv_manager
@@ -337,11 +362,13 @@ class LLMRequestSubStage(Stage):
if not conversation:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
raise RuntimeError("无法创建新的对话。")
return conversation
async def process(
self, event: AstrMessageEvent, _nested: bool = False
) -> Union[None, AsyncGenerator[None, None]]:
) -> None | AsyncGenerator[None, None]:
req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
@@ -356,6 +383,9 @@ class LLMRequestSubStage(Stage):
provider = self._select_provider(event)
if provider is None:
return
if not isinstance(provider, Provider):
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
@@ -390,6 +420,14 @@ class LLMRequestSubStage(Stage):
if not req.prompt and not req.image_urls:
return
# 应用知识库
try:
await inject_kb_context(
umo=event.unified_msg_origin, p_ctx=self.ctx, req=req
)
except Exception as e:
logger.error(f"调用知识库时遇到问题: {e}")
# 执行请求 LLM 前事件钩子。
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
@@ -444,13 +482,19 @@ class LLMRequestSubStage(Stage):
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
plugin = star_map.get(tool.handler_module_path)
mp = tool.handler_module_path
if not mp:
continue
plugin = star_map.get(mp)
if not plugin:
continue
if plugin.name in event.plugins_name or plugin.reserved:
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set
# 备份 req.contexts
backup_contexts = copy.deepcopy(req.contexts)
# run agent
agent_runner = AgentRunner()
logger.debug(
@@ -461,6 +505,7 @@ class LLMRequestSubStage(Stage):
first_provider_request=req,
curr_provider_request=req,
streaming=self.streaming_response,
tool_call_timeout=self.tool_call_timeout,
)
await agent_runner.reset(
provider=provider,
@@ -487,8 +532,10 @@ class LLMRequestSubStage(Stage):
chain = (
MessageChain().message(final_llm_resp.completion_text).chain
)
else:
elif final_llm_resp.result_chain:
chain = final_llm_resp.result_chain.chain
else:
chain = MessageChain().chain
event.set_result(
MessageEventResult(
chain=chain,
@@ -499,16 +546,29 @@ class LLMRequestSubStage(Stage):
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
yield
# 恢复备份的 contexts
req.contexts = backup_contexts
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
# 异步处理 WebChat 特殊情况
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
)
)
async def _handle_webchat(
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
if not req.conversation:
return
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, req.conversation.cid
)
@@ -517,7 +577,23 @@ class LLMRequestSubStage(Stage):
latest_pair = messages[-2:]
if not latest_pair:
return
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
content = latest_pair[0].get("content", "")
if isinstance(content, list):
# 多模态
text_parts = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif item.get("type") == "image":
text_parts.append("[图片]")
elif isinstance(item, str):
text_parts.append(item)
cleaned_text = "User: " + " ".join(text_parts).strip()
elif isinstance(content, str):
cleaned_text = "User: " + content.strip()
else:
return
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
llm_resp = await prov.text_chat(
system_prompt="You are expert in summarizing user's query.",

View File

@@ -34,12 +34,14 @@ class StarRequestSubStage(Stage):
for handler in activated_handlers:
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_path not in star_map:
continue
logger.debug(
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
md = star_map.get(handler.handler_module_path)
if not md:
logger.warning(
f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
)
continue
logger.debug(f"plugin -> {md.name} - {handler.handler_name}")
try:
wrapper = call_handler(event, handler.handler, **params)
async for ret in wrapper:
yield ret
@@ -49,7 +51,7 @@ class StarRequestSubStage(Stage):
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
if event.is_at_or_wake_command:
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()

View File

@@ -0,0 +1,80 @@
from ..context import PipelineContext
from astrbot.core.provider.entities import ProviderRequest
from astrbot.api import logger, sp
async def inject_kb_context(
umo: str,
p_ctx: PipelineContext,
req: ProviderRequest,
) -> None:
"""inject knowledge base context into the provider request
Args:
umo: Unique message object (session ID)
p_ctx: Pipeline context
req: Provider request
"""
kb_mgr = p_ctx.plugin_manager.context.kb_manager
# 1. 优先读取会话级配置
session_config = await sp.session_get(umo, "kb_config", default={})
if session_config and "kb_ids" in session_config:
# 会话级配置
kb_ids = session_config.get("kb_ids", [])
# 如果配置为空列表,明确表示不使用知识库
if not kb_ids:
logger.info(f"[知识库] 会话 {umo} 已被配置为不使用知识库")
return
top_k = session_config.get("top_k", 5)
# 将 kb_ids 转换为 kb_names
kb_names = []
invalid_kb_ids = []
for kb_id in kb_ids:
kb_helper = await kb_mgr.get_kb(kb_id)
if kb_helper:
kb_names.append(kb_helper.kb.kb_name)
else:
logger.warning(f"[知识库] 知识库不存在或未加载: {kb_id}")
invalid_kb_ids.append(kb_id)
if invalid_kb_ids:
logger.warning(
f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}"
)
if not kb_names:
return
logger.debug(f"[知识库] 使用会话级配置,知识库数量: {len(kb_names)}")
else:
kb_names = p_ctx.astrbot_config.get("kb_names", [])
top_k = p_ctx.astrbot_config.get("kb_final_top_k", 5)
logger.debug(f"[知识库] 使用全局配置,知识库数量: {len(kb_names)}")
top_k_fusion = p_ctx.astrbot_config.get("kb_fusion_top_k", 20)
if not kb_names:
return
logger.debug(f"[知识库] 开始检索知识库,数量: {len(kb_names)}, top_k={top_k}")
kb_context = await kb_mgr.retrieve(
query=req.prompt,
kb_names=kb_names,
top_k_fusion=top_k_fusion,
top_m_final=top_k,
)
if not kb_context:
return
formatted = kb_context.get("context_text", "")
if formatted:
results = kb_context.get("results", [])
logger.debug(f"[知识库] 为会话 {umo} 注入了 {len(results)} 条相关知识块")
req.system_prompt = f"{formatted}\n\n{req.system_prompt or ''}"

View File

@@ -1,17 +1,15 @@
import random
import asyncio
import math
import traceback
import astrbot.core.message.components as Comp
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
from ..context import PipelineContext, call_event_hook
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core import logger
from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.components import BaseMessageComponent, ComponentType
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.path_util import path_Mapping
from astrbot.core.utils.session_lock import session_lock_manager
@@ -114,6 +112,43 @@ class RespondStage(Stage):
# 如果所有组件都为空
return True
def is_seg_reply_required(self, event: AstrMessageEvent) -> bool:
"""检查是否需要分段回复"""
if not self.enable_seg:
return False
if self.only_llm_result and not event.get_result().is_llm_result():
return False
if event.get_platform_name() in [
"qq_official",
"weixin_official_account",
"dingtalk",
]:
return False
return True
def _extract_comp(
self,
raw_chain: list[BaseMessageComponent],
extract_types: set[ComponentType],
modify_raw_chain: bool = True,
):
extracted = []
if modify_raw_chain:
remaining = []
for comp in raw_chain:
if comp.type in extract_types:
extracted.append(comp)
else:
remaining.append(comp)
raw_chain[:] = remaining
else:
extracted = [comp for comp in raw_chain if comp.type in extract_types]
return extracted
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
@@ -123,7 +158,14 @@ class RespondStage(Stage):
if result.result_content_type == ResultContentType.STREAMING_FINISH:
return
logger.info(
f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
)
if result.result_content_type == ResultContentType.STREAMING_RESULT:
if result.async_stream is None:
logger.warning("async_stream 为空,跳过发送。")
return
# 流式结果直接交付平台适配器处理
use_fallback = self.config.get("provider_settings", {}).get(
"streaming_segmented", False
@@ -148,87 +190,81 @@ class RespondStage(Stage):
except Exception as e:
logger.warning(f"空内容检查异常: {e}")
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
non_record_comps = [
c for c in result.chain if not isinstance(c, Comp.Record)
# 将 Plain 为空的消息段移除
result.chain = [
comp
for comp in result.chain
if not (
isinstance(comp, Comp.Plain)
and (not comp.text or not comp.text.strip())
)
]
if (
self.enable_seg
and (
(self.only_llm_result and result.is_llm_result())
or not self.only_llm_result
# 发送消息链
# Record 需要强制单独发送
need_separately = {ComponentType.Record}
if self.is_seg_reply_required(event):
header_comps = self._extract_comp(
result.chain,
{ComponentType.Reply, ComponentType.At},
modify_raw_chain=True,
)
and event.get_platform_name()
not in ["qq_official", "weixin_official_account", "dingtalk"]
):
decorated_comps = []
if self.reply_with_mention:
for comp in result.chain:
if isinstance(comp, Comp.At):
decorated_comps.append(comp)
result.chain.remove(comp)
break
if self.reply_with_quote:
for comp in result.chain:
if isinstance(comp, Comp.Reply):
decorated_comps.append(comp)
result.chain.remove(comp)
break
# leverage lock to guarentee the order of message sending among different events
if not result.chain or len(result.chain) == 0:
# may fix #2670
logger.warning(
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}"
)
return
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
for rcomp in record_comps:
i = await self._calc_comp_interval(rcomp)
await asyncio.sleep(i)
try:
await event.send(MessageChain([rcomp]))
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
# 分段回复
for comp in non_record_comps:
for comp in result.chain:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
try:
await event.send(MessageChain([*decorated_comps, comp]))
decorated_comps = [] # 清空已发送的装饰组件
if comp.type in need_separately:
await event.send(MessageChain([comp]))
else:
await event.send(MessageChain([*header_comps, comp]))
header_comps.clear()
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
logger.error(
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
exc_info=True,
)
else:
for rcomp in record_comps:
if all(
comp.type in {ComponentType.Reply, ComponentType.At}
for comp in result.chain
):
# may fix #2670
logger.warning(
f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}"
)
return
sep_comps = self._extract_comp(
result.chain,
need_separately,
modify_raw_chain=True,
)
for comp in sep_comps:
chain = MessageChain([comp])
try:
await event.send(MessageChain([rcomp]))
await event.send(chain)
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
logger.error(
f"发送消息链失败: chain = {chain}, error = {e}",
exc_info=True,
)
chain = MessageChain(result.chain)
if result.chain and len(result.chain) > 0:
try:
await event.send(chain)
except Exception as e:
logger.error(
f"发送消息链失败: chain = {chain}, error = {e}",
exc_info=True,
)
try:
await event.send(MessageChain(non_record_comps))
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"发送消息失败: {e} chain: {result.chain}")
logger.info(
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name
)
for handler in handlers:
try:
logger.debug(
f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
if await call_event_hook(event, EventType.OnAfterMessageSentEvent):
return
event.clear_result()

View File

@@ -183,56 +183,60 @@ class ResultDecorateStage(Stage):
if (
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
and result.is_llm_result()
and tts_provider
and SessionServiceManager.should_process_tts_request(event)
):
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info(f"TTS 请求: {comp.text}")
audio_path = await tts_provider.get_audio(comp.text)
logger.info(f"TTS 结果: {audio_path}")
if not audio_path:
logger.error(
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
if not tts_provider:
logger.warning(
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。"
)
else:
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info(f"TTS 请求: {comp.text}")
audio_path = await tts_provider.get_audio(comp.text)
logger.info(f"TTS 结果: {audio_path}")
if not audio_path:
logger.error(
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
)
new_chain.append(comp)
continue
use_file_service = self.ctx.astrbot_config[
"provider_tts_settings"
]["use_file_service"]
callback_api_base = self.ctx.astrbot_config[
"callback_api_base"
]
dual_output = self.ctx.astrbot_config[
"provider_tts_settings"
]["dual_output"]
url = None
if use_file_service and callback_api_base:
token = await file_token_service.register_file(
audio_path
)
url = f"{callback_api_base}/api/file/{token}"
logger.debug(f"已注册:{url}")
new_chain.append(
Record(
file=url or audio_path,
url=url or audio_path,
)
)
if dual_output:
new_chain.append(comp)
except Exception:
logger.error(traceback.format_exc())
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
continue
use_file_service = self.ctx.astrbot_config[
"provider_tts_settings"
]["use_file_service"]
callback_api_base = self.ctx.astrbot_config[
"callback_api_base"
]
dual_output = self.ctx.astrbot_config[
"provider_tts_settings"
]["dual_output"]
url = None
if use_file_service and callback_api_base:
token = await file_token_service.register_file(
audio_path
)
url = f"{callback_api_base}/api/file/{token}"
logger.debug(f"已注册:{url}")
new_chain.append(
Record(
file=url or audio_path,
url=url or audio_path,
)
)
if dual_output:
new_chain.append(comp)
except Exception:
logger.error(traceback.format_exc())
logger.error("TTS 失败,使用文本发送。")
else:
new_chain.append(comp)
else:
new_chain.append(comp)
result.chain = new_chain
result.chain = new_chain
# 文本转图片
elif (
@@ -275,7 +279,6 @@ class ResultDecorateStage(Stage):
result.chain = [Image.fromFileSystem(url)]
# 触发转发消息
has_forwarded = False
if event.get_platform_name() == "aiocqhttp":
word_cnt = 0
for comp in result.chain:
@@ -286,9 +289,9 @@ class ResultDecorateStage(Stage):
uin=event.get_self_id(), name="AstrBot", content=[*result.chain]
)
result.chain = [node]
has_forwarded = True
if not has_forwarded:
has_plain = any(isinstance(item, Plain) for item in result.chain)
if has_plain:
# at 回复
if (
self.reply_with_mention

View File

@@ -74,7 +74,7 @@ class PipelineScheduler:
await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if event.get_platform_name() == "webchat":
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
await event.send(None)
logger.debug("pipeline 执行完毕。")

View File

@@ -11,7 +11,8 @@ class SessionStatusCheckStage(Stage):
"""检查会话是否整体启用"""
async def initialize(self, ctx: PipelineContext) -> None:
pass
self.ctx = ctx
self.conv_mgr = ctx.plugin_manager.context.conversation_manager
async def process(
self, event: AstrMessageEvent
@@ -19,4 +20,14 @@ class SessionStatusCheckStage(Stage):
# 检查会话是否整体启用
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
# workaround for #2309
conv_id = await self.conv_mgr.get_curr_conversation_id(
event.unified_msg_origin
)
if not conv_id:
await self.conv_mgr.new_conversation(
event.unified_msg_origin, platform_id=event.get_platform_id()
)
event.stop_event()

View File

@@ -5,6 +5,7 @@ from astrbot.core.message.components import At, AtAll, Reply
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
@@ -170,11 +171,15 @@ class WakingCheckStage(Stage):
is_wake = True
event.is_wake = True
activated_handlers.append(handler)
if "parsed_params" in event.get_extra():
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
"parsed_params"
)
is_group_cmd_handler = any(
isinstance(f, CommandGroupFilter) for f in handler.event_filters
)
if not is_group_cmd_handler:
activated_handlers.append(handler)
if "parsed_params" in event.get_extra(default={}):
handlers_parsed_params[handler.handler_full_name] = (
event.get_extra("parsed_params")
)
event._extras.pop("parsed_params", None)

View File

@@ -4,7 +4,7 @@ import re
import hashlib
import uuid
from typing import List, Union, Optional, AsyncGenerator
from typing import List, Union, Optional, AsyncGenerator, Any
from astrbot import logger
from astrbot.core.db.po import Conversation
@@ -49,7 +49,7 @@ class AstrMessageEvent(abc.ABC):
"""是否唤醒(是否通过 WakingStage)"""
self.is_at_or_wake_command = False
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
self._extras = {}
self._extras: dict[str, Any] = {}
self.session = MessageSesion(
platform_name=platform_meta.id,
message_type=message_obj.type,
@@ -57,7 +57,7 @@ class AstrMessageEvent(abc.ABC):
)
self.unified_msg_origin = str(self.session)
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
self._result: MessageEventResult = None
self._result: MessageEventResult | None = None
"""消息事件的结果"""
self._has_send_oper = False
@@ -90,8 +90,10 @@ class AstrMessageEvent(abc.ABC):
"""
return self.message_str
def _outline_chain(self, chain: List[BaseMessageComponent]) -> str:
def _outline_chain(self, chain: Optional[List[BaseMessageComponent]]) -> str:
outline = ""
if not chain:
return outline
for i in chain:
if isinstance(i, Plain):
outline += i.text
@@ -173,13 +175,13 @@ class AstrMessageEvent(abc.ABC):
"""
self._extras[key] = value
def get_extra(self, key=None):
def get_extra(self, key: str | None = None, default=None) -> Any:
"""
获取额外的信息。
"""
if key is None:
return self._extras
return self._extras.get(key, None)
return self._extras.get(key, default)
def clear_extra(self):
"""
@@ -261,6 +263,9 @@ class AstrMessageEvent(abc.ABC):
"""
if isinstance(result, str):
result = MessageEventResult().message(result)
# 兼容外部插件或调用方传入的 chain=None 的情况,确保为可迭代列表
if isinstance(result, MessageEventResult) and result.chain is None:
result.chain = []
self._result = result
def stop_event(self):
@@ -412,6 +417,16 @@ class AstrMessageEvent(abc.ABC):
)
self._has_send_oper = True
async def react(self, emoji: str):
"""
对消息添加表情回应。
默认实现为发送一条包含该表情的消息。
注意:此实现并不一定符合所有平台的原生“表情回应”行为。
如需支持平台原生的消息反应功能,请在对应平台的子类中重写本方法。
"""
await self.send(MessageChain([Plain(emoji)]))
async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]:
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息返回当前群聊的数据。

View File

@@ -55,7 +55,7 @@ class AstrBotMessage:
self_id: str # 机器人的识别id
session_id: str # 会话id。取决于 unique_session 的设置。
message_id: str # 消息id
group_id: str = "" # 群组id如果为私聊则为空
group: Group # 群组
sender: MessageMember # 发送者
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串
@@ -64,6 +64,28 @@ class AstrBotMessage:
def __init__(self) -> None:
self.timestamp = int(time.time())
self.group = None
def __str__(self) -> str:
return str(self.__dict__)
@property
def group_id(self) -> str:
"""
向后兼容的 group_id 属性
群组id如果为私聊则为空
"""
if self.group:
return self.group.group_id
return ""
@group_id.setter
def group_id(self, value: str):
"""设置 group_id"""
if value:
if self.group:
self.group.group_id = value
else:
self.group = Group(group_id=value)
else:
self.group = None

View File

@@ -82,6 +82,10 @@ class PlatformManager:
from .sources.wecom.wecom_adapter import (
WecomPlatformAdapter, # noqa: F401
)
case "wecom_ai_bot":
from .sources.wecom_ai_bot.wecomai_adapter import (
WecomAIBotAdapter, # noqa: F401
)
case "weixin_official_account":
from .sources.weixin_official_account.weixin_offacc_adapter import (
WeixinOfficialAccountPlatformAdapter, # noqa: F401
@@ -90,6 +94,10 @@ class PlatformManager:
from .sources.discord.discord_platform_adapter import (
DiscordPlatformAdapter, # noqa: F401
)
case "misskey":
from .sources.misskey.misskey_adapter import (
MisskeyPlatformAdapter, # noqa: F401
)
case "slack":
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
case "satori":

View File

@@ -14,3 +14,5 @@ class PlatformMetadata:
"""平台的默认配置模板"""
adapter_display_name: str = None
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
logo_path: str = None
"""平台适配器的 logo 文件路径(相对于插件目录)"""

View File

@@ -13,10 +13,12 @@ def register_platform_adapter(
desc: str,
default_config_tmpl: dict = None,
adapter_display_name: str = None,
logo_path: str = None,
):
"""用于注册平台适配器的带参装饰器。
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
logo_path 指定了平台适配器的 logo 文件路径,是相对于插件目录的路径。
"""
def decorator(cls):
@@ -39,6 +41,7 @@ def register_platform_adapter(
description=desc,
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name,
logo_path=logo_path,
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls

View File

@@ -182,11 +182,13 @@ class AiocqhttpAdapter(Platform):
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(
str(event.sender["user_id"]), event.sender["nickname"]
str(event.sender["user_id"]),
event.sender.get("card") or event.sender.get("nickname", "N/A"),
)
if event["message_type"] == "group":
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id)
abm.group.group_name = event.get("group_name", "N/A")
elif event["message_type"] == "private":
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:

View File

@@ -107,6 +107,22 @@ class LarkMessageEvent(AstrMessageEvent):
await super().send(message)
async def react(self, emoji: str):
request = (
CreateMessageReactionRequest.builder()
.message_id(self.message_obj.message_id)
.request_body(
CreateMessageReactionRequestBody.builder()
.reaction_type(Emoji.builder().emoji_type(emoji).build())
.build()
)
.build()
)
response = await self.bot.im.v1.message_reaction.acreate(request)
if not response.success():
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
return None
async def send_streaming(self, generator, use_fallback: bool = False):
buffer = None
async for chain in generator:

View File

@@ -0,0 +1,727 @@
import asyncio
import random
from typing import Dict, Any, Optional, Awaitable, List
from astrbot.api import logger
from astrbot.api.event import MessageChain
from astrbot.api.platform import (
AstrBotMessage,
Platform,
PlatformMetadata,
register_platform_adapter,
)
from astrbot.core.platform.astr_message_event import MessageSession
import astrbot.api.message_components as Comp
from .misskey_api import MisskeyAPI
import os
try:
import magic # type: ignore
except Exception:
magic = None
from .misskey_event import MisskeyPlatformEvent
from .misskey_utils import (
serialize_message_chain,
resolve_message_visibility,
is_valid_user_session_id,
is_valid_room_session_id,
add_at_mention_if_needed,
process_files,
extract_sender_info,
create_base_message,
process_at_mention,
format_poll,
cache_user_info,
cache_room_info,
)
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
# Constants
MAX_FILE_UPLOAD_COUNT = 16
DEFAULT_UPLOAD_CONCURRENCY = 3
@register_platform_adapter("misskey", "Misskey 平台适配器")
class MisskeyPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config or {}
self.settings = platform_settings or {}
self.instance_url = self.config.get("misskey_instance_url", "")
self.access_token = self.config.get("misskey_token", "")
self.max_message_length = self.config.get("max_message_length", 3000)
self.default_visibility = self.config.get(
"misskey_default_visibility", "public"
)
self.local_only = self.config.get("misskey_local_only", False)
self.enable_chat = self.config.get("misskey_enable_chat", True)
self.enable_file_upload = self.config.get("misskey_enable_file_upload", True)
self.upload_folder = self.config.get("misskey_upload_folder")
# download / security related options (exposed to platform_config)
self.allow_insecure_downloads = bool(
self.config.get("misskey_allow_insecure_downloads", False)
)
# parse download timeout and chunk size safely
_dt = self.config.get("misskey_download_timeout")
try:
self.download_timeout = int(_dt) if _dt is not None else 15
except Exception:
self.download_timeout = 15
_chunk = self.config.get("misskey_download_chunk_size")
try:
self.download_chunk_size = int(_chunk) if _chunk is not None else 64 * 1024
except Exception:
self.download_chunk_size = 64 * 1024
# parse max download bytes safely
_md_bytes = self.config.get("misskey_max_download_bytes")
try:
self.max_download_bytes = int(_md_bytes) if _md_bytes is not None else None
except Exception:
self.max_download_bytes = None
self.unique_session = platform_settings["unique_session"]
self.api: Optional[MisskeyAPI] = None
self._running = False
self.client_self_id = ""
self._bot_username = ""
self._user_cache = {}
def meta(self) -> PlatformMetadata:
default_config = {
"misskey_instance_url": "",
"misskey_token": "",
"max_message_length": 3000,
"misskey_default_visibility": "public",
"misskey_local_only": False,
"misskey_enable_chat": True,
# download / security options
"misskey_allow_insecure_downloads": False,
"misskey_download_timeout": 15,
"misskey_download_chunk_size": 65536,
"misskey_max_download_bytes": None,
}
default_config.update(self.config)
return PlatformMetadata(
name="misskey",
description="Misskey 平台适配器",
id=self.config.get("id", "misskey"),
default_config_tmpl=default_config,
)
async def run(self):
if not self.instance_url or not self.access_token:
logger.error("[Misskey] 配置不完整,无法启动")
return
self.api = MisskeyAPI(
self.instance_url,
self.access_token,
allow_insecure_downloads=self.allow_insecure_downloads,
download_timeout=self.download_timeout,
chunk_size=self.download_chunk_size,
max_download_bytes=self.max_download_bytes,
)
self._running = True
try:
user_info = await self.api.get_current_user()
self.client_self_id = str(user_info.get("id", ""))
self._bot_username = user_info.get("username", "")
logger.info(
f"[Misskey] 已连接用户: {self._bot_username} (ID: {self.client_self_id})"
)
except Exception as e:
logger.error(f"[Misskey] 获取用户信息失败: {e}")
self._running = False
return
await self._start_websocket_connection()
def _register_event_handlers(self, streaming):
"""注册事件处理器"""
streaming.add_message_handler("notification", self._handle_notification)
streaming.add_message_handler("main:notification", self._handle_notification)
if self.enable_chat:
streaming.add_message_handler("newChatMessage", self._handle_chat_message)
streaming.add_message_handler(
"messaging:newChatMessage", self._handle_chat_message
)
streaming.add_message_handler("_debug", self._debug_handler)
async def _send_text_only_message(
self, session_id: str, text: str, session, message_chain
):
"""发送纯文本消息(无文件上传)"""
if not self.api:
return await super().send_by_session(session, message_chain)
if session_id and is_valid_user_session_id(session_id):
from .misskey_utils import extract_user_id_from_session_id
user_id = extract_user_id_from_session_id(session_id)
payload: Dict[str, Any] = {"toUserId": user_id, "text": text}
await self.api.send_message(payload)
elif session_id and is_valid_room_session_id(session_id):
from .misskey_utils import extract_room_id_from_session_id
room_id = extract_room_id_from_session_id(session_id)
payload = {"toRoomId": room_id, "text": text}
await self.api.send_room_message(payload)
return await super().send_by_session(session, message_chain)
def _process_poll_data(
self, message: AstrBotMessage, poll: Dict[str, Any], message_parts: List[str]
):
"""处理投票数据,将其添加到消息中"""
try:
if not isinstance(message.raw_message, dict):
message.raw_message = {}
message.raw_message["poll"] = poll
setattr(message, "poll", poll)
except Exception:
pass
poll_text = format_poll(poll)
if poll_text:
message.message.append(Comp.Plain(poll_text))
message_parts.append(poll_text)
def _extract_additional_fields(self, session, message_chain) -> Dict[str, Any]:
"""从会话和消息链中提取额外字段"""
fields = {"cw": None, "poll": None, "renote_id": None, "channel_id": None}
for comp in message_chain.chain:
if hasattr(comp, "cw") and getattr(comp, "cw", None):
fields["cw"] = getattr(comp, "cw")
break
if hasattr(session, "extra_data") and isinstance(
getattr(session, "extra_data", None), dict
):
extra_data = getattr(session, "extra_data")
fields.update(
{
"poll": extra_data.get("poll"),
"renote_id": extra_data.get("renote_id"),
"channel_id": extra_data.get("channel_id"),
}
)
return fields
async def _start_websocket_connection(self):
backoff_delay = 1.0
max_backoff = 300.0
backoff_multiplier = 1.5
connection_attempts = 0
while self._running:
try:
connection_attempts += 1
if not self.api:
logger.error("[Misskey] API 客户端未初始化")
break
streaming = self.api.get_streaming_client()
self._register_event_handlers(streaming)
if await streaming.connect():
logger.info(
f"[Misskey] WebSocket 已连接 (尝试 #{connection_attempts})"
)
connection_attempts = 0
await streaming.subscribe_channel("main")
if self.enable_chat:
await streaming.subscribe_channel("messaging")
await streaming.subscribe_channel("messagingIndex")
logger.info("[Misskey] 聊天频道已订阅")
backoff_delay = 1.0
await streaming.listen()
else:
logger.error(
f"[Misskey] WebSocket 连接失败 (尝试 #{connection_attempts})"
)
except Exception as e:
logger.error(
f"[Misskey] WebSocket 异常 (尝试 #{connection_attempts}): {e}"
)
if self._running:
jitter = random.uniform(0, 1.0)
sleep_time = backoff_delay + jitter
logger.info(
f"[Misskey] {sleep_time:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})"
)
await asyncio.sleep(sleep_time)
backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff)
async def _handle_notification(self, data: Dict[str, Any]):
try:
notification_type = data.get("type")
logger.debug(
f"[Misskey] 收到通知事件: type={notification_type}, user_id={data.get('userId', 'unknown')}"
)
if notification_type in ["mention", "reply", "quote"]:
note = data.get("note")
if note and self._is_bot_mentioned(note):
logger.info(
f"[Misskey] 处理贴文提及: {note.get('text', '')[:50]}..."
)
message = await self.convert_message(note)
event = MisskeyPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self,
)
self.commit_event(event)
except Exception as e:
logger.error(f"[Misskey] 处理通知失败: {e}")
async def _handle_chat_message(self, data: Dict[str, Any]):
try:
sender_id = str(
data.get("fromUserId", "") or data.get("fromUser", {}).get("id", "")
)
room_id = data.get("toRoomId")
logger.debug(
f"[Misskey] 收到聊天事件: sender_id={sender_id}, room_id={room_id}, is_self={sender_id == self.client_self_id}"
)
if sender_id == self.client_self_id:
return
if room_id:
raw_text = data.get("text", "")
logger.debug(
f"[Misskey] 检查群聊消息: '{raw_text}', 机器人用户名: '{self._bot_username}'"
)
message = await self.convert_room_message(data)
logger.info(f"[Misskey] 处理群聊消息: {message.message_str[:50]}...")
else:
message = await self.convert_chat_message(data)
logger.info(f"[Misskey] 处理私聊消息: {message.message_str[:50]}...")
event = MisskeyPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self,
)
self.commit_event(event)
except Exception as e:
logger.error(f"[Misskey] 处理聊天消息失败: {e}")
async def _debug_handler(self, data: Dict[str, Any]):
event_type = data.get("type", "unknown")
logger.debug(
f"[Misskey] 收到未处理事件: type={event_type}, channel={data.get('channel', 'unknown')}"
)
def _is_bot_mentioned(self, note: Dict[str, Any]) -> bool:
text = note.get("text", "")
if not text:
return False
mentions = note.get("mentions", [])
if self._bot_username and f"@{self._bot_username}" in text:
return True
if self.client_self_id in [str(uid) for uid in mentions]:
return True
reply = note.get("reply")
if reply and isinstance(reply, dict):
reply_user_id = str(reply.get("user", {}).get("id", ""))
if reply_user_id == self.client_self_id:
return bool(self._bot_username and f"@{self._bot_username}" in text)
return False
async def send_by_session(
self, session: MessageSession, message_chain: MessageChain
) -> Awaitable[Any]:
if not self.api:
logger.error("[Misskey] API 客户端未初始化")
return await super().send_by_session(session, message_chain)
try:
session_id = session.session_id
text, has_at_user = serialize_message_chain(message_chain.chain)
if not has_at_user and session_id:
# 从session_id中提取用户ID用于缓存查询
# session_id格式为: "chat%<user_id>" 或 "room%<room_id>" 或 "note%<user_id>"
user_id_for_cache = None
if "%" in session_id:
parts = session_id.split("%")
if len(parts) >= 2:
user_id_for_cache = parts[1]
user_info = None
if user_id_for_cache:
user_info = self._user_cache.get(user_id_for_cache)
text = add_at_mention_if_needed(text, user_info, has_at_user)
# 检查是否有文件组件
has_file_components = any(
isinstance(comp, Comp.Image)
or isinstance(comp, Comp.File)
or hasattr(comp, "convert_to_file_path")
or hasattr(comp, "get_file")
or any(
hasattr(comp, a) for a in ("file", "url", "path", "src", "source")
)
for comp in message_chain.chain
)
if not text or not text.strip():
if not has_file_components:
logger.warning("[Misskey] 消息内容为空且无文件组件,跳过发送")
return await super().send_by_session(session, message_chain)
else:
text = ""
if len(text) > self.max_message_length:
text = text[: self.max_message_length] + "..."
file_ids: List[str] = []
fallback_urls: List[str] = []
if not self.enable_file_upload:
return await self._send_text_only_message(
session_id, text, session, message_chain
)
MAX_UPLOAD_CONCURRENCY = 10
upload_concurrency = int(
self.config.get(
"misskey_upload_concurrency", DEFAULT_UPLOAD_CONCURRENCY
)
)
upload_concurrency = min(upload_concurrency, MAX_UPLOAD_CONCURRENCY)
sem = asyncio.Semaphore(upload_concurrency)
async def _upload_comp(comp) -> Optional[object]:
"""组件上传函数:处理 URL下载后上传或本地文件直接上传"""
from .misskey_utils import (
resolve_component_url_or_path,
upload_local_with_retries,
)
local_path = None
try:
async with sem:
if not self.api:
return None
# 解析组件的 URL 或本地路径
url_candidate, local_path = await resolve_component_url_or_path(
comp
)
if not url_candidate and not local_path:
return None
preferred_name = getattr(comp, "name", None) or getattr(
comp, "file", None
)
# URL 上传:下载后本地上传
if url_candidate:
result = await self.api.upload_and_find_file(
str(url_candidate),
preferred_name,
folder_id=self.upload_folder,
)
if isinstance(result, dict) and result.get("id"):
return str(result["id"])
# 本地文件上传
if local_path:
file_id = await upload_local_with_retries(
self.api,
str(local_path),
preferred_name,
self.upload_folder,
)
if file_id:
return file_id
# 所有上传都失败,尝试获取 URL 作为回退
if hasattr(comp, "register_to_file_service"):
try:
url = await comp.register_to_file_service()
if url:
return {"fallback_url": url}
except Exception:
pass
return None
finally:
# 清理临时文件
if local_path and isinstance(local_path, str):
data_temp = os.path.join(get_astrbot_data_path(), "temp")
if local_path.startswith(data_temp) and os.path.exists(
local_path
):
try:
os.remove(local_path)
logger.debug(f"[Misskey] 已清理临时文件: {local_path}")
except Exception:
pass
# 收集所有可能包含文件/URL信息的组件支持异步接口或同步字段
file_components = []
for comp in message_chain.chain:
try:
if (
isinstance(comp, Comp.Image)
or isinstance(comp, Comp.File)
or hasattr(comp, "convert_to_file_path")
or hasattr(comp, "get_file")
or any(
hasattr(comp, a)
for a in ("file", "url", "path", "src", "source")
)
):
file_components.append(comp)
except Exception:
# 保守跳过无法访问属性的组件
continue
if len(file_components) > MAX_FILE_UPLOAD_COUNT:
logger.warning(
f"[Misskey] 文件数量超过限制 ({len(file_components)} > {MAX_FILE_UPLOAD_COUNT}),只上传前{MAX_FILE_UPLOAD_COUNT}个文件"
)
file_components = file_components[:MAX_FILE_UPLOAD_COUNT]
upload_tasks = [_upload_comp(comp) for comp in file_components]
try:
results = await asyncio.gather(*upload_tasks) if upload_tasks else []
for r in results:
if not r:
continue
if isinstance(r, dict) and r.get("fallback_url"):
url = r.get("fallback_url")
if url:
fallback_urls.append(str(url))
else:
try:
fid_str = str(r)
if fid_str:
file_ids.append(fid_str)
except Exception:
pass
except Exception:
logger.debug("[Misskey] 并发上传过程中出现异常,继续发送文本")
if session_id and is_valid_room_session_id(session_id):
from .misskey_utils import extract_room_id_from_session_id
room_id = extract_room_id_from_session_id(session_id)
if fallback_urls:
appended = "\n" + "\n".join(fallback_urls)
text = (text or "") + appended
payload: Dict[str, Any] = {"toRoomId": room_id, "text": text}
if file_ids:
payload["fileIds"] = file_ids
await self.api.send_room_message(payload)
elif session_id:
from .misskey_utils import (
extract_user_id_from_session_id,
is_valid_chat_session_id,
)
if is_valid_chat_session_id(session_id):
user_id = extract_user_id_from_session_id(session_id)
if fallback_urls:
appended = "\n" + "\n".join(fallback_urls)
text = (text or "") + appended
payload: Dict[str, Any] = {"toUserId": user_id, "text": text}
if file_ids:
# 聊天消息只支持单个文件,使用 fileId 而不是 fileIds
payload["fileId"] = file_ids[0]
if len(file_ids) > 1:
logger.warning(
f"[Misskey] 聊天消息只支持单个文件,忽略其余 {len(file_ids) - 1} 个文件"
)
await self.api.send_message(payload)
else:
# 回退到发帖逻辑
# 去掉 session_id 中的 note% 前缀以匹配 user_cache 的键格式
user_id_for_cache = (
session_id.split("%")[1] if "%" in session_id else session_id
)
# 获取用户缓存信息包含reply_to_note_id
user_info_for_reply = self._user_cache.get(user_id_for_cache, {})
visibility, visible_user_ids = resolve_message_visibility(
user_id=user_id_for_cache,
user_cache=self._user_cache,
self_id=self.client_self_id,
default_visibility=self.default_visibility,
)
logger.debug(
f"[Misskey] 解析可见性: visibility={visibility}, visible_user_ids={visible_user_ids}, session_id={session_id}, user_id_for_cache={user_id_for_cache}"
)
fields = self._extract_additional_fields(session, message_chain)
if fallback_urls:
appended = "\n" + "\n".join(fallback_urls)
text = (text or "") + appended
# 从缓存中获取原消息ID作为reply_id
reply_id = user_info_for_reply.get("reply_to_note_id")
await self.api.create_note(
text=text,
visibility=visibility,
visible_user_ids=visible_user_ids,
file_ids=file_ids or None,
local_only=self.local_only,
reply_id=reply_id, # 添加reply_id参数
cw=fields["cw"],
poll=fields["poll"],
renote_id=fields["renote_id"],
channel_id=fields["channel_id"],
)
except Exception as e:
logger.error(f"[Misskey] 发送消息失败: {e}")
return await super().send_by_session(session, message_chain)
async def convert_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
"""将 Misskey 贴文数据转换为 AstrBotMessage 对象"""
sender_info = extract_sender_info(raw_data, is_chat=False)
message = create_base_message(
raw_data,
sender_info,
self.client_self_id,
is_chat=False,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
)
message_parts = []
raw_text = raw_data.get("text", "")
if raw_text:
text_parts, processed_text = process_at_mention(
message, raw_text, self._bot_username, self.client_self_id
)
message_parts.extend(text_parts)
files = raw_data.get("files", [])
file_parts = process_files(message, files)
message_parts.extend(file_parts)
poll = raw_data.get("poll") or (
raw_data.get("note", {}).get("poll")
if isinstance(raw_data.get("note"), dict)
else None
)
if poll and isinstance(poll, dict):
self._process_poll_data(message, poll, message_parts)
message.message_str = (
" ".join(part for part in message_parts if part.strip())
if message_parts
else ""
)
return message
async def convert_chat_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
"""将 Misskey 聊天消息数据转换为 AstrBotMessage 对象"""
sender_info = extract_sender_info(raw_data, is_chat=True)
message = create_base_message(
raw_data,
sender_info,
self.client_self_id,
is_chat=True,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=True
)
raw_text = raw_data.get("text", "")
if raw_text:
message.message.append(Comp.Plain(raw_text))
files = raw_data.get("files", [])
process_files(message, files, include_text_parts=False)
message.message_str = raw_text if raw_text else ""
return message
async def convert_room_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
"""将 Misskey 群聊消息数据转换为 AstrBotMessage 对象"""
sender_info = extract_sender_info(raw_data, is_chat=True)
room_id = raw_data.get("toRoomId", "")
message = create_base_message(
raw_data,
sender_info,
self.client_self_id,
is_chat=False,
room_id=room_id,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
)
cache_room_info(self._user_cache, raw_data, self.client_self_id)
raw_text = raw_data.get("text", "")
message_parts = []
if raw_text:
if self._bot_username and f"@{self._bot_username}" in raw_text:
text_parts, processed_text = process_at_mention(
message, raw_text, self._bot_username, self.client_self_id
)
message_parts.extend(text_parts)
else:
message.message.append(Comp.Plain(raw_text))
message_parts.append(raw_text)
files = raw_data.get("files", [])
file_parts = process_files(message, files)
message_parts.extend(file_parts)
message.message_str = (
" ".join(part for part in message_parts if part.strip())
if message_parts
else ""
)
return message
async def terminate(self):
self._running = False
if self.api:
await self.api.close()
def get_client(self) -> Any:
return self.api

View File

@@ -0,0 +1,940 @@
import json
import random
import asyncio
from typing import Any, Optional, Dict, List, Callable, Awaitable
import uuid
try:
import aiohttp
import websockets
except ImportError as e:
raise ImportError(
"aiohttp and websockets are required for Misskey API. Please install them with: pip install aiohttp websockets"
) from e
from astrbot.api import logger
from .misskey_utils import FileIDExtractor
# Constants
API_MAX_RETRIES = 3
HTTP_OK = 200
class APIError(Exception):
"""Misskey API 基础异常"""
pass
class APIConnectionError(APIError):
"""网络连接异常"""
pass
class APIRateLimitError(APIError):
"""API 频率限制异常"""
pass
class AuthenticationError(APIError):
"""认证失败异常"""
pass
class WebSocketError(APIError):
"""WebSocket 连接异常"""
pass
class StreamingClient:
def __init__(self, instance_url: str, access_token: str):
self.instance_url = instance_url.rstrip("/")
self.access_token = access_token
self.websocket: Optional[Any] = None
self.is_connected = False
self.message_handlers: Dict[str, Callable] = {}
self.channels: Dict[str, str] = {}
self.desired_channels: Dict[str, Optional[Dict]] = {}
self._running = False
self._last_pong = None
async def connect(self) -> bool:
try:
ws_url = self.instance_url.replace("https://", "wss://").replace(
"http://", "ws://"
)
ws_url += f"/streaming?i={self.access_token}"
self.websocket = await websockets.connect(
ws_url, ping_interval=30, ping_timeout=10
)
self.is_connected = True
self._running = True
logger.info("[Misskey WebSocket] 已连接")
if self.desired_channels:
try:
desired = list(self.desired_channels.items())
for channel_type, params in desired:
try:
await self.subscribe_channel(channel_type, params)
except Exception as e:
logger.warning(
f"[Misskey WebSocket] 重新订阅 {channel_type} 失败: {e}"
)
except Exception:
pass
return True
except Exception as e:
logger.error(f"[Misskey WebSocket] 连接失败: {e}")
self.is_connected = False
return False
async def disconnect(self):
self._running = False
if self.websocket:
await self.websocket.close()
self.websocket = None
self.is_connected = False
logger.info("[Misskey WebSocket] 连接已断开")
async def subscribe_channel(
self, channel_type: str, params: Optional[Dict] = None
) -> str:
if not self.is_connected or not self.websocket:
raise WebSocketError("WebSocket 未连接")
channel_id = str(uuid.uuid4())
message = {
"type": "connect",
"body": {"channel": channel_type, "id": channel_id, "params": params or {}},
}
await self.websocket.send(json.dumps(message))
self.channels[channel_id] = channel_type
return channel_id
async def unsubscribe_channel(self, channel_id: str):
if (
not self.is_connected
or not self.websocket
or channel_id not in self.channels
):
return
message = {"type": "disconnect", "body": {"id": channel_id}}
await self.websocket.send(json.dumps(message))
channel_type = self.channels.get(channel_id)
if channel_id in self.channels:
del self.channels[channel_id]
if channel_type and channel_type not in self.channels.values():
self.desired_channels.pop(channel_type, None)
def add_message_handler(
self, event_type: str, handler: Callable[[Dict], Awaitable[None]]
):
self.message_handlers[event_type] = handler
async def listen(self):
if not self.is_connected or not self.websocket:
raise WebSocketError("WebSocket 未连接")
try:
async for message in self.websocket:
if not self._running:
break
try:
data = json.loads(message)
await self._handle_message(data)
except json.JSONDecodeError as e:
logger.warning(f"[Misskey WebSocket] 无法解析消息: {e}")
except Exception as e:
logger.error(f"[Misskey WebSocket] 处理消息失败: {e}")
except websockets.exceptions.ConnectionClosedError as e:
logger.warning(f"[Misskey WebSocket] 连接意外关闭: {e}")
self.is_connected = False
try:
await self.disconnect()
except Exception:
pass
except websockets.exceptions.ConnectionClosed as e:
logger.warning(
f"[Misskey WebSocket] 连接已关闭 (代码: {e.code}, 原因: {e.reason})"
)
self.is_connected = False
try:
await self.disconnect()
except Exception:
pass
except websockets.exceptions.InvalidHandshake as e:
logger.error(f"[Misskey WebSocket] 握手失败: {e}")
self.is_connected = False
try:
await self.disconnect()
except Exception:
pass
except Exception as e:
logger.error(f"[Misskey WebSocket] 监听消息失败: {e}")
self.is_connected = False
try:
await self.disconnect()
except Exception:
pass
async def _handle_message(self, data: Dict[str, Any]):
message_type = data.get("type")
body = data.get("body", {})
def _build_channel_summary(message_type: Optional[str], body: Any) -> str:
try:
if not isinstance(body, dict):
return f"[Misskey WebSocket] 收到消息类型: {message_type}"
inner = body.get("body") if isinstance(body.get("body"), dict) else body
note = (
inner.get("note")
if isinstance(inner, dict) and isinstance(inner.get("note"), dict)
else None
)
text = note.get("text") if note else None
note_id = note.get("id") if note else None
files = note.get("files") or [] if note else []
has_files = bool(files)
is_hidden = bool(note.get("isHidden")) if note else False
user = note.get("user", {}) if note else None
return (
f"[Misskey WebSocket] 收到消息类型: {message_type} | "
f"note_id={note_id} | user={user.get('username') if user else None} | "
f"text={text[:80] if text else '[no-text]'} | files={has_files} | hidden={is_hidden}"
)
except Exception:
return f"[Misskey WebSocket] 收到消息类型: {message_type}"
channel_summary = _build_channel_summary(message_type, body)
logger.info(channel_summary)
if message_type == "channel":
channel_id = body.get("id")
event_type = body.get("type")
event_body = body.get("body", {})
logger.debug(
f"[Misskey WebSocket] 频道消息: {channel_id}, 事件类型: {event_type}"
)
if channel_id in self.channels:
channel_type = self.channels[channel_id]
handler_key = f"{channel_type}:{event_type}"
if handler_key in self.message_handlers:
logger.debug(f"[Misskey WebSocket] 使用处理器: {handler_key}")
await self.message_handlers[handler_key](event_body)
elif event_type in self.message_handlers:
logger.debug(f"[Misskey WebSocket] 使用事件处理器: {event_type}")
await self.message_handlers[event_type](event_body)
else:
logger.debug(
f"[Misskey WebSocket] 未找到处理器: {handler_key}{event_type}"
)
if "_debug" in self.message_handlers:
await self.message_handlers["_debug"](
{
"type": event_type,
"body": event_body,
"channel": channel_type,
}
)
elif message_type in self.message_handlers:
logger.debug(f"[Misskey WebSocket] 直接消息处理器: {message_type}")
await self.message_handlers[message_type](body)
else:
logger.debug(f"[Misskey WebSocket] 未处理的消息类型: {message_type}")
if "_debug" in self.message_handlers:
await self.message_handlers["_debug"](data)
def retry_async(
max_retries: int = 3,
retryable_exceptions: tuple = (APIConnectionError, APIRateLimitError),
backoff_base: float = 1.0,
max_backoff: float = 30.0,
):
"""
智能异步重试装饰器
Args:
max_retries: 最大重试次数
retryable_exceptions: 可重试的异常类型
backoff_base: 退避基数
max_backoff: 最大退避时间
"""
def decorator(func):
async def wrapper(*args, **kwargs):
last_exc = None
func_name = getattr(func, "__name__", "unknown")
for attempt in range(1, max_retries + 1):
try:
return await func(*args, **kwargs)
except retryable_exceptions as e:
last_exc = e
if attempt == max_retries:
logger.error(
f"[Misskey API] {func_name} 重试 {max_retries} 次后仍失败: {e}"
)
break
# 智能退避策略
if isinstance(e, APIRateLimitError):
# 频率限制用更长的退避时间
backoff = min(backoff_base * (3**attempt), max_backoff)
else:
# 其他错误用指数退避
backoff = min(backoff_base * (2**attempt), max_backoff)
jitter = random.uniform(0.1, 0.5) # 随机抖动
sleep_time = backoff + jitter
logger.warning(
f"[Misskey API] {func_name}{attempt} 次重试失败: {e}"
f"{sleep_time:.1f}s后重试"
)
await asyncio.sleep(sleep_time)
continue
except Exception as e:
# 非可重试异常直接抛出
logger.error(f"[Misskey API] {func_name} 遇到不可重试异常: {e}")
raise
if last_exc:
raise last_exc
return wrapper
return decorator
class MisskeyAPI:
def __init__(
self,
instance_url: str,
access_token: str,
*,
allow_insecure_downloads: bool = False,
download_timeout: int = 15,
chunk_size: int = 64 * 1024,
max_download_bytes: Optional[int] = None,
):
self.instance_url = instance_url.rstrip("/")
self.access_token = access_token
self._session: Optional[aiohttp.ClientSession] = None
self.streaming: Optional[StreamingClient] = None
# download options
self.allow_insecure_downloads = allow_insecure_downloads
self.download_timeout = download_timeout
self.chunk_size = chunk_size
self.max_download_bytes = (
int(max_download_bytes) if max_download_bytes is not None else None
)
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
return False
async def close(self) -> None:
if self.streaming:
await self.streaming.disconnect()
self.streaming = None
if self._session:
await self._session.close()
self._session = None
logger.debug("[Misskey API] 客户端已关闭")
def get_streaming_client(self) -> StreamingClient:
if not self.streaming:
self.streaming = StreamingClient(self.instance_url, self.access_token)
return self.streaming
@property
def session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
headers = {"Authorization": f"Bearer {self.access_token}"}
self._session = aiohttp.ClientSession(headers=headers)
return self._session
def _handle_response_status(self, status: int, endpoint: str):
"""处理 HTTP 响应状态码"""
if status == 400:
logger.error(f"[Misskey API] 请求参数错误: {endpoint} (HTTP {status})")
raise APIError(f"Bad request for {endpoint}")
elif status == 401:
logger.error(f"[Misskey API] 未授权访问: {endpoint} (HTTP {status})")
raise AuthenticationError(f"Unauthorized access for {endpoint}")
elif status == 403:
logger.error(f"[Misskey API] 访问被禁止: {endpoint} (HTTP {status})")
raise AuthenticationError(f"Forbidden access for {endpoint}")
elif status == 404:
logger.error(f"[Misskey API] 资源不存在: {endpoint} (HTTP {status})")
raise APIError(f"Resource not found for {endpoint}")
elif status == 413:
logger.error(f"[Misskey API] 请求体过大: {endpoint} (HTTP {status})")
raise APIError(f"Request entity too large for {endpoint}")
elif status == 429:
logger.warning(f"[Misskey API] 请求频率限制: {endpoint} (HTTP {status})")
raise APIRateLimitError(f"Rate limit exceeded for {endpoint}")
elif status == 500:
logger.error(f"[Misskey API] 服务器内部错误: {endpoint} (HTTP {status})")
raise APIConnectionError(f"Internal server error for {endpoint}")
elif status == 502:
logger.error(f"[Misskey API] 网关错误: {endpoint} (HTTP {status})")
raise APIConnectionError(f"Bad gateway for {endpoint}")
elif status == 503:
logger.error(f"[Misskey API] 服务不可用: {endpoint} (HTTP {status})")
raise APIConnectionError(f"Service unavailable for {endpoint}")
elif status == 504:
logger.error(f"[Misskey API] 网关超时: {endpoint} (HTTP {status})")
raise APIConnectionError(f"Gateway timeout for {endpoint}")
else:
logger.error(f"[Misskey API] 未知错误: {endpoint} (HTTP {status})")
raise APIConnectionError(f"HTTP {status} for {endpoint}")
async def _process_response(
self, response: aiohttp.ClientResponse, endpoint: str
) -> Any:
"""处理 API 响应"""
if response.status == HTTP_OK:
try:
result = await response.json()
if endpoint == "i/notifications":
notifications_data = (
result
if isinstance(result, list)
else result.get("notifications", [])
if isinstance(result, dict)
else []
)
if notifications_data:
logger.debug(
f"[Misskey API] 获取到 {len(notifications_data)} 条新通知"
)
else:
logger.debug(f"[Misskey API] 请求成功: {endpoint}")
return result
except json.JSONDecodeError as e:
logger.error(f"[Misskey API] 响应格式错误: {e}")
raise APIConnectionError("Invalid JSON response") from e
else:
try:
error_text = await response.text()
logger.error(
f"[Misskey API] 请求失败: {endpoint} - HTTP {response.status}, 响应: {error_text}"
)
except Exception:
logger.error(
f"[Misskey API] 请求失败: {endpoint} - HTTP {response.status}"
)
self._handle_response_status(response.status, endpoint)
raise APIConnectionError(f"Request failed for {endpoint}")
@retry_async(
max_retries=API_MAX_RETRIES,
retryable_exceptions=(APIConnectionError, APIRateLimitError),
)
async def _make_request(
self, endpoint: str, data: Optional[Dict[str, Any]] = None
) -> Any:
url = f"{self.instance_url}/api/{endpoint}"
payload = {"i": self.access_token}
if data:
payload.update(data)
try:
async with self.session.post(url, json=payload) as response:
return await self._process_response(response, endpoint)
except aiohttp.ClientError as e:
logger.error(f"[Misskey API] HTTP 请求错误: {e}")
raise APIConnectionError(f"HTTP request failed: {e}") from e
async def create_note(
self,
text: Optional[str] = None,
visibility: str = "public",
reply_id: Optional[str] = None,
visible_user_ids: Optional[List[str]] = None,
file_ids: Optional[List[str]] = None,
local_only: bool = False,
cw: Optional[str] = None,
poll: Optional[Dict[str, Any]] = None,
renote_id: Optional[str] = None,
channel_id: Optional[str] = None,
reaction_acceptance: Optional[str] = None,
no_extract_mentions: Optional[bool] = None,
no_extract_hashtags: Optional[bool] = None,
no_extract_emojis: Optional[bool] = None,
media_ids: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""Create a note (wrapper for notes/create). All additional fields are optional and passed through to the API."""
data: Dict[str, Any] = {}
if text is not None:
data["text"] = text
data["visibility"] = visibility
data["localOnly"] = local_only
if reply_id:
data["replyId"] = reply_id
if visible_user_ids and visibility == "specified":
data["visibleUserIds"] = visible_user_ids
if file_ids:
data["fileIds"] = file_ids
if media_ids:
data["mediaIds"] = media_ids
if cw is not None:
data["cw"] = cw
if poll is not None:
data["poll"] = poll
if renote_id is not None:
data["renoteId"] = renote_id
if channel_id is not None:
data["channelId"] = channel_id
if reaction_acceptance is not None:
data["reactionAcceptance"] = reaction_acceptance
if no_extract_mentions is not None:
data["noExtractMentions"] = bool(no_extract_mentions)
if no_extract_hashtags is not None:
data["noExtractHashtags"] = bool(no_extract_hashtags)
if no_extract_emojis is not None:
data["noExtractEmojis"] = bool(no_extract_emojis)
result = await self._make_request("notes/create", data)
note_id = (
result.get("createdNote", {}).get("id", "unknown")
if isinstance(result, dict)
else "unknown"
)
logger.debug(f"[Misskey API] 发帖成功: {note_id}")
return result
async def upload_file(
self,
file_path: str,
name: Optional[str] = None,
folder_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Upload a file to Misskey drive/files/create and return a dict containing id and raw result."""
if not file_path:
raise APIError("No file path provided for upload")
url = f"{self.instance_url}/api/drive/files/create"
form = aiohttp.FormData()
form.add_field("i", self.access_token)
try:
filename = name or file_path.split("/")[-1]
if folder_id:
form.add_field("folderId", str(folder_id))
try:
f = open(file_path, "rb")
except FileNotFoundError as e:
logger.error(f"[Misskey API] 本地文件不存在: {file_path}")
raise APIError(f"File not found: {file_path}") from e
try:
form.add_field("file", f, filename=filename)
async with self.session.post(url, data=form) as resp:
result = await self._process_response(resp, "drive/files/create")
file_id = FileIDExtractor.extract_file_id(result)
logger.debug(
f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}"
)
return {"id": file_id, "raw": result}
finally:
f.close()
except aiohttp.ClientError as e:
logger.error(f"[Misskey API] 文件上传网络错误: {e}")
raise APIConnectionError(f"Upload failed: {e}") from e
async def find_files_by_hash(self, md5_hash: str) -> List[Dict[str, Any]]:
"""Find files by MD5 hash"""
if not md5_hash:
raise APIError("No MD5 hash provided for find-by-hash")
data = {"md5": md5_hash}
try:
logger.debug(f"[Misskey API] find-by-hash 请求: md5={md5_hash}")
result = await self._make_request("drive/files/find-by-hash", data)
logger.debug(
f"[Misskey API] find-by-hash 响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件"
)
return result if isinstance(result, list) else []
except Exception as e:
logger.error(f"[Misskey API] 根据哈希查找文件失败: {e}")
raise
async def find_files_by_name(
self, name: str, folder_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""Find files by name"""
if not name:
raise APIError("No name provided for find")
data: Dict[str, Any] = {"name": name}
if folder_id:
data["folderId"] = folder_id
try:
logger.debug(f"[Misskey API] find 请求: name={name}, folder_id={folder_id}")
result = await self._make_request("drive/files/find", data)
logger.debug(
f"[Misskey API] find 响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件"
)
return result if isinstance(result, list) else []
except Exception as e:
logger.error(f"[Misskey API] 根据名称查找文件失败: {e}")
raise
async def find_files(
self,
limit: int = 10,
folder_id: Optional[str] = None,
type: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""List files with optional filters"""
data: Dict[str, Any] = {"limit": limit}
if folder_id is not None:
data["folderId"] = folder_id
if type is not None:
data["type"] = type
try:
logger.debug(
f"[Misskey API] 列表文件请求: limit={limit}, folder_id={folder_id}, type={type}"
)
result = await self._make_request("drive/files", data)
logger.debug(
f"[Misskey API] 列表文件响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件"
)
return result if isinstance(result, list) else []
except Exception as e:
logger.error(f"[Misskey API] 列表文件失败: {e}")
raise
async def _download_with_existing_session(
self, url: str, ssl_verify: bool = True
) -> Optional[bytes]:
"""使用现有会话下载文件"""
if not (hasattr(self, "session") and self.session):
raise APIConnectionError("No existing session available")
async with self.session.get(
url, timeout=aiohttp.ClientTimeout(total=15), ssl=ssl_verify
) as response:
if response.status == 200:
return await response.read()
return None
async def _download_with_temp_session(
self, url: str, ssl_verify: bool = True
) -> Optional[bytes]:
"""使用临时会话下载文件"""
connector = aiohttp.TCPConnector(ssl=ssl_verify)
async with aiohttp.ClientSession(connector=connector) as temp_session:
async with temp_session.get(
url, timeout=aiohttp.ClientTimeout(total=15)
) as response:
if response.status == 200:
return await response.read()
return None
async def upload_and_find_file(
self,
url: str,
name: Optional[str] = None,
folder_id: Optional[str] = None,
max_wait_time: float = 30.0,
check_interval: float = 2.0,
) -> Optional[Dict[str, Any]]:
"""
简化的文件上传:尝试 URL 上传,失败则下载后本地上传
Args:
url: 文件URL
name: 文件名(可选)
folder_id: 文件夹ID可选
max_wait_time: 保留参数(未使用)
check_interval: 保留参数(未使用)
Returns:
包含文件ID和元信息的字典失败时返回None
"""
if not url:
raise APIError("URL不能为空")
# 通过本地上传获取即时文件 ID下载文件 → 上传 → 返回 ID
try:
import tempfile
import os
# SSL 验证下载,失败则重试不验证 SSL
tmp_bytes = None
try:
tmp_bytes = await self._download_with_existing_session(
url, ssl_verify=True
) or await self._download_with_temp_session(url, ssl_verify=True)
except Exception as ssl_error:
logger.debug(
f"[Misskey API] SSL 验证下载失败: {ssl_error},重试不验证 SSL"
)
try:
tmp_bytes = await self._download_with_existing_session(
url, ssl_verify=False
) or await self._download_with_temp_session(url, ssl_verify=False)
except Exception:
pass
if tmp_bytes:
with tempfile.NamedTemporaryFile(delete=False) as tmpf:
tmpf.write(tmp_bytes)
tmp_path = tmpf.name
try:
result = await self.upload_file(tmp_path, name, folder_id)
logger.debug(f"[Misskey API] 本地上传成功: {result.get('id')}")
return result
finally:
try:
os.unlink(tmp_path)
except Exception:
pass
except Exception as e:
logger.error(f"[Misskey API] 本地上传失败: {e}")
return None
async def get_current_user(self) -> Dict[str, Any]:
"""获取当前用户信息"""
return await self._make_request("i", {})
async def send_message(
self, user_id_or_payload: Any, text: Optional[str] = None
) -> Dict[str, Any]:
"""发送聊天消息。
Accepts either (user_id: str, text: str) or a single dict payload prepared by caller.
"""
if isinstance(user_id_or_payload, dict):
data = user_id_or_payload
else:
data = {"toUserId": user_id_or_payload, "text": text}
result = await self._make_request("chat/messages/create-to-user", data)
message_id = result.get("id", "unknown")
logger.debug(f"[Misskey API] 聊天消息发送成功: {message_id}")
return result
async def send_room_message(
self, room_id_or_payload: Any, text: Optional[str] = None
) -> Dict[str, Any]:
"""发送房间消息。
Accepts either (room_id: str, text: str) or a single dict payload.
"""
if isinstance(room_id_or_payload, dict):
data = room_id_or_payload
else:
data = {"toRoomId": room_id_or_payload, "text": text}
result = await self._make_request("chat/messages/create-to-room", data)
message_id = result.get("id", "unknown")
logger.debug(f"[Misskey API] 房间消息发送成功: {message_id}")
return result
async def get_messages(
self, user_id: str, limit: int = 10, since_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""获取聊天消息历史"""
data: Dict[str, Any] = {"userId": user_id, "limit": limit}
if since_id:
data["sinceId"] = since_id
result = await self._make_request("chat/messages/user-timeline", data)
if isinstance(result, list):
return result
logger.warning(f"[Misskey API] 聊天消息响应格式异常: {type(result)}")
return []
async def get_mentions(
self, limit: int = 10, since_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""获取提及通知"""
data: Dict[str, Any] = {"limit": limit}
if since_id:
data["sinceId"] = since_id
data["includeTypes"] = ["mention", "reply", "quote"]
result = await self._make_request("i/notifications", data)
if isinstance(result, list):
return result
elif isinstance(result, dict) and "notifications" in result:
return result["notifications"]
else:
logger.warning(f"[Misskey API] 提及通知响应格式异常: {type(result)}")
return []
async def send_message_with_media(
self,
message_type: str,
target_id: str,
text: Optional[str] = None,
media_urls: Optional[List[str]] = None,
local_files: Optional[List[str]] = None,
**kwargs,
) -> Dict[str, Any]:
"""
通用消息发送函数:统一处理文本+媒体发送
Args:
message_type: 消息类型 ('chat', 'room', 'note')
target_id: 目标ID (用户ID/房间ID/频道ID等)
text: 文本内容
media_urls: 媒体文件URL列表
local_files: 本地文件路径列表
**kwargs: 其他参数如visibility等
Returns:
发送结果字典
Raises:
APIError: 参数错误或发送失败
"""
if not text and not media_urls and not local_files:
raise APIError("消息内容不能为空:需要文本或媒体文件")
file_ids = []
# 处理远程媒体文件
if media_urls:
file_ids.extend(await self._process_media_urls(media_urls))
# 处理本地文件
if local_files:
file_ids.extend(await self._process_local_files(local_files))
# 根据消息类型发送
return await self._dispatch_message(
message_type, target_id, text, file_ids, **kwargs
)
async def _process_media_urls(self, urls: List[str]) -> List[str]:
"""处理远程媒体文件URL列表返回文件ID列表"""
file_ids = []
for url in urls:
try:
result = await self.upload_and_find_file(url)
if result and result.get("id"):
file_ids.append(result["id"])
logger.debug(f"[Misskey API] URL媒体上传成功: {result['id']}")
else:
logger.error(f"[Misskey API] URL媒体上传失败: {url}")
except Exception as e:
logger.error(f"[Misskey API] URL媒体处理失败 {url}: {e}")
# 继续处理其他文件,不中断整个流程
continue
return file_ids
async def _process_local_files(self, file_paths: List[str]) -> List[str]:
"""处理本地文件路径列表返回文件ID列表"""
file_ids = []
for file_path in file_paths:
try:
result = await self.upload_file(file_path)
if result and result.get("id"):
file_ids.append(result["id"])
logger.debug(f"[Misskey API] 本地文件上传成功: {result['id']}")
else:
logger.error(f"[Misskey API] 本地文件上传失败: {file_path}")
except Exception as e:
logger.error(f"[Misskey API] 本地文件处理失败 {file_path}: {e}")
continue
return file_ids
async def _dispatch_message(
self,
message_type: str,
target_id: str,
text: Optional[str],
file_ids: List[str],
**kwargs,
) -> Dict[str, Any]:
"""根据消息类型分发到对应的发送方法"""
if message_type == "chat":
# 聊天消息使用 fileId (单数)
payload = {"toUserId": target_id}
if text:
payload["text"] = text
if file_ids:
if len(file_ids) == 1:
payload["fileId"] = file_ids[0]
else:
# 多文件时逐个发送
results = []
for file_id in file_ids:
single_payload = payload.copy()
single_payload["fileId"] = file_id
result = await self.send_message(single_payload)
results.append(result)
return {"multiple": True, "results": results}
return await self.send_message(payload)
elif message_type == "room":
# 房间消息使用 fileId (单数)
payload = {"toRoomId": target_id}
if text:
payload["text"] = text
if file_ids:
if len(file_ids) == 1:
payload["fileId"] = file_ids[0]
else:
# 多文件时逐个发送
results = []
for file_id in file_ids:
single_payload = payload.copy()
single_payload["fileId"] = file_id
result = await self.send_room_message(single_payload)
results.append(result)
return {"multiple": True, "results": results}
return await self.send_room_message(payload)
elif message_type == "note":
# 发帖使用 fileIds (复数)
note_kwargs = {
"text": text,
"file_ids": file_ids or None,
}
# 合并其他参数
note_kwargs.update(kwargs)
return await self.create_note(**note_kwargs)
else:
raise APIError(f"不支持的消息类型: {message_type}")

View File

@@ -0,0 +1,158 @@
import asyncio
import re
from typing import AsyncGenerator
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import PlatformMetadata, AstrBotMessage
from astrbot.api.message_components import Plain
from .misskey_utils import (
serialize_message_chain,
resolve_visibility_from_raw_message,
is_valid_user_session_id,
is_valid_room_session_id,
add_at_mention_if_needed,
extract_user_id_from_session_id,
extract_room_id_from_session_id,
)
class MisskeyPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
def _is_system_command(self, message_str: str) -> bool:
"""检测是否为系统指令"""
if not message_str or not message_str.strip():
return False
system_prefixes = ["/", "!", "#", ".", "^"]
message_trimmed = message_str.strip()
return any(message_trimmed.startswith(prefix) for prefix in system_prefixes)
async def send(self, message: MessageChain):
"""发送消息,使用适配器的完整上传和发送逻辑"""
try:
logger.debug(
f"[MisskeyEvent] send 方法被调用,消息链包含 {len(message.chain)} 个组件"
)
# 使用适配器的 send_by_session 方法,它包含文件上传逻辑
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.platform.message_type import MessageType
# 根据session_id类型确定消息类型
if is_valid_user_session_id(self.session_id):
message_type = MessageType.FRIEND_MESSAGE
elif is_valid_room_session_id(self.session_id):
message_type = MessageType.GROUP_MESSAGE
else:
message_type = MessageType.FRIEND_MESSAGE # 默认
session = MessageSession(
platform_name=self.platform_meta.name,
message_type=message_type,
session_id=self.session_id,
)
logger.debug(
f"[MisskeyEvent] 检查适配器方法: hasattr(self.client, 'send_by_session') = {hasattr(self.client, 'send_by_session')}"
)
# 调用适配器的 send_by_session 方法
if hasattr(self.client, "send_by_session"):
logger.debug("[MisskeyEvent] 调用适配器的 send_by_session 方法")
await self.client.send_by_session(session, message)
else:
# 回退到原来的简化发送逻辑
content, has_at = serialize_message_chain(message.chain)
if not content:
logger.debug("[MisskeyEvent] 内容为空,跳过发送")
return
original_message_id = getattr(self.message_obj, "message_id", None)
raw_message = getattr(self.message_obj, "raw_message", {})
if raw_message and not has_at:
user_data = raw_message.get("user", {})
user_info = {
"username": user_data.get("username", ""),
"nickname": user_data.get(
"name", user_data.get("username", "")
),
}
content = add_at_mention_if_needed(content, user_info, has_at)
# 根据会话类型选择发送方式
if hasattr(self.client, "send_message") and is_valid_user_session_id(
self.session_id
):
user_id = extract_user_id_from_session_id(self.session_id)
await self.client.send_message(user_id, content)
elif hasattr(
self.client, "send_room_message"
) and is_valid_room_session_id(self.session_id):
room_id = extract_room_id_from_session_id(self.session_id)
await self.client.send_room_message(room_id, content)
elif original_message_id and hasattr(self.client, "create_note"):
visibility, visible_user_ids = resolve_visibility_from_raw_message(
raw_message
)
await self.client.create_note(
content,
reply_id=original_message_id,
visibility=visibility,
visible_user_ids=visible_user_ids,
)
elif hasattr(self.client, "create_note"):
logger.debug("[MisskeyEvent] 创建新帖子")
await self.client.create_note(content)
await super().send(message)
except Exception as e:
logger.error(f"[MisskeyEvent] 发送失败: {e}")
async def send_streaming(
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
):
if not use_fallback:
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
buffer = ""
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
async for chain in generator:
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
buffer = await self.process_buffer(buffer, pattern)
else:
await self.send(MessageChain(chain=[comp]))
await asyncio.sleep(1.5) # 限速
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator, use_fallback)

View File

@@ -0,0 +1,538 @@
"""Misskey 平台适配器通用工具函数"""
from typing import Dict, Any, List, Tuple, Optional, Union
import astrbot.api.message_components as Comp
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
class FileIDExtractor:
"""从 API 响应中提取文件 ID 的帮助类(无状态)。"""
@staticmethod
def extract_file_id(result: Any) -> Optional[str]:
if not isinstance(result, dict):
return None
id_paths = [
lambda r: r.get("createdFile", {}).get("id"),
lambda r: r.get("file", {}).get("id"),
lambda r: r.get("id"),
]
for p in id_paths:
try:
if fid := p(result):
return fid
except Exception:
continue
return None
class MessagePayloadBuilder:
"""构建不同类型消息负载的帮助类(无状态)。"""
@staticmethod
def build_chat_payload(
user_id: str, text: Optional[str], file_id: Optional[str] = None
) -> Dict[str, Any]:
payload = {"toUserId": user_id}
if text:
payload["text"] = text
if file_id:
payload["fileId"] = file_id
return payload
@staticmethod
def build_room_payload(
room_id: str, text: Optional[str], file_id: Optional[str] = None
) -> Dict[str, Any]:
payload = {"toRoomId": room_id}
if text:
payload["text"] = text
if file_id:
payload["fileId"] = file_id
return payload
@staticmethod
def build_note_payload(
text: Optional[str], file_ids: Optional[List[str]] = None, **kwargs
) -> Dict[str, Any]:
payload: Dict[str, Any] = {}
if text:
payload["text"] = text
if file_ids:
payload["fileIds"] = file_ids
payload |= kwargs
return payload
def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]:
"""将消息链序列化为文本字符串"""
text_parts = []
has_at = False
def process_component(component):
nonlocal has_at
if isinstance(component, Comp.Plain):
return component.text
elif isinstance(component, Comp.File):
# 为文件组件返回占位符,但适配器仍会处理原组件
return "[文件]"
elif isinstance(component, Comp.Image):
# 为图片组件返回占位符,但适配器仍会处理原组件
return "[图片]"
elif isinstance(component, Comp.At):
has_at = True
# 优先使用name字段用户名如果没有则使用qq字段
# 这样可以避免在Misskey中生成 @<user_id> 这样的无效提及
if hasattr(component, "name") and component.name:
return f"@{component.name}"
else:
return f"@{component.qq}"
elif hasattr(component, "text"):
text = getattr(component, "text", "")
if "@" in text:
has_at = True
return text
else:
return str(component)
for component in chain:
if isinstance(component, Comp.Node) and component.content:
for node_comp in component.content:
result = process_component(node_comp)
if result:
text_parts.append(result)
else:
result = process_component(component)
if result:
text_parts.append(result)
return "".join(text_parts), has_at
def resolve_message_visibility(
user_id: Optional[str] = None,
user_cache: Optional[Dict[str, Any]] = None,
self_id: Optional[str] = None,
raw_message: Optional[Dict[str, Any]] = None,
default_visibility: str = "public",
) -> Tuple[str, Optional[List[str]]]:
"""解析 Misskey 消息的可见性设置
可以从 user_cache 或 raw_message 中解析,支持两种调用方式:
1. 基于 user_cache: resolve_message_visibility(user_id, user_cache, self_id)
2. 基于 raw_message: resolve_message_visibility(raw_message=raw_message, self_id=self_id)
"""
visibility = default_visibility
visible_user_ids = None
# 优先从 user_cache 解析
if user_id and user_cache:
user_info = user_cache.get(user_id)
if user_info:
original_visibility = user_info.get("visibility", default_visibility)
if original_visibility == "specified":
visibility = "specified"
original_visible_users = user_info.get("visible_user_ids", [])
users_to_include = [user_id]
if self_id:
users_to_include.append(self_id)
visible_user_ids = list(set(original_visible_users + users_to_include))
visible_user_ids = [uid for uid in visible_user_ids if uid]
else:
visibility = original_visibility
return visibility, visible_user_ids
# 回退到从 raw_message 解析
if raw_message:
original_visibility = raw_message.get("visibility", default_visibility)
if original_visibility == "specified":
visibility = "specified"
original_visible_users = raw_message.get("visibleUserIds", [])
sender_id = raw_message.get("userId", "")
users_to_include = []
if sender_id:
users_to_include.append(sender_id)
if self_id:
users_to_include.append(self_id)
visible_user_ids = list(set(original_visible_users + users_to_include))
visible_user_ids = [uid for uid in visible_user_ids if uid]
else:
visibility = original_visibility
return visibility, visible_user_ids
# 保留旧函数名作为向后兼容的别名
def resolve_visibility_from_raw_message(
raw_message: Dict[str, Any], self_id: Optional[str] = None
) -> Tuple[str, Optional[List[str]]]:
"""从原始消息数据中解析可见性设置(已弃用,使用 resolve_message_visibility 替代)"""
return resolve_message_visibility(raw_message=raw_message, self_id=self_id)
def is_valid_user_session_id(session_id: Union[str, Any]) -> bool:
"""检查 session_id 是否是有效的聊天用户 session_id (仅限chat%前缀)"""
if not isinstance(session_id, str) or "%" not in session_id:
return False
parts = session_id.split("%")
return (
len(parts) == 2
and parts[0] == "chat"
and bool(parts[1])
and parts[1] != "unknown"
)
def is_valid_room_session_id(session_id: Union[str, Any]) -> bool:
"""检查 session_id 是否是有效的房间 session_id (仅限room%前缀)"""
if not isinstance(session_id, str) or "%" not in session_id:
return False
parts = session_id.split("%")
return (
len(parts) == 2
and parts[0] == "room"
and bool(parts[1])
and parts[1] != "unknown"
)
def is_valid_chat_session_id(session_id: Union[str, Any]) -> bool:
"""检查 session_id 是否是有效的聊天 session_id (仅限chat%前缀)"""
if not isinstance(session_id, str) or "%" not in session_id:
return False
parts = session_id.split("%")
return (
len(parts) == 2
and parts[0] == "chat"
and bool(parts[1])
and parts[1] != "unknown"
)
def extract_user_id_from_session_id(session_id: str) -> str:
"""从 session_id 中提取用户 ID"""
if "%" in session_id:
parts = session_id.split("%")
if len(parts) >= 2:
return parts[1]
return session_id
def extract_room_id_from_session_id(session_id: str) -> str:
"""从 session_id 中提取房间 ID"""
if "%" in session_id:
parts = session_id.split("%")
if len(parts) >= 2 and parts[0] == "room":
return parts[1]
return session_id
def add_at_mention_if_needed(
text: str, user_info: Optional[Dict[str, Any]], has_at: bool = False
) -> str:
"""如果需要且没有@用户,则添加@用户
注意仅在有有效的username时才添加@提及避免使用用户ID
"""
if has_at or not user_info:
return text
username = user_info.get("username")
# 如果没有username则不添加@提及,返回原文本
# 这样可以避免生成 @<user_id> 这样的无效提及
if not username:
return text
mention = f"@{username}"
if not text.startswith(mention):
text = f"{mention}\n{text}".strip()
return text
def create_file_component(file_info: Dict[str, Any]) -> Tuple[Any, str]:
"""创建文件组件和描述文本"""
file_url = file_info.get("url", "")
file_name = file_info.get("name", "未知文件")
file_type = file_info.get("type", "")
if file_type.startswith("image/"):
return Comp.Image(url=file_url, file=file_name), f"图片[{file_name}]"
elif file_type.startswith("audio/"):
return Comp.Record(url=file_url, file=file_name), f"音频[{file_name}]"
elif file_type.startswith("video/"):
return Comp.Video(url=file_url, file=file_name), f"视频[{file_name}]"
else:
return Comp.File(name=file_name, url=file_url), f"文件[{file_name}]"
def process_files(
message: AstrBotMessage, files: list, include_text_parts: bool = True
) -> list:
"""处理文件列表,添加到消息组件中并返回文本描述"""
file_parts = []
for file_info in files:
component, part_text = create_file_component(file_info)
message.message.append(component)
if include_text_parts:
file_parts.append(part_text)
return file_parts
def format_poll(poll: Dict[str, Any]) -> str:
"""将 Misskey 的 poll 对象格式化为可读字符串。"""
if not poll or not isinstance(poll, dict):
return ""
multiple = poll.get("multiple", False)
choices = poll.get("choices", [])
text_choices = [
f"({idx}) {c.get('text', '')} [{c.get('votes', 0)}票]"
for idx, c in enumerate(choices, start=1)
]
parts = ["[投票]", ("允许多选" if multiple else "单选")] + (
["选项: " + ", ".join(text_choices)] if text_choices else []
)
return " ".join(parts)
def extract_sender_info(
raw_data: Dict[str, Any], is_chat: bool = False
) -> Dict[str, Any]:
"""提取发送者信息"""
if is_chat:
sender = raw_data.get("fromUser", {})
sender_id = str(sender.get("id", "") or raw_data.get("fromUserId", ""))
else:
sender = raw_data.get("user", {})
sender_id = str(sender.get("id", ""))
return {
"sender": sender,
"sender_id": sender_id,
"nickname": sender.get("name", sender.get("username", "")),
"username": sender.get("username", ""),
}
def create_base_message(
raw_data: Dict[str, Any],
sender_info: Dict[str, Any],
client_self_id: str,
is_chat: bool = False,
room_id: Optional[str] = None,
unique_session: bool = False,
) -> AstrBotMessage:
"""创建基础消息对象"""
message = AstrBotMessage()
message.raw_message = raw_data
message.message = []
message.sender = MessageMember(
user_id=sender_info["sender_id"],
nickname=sender_info["nickname"],
)
if room_id:
session_prefix = "room"
session_id = f"{session_prefix}%{room_id}"
if unique_session:
session_id += f"_{sender_info['sender_id']}"
message.type = MessageType.GROUP_MESSAGE
message.group_id = room_id
elif is_chat:
session_prefix = "chat"
session_id = f"{session_prefix}%{sender_info['sender_id']}"
message.type = MessageType.FRIEND_MESSAGE
else:
session_prefix = "note"
session_id = f"{session_prefix}%{sender_info['sender_id']}"
message.type = MessageType.OTHER_MESSAGE
message.session_id = (
session_id if sender_info["sender_id"] else f"{session_prefix}%unknown"
)
message.message_id = str(raw_data.get("id", ""))
message.self_id = client_self_id
return message
def process_at_mention(
message: AstrBotMessage, raw_text: str, bot_username: str, client_self_id: str
) -> Tuple[List[str], str]:
"""处理@提及逻辑,返回消息部分列表和处理后的文本"""
message_parts = []
if not raw_text:
return message_parts, ""
if bot_username and raw_text.startswith(f"@{bot_username}"):
at_mention = f"@{bot_username}"
message.message.append(Comp.At(qq=client_self_id))
remaining_text = raw_text[len(at_mention) :].strip()
if remaining_text:
message.message.append(Comp.Plain(remaining_text))
message_parts.append(remaining_text)
return message_parts, remaining_text
else:
message.message.append(Comp.Plain(raw_text))
message_parts.append(raw_text)
return message_parts, raw_text
def cache_user_info(
user_cache: Dict[str, Any],
sender_info: Dict[str, Any],
raw_data: Dict[str, Any],
client_self_id: str,
is_chat: bool = False,
):
"""缓存用户信息"""
if is_chat:
user_cache_data = {
"username": sender_info["username"],
"nickname": sender_info["nickname"],
"visibility": "specified",
"visible_user_ids": [client_self_id, sender_info["sender_id"]],
}
else:
user_cache_data = {
"username": sender_info["username"],
"nickname": sender_info["nickname"],
"visibility": raw_data.get("visibility", "public"),
"visible_user_ids": raw_data.get("visibleUserIds", []),
# 保存原消息ID用于回复时作为reply_id
"reply_to_note_id": raw_data.get("id"),
}
user_cache[sender_info["sender_id"]] = user_cache_data
def cache_room_info(
user_cache: Dict[str, Any], raw_data: Dict[str, Any], client_self_id: str
):
"""缓存房间信息"""
room_data = raw_data.get("toRoom")
room_id = raw_data.get("toRoomId")
if room_data and room_id:
room_cache_key = f"room:{room_id}"
user_cache[room_cache_key] = {
"room_id": room_id,
"room_name": room_data.get("name", ""),
"room_description": room_data.get("description", ""),
"owner_id": room_data.get("ownerId", ""),
"visibility": "specified",
"visible_user_ids": [client_self_id],
}
async def resolve_component_url_or_path(
comp: Any,
) -> Tuple[Optional[str], Optional[str]]:
"""尝试从组件解析可上传的远程 URL 或本地路径。
返回 (url_candidate, local_path)。两者可能都为 None。
这个函数尽量不抛异常,调用方可按需处理 None。
"""
url_candidate = None
local_path = None
async def _get_str_value(coro_or_val):
"""辅助函数:统一处理协程或普通值"""
try:
if hasattr(coro_or_val, "__await__"):
result = await coro_or_val
else:
result = coro_or_val
return result if isinstance(result, str) else None
except Exception:
return None
try:
# 1. 尝试异步方法
for method in ["convert_to_file_path", "get_file", "register_to_file_service"]:
if not hasattr(comp, method):
continue
try:
value = await _get_str_value(getattr(comp, method)())
if value:
if value.startswith("http"):
url_candidate = value
break
else:
local_path = value
except Exception:
continue
# 2. 尝试 get_file(True) 获取可直接访问的 URL
if not url_candidate and hasattr(comp, "get_file"):
try:
value = await _get_str_value(comp.get_file(True))
if value and value.startswith("http"):
url_candidate = value
except Exception:
pass
# 3. 回退到同步属性
if not url_candidate and not local_path:
for attr in ("file", "url", "path", "src", "source"):
try:
value = getattr(comp, attr, None)
if value and isinstance(value, str):
if value.startswith("http"):
url_candidate = value
break
else:
local_path = value
break
except Exception:
continue
except Exception:
pass
return url_candidate, local_path
def summarize_component_for_log(comp: Any) -> Dict[str, Any]:
"""生成适合日志的组件属性字典(尽量不抛异常)。"""
attrs = {}
for a in ("file", "url", "path", "src", "source", "name"):
try:
v = getattr(comp, a, None)
if v is not None:
attrs[a] = v
except Exception:
continue
return attrs
async def upload_local_with_retries(
api: Any,
local_path: str,
preferred_name: Optional[str],
folder_id: Optional[str],
) -> Optional[str]:
"""尝试本地上传,返回 file id 或 None。如果文件类型不允许则直接失败。"""
try:
res = await api.upload_file(local_path, preferred_name, folder_id)
if isinstance(res, dict):
fid = res.get("id") or (res.get("raw") or {}).get("createdFile", {}).get(
"id"
)
if fid:
return str(fid)
except Exception:
# 上传失败,直接返回 None让上层处理错误
return None
return None

View File

@@ -15,12 +15,13 @@ class QQOfficialWebhook:
self.appid = config["appid"]
self.secret = config["secret"]
self.port = config.get("port", 6196)
self.is_sandbox = config.get("is_sandbox", False)
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
if isinstance(self.port, str):
self.port = int(self.port)
self.http: BotHttp = BotHttp(timeout=300)
self.http: BotHttp = BotHttp(timeout=300, is_sandbox=self.is_sandbox)
self.api: BotAPI = BotAPI(http=self.http)
self.token = Token(self.appid, self.secret)

View File

@@ -17,7 +17,14 @@ from astrbot.api.platform import (
register_platform_adapter,
)
from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.api.message_components import Plain, Image, At, File, Record
from astrbot.api.message_components import (
Plain,
Image,
At,
File,
Record,
Reply,
)
from xml.etree import ElementTree as ET
@@ -38,12 +45,18 @@ class SatoriPlatformAdapter(Platform):
)
self.token = self.config.get("satori_token", "")
self.endpoint = self.config.get(
"satori_endpoint", "ws://127.0.0.1:5140/satori/v1/events"
"satori_endpoint", "ws://localhost:5140/satori/v1/events"
)
self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)
self.metadata = PlatformMetadata(
name="satori",
description="Satori 通用协议适配器",
id=self.config["id"],
)
self.ws: Optional[ClientConnection] = None
self.session: Optional[ClientSession] = None
self.sequence = 0
@@ -63,7 +76,7 @@ class SatoriPlatformAdapter(Platform):
await super().send_by_session(session, message_chain)
def meta(self) -> PlatformMetadata:
return PlatformMetadata(name="satori", description="Satori 通用协议适配器")
return self.metadata
def _is_websocket_closed(self, ws) -> bool:
"""检查WebSocket连接是否已关闭"""
@@ -312,12 +325,52 @@ class SatoriPlatformAdapter(Platform):
abm.self_id = login.get("user", {}).get("id", "")
content = message.get("content", "")
abm.message = await self.parse_satori_elements(content)
# 消息链
abm.message = []
content = message.get("content", "")
quote = message.get("quote")
content_for_parsing = content # 副本
# 提取<quote>标签
if "<quote" in content:
try:
quote_info = await self._extract_quote_element(content)
if quote_info:
quote = quote_info["quote"]
content_for_parsing = quote_info["content_without_quote"]
except Exception as e:
logger.error(f"解析<quote>标签时发生错误: {e}, 错误内容: {content}")
if quote:
# 引用消息
quote_abm = await self._convert_quote_message(quote)
if quote_abm:
sender_id = quote_abm.sender.user_id
if isinstance(sender_id, str) and sender_id.isdigit():
sender_id = int(sender_id)
elif not isinstance(sender_id, int):
sender_id = 0 # 默认值
reply_component = Reply(
id=quote_abm.message_id,
chain=quote_abm.message,
sender_id=quote_abm.sender.user_id,
sender_nickname=quote_abm.sender.nickname,
time=quote_abm.timestamp,
message_str=quote_abm.message_str,
text=quote_abm.message_str,
qq=sender_id,
)
abm.message.append(reply_component)
# 解析消息内容
content_elements = await self.parse_satori_elements(content_for_parsing)
abm.message.extend(content_elements)
# parse message_str
abm.message_str = ""
for comp in abm.message:
for comp in content_elements:
if isinstance(comp, Plain):
abm.message_str += comp.text
@@ -333,6 +386,189 @@ class SatoriPlatformAdapter(Platform):
logger.error(f"转换 Satori 消息失败: {e}")
return None
def _extract_namespace_prefixes(self, content: str) -> set:
"""提取XML内容中的命名空间前缀"""
prefixes = set()
# 查找所有标签
i = 0
while i < len(content):
# 查找开始标签
if content[i] == "<" and i + 1 < len(content) and content[i + 1] != "/":
# 找到标签结束位置
tag_end = content.find(">", i)
if tag_end != -1:
# 提取标签内容
tag_content = content[i + 1 : tag_end]
# 检查是否有命名空间前缀
if ":" in tag_content and "xmlns:" not in tag_content:
# 分割标签名
parts = tag_content.split()
if parts:
tag_name = parts[0]
if ":" in tag_name:
prefix = tag_name.split(":")[0]
# 确保是有效的命名空间前缀
if (
prefix.isalnum()
or prefix.replace("_", "").isalnum()
):
prefixes.add(prefix)
i = tag_end + 1
else:
i += 1
# 查找结束标签
elif content[i] == "<" and i + 1 < len(content) and content[i + 1] == "/":
# 找到标签结束位置
tag_end = content.find(">", i)
if tag_end != -1:
# 提取标签内容
tag_content = content[i + 2 : tag_end]
# 检查是否有命名空间前缀
if ":" in tag_content:
prefix = tag_content.split(":")[0]
# 确保是有效的命名空间前缀
if prefix.isalnum() or prefix.replace("_", "").isalnum():
prefixes.add(prefix)
i = tag_end + 1
else:
i += 1
else:
i += 1
return prefixes
async def _extract_quote_element(self, content: str) -> Optional[dict]:
"""提取<quote>标签信息"""
try:
# 处理命名空间前缀问题
processed_content = content
if ":" in content and not content.startswith("<root"):
prefixes = self._extract_namespace_prefixes(content)
# 构建命名空间声明
ns_declarations = " ".join(
[
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
for prefix in prefixes
]
)
# 包装内容
processed_content = f"<root {ns_declarations}>{content}</root>"
elif not content.startswith("<root"):
processed_content = f"<root>{content}</root>"
else:
processed_content = content
root = ET.fromstring(processed_content)
# 查找<quote>标签
quote_element = None
for elem in root.iter():
tag_name = elem.tag
if "}" in tag_name:
tag_name = tag_name.split("}")[1]
if tag_name.lower() == "quote":
quote_element = elem
break
if quote_element is not None:
# 提取quote标签的属性
quote_id = quote_element.get("id", "")
# 提取<quote>标签内部的内容
inner_content = ""
if quote_element.text:
inner_content += quote_element.text
for child in quote_element:
inner_content += ET.tostring(
child, encoding="unicode", method="xml"
)
if child.tail:
inner_content += child.tail
# 构造移除了<quote>标签的内容
content_without_quote = content.replace(
ET.tostring(quote_element, encoding="unicode", method="xml"), ""
)
return {
"quote": {"id": quote_id, "content": inner_content},
"content_without_quote": content_without_quote,
}
return None
except ET.ParseError as e:
logger.warning(f"XML解析失败使用正则提取: {e}")
return await self._extract_quote_with_regex(content)
except Exception as e:
logger.error(f"提取<quote>标签时发生错误: {e}")
return None
async def _extract_quote_with_regex(self, content: str) -> Optional[dict]:
"""使用正则表达式提取quote标签信息"""
import re
quote_pattern = r"<quote\s+([^>]*)>(.*?)</quote>"
match = re.search(quote_pattern, content, re.DOTALL)
if not match:
return None
attrs_str = match.group(1)
inner_content = match.group(2)
id_match = re.search(r'id\s*=\s*["\']([^"\']*)["\']', attrs_str)
quote_id = id_match.group(1) if id_match else ""
content_without_quote = content.replace(match.group(0), "")
content_without_quote = content_without_quote.strip()
return {
"quote": {"id": quote_id, "content": inner_content},
"content_without_quote": content_without_quote,
}
async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]:
"""转换引用消息"""
try:
quote_abm = AstrBotMessage()
quote_abm.message_id = quote.get("id", "")
# 解析引用消息的发送者
quote_author = quote.get("author", {})
if quote_author:
quote_abm.sender = MessageMember(
user_id=quote_author.get("id", ""),
nickname=quote_author.get("nick", quote_author.get("name", "")),
)
else:
# 如果没有作者信息,使用默认值
quote_abm.sender = MessageMember(
user_id=quote.get("user_id", ""),
nickname="内容",
)
# 解析引用消息内容
quote_content = quote.get("content", "")
quote_abm.message = await self.parse_satori_elements(quote_content)
quote_abm.message_str = ""
for comp in quote_abm.message:
if isinstance(comp, Plain):
quote_abm.message_str += comp.text
quote_abm.timestamp = int(quote.get("timestamp", time.time()))
# 如果没有任何内容,使用默认文本
if not quote_abm.message_str.strip():
quote_abm.message_str = "[引用消息]"
return quote_abm
except Exception as e:
logger.error(f"转换引用消息失败: {e}")
return None
async def parse_satori_elements(self, content: str) -> list:
"""解析 Satori 消息元素"""
elements = []
@@ -341,12 +577,35 @@ class SatoriPlatformAdapter(Platform):
return elements
try:
wrapped_content = f"<root>{content}</root>"
root = ET.fromstring(wrapped_content)
# 处理命名空间前缀问题
processed_content = content
if ":" in content and not content.startswith("<root"):
prefixes = self._extract_namespace_prefixes(content)
# 构建命名空间声明
ns_declarations = " ".join(
[
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
for prefix in prefixes
]
)
# 包装内容
processed_content = f"<root {ns_declarations}>{content}</root>"
elif not content.startswith("<root"):
processed_content = f"<root>{content}</root>"
else:
processed_content = content
root = ET.fromstring(processed_content)
await self._parse_xml_node(root, elements)
except ET.ParseError as e:
raise ValueError(f"解析 Satori 元素时发生解析错误: {e}")
logger.warning(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
# 如果解析失败,将整个内容当作纯文本
if content.strip():
elements.append(Plain(text=content))
except Exception as e:
logger.error(f"解析 Satori 元素时发生未知错误: {e}")
raise e
# 如果没有解析到任何元素,将整个内容当作纯文本
@@ -361,7 +620,12 @@ class SatoriPlatformAdapter(Platform):
elements.append(Plain(text=node.text))
for child in node:
tag_name = child.tag.lower()
# 获取标签名,去除命名空间前缀
tag_name = child.tag
if "}" in tag_name:
tag_name = tag_name.split("}")[1]
tag_name = tag_name.lower()
attrs = child.attrib
if tag_name == "at":
@@ -372,31 +636,59 @@ class SatoriPlatformAdapter(Platform):
src = attrs.get("src", "")
if not src:
continue
if src.startswith("data:image/"):
src = src.split(",")[1]
elements.append(Image.fromBase64(src))
elif src.startswith("http"):
elements.append(Image.fromURL(src))
else:
logger.error(f"未知的图片 src 格式: {str(src)[:16]}")
elements.append(Image(file=src))
elif tag_name == "file":
src = attrs.get("src", "")
name = attrs.get("name", "文件")
if src:
elements.append(File(file=src, name=name))
elements.append(File(name=name, file=src))
elif tag_name in ("audio", "record"):
src = attrs.get("src", "")
if not src:
continue
if src.startswith("data:audio/"):
src = src.split(",")[1]
elements.append(Record.fromBase64(src))
elif src.startswith("http"):
elements.append(Record.fromURL(src))
elements.append(Record(file=src))
elif tag_name == "quote":
# quote标签已经被特殊处理
pass
elif tag_name == "face":
face_id = attrs.get("id", "")
face_name = attrs.get("name", "")
face_type = attrs.get("type", "")
if face_name:
elements.append(Plain(text=f"[表情:{face_name}]"))
elif face_id and face_type:
elements.append(Plain(text=f"[表情ID:{face_id},类型:{face_type}]"))
elif face_id:
elements.append(Plain(text=f"[表情ID:{face_id}]"))
else:
logger.error(f"未知的音频 src 格式: {str(src)[:16]}")
elements.append(Plain(text="[表情]"))
elif tag_name == "ark":
# 作为纯文本添加到消息链中
data = attrs.get("data", "")
if data:
import html
decoded_data = html.unescape(data)
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
else:
elements.append(Plain(text="[ARK卡片]"))
elif tag_name == "json":
# JSON标签 视为ARK卡片消息
data = attrs.get("data", "")
if data:
import html
decoded_data = html.unescape(data)
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
else:
elements.append(Plain(text="[JSON卡片]"))
else:
# 未知标签,递归处理其内容

View File

@@ -2,7 +2,18 @@ from typing import TYPE_CHECKING
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, At, File, Record
from astrbot.api.message_components import (
Plain,
Image,
At,
File,
Record,
Video,
Reply,
Forward,
Node,
Nodes,
)
if TYPE_CHECKING:
from .satori_adapter import SatoriPlatformAdapter
@@ -17,6 +28,15 @@ class SatoriPlatformEvent(AstrMessageEvent):
session_id: str,
adapter: "SatoriPlatformAdapter",
):
# 更新平台元数据
if adapter and hasattr(adapter, "logins") and adapter.logins:
current_login = adapter.logins[0]
platform_name = current_login.get("platform", "satori")
user = current_login.get("user", {})
user_id = user.get("id", "") if user else ""
if not platform_meta.id and user_id:
platform_meta.id = f"{platform_name}({user_id})"
super().__init__(message_str, message_obj, platform_meta, session_id)
self.adapter = adapter
self.platform = None
@@ -39,44 +59,24 @@ class SatoriPlatformEvent(AstrMessageEvent):
content_parts = []
for component in message.chain:
if isinstance(component, Plain):
text = (
component.text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
content_parts.append(text)
component_content = await cls._convert_component_to_satori_static(
component
)
if component_content:
content_parts.append(component_content)
elif isinstance(component, At):
if component.qq:
content_parts.append(f'<at id="{component.qq}"/>')
elif component.name:
content_parts.append(f'<at name="{component.name}"/>')
# 特殊处理 Node 和 Nodes 组件
if isinstance(component, Node):
# 单个转发节点
node_content = await cls._convert_node_to_satori_static(component)
if node_content:
content_parts.append(node_content)
elif isinstance(component, Image):
try:
image_base64 = await component.convert_to_base64()
if image_base64:
content_parts.append(
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
)
except Exception as e:
logger.error(f"图片转换为base64失败: {e}")
elif isinstance(component, File):
content_parts.append(
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
)
elif isinstance(component, Record):
try:
record_base64 = await component.convert_to_base64()
if record_base64:
content_parts.append(
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
)
except Exception as e:
logger.error(f"语音转换为base64失败: {e}")
elif isinstance(component, Nodes):
# 合并转发消息
node_content = await cls._convert_nodes_to_satori_static(component)
if node_content:
content_parts.append(node_content)
content = "".join(content_parts)
channel_id = session_id
@@ -118,44 +118,22 @@ class SatoriPlatformEvent(AstrMessageEvent):
content_parts = []
for component in message.chain:
if isinstance(component, Plain):
text = (
component.text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
content_parts.append(text)
component_content = await self._convert_component_to_satori(component)
if component_content:
content_parts.append(component_content)
elif isinstance(component, At):
if component.qq:
content_parts.append(f'<at id="{component.qq}"/>')
elif component.name:
content_parts.append(f'<at name="{component.name}"/>')
# 特殊处理 Node 和 Nodes 组件
if isinstance(component, Node):
# 单个转发节点
node_content = await self._convert_node_to_satori(component)
if node_content:
content_parts.append(node_content)
elif isinstance(component, Image):
try:
image_base64 = await component.convert_to_base64()
if image_base64:
content_parts.append(
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
)
except Exception as e:
logger.error(f"图片转换为base64失败: {e}")
elif isinstance(component, File):
content_parts.append(
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
)
elif isinstance(component, Record):
try:
record_base64 = await component.convert_to_base64()
if record_base64:
content_parts.append(
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
)
except Exception as e:
logger.error(f"语音转换为base64失败: {e}")
elif isinstance(component, Nodes):
# 合并转发消息
node_content = await self._convert_nodes_to_satori(component)
if node_content:
content_parts.append(node_content)
content = "".join(content_parts)
channel_id = self.session_id
@@ -219,3 +197,227 @@ class SatoriPlatformEvent(AstrMessageEvent):
logger.error(f"Satori 流式消息发送异常: {e}")
return await super().send_streaming(generator, use_fallback)
async def _convert_component_to_satori(self, component) -> str:
"""将单个消息组件转换为 Satori 格式"""
try:
if isinstance(component, Plain):
text = (
component.text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
return text
elif isinstance(component, At):
if component.qq:
return f'<at id="{component.qq}"/>'
elif component.name:
return f'<at name="{component.name}"/>'
elif isinstance(component, Image):
try:
image_base64 = await component.convert_to_base64()
if image_base64:
return f'<img src="data:image/jpeg;base64,{image_base64}"/>'
except Exception as e:
logger.error(f"图片转换为base64失败: {e}")
elif isinstance(component, File):
return (
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
)
elif isinstance(component, Record):
try:
record_base64 = await component.convert_to_base64()
if record_base64:
return f'<audio src="data:audio/wav;base64,{record_base64}"/>'
except Exception as e:
logger.error(f"语音转换为base64失败: {e}")
elif isinstance(component, Reply):
return f'<reply id="{component.id}"/>'
elif isinstance(component, Video):
try:
video_path_url = await component.convert_to_file_path()
if video_path_url:
return f'<video src="{video_path_url}"/>'
except Exception as e:
logger.error(f"视频文件转换失败: {e}")
elif isinstance(component, Forward):
return f'<message id="{component.id}" forward/>'
# 对于其他未处理的组件类型,返回空字符串
return ""
except Exception as e:
logger.error(f"转换消息组件失败: {e}")
return ""
async def _convert_node_to_satori(self, node: Node) -> str:
"""将单个转发节点转换为 Satori 格式"""
try:
content_parts = []
if node.content:
for content_component in node.content:
component_content = await self._convert_component_to_satori(
content_component
)
if component_content:
content_parts.append(component_content)
content = "".join(content_parts)
# 如果内容为空,添加默认内容
if not content.strip():
content = "[转发消息]"
# 构建 Satori 格式的转发节点
author_attrs = []
if node.uin:
author_attrs.append(f'id="{node.uin}"')
if node.name:
author_attrs.append(f'name="{node.name}"')
author_attr_str = " ".join(author_attrs)
return f"<message><author {author_attr_str}/>{content}</message>"
except Exception as e:
logger.error(f"转换转发节点失败: {e}")
return ""
@classmethod
async def _convert_component_to_satori_static(cls, component) -> str:
"""将单个消息组件转换为 Satori 格式"""
try:
if isinstance(component, Plain):
text = (
component.text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
return text
elif isinstance(component, At):
if component.qq:
return f'<at id="{component.qq}"/>'
elif component.name:
return f'<at name="{component.name}"/>'
elif isinstance(component, Image):
try:
image_base64 = await component.convert_to_base64()
if image_base64:
return f'<img src="data:image/jpeg;base64,{image_base64}"/>'
except Exception as e:
logger.error(f"图片转换为base64失败: {e}")
elif isinstance(component, File):
return (
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
)
elif isinstance(component, Record):
try:
record_base64 = await component.convert_to_base64()
if record_base64:
return f'<audio src="data:audio/wav;base64,{record_base64}"/>'
except Exception as e:
logger.error(f"语音转换为base64失败: {e}")
elif isinstance(component, Reply):
return f'<reply id="{component.id}"/>'
elif isinstance(component, Video):
try:
video_path_url = await component.convert_to_file_path()
if video_path_url:
return f'<video src="{video_path_url}"/>'
except Exception as e:
logger.error(f"视频文件转换失败: {e}")
elif isinstance(component, Forward):
return f'<message id="{component.id}" forward/>'
# 对于其他未处理的组件类型,返回空字符串
return ""
except Exception as e:
logger.error(f"转换消息组件失败: {e}")
return ""
@classmethod
async def _convert_node_to_satori_static(cls, node: Node) -> str:
"""将单个转发节点转换为 Satori 格式"""
try:
content_parts = []
if node.content:
for content_component in node.content:
component_content = await cls._convert_component_to_satori_static(
content_component
)
if component_content:
content_parts.append(component_content)
content = "".join(content_parts)
# 如果内容为空,添加默认内容
if not content.strip():
content = "[转发消息]"
author_attrs = []
if node.uin:
author_attrs.append(f'id="{node.uin}"')
if node.name:
author_attrs.append(f'name="{node.name}"')
author_attr_str = " ".join(author_attrs)
return f"<message><author {author_attr_str}/>{content}</message>"
except Exception as e:
logger.error(f"转换转发节点失败: {e}")
return ""
async def _convert_nodes_to_satori(self, nodes: Nodes) -> str:
"""将多个转发节点转换为 Satori 格式的合并转发"""
try:
node_parts = []
for node in nodes.nodes:
node_content = await self._convert_node_to_satori(node)
if node_content:
node_parts.append(node_content)
if node_parts:
return f"<message forward>{''.join(node_parts)}</message>"
else:
return ""
except Exception as e:
logger.error(f"转换合并转发消息失败: {e}")
return ""
@classmethod
async def _convert_nodes_to_satori_static(cls, nodes: Nodes) -> str:
"""将多个转发节点转换为 Satori 格式的合并转发"""
try:
node_parts = []
for node in nodes.nodes:
node_content = await cls._convert_node_to_satori_static(node)
if node_content:
node_parts.append(node_content)
if node_parts:
return f"<message forward>{''.join(node_parts)}</message>"
else:
return ""
except Exception as e:
logger.error(f"转换合并转发消息失败: {e}")
return ""

View File

@@ -95,9 +95,8 @@ class TelegramPlatformAdapter(Platform):
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="telegram", description="telegram 适配器", id=self.config.get("id")
)
id_ = self.config.get("id") or "telegram"
return PlatformMetadata(name="telegram", description="telegram 适配器", id=id_)
@override
async def run(self):
@@ -117,6 +116,10 @@ class TelegramPlatformAdapter(Platform):
)
self.scheduler.start()
if not self.application.updater:
logger.error("Telegram Updater is not initialized. Cannot start polling.")
return
queue = self.application.updater.start_polling()
logger.info("Telegram Platform Adapter is running.")
await queue
@@ -194,6 +197,11 @@ class TelegramPlatformAdapter(Platform):
return cmd_name, description
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
if not update.effective_chat:
logger.warning(
"Received a start command without an effective chat, skipping /start reply."
)
return
await context.bot.send_message(
chat_id=update.effective_chat.id, text=self.config["start_message"]
)
@@ -206,15 +214,20 @@ class TelegramPlatformAdapter(Platform):
async def convert_message(
self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True
) -> AstrBotMessage:
) -> AstrBotMessage | None:
"""转换 Telegram 的消息对象为 AstrBotMessage 对象。
@param update: Telegram 的 Update 对象。
@param context: Telegram 的 Context 对象。
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
"""
if not update.message:
logger.warning("Received an update without a message.")
return None
message = AstrBotMessage()
message.session_id = str(update.message.chat.id)
# 获得是群聊还是私聊
if update.message.chat.type == ChatType.PRIVATE:
message.type = MessageType.FRIEND_MESSAGE
@@ -225,10 +238,13 @@ class TelegramPlatformAdapter(Platform):
# Topic Group
message.group_id += "#" + str(update.message.message_thread_id)
message.session_id = message.group_id
message.message_id = str(update.message.message_id)
_from_user = update.message.from_user
if not _from_user:
logger.warning("[Telegram] Received a message without a from_user.")
return None
message.sender = MessageMember(
str(update.message.from_user.id), update.message.from_user.username
str(_from_user.id), _from_user.username or "Unknown"
)
message.self_id = str(context.bot.username)
message.raw_message = update
@@ -247,22 +263,32 @@ class TelegramPlatformAdapter(Platform):
)
reply_abm = await self.convert_message(reply_update, context, False)
message.message.append(
Comp.Reply(
id=reply_abm.message_id,
chain=reply_abm.message,
sender_id=reply_abm.sender.user_id,
sender_nickname=reply_abm.sender.nickname,
time=reply_abm.timestamp,
message_str=reply_abm.message_str,
text=reply_abm.message_str,
qq=reply_abm.sender.user_id,
if reply_abm:
message.message.append(
Comp.Reply(
id=reply_abm.message_id,
chain=reply_abm.message,
sender_id=reply_abm.sender.user_id,
sender_nickname=reply_abm.sender.nickname,
time=reply_abm.timestamp,
message_str=reply_abm.message_str,
text=reply_abm.message_str,
qq=reply_abm.sender.user_id,
)
)
)
if update.message.text:
# 处理文本消息
plain_text = update.message.text
if (
message.type == MessageType.GROUP_MESSAGE
and update.message
and update.message.reply_to_message
and update.message.reply_to_message.from_user
and update.message.reply_to_message.from_user.id == context.bot.id
):
plain_text2 = f"/@{context.bot.username} " + plain_text
plain_text = plain_text2
# 群聊场景命令特殊处理
if plain_text.startswith("/"):
@@ -328,15 +354,25 @@ class TelegramPlatformAdapter(Platform):
elif update.message.document:
file = await update.message.document.get_file()
message.message = [
Comp.File(file=file.file_path, name=update.message.document.file_name),
]
file_name = update.message.document.file_name or uuid.uuid4().hex
file_path = file.file_path
if file_path is None:
logger.warning(
f"Telegram document file_path is None, cannot save the file {file_name}."
)
else:
message.message.append(Comp.File(file=file_path, name=file_name))
elif update.message.video:
file = await update.message.video.get_file()
message.message = [
Comp.Video(file=file.file_path, path=file.file_path),
]
file_name = update.message.video.file_name or uuid.uuid4().hex
file_path = file.file_path
if file_path is None:
logger.warning(
f"Telegram video file_path is None, cannot save the file {file_name}."
)
else:
message.message.append(Comp.Video(file=file_path, path=file.file_path))
return message

View File

@@ -16,6 +16,7 @@ from telegram.ext import ExtBot
from astrbot.core.utils.io import download_file
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from telegram import ReactionTypeEmoji, ReactionTypeCustomEmoji
class TelegramPlatformEvent(AstrMessageEvent):
@@ -135,6 +136,39 @@ class TelegramPlatformEvent(AstrMessageEvent):
await self.send_with_client(self.client, message, self.get_sender_id())
await super().send(message)
async def react(self, emoji: str | None, big: bool = False):
"""
给原消息添加 Telegram 反应:
- 普通 emoji传入 '👍''😂'
- 自定义表情:传入其 custom_emoji_id纯数字字符串
- 取消本机器人的反应:传入 None 或空字符串
"""
try:
# 解析 chat_id去掉超级群的 "#<thread_id>" 片段)
if self.get_message_type() == MessageType.GROUP_MESSAGE:
chat_id = (self.message_obj.group_id or "").split("#")[0]
else:
chat_id = self.get_sender_id()
message_id = int(self.message_obj.message_id)
# 组装 reaction 参数(必须是 ReactionType 的列表)
if not emoji: # 清空本 bot 的反应
reaction_param = [] # 空列表表示移除本 bot 的反应
elif emoji.isdigit(): # 自定义表情:传 custom_emoji_id
reaction_param = [ReactionTypeCustomEmoji(emoji)]
else: # 普通 emoji
reaction_param = [ReactionTypeEmoji(emoji)]
await self.client.set_message_reaction(
chat_id=chat_id,
message_id=message_id,
reaction=reaction_param, # 注意是列表
is_big=big, # 可选:大动画
)
except Exception as e:
logger.error(f"[Telegram] 添加反应失败: {e}")
async def send_streaming(self, generator, use_fallback: bool = False):
message_thread_id = None
@@ -218,7 +252,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
try:
msg = await self.client.send_message(text=delta, **payload)
current_content = delta
delta = ""
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id

View File

@@ -91,7 +91,6 @@ class WebChatAdapter(Platform):
abm = AstrBotMessage()
abm.self_id = "webchat"
abm.tag = "webchat"
abm.sender = MessageMember(username, username)
abm.type = MessageType.FRIEND_MESSAGE

View File

@@ -185,6 +185,7 @@ class WecomPlatformAdapter(Platform):
return PlatformMetadata(
"wecom",
"wecom 适配器",
id=self.config.get("id", "wecom"),
)
@override

View File

@@ -0,0 +1,289 @@
#!/usr/bin/env python
# -*- encoding:utf-8 -*-
"""对企业微信发送给企业后台的消息加解密示例代码.
@copyright: Copyright (c) 1998-2020 Tencent Inc.
"""
# ------------------------------------------------------------------------
import logging
import base64
import random
import hashlib
import time
import struct
from Crypto.Cipher import AES
import socket
import json
from . import ierror
"""
关于Crypto.Cipher模块ImportError: No module named 'Crypto'解决方案
请到官方网站 https://www.dlitz.net/software/pycrypto/ 下载pycrypto。
下载后按照README中的“Installation”小节的提示进行pycrypto安装。
"""
class FormatException(Exception):
pass
def throw_exception(message, exception_class=FormatException):
"""my define raise exception function"""
raise exception_class(message)
class SHA1:
"""计算企业微信的消息签名接口"""
def getSHA1(self, token, timestamp, nonce, encrypt):
"""用SHA1算法生成安全签名
@param token: 票据
@param timestamp: 时间戳
@param encrypt: 密文
@param nonce: 随机字符串
@return: 安全签名
"""
try:
# 确保所有输入都是字符串类型
if isinstance(encrypt, bytes):
encrypt = encrypt.decode("utf-8")
sortlist = [str(token), str(timestamp), str(nonce), str(encrypt)]
sortlist.sort()
sha = hashlib.sha1()
sha.update("".join(sortlist).encode("utf-8"))
return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
except Exception as e:
print(e)
return ierror.WXBizMsgCrypt_ComputeSignature_Error, None
class JsonParse:
"""提供提取消息格式中的密文及生成回复消息格式的接口"""
# json消息模板
AES_TEXT_RESPONSE_TEMPLATE = """{
"encrypt": "%(msg_encrypt)s",
"msgsignature": "%(msg_signaturet)s",
"timestamp": "%(timestamp)s",
"nonce": "%(nonce)s"
}"""
def extract(self, jsontext):
"""提取出json数据包中的加密消息
@param jsontext: 待提取的json字符串
@return: 提取出的加密消息字符串
"""
try:
json_dict = json.loads(jsontext)
return ierror.WXBizMsgCrypt_OK, json_dict["encrypt"]
except Exception as e:
print(e)
return ierror.WXBizMsgCrypt_ParseJson_Error, None
def generate(self, encrypt, signature, timestamp, nonce):
"""生成json消息
@param encrypt: 加密后的消息密文
@param signature: 安全签名
@param timestamp: 时间戳
@param nonce: 随机字符串
@return: 生成的json字符串
"""
resp_dict = {
"msg_encrypt": encrypt,
"msg_signaturet": signature,
"timestamp": timestamp,
"nonce": nonce,
}
resp_json = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict
return resp_json
class PKCS7Encoder:
"""提供基于PKCS7算法的加解密接口"""
block_size = 32
def encode(self, text):
"""对需要加密的明文进行填充补位
@param text: 需要进行填充补位操作的明文(bytes类型)
@return: 补齐明文字符串(bytes类型)
"""
text_length = len(text)
# 计算需要填充的位数
amount_to_pad = self.block_size - (text_length % self.block_size)
if amount_to_pad == 0:
amount_to_pad = self.block_size
# 获得补位所用的字符
pad = bytes([amount_to_pad])
# 确保text是bytes类型
if isinstance(text, str):
text = text.encode("utf-8")
return text + pad * amount_to_pad
def decode(self, decrypted):
"""删除解密后明文的补位字符
@param decrypted: 解密后的明文
@return: 删除补位字符后的明文
"""
pad = ord(decrypted[-1])
if pad < 1 or pad > 32:
pad = 0
return decrypted[:-pad]
class Prpcrypt(object):
"""提供接收和推送给企业微信消息的加解密接口"""
def __init__(self, key):
# self.key = base64.b64decode(key+"=")
self.key = key
# 设置加解密模式为AES的CBC模式
self.mode = AES.MODE_CBC
def encrypt(self, text, receiveid):
"""对明文进行加密
@param text: 需要加密的明文
@return: 加密得到的字符串
"""
# 16位随机字符串添加到明文开头
text = text.encode()
text = (
self.get_random_str()
+ struct.pack("I", socket.htonl(len(text)))
+ text
+ receiveid.encode()
)
# 使用自定义的填充方式对明文进行补位填充
pkcs7 = PKCS7Encoder()
text = pkcs7.encode(text)
# 加密
cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore
try:
ciphertext = cryptor.encrypt(text)
# 使用BASE64对加密后的字符串进行编码
return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext)
except Exception as e:
logger = logging.getLogger("astrbot")
logger.error(e)
return ierror.WXBizMsgCrypt_EncryptAES_Error, None
def decrypt(self, text, receiveid):
"""对解密后的明文进行补位删除
@param text: 密文
@return: 删除填充补位后的明文
"""
try:
cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore
# 使用BASE64对密文进行解码然后AES-CBC解密
plain_text = cryptor.decrypt(base64.b64decode(text))
except Exception as e:
print(e)
return ierror.WXBizMsgCrypt_DecryptAES_Error, None
try:
pad = plain_text[-1]
# 去掉补位字符串
# pkcs7 = PKCS7Encoder()
# plain_text = pkcs7.encode(plain_text)
# 去除16位随机字符串
content = plain_text[16:-pad]
json_len = socket.ntohl(struct.unpack("I", content[:4])[0])
json_content = content[4 : json_len + 4].decode("utf-8")
from_receiveid = content[json_len + 4 :].decode("utf-8")
except Exception as e:
print(e)
return ierror.WXBizMsgCrypt_IllegalBuffer, None
if from_receiveid != receiveid:
print("receiveid not match", receiveid, from_receiveid)
return ierror.WXBizMsgCrypt_ValidateCorpid_Error, None
return 0, json_content
def get_random_str(self):
"""随机生成16位字符串
@return: 16位字符串
"""
return str(random.randint(1000000000000000, 9999999999999999)).encode()
class WXBizJsonMsgCrypt(object):
# 构造函数
def __init__(self, sToken, sEncodingAESKey, sReceiveId):
try:
self.key = base64.b64decode(sEncodingAESKey + "=")
assert len(self.key) == 32
except Exception as e:
throw_exception(f"[error]: EncodingAESKey invalid: {e}", FormatException)
# return ierror.WXBizMsgCrypt_IllegalAesKey,None
self.m_sToken = sToken
self.m_sReceiveId = sReceiveId
# 验证URL
# @param sMsgSignature: 签名串对应URL参数的msg_signature
# @param sTimeStamp: 时间戳对应URL参数的timestamp
# @param sNonce: 随机串对应URL参数的nonce
# @param sEchoStr: 随机串对应URL参数的echostr
# @param sReplyEchoStr: 解密之后的echostr当return返回0时有效
# @return成功0失败返回对应的错误码
def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr):
sha1 = SHA1()
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr)
if ret != 0:
return ret, None
if not signature == sMsgSignature:
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
pc = Prpcrypt(self.key)
ret, sReplyEchoStr = pc.decrypt(sEchoStr, self.m_sReceiveId)
return ret, sReplyEchoStr
def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None):
# 将企业回复用户的消息加密打包
# @param sReplyMsg: 企业号待回复用户的消息json格式的字符串
# @param sTimeStamp: 时间戳可以自己生成也可以用URL参数的timestamp,如为None则自动用当前时间
# @param sNonce: 随机串可以自己生成也可以用URL参数的nonce
# sEncryptMsg: 加密后的可以直接回复用户的密文包括msg_signature, timestamp, nonce, encrypt的json格式的字符串,
# return成功0sEncryptMsg,失败返回对应的错误码None
pc = Prpcrypt(self.key)
ret, encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId)
encrypt = encrypt.decode("utf-8") # type: ignore
if ret != 0:
return ret, None
if timestamp is None:
timestamp = str(int(time.time()))
# 生成安全签名
sha1 = SHA1()
ret, signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt)
if ret != 0:
return ret, None
jsonParse = JsonParse()
return ret, jsonParse.generate(encrypt, signature, timestamp, sNonce)
def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce):
# 检验消息的真实性,并且获取解密后的明文
# @param sMsgSignature: 签名串对应URL参数的msg_signature
# @param sTimeStamp: 时间戳对应URL参数的timestamp
# @param sNonce: 随机串对应URL参数的nonce
# @param sPostData: 密文对应POST请求的数据
# json_content: 解密后的原文当return返回0时有效
# @return: 成功0失败返回对应的错误码
# 验证安全签名
jsonParse = JsonParse()
ret, encrypt = jsonParse.extract(sPostData)
if ret != 0:
return ret, None
sha1 = SHA1()
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt)
if ret != 0:
return ret, None
if not signature == sMsgSignature:
print("signature not match")
print(signature)
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
pc = Prpcrypt(self.key)
ret, json_content = pc.decrypt(encrypt, self.m_sReceiveId)
return ret, json_content

View File

@@ -0,0 +1,17 @@
"""
企业微信智能机器人平台适配器包
"""
from .wecomai_adapter import WecomAIBotAdapter
from .wecomai_api import WecomAIBotAPIClient
from .wecomai_event import WecomAIBotMessageEvent
from .wecomai_server import WecomAIBotServer
from .wecomai_utils import WecomAIBotConstants
__all__ = [
"WecomAIBotAdapter",
"WecomAIBotAPIClient",
"WecomAIBotMessageEvent",
"WecomAIBotServer",
"WecomAIBotConstants",
]

View File

@@ -0,0 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#########################################################################
# Author: jonyqin
# Created Time: Thu 11 Sep 2014 01:53:58 PM CST
# File Name: ierror.py
# Description:定义错误码含义
#########################################################################
WXBizMsgCrypt_OK = 0
WXBizMsgCrypt_ValidateSignature_Error = -40001
WXBizMsgCrypt_ParseJson_Error = -40002
WXBizMsgCrypt_ComputeSignature_Error = -40003
WXBizMsgCrypt_IllegalAesKey = -40004
WXBizMsgCrypt_ValidateCorpid_Error = -40005
WXBizMsgCrypt_EncryptAES_Error = -40006
WXBizMsgCrypt_DecryptAES_Error = -40007
WXBizMsgCrypt_IllegalBuffer = -40008
WXBizMsgCrypt_EncodeBase64_Error = -40009
WXBizMsgCrypt_DecodeBase64_Error = -40010
WXBizMsgCrypt_GenReturnJson_Error = -40011

View File

@@ -0,0 +1,445 @@
"""
企业微信智能机器人平台适配器
基于企业微信智能机器人 API 的消息平台适配器,支持 HTTP 回调
参考webchat_adapter.py的队列机制实现异步消息处理和流式响应
"""
import time
import asyncio
import uuid
import hashlib
import base64
from typing import Awaitable, Any, Dict, Optional, Callable
from astrbot.api.platform import (
Platform,
AstrBotMessage,
MessageMember,
MessageType,
PlatformMetadata,
)
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain, At, Image
from astrbot.api import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from .wecomai_api import (
WecomAIBotAPIClient,
WecomAIBotMessageParser,
WecomAIBotStreamMessageBuilder,
)
from .wecomai_event import WecomAIBotMessageEvent
from .wecomai_server import WecomAIBotServer
from .wecomai_queue_mgr import wecomai_queue_mgr, WecomAIQueueMgr
from .wecomai_utils import (
WecomAIBotConstants,
format_session_id,
generate_random_string,
process_encrypted_image,
)
class WecomAIQueueListener:
"""企业微信智能机器人队列监听器参考webchat的QueueListener设计"""
def __init__(
self, queue_mgr: WecomAIQueueMgr, callback: Callable[[dict], Awaitable[None]]
) -> None:
self.queue_mgr = queue_mgr
self.callback = callback
self.running_tasks = set()
async def listen_to_queue(self, session_id: str):
"""监听特定会话的队列"""
queue = self.queue_mgr.get_or_create_queue(session_id)
while True:
try:
data = await queue.get()
await self.callback(data)
except Exception as e:
logger.error(f"处理会话 {session_id} 消息时发生错误: {e}")
break
async def run(self):
"""监控新会话队列并启动监听器"""
monitored_sessions = set()
while True:
# 检查新会话
current_sessions = set(self.queue_mgr.queues.keys())
new_sessions = current_sessions - monitored_sessions
# 为新会话启动监听器
for session_id in new_sessions:
task = asyncio.create_task(self.listen_to_queue(session_id))
self.running_tasks.add(task)
task.add_done_callback(self.running_tasks.discard)
monitored_sessions.add(session_id)
logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}")
# 清理已不存在的会话
removed_sessions = monitored_sessions - current_sessions
monitored_sessions -= removed_sessions
# 清理过期的待处理响应
self.queue_mgr.cleanup_expired_responses()
await asyncio.sleep(1) # 每秒检查一次新会话
@register_platform_adapter(
"wecom_ai_bot", "企业微信智能机器人适配器,支持 HTTP 回调接收消息"
)
class WecomAIBotAdapter(Platform):
"""企业微信智能机器人适配器"""
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
# 初始化配置参数
self.token = self.config["token"]
self.encoding_aes_key = self.config["encoding_aes_key"]
self.port = int(self.config["port"])
self.host = self.config.get("callback_server_host", "0.0.0.0")
self.bot_name = self.config.get("wecom_ai_bot_name", "")
self.initial_respond_text = self.config.get(
"wecomaibot_init_respond_text", "💭 思考中..."
)
self.friend_message_welcome_text = self.config.get(
"wecomaibot_friend_message_welcome_text", ""
)
# 平台元数据
self.metadata = PlatformMetadata(
name="wecom_ai_bot",
description="企业微信智能机器人适配器,支持 HTTP 回调接收消息",
id=self.config.get("id", "wecom_ai_bot"),
)
# 初始化 API 客户端
self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key)
# 初始化 HTTP 服务器
self.server = WecomAIBotServer(
host=self.host,
port=self.port,
api_client=self.api_client,
message_handler=self._process_message,
)
# 事件循环和关闭信号
self.shutdown_event = asyncio.Event()
# 队列监听器
self.queue_listener = WecomAIQueueListener(
wecomai_queue_mgr, self._handle_queued_message
)
async def _handle_queued_message(self, data: dict):
"""处理队列中的消息类似webchat的callback"""
try:
abm = await self.convert_message(data)
await self.handle_msg(abm)
except Exception as e:
logger.error(f"处理队列消息时发生异常: {e}")
async def _process_message(
self, message_data: Dict[str, Any], callback_params: Dict[str, str]
) -> Optional[str]:
"""处理接收到的消息
Args:
message_data: 解密后的消息数据
callback_params: 回调参数 (nonce, timestamp)
Returns:
加密后的响应消息,无需响应时返回 None
"""
msgtype = message_data.get("msgtype")
if not msgtype:
logger.warning(f"消息类型未知,忽略: {message_data}")
return None
session_id = self._extract_session_id(message_data)
if msgtype in ("text", "image", "mixed"):
# user sent a text / image / mixed message
try:
# create a brand-new unique stream_id for this message session
stream_id = f"{session_id}_{generate_random_string(10)}"
await self._enqueue_message(
message_data, callback_params, stream_id, session_id
)
wecomai_queue_mgr.set_pending_response(stream_id, callback_params)
resp = WecomAIBotStreamMessageBuilder.make_text_stream(
stream_id, self.initial_respond_text, False
)
return await self.api_client.encrypt_message(
resp, callback_params["nonce"], callback_params["timestamp"]
)
except Exception as e:
logger.error("处理消息时发生异常: %s", e)
return None
elif msgtype == "stream":
# wechat server is requesting for updates of a stream
stream_id = message_data["stream"]["id"]
if not wecomai_queue_mgr.has_back_queue(stream_id):
logger.error(f"Cannot find back queue for stream_id: {stream_id}")
# 返回结束标志,告诉微信服务器流已结束
end_message = WecomAIBotStreamMessageBuilder.make_text_stream(
stream_id, "", True
)
resp = await self.api_client.encrypt_message(
end_message,
callback_params["nonce"],
callback_params["timestamp"],
)
return resp
queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
if queue.empty():
logger.debug(
f"No new messages in back queue for stream_id: {stream_id}"
)
return None
# aggregate all delta chains in the back queue
latest_plain_content = ""
image_base64 = []
finish = False
while not queue.empty():
msg = await queue.get()
if msg["type"] == "plain":
latest_plain_content = msg["data"] or ""
elif msg["type"] == "image":
image_base64.append(msg["image_data"])
elif msg["type"] == "end":
# stream end
finish = True
wecomai_queue_mgr.remove_queues(stream_id)
break
else:
pass
logger.debug(
f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}"
)
if latest_plain_content or image_base64:
msg_items = []
if finish and image_base64:
for img_b64 in image_base64:
# get md5 of image
img_data = base64.b64decode(img_b64)
img_md5 = hashlib.md5(img_data).hexdigest()
msg_items.append(
{
"msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE,
"image": {"base64": img_b64, "md5": img_md5},
}
)
image_base64 = []
plain_message = WecomAIBotStreamMessageBuilder.make_mixed_stream(
stream_id, latest_plain_content, msg_items, finish
)
encrypted_message = await self.api_client.encrypt_message(
plain_message,
callback_params["nonce"],
callback_params["timestamp"],
)
if encrypted_message:
logger.debug(
f"Stream message sent successfully, stream_id: {stream_id}"
)
else:
logger.error("消息加密失败")
return encrypted_message
return None
elif msgtype == "event":
event = message_data.get("event")
if event == "enter_chat" and self.friend_message_welcome_text:
# 用户进入会话,发送欢迎消息
try:
resp = WecomAIBotStreamMessageBuilder.make_text(
self.friend_message_welcome_text
)
return await self.api_client.encrypt_message(
resp,
callback_params["nonce"],
callback_params["timestamp"],
)
except Exception as e:
logger.error("处理欢迎消息时发生异常: %s", e)
return None
pass
def _extract_session_id(self, message_data: Dict[str, Any]) -> str:
"""从消息数据中提取会话ID"""
user_id = message_data.get("from", {}).get("userid", "default_user")
return format_session_id("wecomai", user_id)
async def _enqueue_message(
self,
message_data: Dict[str, Any],
callback_params: Dict[str, str],
stream_id: str,
session_id: str,
):
"""将消息放入队列进行异步处理"""
input_queue = wecomai_queue_mgr.get_or_create_queue(stream_id)
_ = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
message_payload = {
"message_data": message_data,
"callback_params": callback_params,
"session_id": session_id,
"stream_id": stream_id,
}
await input_queue.put(message_payload)
logger.debug(f"[WecomAI] 消息已入队: {stream_id}")
async def convert_message(self, payload: dict) -> AstrBotMessage:
"""转换队列中的消息数据为AstrBotMessage类似webchat的convert_message"""
message_data = payload["message_data"]
session_id = payload["session_id"]
# callback_params = payload["callback_params"] # 保留但暂时不使用
# 解析消息内容
msgtype = message_data.get("msgtype")
content = ""
image_base64 = []
_img_url_to_process = []
msg_items = []
if msgtype == WecomAIBotConstants.MSG_TYPE_TEXT:
content = WecomAIBotMessageParser.parse_text_message(message_data)
elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE:
_img_url_to_process.append(
WecomAIBotMessageParser.parse_image_message(message_data)
)
elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED:
# 提取混合消息中的文本内容
msg_items = WecomAIBotMessageParser.parse_mixed_message(message_data)
text_parts = []
for item in msg_items or []:
if item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_TEXT:
text_content = item.get("text", {}).get("content", "")
if text_content:
text_parts.append(text_content)
elif item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_IMAGE:
image_url = item.get("image", {}).get("url", "")
if image_url:
_img_url_to_process.append(image_url)
content = " ".join(text_parts) if text_parts else ""
else:
content = f"[{msgtype}消息]"
# 并行处理图片下载和解密
if _img_url_to_process:
tasks = [
process_encrypted_image(url, self.encoding_aes_key)
for url in _img_url_to_process
]
results = await asyncio.gather(*tasks)
for success, result in results:
if success:
image_base64.append(result)
else:
logger.error(f"处理加密图片失败: {result}")
# 构建 AstrBotMessage
abm = AstrBotMessage()
abm.self_id = self.bot_name
abm.message_str = content or "[未知消息]"
abm.message_id = str(uuid.uuid4())
abm.timestamp = int(time.time())
abm.raw_message = payload
# 发送者信息
abm.sender = MessageMember(
user_id=message_data.get("from", {}).get("userid", "unknown"),
nickname=message_data.get("from", {}).get("userid", "unknown"),
)
# 消息类型
abm.type = (
MessageType.GROUP_MESSAGE
if message_data.get("chattype") == "group"
else MessageType.FRIEND_MESSAGE
)
abm.session_id = session_id
# 消息内容
abm.message = []
# 处理 At
if self.bot_name and f"@{self.bot_name}" in abm.message_str:
abm.message_str = abm.message_str.replace(f"@{self.bot_name}", "").strip()
abm.message.append(At(qq=self.bot_name, name=self.bot_name))
abm.message.append(Plain(abm.message_str))
if image_base64:
for img_b64 in image_base64:
abm.message.append(Image.fromBase64(img_b64))
logger.debug(f"WecomAIAdapter: {abm.message}")
return abm
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
"""通过会话发送消息"""
# 企业微信智能机器人主要通过回调响应,这里记录日志
logger.info("会话发送消息: %s -> %s", session.session_id, message_chain)
await super().send_by_session(session, message_chain)
def run(self) -> Awaitable[Any]:
"""运行适配器同时启动HTTP服务器和队列监听器"""
logger.info("启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port)
async def run_both():
# 同时运行HTTP服务器和队列监听器
await asyncio.gather(
self.server.start_server(),
self.queue_listener.run(),
)
return run_both()
async def terminate(self):
"""终止适配器"""
logger.info("企业微信智能机器人适配器正在关闭...")
self.shutdown_event.set()
await self.server.shutdown()
def meta(self) -> PlatformMetadata:
"""获取平台元数据"""
return self.metadata
async def handle_msg(self, message: AstrBotMessage):
"""处理消息,创建消息事件并提交到事件队列"""
try:
message_event = WecomAIBotMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
api_client=self.api_client,
)
self.commit_event(message_event)
except Exception as e:
logger.error("处理消息时发生异常: %s", e)
def get_client(self) -> WecomAIBotAPIClient:
"""获取 API 客户端"""
return self.api_client
def get_server(self) -> WecomAIBotServer:
"""获取 HTTP 服务器实例"""
return self.server

View File

@@ -0,0 +1,378 @@
"""
企业微信智能机器人 API 客户端
处理消息加密解密、API 调用等
"""
import json
import base64
import hashlib
from typing import Dict, Any, Optional, Tuple, Union
from Crypto.Cipher import AES
import aiohttp
from .WXBizJsonMsgCrypt import WXBizJsonMsgCrypt
from .wecomai_utils import WecomAIBotConstants
from astrbot import logger
class WecomAIBotAPIClient:
"""企业微信智能机器人 API 客户端"""
def __init__(self, token: str, encoding_aes_key: str):
"""初始化 API 客户端
Args:
token: 企业微信机器人 Token
encoding_aes_key: 消息加密密钥
"""
self.token = token
self.encoding_aes_key = encoding_aes_key
self.wxcpt = WXBizJsonMsgCrypt(token, encoding_aes_key, "") # receiveid 为空串
async def decrypt_message(
self, encrypted_data: bytes, msg_signature: str, timestamp: str, nonce: str
) -> Tuple[int, Optional[Dict[str, Any]]]:
"""解密企业微信消息
Args:
encrypted_data: 加密的消息数据
msg_signature: 消息签名
timestamp: 时间戳
nonce: 随机数
Returns:
(错误码, 解密后的消息数据字典)
"""
try:
ret, decrypted_msg = self.wxcpt.DecryptMsg(
encrypted_data, msg_signature, timestamp, nonce
)
if ret != WecomAIBotConstants.SUCCESS:
logger.error(f"消息解密失败,错误码: {ret}")
return ret, None
# 解析 JSON
if decrypted_msg:
try:
message_data = json.loads(decrypted_msg)
logger.debug(f"解密成功,消息内容: {message_data}")
return WecomAIBotConstants.SUCCESS, message_data
except json.JSONDecodeError as e:
logger.error(f"JSON 解析失败: {e}, 原始消息: {decrypted_msg}")
return WecomAIBotConstants.PARSE_XML_ERROR, None
else:
logger.error("解密消息为空")
return WecomAIBotConstants.DECRYPT_ERROR, None
except Exception as e:
logger.error(f"解密过程发生异常: {e}")
return WecomAIBotConstants.DECRYPT_ERROR, None
async def encrypt_message(
self, plain_message: str, nonce: str, timestamp: str
) -> Optional[str]:
"""加密消息
Args:
plain_message: 明文消息
nonce: 随机数
timestamp: 时间戳
Returns:
加密后的消息,失败时返回 None
"""
try:
ret, encrypted_msg = self.wxcpt.EncryptMsg(plain_message, nonce, timestamp)
if ret != WecomAIBotConstants.SUCCESS:
logger.error(f"消息加密失败,错误码: {ret}")
return None
logger.debug("消息加密成功")
return encrypted_msg
except Exception as e:
logger.error(f"加密过程发生异常: {e}")
return None
def verify_url(
self, msg_signature: str, timestamp: str, nonce: str, echostr: str
) -> str:
"""验证回调 URL
Args:
msg_signature: 消息签名
timestamp: 时间戳
nonce: 随机数
echostr: 验证字符串
Returns:
验证结果字符串
"""
try:
ret, echo_result = self.wxcpt.VerifyURL(
msg_signature, timestamp, nonce, echostr
)
if ret != WecomAIBotConstants.SUCCESS:
logger.error(f"URL 验证失败,错误码: {ret}")
return "verify fail"
logger.info("URL 验证成功")
return echo_result if echo_result else "verify fail"
except Exception as e:
logger.error(f"URL 验证发生异常: {e}")
return "verify fail"
async def process_encrypted_image(
self, image_url: str, aes_key_base64: Optional[str] = None
) -> Tuple[bool, Union[bytes, str]]:
"""下载并解密加密图片
Args:
image_url: 加密图片的 URL
aes_key_base64: Base64 编码的 AES 密钥,如果为 None 则使用实例的密钥
Returns:
(是否成功, 图片数据或错误信息)
"""
try:
# 下载图片
logger.info(f"开始下载加密图片: {image_url}")
async with aiohttp.ClientSession() as session:
async with session.get(image_url, timeout=15) as response:
if response.status != 200:
error_msg = f"图片下载失败,状态码: {response.status}"
logger.error(error_msg)
return False, error_msg
encrypted_data = await response.read()
logger.info(f"图片下载成功,大小: {len(encrypted_data)} 字节")
# 准备解密密钥
if aes_key_base64 is None:
aes_key_base64 = self.encoding_aes_key
if not aes_key_base64:
raise ValueError("AES 密钥不能为空")
# Base64 解码密钥
aes_key = base64.b64decode(
aes_key_base64 + "=" * (-len(aes_key_base64) % 4)
)
if len(aes_key) != 32:
raise ValueError("无效的 AES 密钥长度: 应为 32 字节")
iv = aes_key[:16] # 初始向量为密钥前 16 字节
# 解密图片数据
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
decrypted_data = cipher.decrypt(encrypted_data)
# 去除 PKCS#7 填充
pad_len = decrypted_data[-1]
if pad_len > 32: # AES-256 块大小为 32 字节
raise ValueError("无效的填充长度 (大于32字节)")
decrypted_data = decrypted_data[:-pad_len]
logger.info(f"图片解密成功,解密后大小: {len(decrypted_data)} 字节")
return True, decrypted_data
except aiohttp.ClientError as e:
error_msg = f"图片下载失败: {str(e)}"
logger.error(error_msg)
return False, error_msg
except ValueError as e:
error_msg = f"参数错误: {str(e)}"
logger.error(error_msg)
return False, error_msg
except Exception as e:
error_msg = f"图片处理异常: {str(e)}"
logger.error(error_msg)
return False, error_msg
class WecomAIBotStreamMessageBuilder:
"""企业微信智能机器人流消息构建器"""
@staticmethod
def make_text_stream(stream_id: str, content: str, finish: bool = False) -> str:
"""构建文本流消息
Args:
stream_id: 流 ID
content: 文本内容
finish: 是否结束
Returns:
JSON 格式的流消息字符串
"""
plain = {
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
"stream": {"id": stream_id, "finish": finish, "content": content},
}
return json.dumps(plain, ensure_ascii=False)
@staticmethod
def make_image_stream(
stream_id: str, image_data: bytes, finish: bool = False
) -> str:
"""构建图片流消息
Args:
stream_id: 流 ID
image_data: 图片二进制数据
finish: 是否结束
Returns:
JSON 格式的流消息字符串
"""
image_md5 = hashlib.md5(image_data).hexdigest()
image_base64 = base64.b64encode(image_data).decode("utf-8")
plain = {
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
"stream": {
"id": stream_id,
"finish": finish,
"msg_item": [
{
"msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE,
"image": {"base64": image_base64, "md5": image_md5},
}
],
},
}
return json.dumps(plain, ensure_ascii=False)
@staticmethod
def make_mixed_stream(
stream_id: str, content: str, msg_items: list, finish: bool = False
) -> str:
"""构建混合类型流消息
Args:
stream_id: 流 ID
content: 文本内容
msg_items: 消息项列表
finish: 是否结束
Returns:
JSON 格式的流消息字符串
"""
plain = {
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
"stream": {"id": stream_id, "finish": finish, "msg_item": msg_items},
}
if content:
plain["stream"]["content"] = content
return json.dumps(plain, ensure_ascii=False)
@staticmethod
def make_text(content: str) -> str:
"""构建文本消息
Args:
content: 文本内容
Returns:
JSON 格式的文本消息字符串
"""
plain = {"msgtype": "text", "text": {"content": content}}
return json.dumps(plain, ensure_ascii=False)
class WecomAIBotMessageParser:
"""企业微信智能机器人消息解析器"""
@staticmethod
def parse_text_message(data: Dict[str, Any]) -> Optional[str]:
"""解析文本消息
Args:
data: 消息数据
Returns:
文本内容,解析失败返回 None
"""
try:
return data.get("text", {}).get("content")
except (KeyError, TypeError):
logger.warning("文本消息解析失败")
return None
@staticmethod
def parse_image_message(data: Dict[str, Any]) -> Optional[str]:
"""解析图片消息
Args:
data: 消息数据
Returns:
图片 URL解析失败返回 None
"""
try:
return data.get("image", {}).get("url")
except (KeyError, TypeError):
logger.warning("图片消息解析失败")
return None
@staticmethod
def parse_stream_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""解析流消息
Args:
data: 消息数据
Returns:
流消息数据,解析失败返回 None
"""
try:
stream_data = data.get("stream", {})
return {
"id": stream_data.get("id"),
"finish": stream_data.get("finish"),
"content": stream_data.get("content"),
"msg_item": stream_data.get("msg_item", []),
}
except (KeyError, TypeError):
logger.warning("流消息解析失败")
return None
@staticmethod
def parse_mixed_message(data: Dict[str, Any]) -> Optional[list]:
"""解析混合消息
Args:
data: 消息数据
Returns:
消息项列表,解析失败返回 None
"""
try:
return data.get("mixed", {}).get("msg_item", [])
except (KeyError, TypeError):
logger.warning("混合消息解析失败")
return None
@staticmethod
def parse_event_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""解析事件消息
Args:
data: 消息数据
Returns:
事件数据,解析失败返回 None
"""
try:
return data.get("event", {})
except (KeyError, TypeError):
logger.warning("事件消息解析失败")
return None

View File

@@ -0,0 +1,149 @@
"""
企业微信智能机器人事件处理模块,处理消息事件的发送和接收
"""
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import (
Image,
Plain,
)
from astrbot.api import logger
from .wecomai_api import WecomAIBotAPIClient
from .wecomai_queue_mgr import wecomai_queue_mgr
class WecomAIBotMessageEvent(AstrMessageEvent):
"""企业微信智能机器人消息事件"""
def __init__(
self,
message_str: str,
message_obj,
platform_meta,
session_id: str,
api_client: WecomAIBotAPIClient,
):
"""初始化消息事件
Args:
message_str: 消息字符串
message_obj: 消息对象
platform_meta: 平台元数据
session_id: 会话 ID
api_client: API 客户端
"""
super().__init__(message_str, message_obj, platform_meta, session_id)
self.api_client = api_client
@staticmethod
async def _send(
message_chain: MessageChain,
stream_id: str,
streaming: bool = False,
):
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
if not message_chain:
await back_queue.put(
{
"type": "end",
"data": "",
"streaming": False,
}
)
return ""
data = ""
for comp in message_chain.chain:
if isinstance(comp, Plain):
data = comp.text
await back_queue.put(
{
"type": "plain",
"data": data,
"streaming": streaming,
"session_id": stream_id,
}
)
elif isinstance(comp, Image):
# 处理图片消息
try:
image_base64 = await comp.convert_to_base64()
if image_base64:
await back_queue.put(
{
"type": "image",
"image_data": image_base64,
"streaming": streaming,
"session_id": stream_id,
}
)
else:
logger.warning("图片数据为空,跳过")
except Exception as e:
logger.error("处理图片消息失败: %s", e)
else:
logger.warning(f"[WecomAI] 不支持的消息组件类型: {type(comp)}, 跳过")
return data
async def send(self, message: MessageChain):
"""发送消息"""
raw = self.message_obj.raw_message
assert isinstance(raw, dict), (
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
await WecomAIBotMessageEvent._send(message, stream_id)
await super().send(message)
async def send_streaming(self, generator, use_fallback=False):
"""流式发送消息参考webchat的send_streaming设计"""
final_data = ""
raw = self.message_obj.raw_message
assert isinstance(raw, dict), (
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送
increment_plain = ""
async for chain in generator:
# 累积增量内容,并改写 Plain 段
chain.squash_plain()
for comp in chain.chain:
if isinstance(comp, Plain):
comp.text = increment_plain + comp.text
increment_plain = comp.text
break
if chain.type == "break" and final_data:
# 分割符
await back_queue.put(
{
"type": "break", # break means a segment end
"data": final_data,
"streaming": True,
"session_id": self.session_id,
}
)
final_data = ""
continue
final_data += await WecomAIBotMessageEvent._send(
chain,
stream_id=stream_id,
streaming=True,
)
await back_queue.put(
{
"type": "complete", # complete means we return the final result
"data": final_data,
"streaming": True,
"session_id": self.session_id,
}
)
await super().send_streaming(generator, use_fallback)

View File

@@ -0,0 +1,148 @@
"""
企业微信智能机器人队列管理器
参考 webchat_queue_mgr.py为企业微信智能机器人实现队列机制
支持异步消息处理和流式响应
"""
import asyncio
from typing import Dict, Any, Optional
from astrbot.api import logger
class WecomAIQueueMgr:
"""企业微信智能机器人队列管理器"""
def __init__(self) -> None:
self.queues: Dict[str, asyncio.Queue] = {}
"""StreamID 到输入队列的映射 - 用于接收用户消息"""
self.back_queues: Dict[str, asyncio.Queue] = {}
"""StreamID 到输出队列的映射 - 用于发送机器人响应"""
self.pending_responses: Dict[str, Dict[str, Any]] = {}
"""待处理的响应缓存,用于流式响应"""
def get_or_create_queue(self, session_id: str) -> asyncio.Queue:
"""获取或创建指定会话的输入队列
Args:
session_id: 会话ID
Returns:
输入队列实例
"""
if session_id not in self.queues:
self.queues[session_id] = asyncio.Queue()
logger.debug(f"[WecomAI] 创建输入队列: {session_id}")
return self.queues[session_id]
def get_or_create_back_queue(self, session_id: str) -> asyncio.Queue:
"""获取或创建指定会话的输出队列
Args:
session_id: 会话ID
Returns:
输出队列实例
"""
if session_id not in self.back_queues:
self.back_queues[session_id] = asyncio.Queue()
logger.debug(f"[WecomAI] 创建输出队列: {session_id}")
return self.back_queues[session_id]
def remove_queues(self, session_id: str):
"""移除指定会话的所有队列
Args:
session_id: 会话ID
"""
if session_id in self.queues:
del self.queues[session_id]
logger.debug(f"[WecomAI] 移除输入队列: {session_id}")
if session_id in self.back_queues:
del self.back_queues[session_id]
logger.debug(f"[WecomAI] 移除输出队列: {session_id}")
if session_id in self.pending_responses:
del self.pending_responses[session_id]
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
def has_queue(self, session_id: str) -> bool:
"""检查是否存在指定会话的队列
Args:
session_id: 会话ID
Returns:
是否存在队列
"""
return session_id in self.queues
def has_back_queue(self, session_id: str) -> bool:
"""检查是否存在指定会话的输出队列
Args:
session_id: 会话ID
Returns:
是否存在输出队列
"""
return session_id in self.back_queues
def set_pending_response(self, session_id: str, callback_params: Dict[str, str]):
"""设置待处理的响应参数
Args:
session_id: 会话ID
callback_params: 回调参数nonce, timestamp等
"""
self.pending_responses[session_id] = {
"callback_params": callback_params,
"timestamp": asyncio.get_event_loop().time(),
}
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
def get_pending_response(self, session_id: str) -> Optional[Dict[str, Any]]:
"""获取待处理的响应参数
Args:
session_id: 会话ID
Returns:
响应参数如果不存在则返回None
"""
return self.pending_responses.get(session_id)
def cleanup_expired_responses(self, max_age_seconds: int = 300):
"""清理过期的待处理响应
Args:
max_age_seconds: 最大存活时间(秒)
"""
current_time = asyncio.get_event_loop().time()
expired_sessions = []
for session_id, response_data in self.pending_responses.items():
if current_time - response_data["timestamp"] > max_age_seconds:
expired_sessions.append(session_id)
for session_id in expired_sessions:
del self.pending_responses[session_id]
logger.debug(f"[WecomAI] 清理过期响应: {session_id}")
def get_stats(self) -> Dict[str, int]:
"""获取队列统计信息
Returns:
统计信息字典
"""
return {
"input_queues": len(self.queues),
"output_queues": len(self.back_queues),
"pending_responses": len(self.pending_responses),
}
# 全局队列管理器实例
wecomai_queue_mgr = WecomAIQueueMgr()

View File

@@ -0,0 +1,166 @@
"""
企业微信智能机器人 HTTP 服务器
处理企业微信智能机器人的 HTTP 回调请求
"""
import asyncio
from typing import Dict, Any, Optional, Callable
import quart
from astrbot.api import logger
from .wecomai_api import WecomAIBotAPIClient
from .wecomai_utils import WecomAIBotConstants
class WecomAIBotServer:
"""企业微信智能机器人 HTTP 服务器"""
def __init__(
self,
host: str,
port: int,
api_client: WecomAIBotAPIClient,
message_handler: Optional[
Callable[[Dict[str, Any], Dict[str, str]], Any]
] = None,
):
"""初始化服务器
Args:
host: 监听地址
port: 监听端口
api_client: API客户端实例
message_handler: 消息处理回调函数
"""
self.host = host
self.port = port
self.api_client = api_client
self.message_handler = message_handler
self.app = quart.Quart(__name__)
self._setup_routes()
self.shutdown_event = asyncio.Event()
def _setup_routes(self):
"""设置 Quart 路由"""
# 使用 Quart 的 add_url_rule 方法添加路由
self.app.add_url_rule(
"/webhook/wecom-ai-bot",
view_func=self.verify_url,
methods=["GET"],
)
self.app.add_url_rule(
"/webhook/wecom-ai-bot",
view_func=self.handle_message,
methods=["POST"],
)
async def verify_url(self):
"""验证回调 URL"""
args = quart.request.args
msg_signature = args.get("msg_signature")
timestamp = args.get("timestamp")
nonce = args.get("nonce")
echostr = args.get("echostr")
if not all([msg_signature, timestamp, nonce, echostr]):
logger.error("URL 验证参数缺失")
return "verify fail", 400
# 类型检查确保不为 None
assert msg_signature is not None
assert timestamp is not None
assert nonce is not None
assert echostr is not None
logger.info("收到企业微信智能机器人 WebHook URL 验证请求。")
result = self.api_client.verify_url(msg_signature, timestamp, nonce, echostr)
return result, 200, {"Content-Type": "text/plain"}
async def handle_message(self):
"""处理消息回调"""
args = quart.request.args
msg_signature = args.get("msg_signature")
timestamp = args.get("timestamp")
nonce = args.get("nonce")
if not all([msg_signature, timestamp, nonce]):
logger.error("消息回调参数缺失")
return "缺少必要参数", 400
# 类型检查确保不为 None
assert msg_signature is not None
assert timestamp is not None
assert nonce is not None
logger.debug(
f"收到消息回调msg_signature={msg_signature}, timestamp={timestamp}, nonce={nonce}"
)
try:
# 获取请求体
post_data = await quart.request.get_data()
# 确保 post_data 是 bytes 类型
if isinstance(post_data, str):
post_data = post_data.encode("utf-8")
# 解密消息
ret_code, message_data = await self.api_client.decrypt_message(
post_data, msg_signature, timestamp, nonce
)
if ret_code != WecomAIBotConstants.SUCCESS or not message_data:
logger.error("消息解密失败,错误码: %d", ret_code)
return "消息解密失败", 400
# 调用消息处理器
response = None
if self.message_handler:
try:
response = await self.message_handler(
message_data, {"nonce": nonce, "timestamp": timestamp}
)
except Exception as e:
logger.error("消息处理器执行异常: %s", e)
return "消息处理异常", 500
if response:
return response, 200, {"Content-Type": "text/plain"}
else:
return "success", 200, {"Content-Type": "text/plain"}
except Exception as e:
logger.error("处理消息时发生异常: %s", e)
return "内部服务器错误", 500
async def start_server(self):
"""启动服务器"""
logger.info("启动企业微信智能机器人服务器,监听 %s:%d", self.host, self.port)
try:
await self.app.run_task(
host=self.host,
port=self.port,
shutdown_trigger=self.shutdown_trigger,
)
except Exception as e:
logger.error("服务器运行异常: %s", e)
raise
async def shutdown_trigger(self):
"""关闭触发器"""
await self.shutdown_event.wait()
async def shutdown(self):
"""关闭服务器"""
logger.info("企业微信智能机器人服务器正在关闭...")
self.shutdown_event.set()
def get_app(self):
"""获取 Quart 应用实例"""
return self.app

View File

@@ -0,0 +1,199 @@
"""
企业微信智能机器人工具模块
提供常量定义、工具函数和辅助方法
"""
import string
import random
import hashlib
import base64
import aiohttp
import asyncio
from Crypto.Cipher import AES
from typing import Any, Tuple
from astrbot.api import logger
# 常量定义
class WecomAIBotConstants:
"""企业微信智能机器人常量"""
# 消息类型
MSG_TYPE_TEXT = "text"
MSG_TYPE_IMAGE = "image"
MSG_TYPE_MIXED = "mixed"
MSG_TYPE_STREAM = "stream"
MSG_TYPE_EVENT = "event"
# 流消息状态
STREAM_CONTINUE = False
STREAM_FINISH = True
# 错误码
SUCCESS = 0
DECRYPT_ERROR = -40001
VALIDATE_SIGNATURE_ERROR = -40002
PARSE_XML_ERROR = -40003
COMPUTE_SIGNATURE_ERROR = -40004
ILLEGAL_AES_KEY = -40005
VALIDATE_APPID_ERROR = -40006
ENCRYPT_AES_ERROR = -40007
ILLEGAL_BUFFER = -40008
def generate_random_string(length: int = 10) -> str:
"""生成随机字符串
Args:
length: 字符串长度,默认为 10
Returns:
随机字符串
"""
letters = string.ascii_letters + string.digits
return "".join(random.choice(letters) for _ in range(length))
def calculate_image_md5(image_data: bytes) -> str:
"""计算图片数据的 MD5 值
Args:
image_data: 图片二进制数据
Returns:
MD5 哈希值(十六进制字符串)
"""
return hashlib.md5(image_data).hexdigest()
def encode_image_base64(image_data: bytes) -> str:
"""将图片数据编码为 Base64
Args:
image_data: 图片二进制数据
Returns:
Base64 编码的字符串
"""
return base64.b64encode(image_data).decode("utf-8")
def format_session_id(session_type: str, session_id: str) -> str:
"""格式化会话 ID
Args:
session_type: 会话类型 ("user", "group")
session_id: 原始会话 ID
Returns:
格式化后的会话 ID
"""
return f"wecom_ai_bot_{session_type}_{session_id}"
def parse_session_id(formatted_session_id: str) -> Tuple[str, str]:
"""解析格式化的会话 ID
Args:
formatted_session_id: 格式化的会话 ID
Returns:
(会话类型, 原始会话ID)
"""
parts = formatted_session_id.split("_", 3)
if (
len(parts) >= 4
and parts[0] == "wecom"
and parts[1] == "ai"
and parts[2] == "bot"
):
return parts[3], "_".join(parts[4:]) if len(parts) > 4 else ""
return "user", formatted_session_id
def safe_json_loads(json_str: str, default: Any = None) -> Any:
"""安全地解析 JSON 字符串
Args:
json_str: JSON 字符串
default: 解析失败时的默认值
Returns:
解析结果或默认值
"""
import json
try:
return json.loads(json_str)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(f"JSON 解析失败: {e}, 原始字符串: {json_str}")
return default
def format_error_response(error_code: int, error_msg: str) -> str:
"""格式化错误响应
Args:
error_code: 错误码
error_msg: 错误信息
Returns:
格式化的错误响应字符串
"""
return f"Error {error_code}: {error_msg}"
async def process_encrypted_image(
image_url: str, aes_key_base64: str
) -> Tuple[bool, str]:
"""下载并解密加密图片
Args:
image_url: 加密图片的URL
aes_key_base64: Base64编码的AES密钥(与回调加解密相同)
Returns:
Tuple[bool, str]: status 为 True 时 data 是解密后的图片数据的 base64 编码,
status 为 False 时 data 是错误信息
"""
# 1. 下载加密图片
logger.info("开始下载加密图片: %s", image_url)
try:
async with aiohttp.ClientSession() as session:
async with session.get(image_url, timeout=15) as response:
response.raise_for_status()
encrypted_data = await response.read()
logger.info("图片下载成功,大小: %d 字节", len(encrypted_data))
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
error_msg = f"下载图片失败: {str(e)}"
logger.error(error_msg)
return False, error_msg
# 2. 准备AES密钥和IV
if not aes_key_base64:
raise ValueError("AES密钥不能为空")
# Base64解码密钥 (自动处理填充)
aes_key = base64.b64decode(aes_key_base64 + "=" * (-len(aes_key_base64) % 4))
if len(aes_key) != 32:
raise ValueError("无效的AES密钥长度: 应为32字节")
iv = aes_key[:16] # 初始向量为密钥前16字节
# 3. 解密图片数据
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
decrypted_data = cipher.decrypt(encrypted_data)
# 4. 去除PKCS#7填充 (Python 3兼容写法)
pad_len = decrypted_data[-1] # 直接获取最后一个字节的整数值
if pad_len > 32: # AES-256块大小为32字节
raise ValueError("无效的填充长度 (大于32字节)")
decrypted_data = decrypted_data[:-pad_len]
logger.info("图片解密成功,解密后大小: %d 字节", len(decrypted_data))
# 5. 转换为base64编码
base64_data = base64.b64encode(decrypted_data).decode("utf-8")
logger.info("图片已转换为base64编码编码后长度: %d", len(base64_data))
return True, base64_data

View File

@@ -184,6 +184,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
return PlatformMetadata(
"weixin_official_account",
"微信公众平台 适配器",
id=self.config.get("id", "weixin_official_account"),
)
@override

View File

@@ -65,13 +65,16 @@ class AssistantMessageSegment:
role: str = "assistant"
def to_dict(self):
ret = {
ret: dict[str, str | list[dict]] = {
"role": self.role,
}
if self.content:
ret["content"] = self.content
if self.tool_calls:
ret["tool_calls"] = self.tool_calls
tool_calls_dict = [
tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls
]
ret["tool_calls"] = tool_calls_dict
return ret
@@ -117,7 +120,14 @@ class ProviderRequest:
"""模型名称,为 None 时使用提供商的默认模型"""
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
return (
f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
f"image_count={len(self.image_urls or [])}, "
f"func_tool={self.func_tool}, "
f"contexts={self._print_friendly_context()}, "
f"system_prompt={self.system_prompt}, "
f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, "
)
def __str__(self):
return self.__repr__()

View File

@@ -4,7 +4,7 @@ import os
import asyncio
import aiohttp
from typing import Dict, List, Awaitable
from typing import Dict, List, Awaitable, Callable, Any
from astrbot import logger
from astrbot.core import sp
@@ -109,7 +109,7 @@ class FunctionToolManager:
name: str,
func_args: list,
desc: str,
handler: Awaitable,
handler: Callable[..., Awaitable[Any]],
) -> FuncTool:
params = {
"type": "object", # hard-coded here
@@ -132,7 +132,7 @@ class FunctionToolManager:
name: str,
func_args: list,
desc: str,
handler: Awaitable,
handler: Callable[..., Awaitable[Any]],
) -> None:
"""添加函数调用工具
@@ -220,7 +220,7 @@ class FunctionToolManager:
name: str,
cfg: dict,
event: asyncio.Event,
ready_future: asyncio.Future = None,
ready_future: asyncio.Future | None = None,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:

View File

@@ -1,13 +1,18 @@
import asyncio
import traceback
from typing import List
from astrbot.core import logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
from .entities import ProviderType
from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
from .provider import (
Provider,
STTProvider,
TTSProvider,
EmbeddingProvider,
RerankProvider,
)
from .register import llm_tools, provider_cls_map
from ..persona_mgr import PersonaManager
@@ -22,7 +27,7 @@ class ProviderManager:
self.persona_mgr = persona_mgr
self.acm = acm
config = acm.confs["default"]
self.providers_config: List = config["provider"]
self.providers_config: list = config["provider"]
self.provider_settings: dict = config["provider_settings"]
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
@@ -30,15 +35,20 @@ class ProviderManager:
# 人格相关属性v4.0.0 版本后被废弃,推荐使用 PersonaManager
self.default_persona_name = persona_mgr.default_persona
self.provider_insts: List[Provider] = []
self.provider_insts: list[Provider] = []
"""加载的 Provider 的实例"""
self.stt_provider_insts: List[STTProvider] = []
self.stt_provider_insts: list[STTProvider] = []
"""加载的 Speech To Text Provider 的实例"""
self.tts_provider_insts: List[TTSProvider] = []
self.tts_provider_insts: list[TTSProvider] = []
"""加载的 Text To Speech Provider 的实例"""
self.embedding_provider_insts: List[EmbeddingProvider] = []
self.embedding_provider_insts: list[EmbeddingProvider] = []
"""加载的 Embedding Provider 的实例"""
self.inst_map: dict[str, Provider] = {}
self.rerank_provider_insts: list[RerankProvider] = []
"""加载的 Rerank Provider 的实例"""
self.inst_map: dict[
str,
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools
@@ -87,19 +97,31 @@ class ProviderManager:
)
return
# 不启用提供商会话隔离模式的情况
self.curr_provider_inst = self.inst_map[provider_id]
if provider_type == ProviderType.TEXT_TO_SPEECH:
prov = self.inst_map[provider_id]
if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
prov, TTSProvider
):
self.curr_tts_provider_inst = prov
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
elif provider_type == ProviderType.SPEECH_TO_TEXT:
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
prov, STTProvider
):
self.curr_stt_provider_inst = prov
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
elif provider_type == ProviderType.CHAT_COMPLETION:
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
prov, Provider
):
self.curr_provider_inst = prov
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
"""根据提供商 ID 获取提供商实例"""
return self.inst_map.get(provider_id)
def get_using_provider(self, provider_type: ProviderType, umo=None):
def get_using_provider(
self, provider_type: ProviderType, umo=None
) -> Provider | STTProvider | TTSProvider | None:
"""获取正在使用的提供商实例。
Args:
@@ -152,7 +174,11 @@ class ProviderManager:
async def initialize(self):
# 逐个初始化提供商
for provider_config in self.providers_config:
await self.load_provider(provider_config)
try:
await self.load_provider(provider_config)
except Exception as e:
logger.error(traceback.format_exc())
logger.error(e)
# 设置默认提供商
selected_provider_id = sp.get(
@@ -211,6 +237,8 @@ class ProviderManager:
)
case "dify":
from .sources.dify_source import ProviderDify as ProviderDify
case "coze":
from .sources.coze_source import ProviderCoze as ProviderCoze
case "dashscope":
from .sources.dashscope_source import (
ProviderDashscope as ProviderDashscope,
@@ -303,12 +331,14 @@ class ProviderManager:
provider_metadata = provider_cls_map[provider_config["type"]]
try:
# 按任务实例化提供商
cls_type = provider_metadata.cls_type
if not cls_type:
logger.error(f"无法找到 {provider_metadata.type} 的类")
return
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
@@ -327,9 +357,7 @@ class ProviderManager:
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
@@ -345,7 +373,7 @@ class ProviderManager:
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = provider_metadata.cls_type(
inst = cls_type(
provider_config,
self.provider_settings,
self.selected_default_persona,
@@ -366,23 +394,25 @@ class ProviderManager:
if not self.curr_provider_inst:
self.curr_provider_inst = inst
elif provider_metadata.provider_type in [
ProviderType.EMBEDDING,
ProviderType.RERANK,
]:
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.embedding_provider_insts.append(inst)
elif provider_metadata.provider_type == ProviderType.RERANK:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.rerank_provider_insts.append(inst)
self.inst_map[provider_config["id"]] = inst
except Exception as e:
logger.error(traceback.format_exc())
logger.error(
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
)
raise Exception(
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
)
async def reload(self, provider_config: dict):
await self.terminate_provider(provider_config["id"])
@@ -430,11 +460,17 @@ class ProviderManager:
)
if self.inst_map[provider_id] in self.provider_insts:
self.provider_insts.remove(self.inst_map[provider_id])
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, Provider):
self.provider_insts.remove(prov_inst)
if self.inst_map[provider_id] in self.stt_provider_insts:
self.stt_provider_insts.remove(self.inst_map[provider_id])
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, STTProvider):
self.stt_provider_insts.remove(prov_inst)
if self.inst_map[provider_id] in self.tts_provider_insts:
self.tts_provider_insts.remove(self.inst_map[provider_id])
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, TTSProvider):
self.tts_provider_insts.remove(prov_inst)
if self.inst_map[provider_id] == self.curr_provider_inst:
self.curr_provider_inst = None

View File

@@ -1,4 +1,5 @@
import abc
import asyncio
from typing import List
from typing import AsyncGenerator
from astrbot.core.agent.tool import ToolSet
@@ -68,14 +69,15 @@ class Provider(AbstractProvider):
def get_keys(self) -> List[str]:
"""获得提供商 Key"""
return self.provider_config.get("key", [])
keys = self.provider_config.get("key", [""])
return keys or [""]
@abc.abstractmethod
def set_key(self, key: str):
raise NotImplementedError()
@abc.abstractmethod
def get_models(self) -> List[str]:
async def get_models(self) -> List[str]:
"""获得支持的模型列表"""
raise NotImplementedError()
@@ -202,6 +204,72 @@ class EmbeddingProvider(AbstractProvider):
"""获取向量的维度"""
...
async def get_embeddings_batch(
self,
texts: list[str],
batch_size: int = 16,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
) -> list[list[float]]:
"""批量获取文本的向量,分批处理以节省内存
Args:
texts: 文本列表
batch_size: 每批处理的文本数量
tasks_limit: 并发任务数量限制
max_retries: 失败时的最大重试次数
progress_callback: 进度回调函数,接收参数 (current, total)
Returns:
向量列表
"""
semaphore = asyncio.Semaphore(tasks_limit)
all_embeddings: list[list[float]] = []
failed_batches: list[tuple[int, list[str]]] = []
completed_count = 0
total_count = len(texts)
async def process_batch(batch_idx: int, batch_texts: list[str]):
nonlocal completed_count
async with semaphore:
for attempt in range(max_retries):
try:
batch_embeddings = await self.get_embeddings(batch_texts)
all_embeddings.extend(batch_embeddings)
completed_count += len(batch_texts)
if progress_callback:
await progress_callback(completed_count, total_count)
return
except Exception as e:
if attempt == max_retries - 1:
# 最后一次重试失败,记录失败的批次
failed_batches.append((batch_idx, batch_texts))
raise Exception(
f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {str(e)}"
)
# 等待一段时间后重试,使用指数退避
await asyncio.sleep(2**attempt)
tasks = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i : i + batch_size]
batch_idx = i // batch_size
tasks.append(process_batch(batch_idx, batch_texts))
# 收集所有任务的结果,包括失败的任务
results = await asyncio.gather(*tasks, return_exceptions=True)
# 检查是否有失败的任务
errors = [r for r in results if isinstance(r, Exception)]
if errors:
error_msg = (
f"{len(errors)} 个批次处理失败: {'; '.join(str(e) for e in errors)}"
)
raise Exception(error_msg)
return all_embeddings
class RerankProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:

View File

@@ -10,7 +10,7 @@ from anthropic.types import Message
from astrbot.core.utils.io import download_image_by_url
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.func_tool_manager import ToolSet
from ..register import register_provider_adapter
from astrbot.core.provider.entities import LLMResponse
from typing import AsyncGenerator
@@ -33,7 +33,7 @@ class ProviderAnthropic(Provider):
)
self.chosen_api_key: str = ""
self.api_keys: List = provider_config.get("key", [])
self.api_keys: List = super().get_keys()
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
self.timeout = provider_config.get("timeout", 120)
@@ -70,9 +70,13 @@ class ProviderAnthropic(Provider):
{
"type": "tool_use",
"name": tool_call["function"]["name"],
"input": json.loads(tool_call["function"]["arguments"])
if isinstance(tool_call["function"]["arguments"], str)
else tool_call["function"]["arguments"],
"input": (
json.loads(tool_call["function"]["arguments"])
if isinstance(
tool_call["function"]["arguments"], str
)
else tool_call["function"]["arguments"]
),
"id": tool_call["id"],
}
)
@@ -100,7 +104,7 @@ class ProviderAnthropic(Provider):
return system_prompt, new_messages
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
payloads["tools"] = tool_list
@@ -131,7 +135,7 @@ class ProviderAnthropic(Provider):
return llm_response
async def _query_stream(
self, payloads: dict, tools: FuncCall
self, payloads: dict, tools: ToolSet | None
) -> AsyncGenerator[LLMResponse, None]:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
@@ -322,7 +326,7 @@ class ProviderAnthropic(Provider):
async for llm_response in self._query_stream(payloads, func_tool):
yield llm_response
async def assemble_context(self, text: str, image_urls: List[str] = None):
async def assemble_context(self, text: str, image_urls: List[str] | None = None):
"""组装上下文,支持文本和图片"""
if not image_urls:
return {"role": "user", "content": text}
@@ -355,9 +359,11 @@ class ProviderAnthropic(Provider):
"source": {
"type": "base64",
"media_type": mime_type,
"data": image_data.split("base64,")[1]
if "base64," in image_data
else image_data,
"data": (
image_data.split("base64,")[1]
if "base64," in image_data
else image_data
),
},
}
)

View File

@@ -0,0 +1,314 @@
import json
import asyncio
import aiohttp
import io
from typing import Dict, List, Any, AsyncGenerator
from astrbot.core import logger
class CozeAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
self.api_key = api_key
self.api_base = api_base
self.session = None
async def _ensure_session(self):
"""确保HTTP session存在"""
if self.session is None:
connector = aiohttp.TCPConnector(
ssl=False if self.api_base.startswith("http://") else True,
limit=100,
limit_per_host=30,
keepalive_timeout=30,
enable_cleanup_closed=True,
)
timeout = aiohttp.ClientTimeout(
total=120, # 默认超时时间
connect=30,
sock_read=120,
)
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "text/event-stream",
}
self.session = aiohttp.ClientSession(
headers=headers, timeout=timeout, connector=connector
)
return self.session
async def upload_file(
self,
file_data: bytes,
) -> str:
"""上传文件到 Coze 并返回 file_id
Args:
file_data (bytes): 文件的二进制数据
Returns:
str: 上传成功后返回的 file_id
"""
session = await self._ensure_session()
url = f"{self.api_base}/v1/files/upload"
try:
file_io = io.BytesIO(file_data)
async with session.post(
url,
data={
"file": file_io,
},
timeout=aiohttp.ClientTimeout(total=60),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
response_text = await response.text()
logger.debug(
f"文件上传响应状态: {response.status}, 内容: {response_text}"
)
if response.status != 200:
raise Exception(
f"文件上传失败,状态码: {response.status}, 响应: {response_text}"
)
try:
result = await response.json()
except json.JSONDecodeError:
raise Exception(f"文件上传响应解析失败: {response_text}")
if result.get("code") != 0:
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
file_id = result["data"]["id"]
logger.debug(f"[Coze] 图片上传成功file_id: {file_id}")
return file_id
except asyncio.TimeoutError:
logger.error("文件上传超时")
raise Exception("文件上传超时")
except Exception as e:
logger.error(f"文件上传失败: {str(e)}")
raise Exception(f"文件上传失败: {str(e)}")
async def download_image(self, image_url: str) -> bytes:
"""下载图片并返回字节数据
Args:
image_url (str): 图片的URL
Returns:
bytes: 图片的二进制数据
"""
session = await self._ensure_session()
try:
async with session.get(image_url) as response:
if response.status != 200:
raise Exception(f"下载图片失败,状态码: {response.status}")
image_data = await response.read()
return image_data
except Exception as e:
logger.error(f"下载图片失败 {image_url}: {str(e)}")
raise Exception(f"下载图片失败: {str(e)}")
async def chat_messages(
self,
bot_id: str,
user_id: str,
additional_messages: List[Dict] | None = None,
conversation_id: str | None = None,
auto_save_history: bool = True,
stream: bool = True,
timeout: float = 120,
) -> AsyncGenerator[Dict[str, Any], None]:
"""发送聊天消息并返回流式响应
Args:
bot_id: Bot ID
user_id: 用户ID
additional_messages: 额外消息列表
conversation_id: 会话ID
auto_save_history: 是否自动保存历史
stream: 是否流式响应
timeout: 超时时间
"""
session = await self._ensure_session()
url = f"{self.api_base}/v3/chat"
payload = {
"bot_id": bot_id,
"user_id": user_id,
"stream": stream,
"auto_save_history": auto_save_history,
}
if additional_messages:
payload["additional_messages"] = additional_messages
params = {}
if conversation_id:
params["conversation_id"] = conversation_id
logger.debug(f"Coze chat_messages payload: {payload}, params: {params}")
try:
async with session.post(
url,
json=payload,
params=params,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
if response.status != 200:
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
# SSE
buffer = ""
event_type = None
event_data = None
async for chunk in response.content:
if chunk:
buffer += chunk.decode("utf-8", errors="ignore")
lines = buffer.split("\n")
buffer = lines[-1]
for line in lines[:-1]:
line = line.strip()
if not line:
if event_type and event_data:
yield {"event": event_type, "data": event_data}
event_type = None
event_data = None
elif line.startswith("event:"):
event_type = line[6:].strip()
elif line.startswith("data:"):
data_str = line[5:].strip()
if data_str and data_str != "[DONE]":
try:
event_data = json.loads(data_str)
except json.JSONDecodeError:
event_data = {"content": data_str}
except asyncio.TimeoutError:
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
except Exception as e:
raise Exception(f"Coze API 流式请求失败: {str(e)}")
async def clear_context(self, conversation_id: str):
"""清空会话上下文
Args:
conversation_id: 会话ID
Returns:
dict: API响应结果
"""
session = await self._ensure_session()
url = f"{self.api_base}/v3/conversation/message/clear_context"
payload = {"conversation_id": conversation_id}
try:
async with session.post(url, json=payload) as response:
response_text = await response.text()
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
if response.status != 200:
raise Exception(f"Coze API 请求失败,状态码: {response.status}")
try:
return json.loads(response_text)
except json.JSONDecodeError:
raise Exception("Coze API 返回非JSON格式")
except asyncio.TimeoutError:
raise Exception("Coze API 请求超时")
except aiohttp.ClientError as e:
raise Exception(f"Coze API 请求失败: {str(e)}")
async def get_message_list(
self,
conversation_id: str,
order: str = "desc",
limit: int = 10,
offset: int = 0,
):
"""获取消息列表
Args:
conversation_id: 会话ID
order: 排序方式 (asc/desc)
limit: 限制数量
offset: 偏移量
Returns:
dict: API响应结果
"""
session = await self._ensure_session()
url = f"{self.api_base}/v3/conversation/message/list"
params = {
"conversation_id": conversation_id,
"order": order,
"limit": limit,
"offset": offset,
}
try:
async with session.get(url, params=params) as response:
response.raise_for_status()
return await response.json()
except Exception as e:
logger.error(f"获取Coze消息列表失败: {str(e)}")
raise Exception(f"获取Coze消息列表失败: {str(e)}")
async def close(self):
"""关闭会话"""
if self.session:
await self.session.close()
self.session = None
if __name__ == "__main__":
import os
import asyncio
async def test_coze_api_client():
api_key = os.getenv("COZE_API_KEY", "")
bot_id = os.getenv("COZE_BOT_ID", "")
client = CozeAPIClient(api_key=api_key)
try:
with open("README.md", "rb") as f:
file_data = f.read()
file_id = await client.upload_file(file_data)
print(f"Uploaded file_id: {file_id}")
async for event in client.chat_messages(
bot_id=bot_id,
user_id="test_user",
additional_messages=[
{
"role": "user",
"content": json.dumps(
[
{"type": "text", "text": "这是什么"},
{"type": "file", "file_id": file_id},
],
ensure_ascii=False,
),
"content_type": "object_string",
},
],
stream=True,
):
print(f"Event: {event}")
finally:
await client.close()
asyncio.run(test_coze_api_client())

View File

@@ -0,0 +1,635 @@
import json
import os
import base64
import hashlib
from typing import AsyncGenerator, Dict
from astrbot.core.message.message_event_result import MessageChain
import astrbot.core.message.components as Comp
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.entities import LLMResponse
from ..register import register_provider_adapter
from .coze_api_client import CozeAPIClient
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
class ProviderCoze(Provider):
def __init__(
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:
raise Exception("Coze API Key 不能为空。")
self.bot_id = provider_config.get("bot_id", "")
if not self.bot_id:
raise Exception("Coze Bot ID 不能为空。")
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
if not isinstance(self.api_base, str) or not self.api_base.startswith(
("http://", "https://")
):
raise Exception(
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。"
)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.auto_save_history = provider_config.get("auto_save_history", True)
self.conversation_ids: Dict[str, str] = {}
self.file_id_cache: Dict[str, Dict[str, str]] = {}
# 创建 API 客户端
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
"""生成统一的缓存键
Args:
data: 图片数据或路径
is_base64: 是否是 base64 数据
Returns:
str: 缓存键
"""
try:
if is_base64 and data.startswith("data:image/"):
try:
header, encoded = data.split(",", 1)
image_bytes = base64.b64decode(encoded)
cache_key = hashlib.md5(image_bytes).hexdigest()
return cache_key
except Exception:
cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
return cache_key
else:
if data.startswith(("http://", "https://")):
# URL图片使用URL作为缓存键
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
return cache_key
else:
clean_path = (
data.split("_")[0]
if "_" in data and len(data.split("_")) >= 3
else data
)
if os.path.exists(clean_path):
with open(clean_path, "rb") as f:
file_content = f.read()
cache_key = hashlib.md5(file_content).hexdigest()
return cache_key
else:
cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
return cache_key
except Exception as e:
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
return cache_key
async def _upload_file(
self,
file_data: bytes,
session_id: str | None = None,
cache_key: str | None = None,
) -> str:
"""上传文件到 Coze 并返回 file_id"""
# 使用 API 客户端上传文件
file_id = await self.api_client.upload_file(file_data)
# 缓存 file_id
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
self.file_id_cache[session_id][cache_key] = file_id
logger.debug(f"[Coze] 图片上传成功并缓存file_id: {file_id}")
return file_id
async def _download_and_upload_image(
self, image_url: str, session_id: str | None = None
) -> str:
"""下载图片并上传到 Coze返回 file_id"""
# 计算哈希实现缓存
cache_key = self._generate_cache_key(image_url) if session_id else None
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
return file_id
try:
image_data = await self.api_client.download_image(image_url)
file_id = await self._upload_file(image_data, session_id, cache_key)
if session_id and cache_key:
self.file_id_cache[session_id][cache_key] = file_id
return file_id
except Exception as e:
logger.error(f"处理图片失败 {image_url}: {str(e)}")
raise Exception(f"处理图片失败: {str(e)}")
async def _process_context_images(
self, content: str | list, session_id: str
) -> str:
"""处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
try:
if isinstance(content, str):
return content
processed_content = []
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
for item in content:
if not isinstance(item, dict):
processed_content.append(item)
continue
if item.get("type") == "text":
processed_content.append(item)
elif item.get("type") == "image_url":
# 处理图片逻辑
if "file_id" in item:
# 已经有 file_id
logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
processed_content.append(item)
else:
# 获取图片数据
image_data = ""
if "image_url" in item and isinstance(item["image_url"], dict):
image_data = item["image_url"].get("url", "")
elif "data" in item:
image_data = item.get("data", "")
elif "url" in item:
image_data = item.get("url", "")
if not image_data:
continue
# 计算哈希用于缓存
cache_key = self._generate_cache_key(
image_data, is_base64=image_data.startswith("data:image/")
)
# 检查缓存
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
processed_content.append(
{"type": "image", "file_id": file_id}
)
else:
# 上传图片并缓存
if image_data.startswith("data:image/"):
# base64 处理
_, encoded = image_data.split(",", 1)
image_bytes = base64.b64decode(encoded)
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
elif image_data.startswith(("http://", "https://")):
# URL 图片
file_id = await self._download_and_upload_image(
image_data, session_id
)
# 为URL图片也添加缓存
self.file_id_cache[session_id][cache_key] = file_id
elif os.path.exists(image_data):
# 本地文件
with open(image_data, "rb") as f:
image_bytes = f.read()
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
else:
logger.warning(
f"无法处理的图片格式: {image_data[:50]}..."
)
continue
processed_content.append(
{"type": "image", "file_id": file_id}
)
result = json.dumps(processed_content, ensure_ascii=False)
return result
except Exception as e:
logger.error(f"处理上下文图片失败: {str(e)}")
if isinstance(content, str):
return content
else:
return json.dumps(content, ensure_ascii=False)
async def text_chat(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
"""文本对话, 内部使用流式接口实现非流式
Args:
prompt (str): 用户提示词
session_id (str): 会话ID
image_urls (List[str]): 图片URL列表
func_tool (FuncCall): 函数调用工具(不支持)
contexts (List): 上下文列表
system_prompt (str): 系统提示语
tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
model (str): 模型名称(不支持)
Returns:
LLMResponse: LLM响应对象
"""
accumulated_content = ""
final_response = None
async for llm_response in self.text_chat_stream(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
model=model,
**kwargs,
):
if llm_response.is_chunk:
if llm_response.completion_text:
accumulated_content += llm_response.completion_text
else:
final_response = llm_response
if final_response:
return final_response
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
return LLMResponse(role="assistant", result_chain=chain)
else:
return LLMResponse(role="assistant", completion_text="")
async def text_chat_stream(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话接口"""
# 用户ID参数(参考文档, 可以自定义)
user_id = session_id or kwargs.get("user", "default_user")
# 获取或创建会话ID
conversation_id = self.conversation_ids.get(user_id)
# 构建消息
additional_messages = []
if system_prompt:
if not self.auto_save_history or not conversation_id:
additional_messages.append(
{"role": "system", "content": system_prompt, "content_type": "text"}
)
if not self.auto_save_history and contexts:
# 如果关闭了自动保存历史,传入上下文
for ctx in contexts:
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
content = ctx["content"]
content_type = ctx.get("content_type", "text")
# 处理可能包含图片的上下文
if (
content_type == "object_string"
or (isinstance(content, str) and content.startswith("["))
or (
isinstance(content, list)
and any(
isinstance(item, dict)
and item.get("type") == "image_url"
for item in content
)
)
):
processed_content = await self._process_context_images(
content, user_id
)
additional_messages.append(
{
"role": ctx["role"],
"content": processed_content,
"content_type": "object_string",
}
)
else:
# 纯文本
additional_messages.append(
{
"role": ctx["role"],
"content": (
content
if isinstance(content, str)
else json.dumps(content, ensure_ascii=False)
),
"content_type": "text",
}
)
else:
logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
if prompt or image_urls:
if image_urls:
# 多模态
object_string_content = []
if prompt:
object_string_content.append({"type": "text", "text": prompt})
for url in image_urls:
try:
if url.startswith(("http://", "https://")):
# 网络图片
file_id = await self._download_and_upload_image(
url, user_id
)
else:
# 本地文件或 base64
if url.startswith("data:image/"):
# base64
_, encoded = url.split(",", 1)
image_data = base64.b64decode(encoded)
cache_key = self._generate_cache_key(
url, is_base64=True
)
file_id = await self._upload_file(
image_data, user_id, cache_key
)
else:
# 本地文件
if os.path.exists(url):
with open(url, "rb") as f:
image_data = f.read()
# 用文件路径和修改时间来缓存
file_stat = os.stat(url)
cache_key = self._generate_cache_key(
f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
is_base64=False,
)
file_id = await self._upload_file(
image_data, user_id, cache_key
)
else:
logger.warning(f"图片文件不存在: {url}")
continue
object_string_content.append(
{
"type": "image",
"file_id": file_id,
}
)
except Exception as e:
logger.error(f"处理图片失败 {url}: {str(e)}")
continue
if object_string_content:
content = json.dumps(object_string_content, ensure_ascii=False)
additional_messages.append(
{
"role": "user",
"content": content,
"content_type": "object_string",
}
)
else:
# 纯文本
if prompt:
additional_messages.append(
{
"role": "user",
"content": prompt,
"content_type": "text",
}
)
try:
accumulated_content = ""
message_started = False
async for chunk in self.api_client.chat_messages(
bot_id=self.bot_id,
user_id=user_id,
additional_messages=additional_messages,
conversation_id=conversation_id,
auto_save_history=self.auto_save_history,
stream=True,
timeout=self.timeout,
):
event_type = chunk.get("event")
data = chunk.get("data", {})
if event_type == "conversation.chat.created":
if isinstance(data, dict) and "conversation_id" in data:
self.conversation_ids[user_id] = data["conversation_id"]
elif event_type == "conversation.message.delta":
if isinstance(data, dict):
content = data.get("content", "")
if not content and "delta" in data:
content = data["delta"].get("content", "")
if not content and "text" in data:
content = data.get("text", "")
if content:
message_started = True
accumulated_content += content
yield LLMResponse(
role="assistant",
completion_text=content,
is_chunk=True,
)
elif event_type == "conversation.message.completed":
if isinstance(data, dict):
msg_type = data.get("type")
if msg_type == "answer" and data.get("role") == "assistant":
final_content = data.get("content", "")
if not accumulated_content and final_content:
chain = MessageChain(chain=[Comp.Plain(final_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
elif event_type == "conversation.chat.completed":
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
break
elif event_type == "done":
break
elif event_type == "error":
error_msg = (
data.get("message", "未知错误")
if isinstance(data, dict)
else str(data)
)
logger.error(f"Coze 流式响应错误: {error_msg}")
yield LLMResponse(
role="err",
completion_text=f"Coze 错误: {error_msg}",
is_chunk=False,
)
break
if not message_started and not accumulated_content:
yield LLMResponse(
role="assistant",
completion_text="LLM 未响应任何内容。",
is_chunk=False,
)
elif message_started and accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
except Exception as e:
logger.error(f"Coze 流式请求失败: {str(e)}")
yield LLMResponse(
role="err",
completion_text=f"Coze 流式请求失败: {str(e)}",
is_chunk=False,
)
async def forget(self, session_id: str):
"""清空指定会话的上下文"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if user_id in self.file_id_cache:
self.file_id_cache.pop(user_id, None)
if not conversation_id:
return True
try:
response = await self.api_client.clear_context(conversation_id)
if "code" in response and response["code"] == 0:
self.conversation_ids.pop(user_id, None)
return True
else:
logger.warning(f"清空 Coze 会话上下文失败: {response}")
return False
except Exception as e:
logger.error(f"清空 Coze 会话失败: {str(e)}")
return False
async def get_current_key(self):
"""获取当前API Key"""
return self.api_key
async def set_key(self, key: str):
"""设置新的API Key"""
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
async def get_models(self):
"""获取可用模型列表"""
return [f"bot_{self.bot_id}"]
def get_model(self):
"""获取当前模型"""
return f"bot_{self.bot_id}"
def set_model(self, model: str):
"""设置模型在Coze中是Bot ID"""
if model.startswith("bot_"):
self.bot_id = model[4:]
else:
self.bot_id = model
async def get_human_readable_context(
self, session_id: str, page: int = 1, page_size: int = 10
):
"""获取人类可读的上下文历史"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if not conversation_id:
return []
try:
data = await self.api_client.get_message_list(
conversation_id=conversation_id,
order="desc",
limit=page_size,
offset=(page - 1) * page_size,
)
if data.get("code") != 0:
logger.warning(f"获取 Coze 消息历史失败: {data}")
return []
messages = data.get("data", {}).get("messages", [])
readable_history = []
for msg in messages:
role = msg.get("role", "unknown")
content = msg.get("content", "")
msg_type = msg.get("type", "")
if role == "user":
readable_history.append(f"用户: {content}")
elif role == "assistant" and msg_type == "answer":
readable_history.append(f"助手: {content}")
return readable_history
except Exception as e:
logger.error(f"获取 Coze 消息历史失败: {str(e)}")
return []
async def terminate(self):
"""清理资源"""
await self.api_client.close()

View File

@@ -1,15 +1,14 @@
import re
import asyncio
import functools
from typing import List
from .. import Provider, Personality
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from ..register import register_provider_adapter
from astrbot.core.message.message_event_result import MessageChain
from .openai_source import ProviderOpenAIOfficial
from astrbot.core import logger, sp
from dashscope import Application
from dashscope.app.application_response import ApplicationResponse
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
@@ -62,11 +61,11 @@ class ProviderDashscope(ProviderOpenAIOfficial):
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
session_id=None,
image_urls=[],
func_tool=None,
contexts=None,
system_prompt=None,
model=None,
**kwargs,
) -> LLMResponse:
@@ -122,6 +121,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
assert isinstance(response, ApplicationResponse)
logger.debug(f"dashscope resp: {response}")
if response.status_code != 200:
@@ -135,12 +136,12 @@ class ProviderDashscope(ProviderOpenAIOfficial):
),
)
output_text = response.output.get("text", "")
output_text = response.output.get("text", "") or ""
# RAG 引用脚标格式化
output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
if self.output_reference and response.output.get("doc_references", None):
ref_str = ""
for ref in response.output.get("doc_references", []):
for ref in response.output.get("doc_references", []) or []:
ref_title = (
ref.get("title", "")
if ref.get("title")

View File

@@ -1,10 +1,22 @@
import os
import dashscope
import uuid
import asyncio
from dashscope.audio.tts_v2 import *
from ..provider import TTSProvider
import base64
import logging
import os
import uuid
from typing import Optional, Tuple
import aiohttp
import dashscope
from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer
try:
from dashscope.aigc.multimodal_conversation import MultiModalConversation
except (
ImportError
): # pragma: no cover - older dashscope versions without Qwen TTS support
MultiModalConversation = None
from ..entities import ProviderType
from ..provider import TTSProvider
from ..register import register_provider_adapter
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@@ -26,16 +38,112 @@ class ProviderDashscopeTTSAPI(TTSProvider):
dashscope.api_key = self.chosen_api_key
async def get_audio(self, text: str) -> str:
model = self.get_model()
if not model:
raise RuntimeError("Dashscope TTS model is not configured.")
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}.wav")
self.synthesizer = SpeechSynthesizer(
model=self.get_model(),
os.makedirs(temp_dir, exist_ok=True)
if self._is_qwen_tts_model(model):
audio_bytes, ext = await self._synthesize_with_qwen_tts(model, text)
else:
audio_bytes, ext = await self._synthesize_with_cosyvoice(model, text)
if not audio_bytes:
raise RuntimeError(
"Audio synthesis failed, returned empty content. The model may not be supported or the service is unavailable."
)
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}{ext}")
with open(path, "wb") as f:
f.write(audio_bytes)
return path
def _call_qwen_tts(self, model: str, text: str):
if MultiModalConversation is None:
raise RuntimeError(
"dashscope SDK missing MultiModalConversation. Please upgrade the dashscope package to use Qwen TTS models."
)
kwargs = {
"model": model,
"text": text,
"api_key": self.chosen_api_key,
"voice": self.voice or "Cherry",
}
if not self.voice:
logging.warning(
"No voice specified for Qwen TTS model, using default 'Cherry'."
)
return MultiModalConversation.call(**kwargs)
async def _synthesize_with_qwen_tts(
self, model: str, text: str
) -> Tuple[Optional[bytes], str]:
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
audio_bytes = await self._extract_audio_from_response(response)
if not audio_bytes:
raise RuntimeError(
f"Audio synthesis failed for model '{model}'. {response}"
)
ext = ".wav"
return audio_bytes, ext
async def _extract_audio_from_response(self, response) -> Optional[bytes]:
output = getattr(response, "output", None)
audio_obj = getattr(output, "audio", None) if output is not None else None
if not audio_obj:
return None
data_b64 = getattr(audio_obj, "data", None)
if data_b64:
try:
return base64.b64decode(data_b64)
except (ValueError, TypeError):
logging.error("Failed to decode base64 audio data.")
return None
url = getattr(audio_obj, "url", None)
if url:
return await self._download_audio_from_url(url)
return None
async def _download_audio_from_url(self, url: str) -> Optional[bytes]:
if not url:
return None
timeout = max(self.timeout_ms / 1000, 1) if self.timeout_ms else 20
try:
async with aiohttp.ClientSession() as session:
async with session.get(
url, timeout=aiohttp.ClientTimeout(total=timeout)
) as response:
return await response.read()
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
logging.error(f"Failed to download audio from URL {url}: {e}")
return None
async def _synthesize_with_cosyvoice(
self, model: str, text: str
) -> Tuple[Optional[bytes], str]:
synthesizer = SpeechSynthesizer(
model=model,
voice=self.voice,
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
)
audio = await asyncio.get_event_loop().run_in_executor(
None, self.synthesizer.call, text, self.timeout_ms
loop = asyncio.get_event_loop()
audio_bytes = await loop.run_in_executor(
None, synthesizer.call, text, self.timeout_ms
)
with open(path, "wb") as f:
f.write(audio)
return path
if not audio_bytes:
resp = synthesizer.get_response()
if resp and isinstance(resp, dict):
raise RuntimeError(
f"Audio synthesis failed for model '{model}'. {resp}".strip()
)
return audio_bytes, ".wav"
def _is_qwen_tts_model(self, model: str) -> bool:
model_lower = model.lower()
return "tts" in model_lower and model_lower.startswith("qwen")

Some files were not shown because too many files have changed in this diff Show More