Compare commits

...

131 Commits

Author SHA1 Message Date
Soulter
7d776e0ce2 chore: bump version to 4.3.5 2025-10-15 12:19:26 +08:00
Soulter
17df1692b9 fix: 修复 /alter_cmd reset scene <num> xxx 不可用的问题 2025-10-15 12:16:13 +08:00
Soulter
9ab652641d feat: 支持配置工具调用超时时间并适配 ModelScope 的 MCP Server 配置 (#3039)
* feat: 支持配置工具调用超时时间并适配 ModelScope 的 MCP Server 配置。

closes: #2939

* fix: Remove unnecessary blank lines in _quick_test_mcp_connection function
2025-10-15 12:06:57 +08:00
shangxue
9119f7166f feat: satori 适配器支持 video、reply 消息类型 (#3035)
* Update satori_event.py

* style: format

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-10-15 10:45:35 +08:00
Soulter
da7d9d8eb9 feat: Add tutorial link for wecom_ai_bot platform 2025-10-15 10:42:31 +08:00
Soulter
80fccc90b7 feat: 支持接入企业微信智能机器人平台 (#3034)
* stage

* stage

* feat: 支持图片收发

* feat: add support for wecom_ai_bot in getPlatformIcon function
2025-10-14 23:20:56 +08:00
Soulter
dcebc70f1a chore: Add new auto-assign users to configuration 2025-10-14 12:16:22 +08:00
dependabot[bot]
259e7bc322 chore(deps): bump github/codeql-action in the github-actions group (#3032)
Bumps the github-actions group with 1 update: [github/codeql-action](https://github.com/github/codeql-action).


Updates `github/codeql-action` from 3 to 4
- [Release notes](https://github.com/github/codeql-action/releases)
- [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md)
- [Commits](https://github.com/github/codeql-action/compare/v3...v4)

---
updated-dependencies:
- dependency-name: github/codeql-action
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-14 09:35:57 +08:00
Soulter
37bdb6c6f6 feat: 内置网页搜索功能支持接入百度 AI 搜索 (#3031)
* feat: 内置网页搜索功能支持接入百度 AI 搜索

* fix: 修正配置文件中的拼写错误,更新为正确的键名

* Fix Baidu AI Search initialization logic
2025-10-14 09:35:34 +08:00
Soulter
dc71afdd3f docs: Revise README for clarity and updated support info
Updated README.md to improve clarity and fix formatting issues. Removed outdated developer group information and added support details for new platforms and services.
2025-10-14 09:13:54 +08:00
Soulter
44638108d0 docs: readme 2025-10-14 08:53:23 +08:00
RC-CHN
93fcac498c feat: 添加并优化服务提供商独立测试功能 (#3024)
* feat: 添加并优化服务提供商独立测试功能

* feat: add small size to action buttons in ItemCard and ProviderPage for better UI consistency

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-10-13 13:03:20 +08:00
Soulter
79e2743aac chore: bump version to 4.3.3 2025-10-12 11:42:18 +08:00
anka
5e9c7cdd91 fix: 当没有填写 api key 时,设置为空字符串 (#2834)
* fix: 修复空key导致的无法创建Provider对象的问题

* style: format code

* Update astrbot/core/provider/provider.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-10-12 10:50:01 +08:00
Dt8333
6f73e5087d feat(core): 在新对话中重用先前的对话人格设置 (#3005)
* feat(core): reuse persona conf in new conversation

#2985

* refactor(core): simplify persona retrieval logic

* style: code format

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-10-12 10:42:35 +08:00
Yaron
8c120b020e fix: 修复阿里云百炼平台 TTS 下接入 CosyVoice V2, Qwen TTS 生成报错的问题 (#2964)
* fix: 修复了CosyVoice V2,Qwen TTS生成报错的问题。Fixed compatability problems with CosyVoice V2, Qwen TTS.

* fix: 将urlopen的同步请求替换为aiohttp的异步请求以下载音频

* fix: cozyvoice 报错显示

* fix: 添加阿里云百炼 TTS API Key 获取提示信息

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-10-12 01:03:06 +08:00
Dt8333
12fc6f9d38 fix(LTM): fix LTM not removed when removing conversation (#3002)
#2983
2025-10-12 00:16:42 +08:00
Dt8333
a6e8483b4c fix: 修复session-management中人格错误的显示为默认人格的问题 (#3000)
* fix: 修复session-management中人格错误的显示为默认人格的问题

#2985

* refactor: 使用命名表达式简化赋值和条件

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* style: format edited code with ruff

format code edited by sourcery-ai

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-10-12 00:12:04 +08:00
Soulter
7191d28ada fix: 启动了 TTS 但未配置 TTS 模型时,At 和 Reply 发送人无效
fixes: #2996
2025-10-10 12:11:03 +08:00
Soulter
e6b5e3d282 feat: tokenpony provider 2025-10-09 16:00:31 +08:00
ctrlkk
1413d6b5fe fix: 让事件钩子被暂停时跳出循环,而不是继续执行 (#2989) 2025-10-09 15:01:45 +08:00
ctrlkk
dcd8a1094c feat: 优化 SQLite 参数配置,对话和会话管理增加输入防抖机制 (#2969)
* feat: 优化 SQLite 数据库初始化设置并增强会话搜索功能,会话管理增加输入防抖

* fix: adjust SQLite cache and mmap size

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-10-06 17:13:53 +08:00
Futureppo
e64b31b9ba fix: Correct default modalities for DeepSeek provider (#2963)
* 更新 package.json

* 更新 ExtensionPage.vue

* fix(provider): Correct default modalities for DeepSeek provider
2025-10-06 16:30:05 +08:00
Dt8333
080f347511 feat: clean browser cache after update (#2958)
* feat: clean browser cache after update

* fix: move const to module

* fix: remove self prefix (a stupid mistake)
2025-10-06 16:29:18 +08:00
Dt8333
eaaff4298d fix(Python-Interpreter): fix incorrect file read method (#2970)
fix getting file by property(Sync) in an async handler

#2960
2025-10-06 16:12:05 +08:00
Soulter
dd5a02e8ef chore: bump version to 4.3.2 2025-10-05 01:01:13 +08:00
Soulter
3211ec57ee fix: handle Google search initialization and errors gracefully 2025-10-05 00:55:47 +08:00
Soulter
6796afdaee fix: googlesearch 2025-10-05 00:54:24 +08:00
Soulter
cc6fe57773 fix: on_tool_end无法获得工具返回的结果 (#2956)
fixes: #2940
2025-10-05 00:37:51 +08:00
Soulter
1dfc831938 fix: 修复 reset 没有清除群聊上下文感知数据的问题 (#2954) 2025-10-05 00:05:42 +08:00
Futureppo
cafeda4abf feat: 为插件市场的搜索增加拼音与首字母搜索功能 (#2936)
* 更新 package.json

* 更新 ExtensionPage.vue
2025-10-03 09:42:57 +08:00
Soulter
d951b99718 fix: 发送阶段将 Plain 为空的消息段移除 2025-10-03 00:45:07 +08:00
Soulter
0ad87209e5 chore: bump version to 4.3.1 2025-10-02 17:25:09 +08:00
Soulter
1b50c5404d fix: enhance knowledge base plugin status check to handle empty data response 2025-10-02 17:25:00 +08:00
Soulter
3007f67cab fix: update Dockerfile to remove npm installation and streamline package setup
closes: #2284
2025-10-02 16:59:11 +08:00
Soulter
ee08659f01 chore: bump version to 4.3.0 2025-10-02 16:37:54 +08:00
Soulter
baf5ad0fab fix: 修复接入智谱提供商后,工具调用无限循环的问题,并停止支持 glm-4v-flash (#2931)
fixes: #2912
2025-10-02 16:03:24 +08:00
kterna
8bdd748aec feat: 支持注册消息平台适配器的 logo (#2109)
* feat: 添加平台适配器 logo 支持

* 优化平台logo注册逻辑,增加缓存机制并支持并行处理

* 去除判断绝对路径

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-10-02 14:36:15 +08:00
Soulter
cef0c22f52 feat: update prompt prefix handling to support placeholder replacement 2025-10-02 14:20:52 +08:00
Soulter
13d3fc5cfe fix: fix type checking error and op, deop, wl, dwl command 2025-10-02 00:18:12 +08:00
Soulter
b91141e2be fix: add plugin activation check and corresponding messages in Knowledge Base 2025-10-01 22:14:03 +08:00
Soulter
f8a4b54165 fix: 修复插件指令注解为联合类型时处理异常的问题 (#2925)
* fix: 修复插件指令注解为联合类型时处理异常的问题

* fix: 修复参数类型检查以支持 typing.Union

* Update astrbot/core/star/filter/command.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update astrbot/core/star/filter/command.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: 修复参数类型检查以支持 typing.Union 的处理逻辑

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-01 21:46:49 +08:00
Soulter
afe007ca0b refactor: 优化 packages/astrbot 内置插件的代码结构以提高可维护性和可读性 (#2924)
* refactor: code structure for improved readability and maintainability

* style: ruff format

* Update packages/astrbot/commands/provider.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Update packages/astrbot/commands/persona.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Update packages/astrbot/commands/llm.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Update packages/astrbot/commands/conversation.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* fix: improve error handling message formatting in key switching

* fix: update LLM command to use safe get for provider settings

* feat: implement ProcessLLMRequest class for handling LLM requests and persona injection

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-10-01 21:29:15 +08:00
Soulter
8a9a044f95 fix: 修复注册指令组指令时的 Pyright 类型检查提示 (#2923) 2025-10-01 20:03:04 +08:00
u0_ani-nya.com
5eaf03e227 perf: 对于 Telegram 群聊,将回复机器人的消息视为唤醒机器人 (#2926)
* reply as at for tg

Add handling for bot replies in group messages.

* style: type checking and ruff format

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-10-01 19:04:37 +08:00
Seayon
a8437d9331 feat: 支持在 Telegram 和飞书下请求 LLM 前预表态功能 (#2737)
*  feat(platform): 为 Telegram 和飞书添加消息表情回应功能

支持在收到命令时自动添加表情回应,提升用户交互体验
新增平台特异配置项,允许自定义启用状态和表情列表

* Update astrbot/core/platform/astr_message_event.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* style: ruff format

* fix: 优化平台特异配置的预回应表情处理逻辑

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-09-30 17:29:34 +08:00
晴空
e0392fa98b fix: 用 mi-googlesearch-python 库代替失效的 googlesearch-python 库 (#2909)
* googlesearch-python库失效,用mi-googlesearch-python库平替,恢复谷歌搜索

* Update googlesearch-python dependency version
2025-09-29 12:54:16 +08:00
ctrlkk
68ff8951de feat: 添加分页和搜索功能以获取会话列表,优化前端与后端的数据交互 (#2906)
* feat: 添加分页和搜索功能以获取会话列表,优化前端与后端的数据交互

* fix: 修复会话计数显示,使用总项数替代会话数组长度

* fix: 将参数类型和名称与实现内容匹配。

* perf: convert for loop into list comprehension

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* fix: type checking error

* fix: 优化 persona_id 的获取逻辑

---------

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-09-28 23:25:30 +08:00
KroMiose
9c6b31e71c Update README.md (#2904) 2025-09-28 14:50:02 +08:00
Soulter
50f74f5ba2 fix: 修复"开启 TTS 时同时输出语音和文字内容"功能不可用的问题 (#2900)
fixes: #2844
2025-09-28 10:48:57 +08:00
Soulter
b9de2aef60 chore: bump version to 4.2.1 2025-09-27 23:36:25 +08:00
Soulter
7a47598538 fix: 修复指令无法使用的问题
fixes: #2897
2025-09-27 23:35:35 +08:00
Soulter
3c8c28ebd5 chore: bump version to 4.2.0 2025-09-27 20:45:50 +08:00
Soulter
524285f767 feat: add cancel button with localized text to AddNewPlatform and update close button in AddNewProvider
fixes: #2889
2025-09-27 20:41:45 +08:00
Soulter
c2a34475f1 feat: 支持删除指定会话以及部分会话管理优化 (#2895)
* feat: add toast notification system with snackbar component

* feat: add session deletion functionality

* feat: support batch operations for updating session persona, provider, LLM, and TTS statuses

fix: #2263

* feat: 修复对话状态关闭,删除对话管理库会导致对话无法恢复

fixes: #2309
2025-09-27 20:36:30 +08:00
Soulter
a69195a02b fix: webchat streaming queue interrupted after user closing tab (#2892)
* feat: add toast notification system with snackbar component

* feat: enhance chat functionality with conversation running state and notifications

* fix: update bot message avatar rendering during streaming

* feat: implement conversation tracking context manager for webchat

* fix: update conversation tracking to remove conversation ID on exit
2025-09-27 17:57:12 +08:00
RC-CHN
19d7438499 fix: unit tests (#2760)
* fix:修复了main和plugin_manager部分单元测试

* fix: 修复了dashboard部分测试

* remove: 删除暂无用的配置测试脚本

* perf:拆分插件增查删改为独立的单元测试

* refactor: 重构插件管理器测试,使用临时环境隔离测试实例

* test: 增加对仪表板文件检查的单元测试,涵盖不同情况

* style: format code

* remove: 删除未使用的导入语句

* delete: remove unused test file for pipeline

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-27 14:43:04 +08:00
anka
ccb380ce06 feat: 支持接入 Coze (#2858)
* feat: 适配 coze 供应商
1. 支持文件上传
2. 支持多模态
3. 支持流式传输
4. 支持 API 端的上下文保存历史记录
5. 支持类似 dify 的 forget 接口

* style: format code

* fix: type checking error

* fix: 修复:
1. 使用coze api端的上下文时, 现在不会重复传递上下文
2. 使用 AstrBot 的上下文时, 正确处理其中的图片信息
3. 上传图片时, 提供一个非持久化的缓存避免重复上传(在解析上下文并将文件转化为file_id传递给coze api时, 如果没有缓存会导致很多的网络资源浪费)
4. 修复reset等指令不能正确重置上下文的问题

* fix: 移除某些地方多余的针对 dify 的断言, 以兼容 Coze

* style: 修改配置项显示/webchat平台对于非预期的类型的处理

* fix: 让conversation_id放到请求中正确的位置

* refactor: extract coze api client

* refactor: improve image processing logic in ProviderCoze

* chore: remove file ext guessing

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-27 14:23:29 +08:00
Ding Jiatong
a35c439bbd fix: 使用增量解码器修复 Dify 流式返回结果偶现的解码错误 (#2888)
* fix: 修复linux下utf-8解码错误的问题

* feat: use incremental decoder

* fix: add type hint for response parameter in _stream_sse and refactor file upload method

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-26 23:04:58 +08:00
Soulter
09d1f96603 fix: 修复 /alter_cmd 指令无法控制指令组、子指令组和子指令组下子指令的问题 (#2873)
* fix: revert changes in command_group.py at 782c036 to fix command group permission check

* fix: 不传递 GroupCommand handler

* perf: alter_cmd 指令支持对子指令、指令组进行配置

* chore: remove test commands and subcommands from test_group

* chore: add cache for complete command names list in CommandFilter and CommandGroupFilter

---------

Co-authored-by: Dt8333 <25431943+Dt8333@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-09-26 14:16:50 +08:00
鸦羽
26aa18d980 Merge pull request #2881 from Raven95676/fix/2879
fix: add missing id field
2025-09-26 11:31:28 +08:00
Raven95676
d10b542797 chore: format 2025-09-26 11:05:32 +08:00
Raven95676
ce4e4fb8dd fix: add missing id field 2025-09-26 10:59:11 +08:00
Soulter
8f4a31cf8c chore: bump version to 4.1.7 2025-09-23 22:16:36 +08:00
Soulter
23549f13d6 Feature: 支持批量删除对话历史 (#2859)
* feat: 支持批量删除对话

closes: #2784

* feat: 添加加载状态禁用功能,优化用户交互体验
2025-09-23 22:10:56 +08:00
Soulter
869d11f9a6 perf: 优化验证配置时的性能,移除配置隐式类型转换
fixes: #2646
2025-09-23 21:04:14 +08:00
Soulter
02e73b82ee fix: 修复无法打开更新对话框的问题 2025-09-23 20:29:10 +08:00
Soulter
f85f87f545 feat: WebChat 支持手动填写模型名
closes: #2830
2025-09-23 15:32:54 +08:00
Soulter
1fff5713f3 refactor: 解耦 PlatformPage 和 ProviderPage 的部分组件 2025-09-23 15:32:54 +08:00
Soulter
8453ec36f0 docs: Revise links for documentation and blog in README
Updated links in the README for documentation and blog.
2025-09-23 14:12:05 +08:00
Soulter
d5b3ce8424 fix: update download_dashboard to log specific dashboard release URLs 2025-09-23 13:10:33 +08:00
Soulter
80cbbfa5ca chore: bump version to 4.1.6 2025-09-23 13:02:06 +08:00
Soulter
9177bb660f fix: improve error handling in run_agent for streaming responses 2025-09-23 10:34:24 +08:00
Soulter
a3df39a01a perf: unified button styles
closes: #2748
2025-09-23 10:27:52 +08:00
Soulter
25dce05cbb refactor: improve webchat UI (#2853) 2025-09-23 10:19:26 +08:00
Soulter
1542ea3e03 fix: context.get_provider_by_id issue 2025-09-22 17:22:50 +08:00
Soulter
6084abbcfe feat: add user_id search capability in get_filtered_conversations 2025-09-21 22:45:55 +08:00
Soulter
ed19b63914 chore: bump version to v4.1.5 2025-09-21 21:47:14 +08:00
Soulter
4efeb85296 chore: remove uv.lock file 2025-09-21 21:47:06 +08:00
shangxue
fc76665615 feat: Satori适配器引用消息无法正确识别 (#2686)
* Update PlatformPage.vue

* Update PlatformPage.vue

* Update PlatformPage.vue

* Update satori_adapter.py

* Update satori_event.py

* Update default.py

* Update satori_adapter.py

* Update satori_adapter.py

* style: format code

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-21 21:45:35 +08:00
Soulter
3a044bb71a fix: 修复 Telegram 下流式传输时,第一次输出的内容会被覆盖掉的问题 (#2838)
fixes: #2481
2025-09-21 21:24:47 +08:00
Soulter
cddd606562 perf: 优化 ExtensionPage 2025-09-21 21:10:03 +08:00
Soulter
7a5bc51c11 fix: 识别引用消息的图片时优先使用默认图片转述提供商 (#2836)
* fix: 识别引用消息的图片时优先使用默认图片转述提供商

closes: #2821

* fix: 添加日志记录以处理未找到图片标题提供者的情况

* style: format code
2025-09-21 20:55:32 +08:00
Soulter
9f939b4b6f fix: 修复对话管理页面的关键词搜索功能失效的问题并优化一些 UI 样式 (#2837)
* fix: 修复对话管理页面的关键词搜索功能失效的问题并优化一些 UI 样式

fixes: #2782

* style: format code

* fix: remove debug print statements from conversation retrieval methods
2025-09-21 20:55:15 +08:00
Soulter
80a86f5b1b fix: 修复 astrbot.core.star 等包下的 type checking error (#2787)
* fix: 修复 astrbot.core.star 等包下的 type checking error

* refactor: improve type checking and annotations

* chore: ruff format
2025-09-21 18:10:04 +08:00
yitaikarma
a0ce1855ab fix: 优化统计页内存占用和消息数据趋势的样式 (#2826)
* fix: 调整统计页内存占用和消息趋势分析的布局,优化响应式显示

* fix: 隐藏增长率为零时的趋势图标
2025-09-21 17:06:47 +08:00
anka
a4b43b884a fix: 修复aiocqhttp适配器at会获取群昵称而消息不会获取的逻辑不一致 (#2769)
* fix: 修复at会获取群昵称而消息不会获取的逻辑不一致

* style: format code
2025-09-19 13:04:51 +08:00
PaloMiku
824c0f6667 feat: 新增 Misskey 平台适配器 (#2774)
* feat: add Misskey platform adapter

* fix: 修复 Misskey 配置项的大小写问题

* feat: 添加消息链序列化功能和可见性解析逻辑

* chore: 删除损坏的 Misskey 平台适配器工具函数文件

* docs: 更新 Misskey 消息适配器设置描述信息

* feat: Misskey 单用户连续上下文对话支持

* feat: 为 Astrbot 添加 Misskey 平台适配器的 ID 配置

* feat: 重构 Misskey 平台适配器,提取通用工具函数并优化消息处理逻辑

* refactor: 清理 Misskey 平台适配器和 API 代码,移除冗余注释

* fix: 修复了使用中和使用者反馈的多个问题

* fix: 修改提及格式,确保提及在新行开始,提升帖子美观和易读性。

* feat: 添加默认可见性和本地仅限设置,优化 Misskey 平台适配器的配置

* fix: 更新 Misskey 平台适配器配置,使用前缀以防止和其他适配器未来可能的冲突问题

* chore: rename 'misskey' to 'Misskey' in config

* feat: Misskey 适配器添加聊天消息响应功能,重构接收和发送逻辑为 Websockets 处理

* fix: 增强 Misskey WebSocket 消息日志输出

* refactor: 优化 Misskey 适配器的消息处理和日志输出

* fix: 增强 Misskey WebSocket 重连接逻辑

* feat: 增强 Misskey 适配器的消息处理,支持房间消息和相关功能,重构通用函数,清理代码重复冗余

* fix: 不屏蔽唤醒前缀对默认 LLM 的唤醒

* fix: 透传所有的群聊消息事件

* fix: 修复 message_type

* perf: 实现 send_streaming 以支援流式请求

* docs(README): update README.md

* fix: super().send(message) 被忽略

* fix: 修正 session 结构

: 作为分隔符可能会导致 umo 组装出现问题

---------

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-09-18 23:34:41 +08:00
Soulter
a030fe8491 feat: add audioop-lts dependencies (#2809)
pydub needs audioop as a requirement but this builtin package has been removed in 3.13
2025-09-18 23:32:04 +08:00
Soulter
3a9429e8ef fix: on_tool_end hook unavailable 2025-09-17 15:48:57 +08:00
anka
c4eb1ab748 chore: bump version to 4.1.4 2025-09-16 20:09:11 +08:00
anka
29ed19d600 Merge pull request #2783 from AstrBotDevs/revert-2778-fix-handler-type
Revert "fix: parameter type/default handling in CommandFilter"
2025-09-16 20:01:23 +08:00
anka
0cc65513a5 Revert "fix: parameter type/default handling in CommandFilter" 2025-09-16 20:01:05 +08:00
Soulter
debc048659 chore: bump version to 4.1.3 2025-09-16 13:16:21 +08:00
邹永赫
92f5c918dd Merge pull request #2778 from MliKiowa/fix-handler-type
fix: parameter type/default handling in CommandFilter
2025-09-16 13:43:53 +09:00
手瓜一十雪
9519f1e8e2 fix: parameter type/default handling in CommandFilter
Adjusts logic to prioritize type annotations over default values when setting handler_params in CommandFilter. This ensures that parameter types are correctly inferred when available.
2025-09-16 11:49:27 +08:00
Soulter
a8f874bf05 fix: 修复分段回复时,引用消息单独发送导致第一条消息内容为空的问题 (#2757) 2025-09-16 10:45:39 +08:00
anka
9d9917e45b feat: 增加群名称识别到 system prompt, 并提供相应的配置 (#2770)
* feat🤖: 增加群名称识别到system prompt, 并提供相应的配置

* feat: 优化实现方式, 重构AstrBotMessage, 向后兼容

* style: format
2025-09-16 10:23:08 +08:00
Soulter
91ee0a870d fix: handle image value correctly for mcp BlobResourceContents (#2753) 2025-09-16 08:22:18 +08:00
dependabot[bot]
6cbbffc5a9 chore(deps): bump the github-actions group with 2 updates (#2771)
Bumps the github-actions group with 2 updates: [actions/checkout](https://github.com/actions/checkout) and [actions/setup-python](https://github.com/actions/setup-python).


Updates `actions/checkout` from 4 to 5
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v5)

Updates `actions/setup-python` from 5 to 6
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](https://github.com/actions/setup-python/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/setup-python
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-16 08:19:31 +08:00
Yokami
8f26fd34d1 feat: add copy button for service providers (#2767) 2025-09-15 22:17:00 +08:00
Soulter
fda655f6d7 fix: 修复配置默认 TTS 或者 STT 模型之后仍无法生效的问题 (#2758)
fixes: #2731
2025-09-15 22:08:40 +08:00
Soulter
a663d6509b chore: bump version to 4.1.2 2025-09-14 21:07:36 +08:00
Soulter
9ec8839efa perf: 检查服务提供商可用性时跳过未启用的提供商 2025-09-14 21:01:32 +08:00
Soulter
a7a0350eb2 fix: 平台配置下的「内容安全」组无法生效 (#2751) 2025-09-14 20:25:53 +08:00
Soulter
39a7a0d960 fix: revert "feat: 兼容指令名和第一个参数之间没有空格的情况 (#2650)" for command issue
This reverts commit 9bfa726107.
2025-09-14 19:31:15 +08:00
Soulter
7740e1e131 ci: add ci stage of code format checking (#2750)
* style: ruff format

* ci(dashboard-ci): ensure GitHub Release action only runs on push events

* ci(code-format): ruff format and ruff check
2025-09-14 18:05:58 +08:00
Soulter
9dce1ed47e chore(github): revise PR template
Updated the pull request template to improve clarity and fix formatting issues.
2025-09-14 14:44:46 +08:00
Soulter
e84a00d3a5 fix: 修复多配置文件配置的不同人格无法生效的问题 (#2739)
fixes: #2724
2025-09-14 14:09:46 +08:00
anka
88a944cb57 chore(github): 优化 PR 模板 2025-09-14 12:58:34 +08:00
Soulter
20c32e72cc chore: bump version to 4.1.1 2025-09-13 16:19:40 +08:00
Soulter
4788c20816 fix: model variable referenced before assignment 2025-09-13 16:18:22 +08:00
Soulter
e83fc570a4 chore: bump version to 4.1.0 2025-09-13 13:31:49 +08:00
Yokami
e841b6af88 feat: 支持在 WebUI 自定义 OpenAI API extra_body 参数 (#2719)
* feat: 支持OPENAI系 模型的自定义标头,以解决qwen模型无法使用的问题

* fix: 修复AI说的问题

* fix: 布尔开关向右对齐
2025-09-13 13:23:49 +08:00
Dt8333
ea6f209557 fix: 修复LLM仍会调用已禁用的工具的问题 (#2729)
* fix: 修复LLM调用已禁用的工具

* feat: 修改工具禁用判断位置,提高效率
未设置可用工具时仍旧循环判断
设置可用工具后在获取工具时即判断
2025-09-12 21:36:10 +08:00
Zhalslar
9bfa726107 feat: 兼容指令名和第一个参数之间没有空格的情况 (#2650)
插件中@filter.command的指令在用户输入“命令+参数” 无空格隔开时无法处理,但只要稍微改动几行代码就可以兼容
2025-09-12 15:40:37 +08:00
shangxue
d24902c66d feat: 添加 --webui-dir 启动参数以支持指定 WebUI 构建文件目录 (#2680)
* Update main.py

* Update server.py

* Update main.py

* Update main.py

* Update main.py

* Update initial_loader.py

* Update server.py

* Update main.py

* chore: update webui_dir type hint and improve dashboard file check logic

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-12 15:15:29 +08:00
RC-CHN
72aea2d3f3 feat: 允许添加多个 tavily API Key 进行轮询 (#2725)
* feat: 允许添加多个tavily API Key进行轮询

* perf: 并发安全的从列表中获取并轮换Tavily API密钥

* fix: 自动迁移旧版 websearch_tavily_key 为列表格式并保存
2025-09-12 15:03:47 +08:00
RC-CHN
dc9612d564 fix: 修复自定义文转图模板更新版本后会被覆盖的问题 (#2677)
* perf: 更新模板管理逻辑,在data目录中管理用户自定义模板,优化热重载逻辑

* refactor: 优化模板管理逻辑,重构模板复制和初始化流程,增强用户模板管理功能

* chore:移除无用注释

* remove:移除了t2i部分中不会走到的异常

* style: format code

* fix: trim whitespace from template names in create, update, and delete operations

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-12 13:34:07 +08:00
Soulter
1770556d56 fix: 修复工具调用时的 content 内容在重新加载后没有显示在 webchat 的问题 (#2727) 2025-09-12 13:05:33 +08:00
Soulter
888fb84aee fix: 修复 WebChat 下,Agent 长时任务时,SSE 连接自动断开的问题 2025-09-12 13:04:27 +08:00
Soulter
d597fd056d fix: 修复知识库不能创建的问题 2025-09-11 17:27:57 +08:00
quirrel
dea0ab3974 fix: 解决插件页表格视图中,点击状态字段表头排序不起作用的问题 (#2714) 2025-09-11 16:20:33 +08:00
Soulter
da6facd7d7 docs: 修复开发者群组错误 2025-09-11 12:44:59 +08:00
Soulter
bb8ab5f173 docs: update readme 2025-09-11 10:40:30 +08:00
Soulter
ac8a541059 docs: remove message stat badge
Removed the old dynamic JSON badge for message volume.
2025-09-11 10:36:18 +08:00
Soulter
0e66771f0e docs: revise acknowledgments and add similar projects
Updated project acknowledgments and added links to similar open-source bot projects.
2025-09-11 10:35:11 +08:00
Soulter
d3a295a801 ci: add auto_assign.yml for auto PR reviewer assignment 2025-09-10 13:21:34 +08:00
shangxue
f2df771771 fix: 修复 Satori 适配器教程链接 (#2668)
* Update PlatformPage.vue

* Update PlatformPage.vue
2025-09-09 21:59:06 +08:00
dependabot[bot]
7b72cd87a5 chore(deps): bump the github-actions group with 2 updates (#2674)
Bumps the github-actions group with 2 updates: [actions/setup-python](https://github.com/actions/setup-python) and [actions/stale](https://github.com/actions/stale).


Updates `actions/setup-python` from 5 to 6
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](https://github.com/actions/setup-python/compare/v5...v6)

Updates `actions/stale` from 9 to 10
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v9...v10)

---
updated-dependencies:
- dependency-name: actions/setup-python
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/stale
  dependency-version: '10'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-09 08:46:04 +08:00
anka
9431efc6d1 feat: 增加 on_platform_loaded 钩子以在消息平台适配器实例化完成后触发 (#2651)
* feat⚒️: 增加平台加载时的钩子

* fix: 补充api

* fix: 只捕获Exception
2025-09-09 08:44:37 +08:00
210 changed files with 13900 additions and 10650 deletions

View File

@@ -1,19 +1,46 @@
<!-- 如果有的话,指定这个 PR 解决的 ISSUE -->
解决了 #XYZ
<!-- 如果有的话,指定 PR 旨在解决的 ISSUE 编号。 -->
<!-- If applicable, please specify the ISSUE number this PR aims to resolve. -->
### Motivation
fixes #XYZ
<!--解释为什么要改动-->
---
### Modifications
### Motivation / 动机
<!--简单解释你的改动-->
<!--请描述此项更改的动机:它解决了什么问题?(例如:修复了 XX 错误,添加了 YY 功能)-->
<!--Please describe the motivation for this change: What problem does it solve? (e.g., Fixes XX bug, adds YY feature)-->
### Check
### Modifications / 改动点
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容-->
<!--请总结你的改动:哪些核心文件被修改了?实现了什么功能?-->
<!--Please summarize your changes: What core files were modified? What functionality was implemented?-->
- [ ] 😊 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
- [ ] 👀 我的更改经过良好的测试
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt``pyproject.toml` 文件相应位置。
- [ ] 😮 我的更改没有引入恶意代码
### Verification Steps / 验证步骤
<!--请为审查者 (Reviewer) 提供清晰、可复现的验证步骤例如1. 导航到... 2. 点击...)。-->
<!--Please provide clear and reproducible verification steps for the Reviewer (e.g., 1. Navigate to... 2. Click...).-->
### Screenshots or Test Results / 运行截图或测试结果
<!--请粘贴截图、GIF 或测试日志,作为执行“验证步骤”的证据,证明此改动有效。-->
<!--Please paste screenshots, GIFs, or test logs here as evidence of executing the "Verification Steps" to prove this change is effective.-->
### Compatibility & Breaking Changes / 兼容性与破坏性变更
<!--请说明此变更的兼容性:哪些是破坏性变更?哪些地方做了向后兼容处理?是否提供了数据迁移方法?-->
<!--Please explain the compatibility of this change: What are the breaking changes? What backward-compatible measures were taken? Are data migration paths provided?-->
- [ ] 这是一个破坏性变更 (Breaking Change)。/ This is a breaking change.
- [ ] 这不是一个破坏性变更。/ This is NOT a breaking change.
---
### Checklist / 检查清单
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
- [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
- [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt``pyproject.toml` 文件相应位置。/ I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in `requirements.txt` and `pyproject.toml`.
- [ ] 😮 我的更改没有引入恶意代码。/ My changes do not introduce malicious code.

38
.github/auto_assign.yml vendored Normal file
View File

@@ -0,0 +1,38 @@
# Set to true to add reviewers to pull requests
addReviewers: true
# Set to true to add assignees to pull requests
addAssignees: false
# A list of reviewers to be added to pull requests (GitHub user name)
reviewers:
- Soulter
- Raven95676
- Larch-C
- anka-afk
- advent259141
- Fridemn
- LIghtJUNction
# - zouyonghe
# A number of reviewers added to the pull request
# Set 0 to add all the reviewers (default: 0)
numberOfReviewers: 2
# A list of assignees, overrides reviewers if set
# assignees:
# - assigneeA
# A number of assignees to add to the pull request
# Set to 0 to add all of the assignees.
# Uses numberOfReviewers if unset.
# numberOfAssignees: 2
# A list of keywords to be skipped the process that add reviewers if pull requests include it
skipKeywords:
- wip
- draft
# A list of users to be skipped by both the add reviewers and add assignees processes
# skipUsers:
# - dependabot[bot]

View File

@@ -73,7 +73,7 @@ jobs:
uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.10'

34
.github/workflows/code-format.yml vendored Normal file
View File

@@ -0,0 +1,34 @@
name: Code Format Check
on:
pull_request:
branches: [ master ]
push:
branches: [ master ]
jobs:
format-check:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.10'
- name: Install UV
run: pip install uv
- name: Install dependencies
run: uv sync
- name: Check code formatting with ruff
run: |
uv run ruff format --check .
- name: Check code style with ruff
run: |
uv run ruff check .

View File

@@ -60,7 +60,7 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
uses: github/codeql-action/init@v4
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
@@ -88,6 +88,6 @@ jobs:
exit 1
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3
uses: github/codeql-action/analyze@v4
with:
category: "/language:${{matrix.language}}"

View File

@@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
- name: Install dependencies
run: |

View File

@@ -37,6 +37,7 @@ jobs:
!dist/**/*.md
- name: Create GitHub Release
if: github.event_name == 'push'
uses: ncipollo/release-action@v1
with:
tag: release-${{ github.sha }}

View File

@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/stale@v9
- uses: actions/stale@v10
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'Stale issue message'

4
.gitignore vendored
View File

@@ -30,4 +30,6 @@ packages/python_interpreter/workplace
.conda/
.idea
pytest.ini
.astrbot
.astrbot
uv.lock

View File

@@ -4,8 +4,6 @@ WORKDIR /AstrBot
COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \
nodejs \
npm \
gcc \
build-essential \
python3-dev \
@@ -13,23 +11,20 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libssl-dev \
ca-certificates \
bash \
ffmpeg \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN apt-get update && apt-get install -y curl gnupg && \
curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
apt-get install -y nodejs && \
rm -rf /var/lib/apt/lists/*
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
RUN uv pip install socksio uv pilk --no-cache-dir --system
# 释出 ffmpeg
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
EXPOSE 6185
EXPOSE 6185
EXPOSE 6186
CMD [ "python", "main.py" ]

111
README.md
View File

@@ -1,5 +1,3 @@
<img width="430" height="31" alt="image" src="https://github.com/user-attachments/assets/474c822c-fab7-41be-8c23-6dae252823ed" /><p align="center">
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
</p>
@@ -13,17 +11,17 @@
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7日消息量&cacheSeconds=3600&style=for-the-badge&color=3b618e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://astrbot.app/">查看文档</a>
<a href="https://astrbot.app/">文档</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">路线图</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
</div>
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架。
## 主要功能
@@ -35,7 +33,7 @@ AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架
## 部署方式
#### Docker 部署
#### Docker 部署(推荐 🥳)
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
@@ -100,8 +98,7 @@ uv run main.py
- 3 群630166526
- 5 群822130018
- 6 群753075035
- 开发者群753075035
- 开发者群备份295657329
- 开发者群:975206796
### Telegram 群组
@@ -111,49 +108,82 @@ uv run main.py
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## ⚡ 消息平台支持情况
**官方维护**
| 平台 | 支持性 |
| -------- | ------- |
| QQ(官方机器人接口) | ✔ |
| QQ(官方平台) | ✔ |
| QQ(OneBot) | ✔ |
| Telegram | ✔ |
| 企业微信 | ✔ |
| 企微应用 | ✔ |
| 微信客服 | ✔ |
| 微信公众号 | ✔ |
| 飞书 | ✔ |
| 钉钉 | ✔ |
| Slack | ✔ |
| Discord | ✔ |
| Satori | ✔ |
| Misskey | ✔ |
| 企微智能机器人 | 将支持 |
| Whatsapp | 将支持 |
| LINE | 将支持 |
**社区维护**
| 平台 | 支持性 |
| -------- | ------- |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
| [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) | ✔ |
## ⚡ 提供商支持情况
| 名称 | 支持性 | 类型 | 备注 |
| -------- | ------- | ------- | ------- |
| OpenAI | ✔ | 文本生成 | 支持任何兼容 OpenAI API 的服务 |
| Anthropic | ✔ | 文本生成 | |
| Google Gemini | ✔ | 文本生成 | |
| Dify | ✔ | LLMOps | |
| 阿里云百炼应用 | ✔ | LLMOps | |
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
| 硅基流动 | ✔ | 模型 API 服务平台 | |
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
| OneAPI | ✔ | LLM 分发系统 | |
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
| OpenAI TTS API | ✔ | 文本转语音 | |
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | 文本转语音 | GPT-Sovits-Inference |
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
**大模型服务**
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI | ✔ | 支持任何兼容 OpenAI API 的服务 |
| Anthropic | ✔ | |
| Google Gemini | ✔ | |
| Moonshot AI | ✔ | |
| 智谱 AI | ✔ | |
| DeepSeek | ✔ | |
| Ollama | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| LM Studio | ✔ | 本地部署 DeepSeek 等开源语言模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | |
| [小马算力](https://www.tokenpony.cn/3YPyf) | ✔ | |
| 硅基流动 | ✔ | |
| PPIO 派欧云 | ✔ | |
| ModelScope | ✔ | |
| OneAPI | ✔ | |
| Dify | ✔ | |
| 阿里云百炼应用 | ✔ | |
| Coze | ✔ | |
**语音转文本服务**
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| Whisper | ✔ | 支持 API、本地部署 |
| SenseVoice | ✔ | 本地部署 |
**文本转语音服务**
| 名称 | 支持性 | 备注 |
| -------- | ------- | ------- |
| OpenAI TTS | ✔ | |
| Gemini TTS | ✔ | |
| GSVI | ✔ | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | GPT-Sovits |
| FishAudio | ✔ | |
| Edge TTS | ✔ | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | |
| Azure TTS | ✔ | |
| Minimax TTS | ✔ | |
| 火山引擎 TTS | ✔ | |
## ❤️ 贡献
@@ -173,7 +203,6 @@ pip install pre-commit
pre-commit install
```
## ❤️ Special Thanks
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
@@ -182,25 +211,21 @@ pre-commit install
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>
此外,本项目的诞生离不开以下开源项目:
此外,本项目的诞生离不开以下开源项目的帮助
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
## ⭐ Star History
> [!TIP]
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我维护这个开源项目的动力 <3
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我维护这个开源项目的动力 <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div>
</details>
_私は、高性能ですから!_

View File

@@ -7,6 +7,7 @@ from astrbot.core.star.register import (
register_permission_type as permission_type,
register_custom_filter as custom_filter,
register_on_astrbot_loaded as on_astrbot_loaded,
register_on_platform_loaded as on_platform_loaded,
register_on_llm_request as on_llm_request,
register_on_llm_response as on_llm_response,
register_llm_tool as llm_tool,
@@ -41,6 +42,7 @@ __all__ = [
"custom_filter",
"PermissionType",
"on_astrbot_loaded",
"on_platform_loaded",
"on_llm_request",
"llm_tool",
"on_decorating_result",

View File

@@ -124,15 +124,17 @@ def build_plug_list(plugins_dir: Path) -> list:
if metadata and all(
k in metadata for k in ["name", "desc", "version", "author", "repo"]
):
result.append({
"name": str(metadata.get("name", "")),
"desc": str(metadata.get("desc", "")),
"version": str(metadata.get("version", "")),
"author": str(metadata.get("author", "")),
"repo": str(metadata.get("repo", "")),
"status": PluginStatus.INSTALLED,
"local_path": str(plugin_dir),
})
result.append(
{
"name": str(metadata.get("name", "")),
"desc": str(metadata.get("desc", "")),
"version": str(metadata.get("version", "")),
"author": str(metadata.get("author", "")),
"repo": str(metadata.get("repo", "")),
"status": PluginStatus.INSTALLED,
"local_path": str(plugin_dir),
}
)
# 获取在线插件列表
online_plugins = []
@@ -142,15 +144,17 @@ def build_plug_list(plugins_dir: Path) -> list:
resp.raise_for_status()
data = resp.json()
for plugin_id, plugin_info in data.items():
online_plugins.append({
"name": str(plugin_id),
"desc": str(plugin_info.get("desc", "")),
"version": str(plugin_info.get("version", "")),
"author": str(plugin_info.get("author", "")),
"repo": str(plugin_info.get("repo", "")),
"status": PluginStatus.NOT_INSTALLED,
"local_path": None,
})
online_plugins.append(
{
"name": str(plugin_id),
"desc": str(plugin_info.get("desc", "")),
"version": str(plugin_info.get("version", "")),
"author": str(plugin_info.get("author", "")),
"repo": str(plugin_info.get("repo", "")),
"status": PluginStatus.NOT_INSTALLED,
"local_path": None,
}
)
except Exception as e:
click.echo(f"获取在线插件列表失败: {e}", err=True)

View File

@@ -9,5 +9,5 @@ from .hooks import BaseAgentRunHooks
class Agent(Generic[TContext]):
name: str
instructions: str | None = None
tools: list[str, FunctionTool] | None = None
tools: list[str | FunctionTool] | None = None
run_hooks: BaseAgentRunHooks[TContext] | None = None

View File

@@ -40,8 +40,15 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
timeout = cfg.get("timeout", 10)
try:
if "transport" in cfg:
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
async with aiohttp.ClientSession() as session:
if cfg.get("transport") == "streamable_http":
if transport_type == "streamable_http":
test_payload = {
"jsonrpc": "2.0",
"method": "initialize",
@@ -92,7 +99,7 @@ class MCPClient:
self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack()
self.name = None
self.name: str | None = None
self.active: bool = True
self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = []
@@ -121,7 +128,14 @@ class MCPClient:
if not success:
raise Exception(error_msg)
if cfg.get("transport") != "streamable_http":
if "transport" in cfg:
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
if transport_type != "streamable_http":
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
@@ -134,7 +148,7 @@ class MCPClient:
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
*streams,
@@ -159,7 +173,7 @@ class MCPClient:
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
read_stream=read_s,
@@ -198,6 +212,8 @@ class MCPClient:
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
if not self.session:
raise Exception("MCP Client is not initialized")
response = await self.session.list_tools()
self.tools = response.tools
return response

View File

@@ -2,6 +2,7 @@ from dataclasses import dataclass
import typing as T
from astrbot.core.message.message_event_result import MessageChain
class AgentResponseData(T.TypedDict):
chain: MessageChain

View File

@@ -14,4 +14,5 @@ class ContextWrapper(Generic[TContext]):
context: TContext
event: AstrMessageEvent
NoContext = ContextWrapper[None]

View File

@@ -198,6 +198,17 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
func_tool = req.func_tool.get_func(func_tool_name)
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
if not func_tool:
logger.warning(f"未找到指定的工具: {func_tool_name},将跳过。")
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: 未找到工具 {func_tool_name}",
)
)
continue
try:
await self.agent_hooks.on_tool_start(
self.run_context, func_tool, func_tool_args
@@ -210,9 +221,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
run_context=self.run_context,
**func_tool_args,
)
async for resp in executor:
_final_resp: CallToolResult | None = None
async for resp in executor: # type: ignore
if isinstance(resp, CallToolResult):
res = resp
_final_resp = resp
if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
@@ -258,7 +272,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
)
yield MessageChain(
type="tool_direct_result"
).base64_image(res.content[0].data)
).base64_image(resource.blob)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
@@ -269,17 +283,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
)
yield MessageChain().message("返回的数据类型不受支持。")
try:
await self.agent_hooks.on_tool_end(
self.run_context,
func_tool_name,
func_tool_args,
resp,
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
elif resp is None:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
@@ -289,27 +292,18 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool_name, func_tool_args, None
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
else:
# 不应该出现其他类型
logger.warning(
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool_name, func_tool_args, None
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool, func_tool_args, _final_resp
)
except Exception as e:
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
self.run_context.event.clear_result()
except Exception as e:

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from deprecated import deprecated
from typing import Awaitable, Literal, Any, Optional
from typing import Awaitable, Callable, Literal, Any, Optional
from .mcp_client import MCPClient
@@ -8,10 +8,10 @@ from .mcp_client import MCPClient
class FunctionTool:
"""A class representing a function tool that can be used in function calling."""
name: str | None = None
name: str
parameters: dict | None = None
description: str | None = None
handler: Awaitable | None = None
handler: Callable[..., Awaitable[Any]] | None = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str | None = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
@@ -51,7 +51,7 @@ class ToolSet:
This class provides methods to add, remove, and retrieve tools, as well as
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
def __init__(self, tools: list[FunctionTool] = None):
def __init__(self, tools: list[FunctionTool] | None = None):
self.tools: list[FunctionTool] = tools or []
def empty(self) -> bool:
@@ -79,7 +79,13 @@ class ToolSet:
return None
@deprecated(reason="Use add_tool() instead", version="4.0.0")
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
def add_func(
self,
name: str,
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
):
"""Add a function tool to the set."""
params = {
"type": "object", # hard-coded here
@@ -104,7 +110,7 @@ class ToolSet:
self.remove_tool(name)
@deprecated(reason="Use get_tool() instead", version="4.0.0")
def get_func(self, name: str) -> list[FunctionTool]:
def get_func(self, name: str) -> FunctionTool | None:
"""Get all function tools."""
return self.get_tool(name)
@@ -125,7 +131,11 @@ class ToolSet:
},
}
if tool.parameters.get("properties") or not omit_empty_parameter_field:
if (
tool.parameters
and tool.parameters.get("properties")
or not omit_empty_parameter_field
):
func_def["function"]["parameters"] = tool.parameters
result.append(func_def)
@@ -135,14 +145,14 @@ class ToolSet:
"""Convert tools to Anthropic API format."""
result = []
for tool in self.tools:
input_schema = {"type": "object"}
if tool.parameters:
input_schema["properties"] = tool.parameters.get("properties", {})
input_schema["required"] = tool.parameters.get("required", [])
tool_def = {
"name": tool.name,
"description": tool.description,
"input_schema": {
"type": "object",
"properties": tool.parameters.get("properties", {}),
"required": tool.parameters.get("required", []),
},
"input_schema": input_schema,
}
result.append(tool_def)
return result
@@ -210,14 +220,15 @@ class ToolSet:
return result
tools = [
{
tools = []
for tool in self.tools:
d = {
"name": tool.name,
"description": tool.description,
"parameters": convert_schema(tool.parameters),
}
for tool in self.tools
]
if tool.parameters:
d["parameters"] = convert_schema(tool.parameters)
tools.append(d)
declarations = {}
if tools:

View File

@@ -9,3 +9,4 @@ class AstrAgentContext:
first_provider_request: ProviderRequest
curr_provider_request: ProviderRequest
streaming: bool
tool_call_timeout: int = 60 # Default tool call timeout in seconds

View File

@@ -6,7 +6,7 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.0.0"
VERSION = "4.3.5"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
# 默认配置
@@ -56,20 +56,23 @@ DEFAULT_CONFIG = {
"wake_prefix": "",
"web_search": False,
"websearch_provider": "default",
"websearch_tavily_key": "",
"websearch_tavily_key": [],
"websearch_baidu_app_builder_key": "",
"web_search_link": False,
"display_reasoning_text": False,
"identifier": False,
"group_name_display": False,
"datetime_system_prompt": True,
"default_personality": "default",
"persona_pool": ["*"],
"prompt_prefix": "",
"prompt_prefix": "{{prompt}}",
"max_context_length": -1,
"dequeue_context_length": 1,
"streaming_response": False,
"show_tool_use_status": False,
"streaming_segmented": False,
"max_agent_step": 30,
"tool_call_timeout": 60,
},
"provider_stt_settings": {
"enable": False,
@@ -115,6 +118,15 @@ DEFAULT_CONFIG = {
"port": 6185,
},
"platform": [],
"platform_specific": {
# 平台特异配置:按平台分类,平台下按功能分组
"lark": {
"pre_ack_emoji": {"enable": False, "emojis": ["Typing"]},
},
"telegram": {
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
},
},
"wake_prefix": ["/"],
"log_level": "INFO",
"pip_install_arg": "",
@@ -197,6 +209,18 @@ CONFIG_METADATA_2 = {
"callback_server_host": "0.0.0.0",
"port": 6195,
},
"企业微信智能机器人": {
"id": "wecom_ai_bot",
"type": "wecom_ai_bot",
"enable": True,
"wecomaibot_init_respond_text": "💭 思考中...",
"wecomaibot_friend_message_welcome_text": "",
"wecom_ai_bot_name": "",
"token": "",
"encoding_aes_key": "",
"callback_server_host": "0.0.0.0",
"port": 6198,
},
"飞书(Lark)": {
"id": "lark",
"type": "lark",
@@ -235,6 +259,16 @@ CONFIG_METADATA_2 = {
"discord_guild_id_for_debug": "",
"discord_activity_name": "",
},
"Misskey": {
"id": "misskey",
"type": "misskey",
"enable": False,
"misskey_instance_url": "https://misskey.example",
"misskey_token": "",
"misskey_default_visibility": "public",
"misskey_local_only": False,
"misskey_enable_chat": True,
},
"Slack": {
"id": "slack",
"type": "slack",
@@ -252,7 +286,7 @@ CONFIG_METADATA_2 = {
"type": "satori",
"enable": False,
"satori_api_base_url": "http://localhost:5140/satori/v1",
"satori_endpoint": "ws://127.0.0.1:5140/satori/v1/events",
"satori_endpoint": "ws://localhost:5140/satori/v1/events",
"satori_token": "",
"satori_auto_reconnect": True,
"satori_heartbeat_interval": 10,
@@ -261,34 +295,34 @@ CONFIG_METADATA_2 = {
},
"items": {
"satori_api_base_url": {
"description": "Satori API Base URL",
"description": "Satori API 终结点",
"type": "string",
"hint": "The base URL for the Satori API.",
"hint": "Satori API 的基础地址。",
},
"satori_endpoint": {
"description": "Satori WebSocket Endpoint",
"description": "Satori WebSocket 终结点",
"type": "string",
"hint": "The WebSocket endpoint for Satori events.",
"hint": "Satori 事件的 WebSocket 端点。",
},
"satori_token": {
"description": "Satori Token",
"description": "Satori 令牌",
"type": "string",
"hint": "The token used for authenticating with the Satori API.",
"hint": "用于 Satori API 身份验证的令牌。",
},
"satori_auto_reconnect": {
"description": "Enable Auto Reconnect",
"description": "启用自动重连",
"type": "bool",
"hint": "Whether to automatically reconnect the WebSocket on disconnection.",
"hint": "断开连接时是否自动重新连接 WebSocket。",
},
"satori_heartbeat_interval": {
"description": "Satori Heartbeat Interval",
"description": "Satori 心跳间隔",
"type": "int",
"hint": "The interval (in seconds) for sending heartbeat messages.",
"hint": "发送心跳消息的间隔(秒)。",
},
"satori_reconnect_delay": {
"description": "Satori Reconnect Delay",
"description": "Satori 重连延迟",
"type": "int",
"hint": "The delay (in seconds) before attempting to reconnect.",
"hint": "尝试重新连接前的延迟时间(秒)。",
},
"slack_connection_mode": {
"description": "Slack Connection Mode",
@@ -336,6 +370,32 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
},
"misskey_instance_url": {
"description": "Misskey 实例 URL",
"type": "string",
"hint": "例如 https://misskey.example填写 Bot 账号所在的 Misskey 实例地址",
},
"misskey_token": {
"description": "Misskey Access Token",
"type": "string",
"hint": "连接服务设置生成的 API 鉴权访问令牌Access token",
},
"misskey_default_visibility": {
"description": "默认帖子可见性",
"type": "string",
"options": ["public", "home", "followers"],
"hint": "机器人发帖时的默认可见性设置。public公开home主页时间线followers仅关注者。",
},
"misskey_local_only": {
"description": "仅限本站(不参与联合)",
"type": "bool",
"hint": "启用后,机器人发出的帖子将仅在本实例可见,不会联合到其他实例",
},
"misskey_enable_chat": {
"description": "启用聊天消息响应",
"type": "bool",
"hint": "启用后,机器人将会监听和响应私信聊天消息",
},
"telegram_command_register": {
"description": "Telegram 命令注册",
"type": "bool",
@@ -401,10 +461,25 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。",
},
"wecom_ai_bot_name": {
"description": "企业微信智能机器人的名字",
"type": "string",
"hint": "请务必填写正确,否则无法使用一些指令。",
},
"wecomaibot_init_respond_text": {
"description": "企业微信智能机器人初始响应文本",
"type": "string",
"hint": "当机器人收到消息时,首先回复的文本内容。留空则使用默认值。",
},
"wecomaibot_friend_message_welcome_text": {
"description": "企业微信智能机器人私聊欢迎语",
"type": "string",
"hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。",
},
"lark_bot_name": {
"description": "飞书机器人的名字",
"type": "string",
"hint": "请务必填,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
"hint": "请务必填写正确,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
},
"discord_token": {
"description": "Discord Bot Token",
@@ -599,6 +674,7 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.openai.com/v1",
"timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
},
@@ -613,6 +689,7 @@ CONFIG_METADATA_2 = {
"api_base": "",
"timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"xAI": {
@@ -625,6 +702,7 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.x.ai/v1",
"timeout": 120,
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Anthropic": {
@@ -654,6 +732,7 @@ CONFIG_METADATA_2 = {
"key": ["ollama"], # ollama 的 key 默认是 ollama
"api_base": "http://localhost:11434/v1",
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"LM Studio": {
@@ -667,6 +746,7 @@ CONFIG_METADATA_2 = {
"model_config": {
"model": "llama-3.1-8b",
},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Gemini(OpenAI兼容)": {
@@ -682,6 +762,7 @@ CONFIG_METADATA_2 = {
"model": "gemini-1.5-flash",
"temperature": 0.4,
},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Gemini": {
@@ -722,7 +803,8 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.deepseek.com/v1",
"timeout": 120,
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
"modalities": ["text", "image", "tool_use"],
"custom_extra_body": {},
"modalities": ["text", "tool_use"],
},
"302.AI": {
"id": "302ai",
@@ -734,6 +816,7 @@ CONFIG_METADATA_2 = {
"api_base": "https://api.302.ai/v1",
"timeout": 120,
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"硅基流动": {
@@ -749,6 +832,7 @@ CONFIG_METADATA_2 = {
"model": "deepseek-ai/DeepSeek-V3",
"temperature": 0.4,
},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"PPIO派欧云": {
@@ -764,6 +848,22 @@ CONFIG_METADATA_2 = {
"model": "deepseek/deepseek-r1",
"temperature": 0.4,
},
"custom_extra_body": {},
},
"小马算力": {
"id": "tokenpony",
"provider": "tokenpony",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.tokenpony.cn/v1",
"timeout": 120,
"model_config": {
"model": "kimi-k2-instruct-0905",
"temperature": 0.7,
},
"custom_extra_body": {},
},
"优云智算": {
"id": "compshare",
@@ -777,6 +877,7 @@ CONFIG_METADATA_2 = {
"model_config": {
"model": "moonshotai/Kimi-K2-Instruct",
},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Kimi": {
@@ -789,6 +890,7 @@ CONFIG_METADATA_2 = {
"timeout": 120,
"api_base": "https://api.moonshot.cn/v1",
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"智谱 AI": {
@@ -820,6 +922,18 @@ CONFIG_METADATA_2 = {
"timeout": 60,
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
},
"Coze": {
"id": "coze",
"provider": "coze",
"provider_type": "chat_completion",
"type": "coze",
"enable": True,
"coze_api_key": "",
"bot_id": "",
"coze_api_base": "https://api.coze.cn",
"timeout": 60,
"auto_save_history": True,
},
"阿里云百炼应用": {
"id": "dashscope",
"provider": "dashscope",
@@ -847,6 +961,7 @@ CONFIG_METADATA_2 = {
"timeout": 120,
"api_base": "https://api-inference.modelscope.cn/v1",
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"FastGPT": {
@@ -858,6 +973,7 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.fastgpt.in/api/v1",
"timeout": 60,
"custom_extra_body": {},
},
"Whisper(API)": {
"id": "whisper",
@@ -969,6 +1085,7 @@ CONFIG_METADATA_2 = {
"timeout": "20",
},
"阿里云百炼 TTS(API)": {
"hint": "API Key 从 https://bailian.console.aliyun.com/?tab=model#/api-key 获取。模型和音色的选择文档请参考: 阿里云百炼语音合成音色名称。具体可参考 https://help.aliyun.com/zh/model-studio/speech-synthesis-and-speech-recognition",
"id": "dashscope_tts",
"provider": "dashscope",
"type": "dashscope_tts",
@@ -1102,6 +1219,12 @@ CONFIG_METADATA_2 = {
"render_type": "checkbox",
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
},
"custom_extra_body": {
"description": "自定义请求体参数",
"type": "dict",
"items": {},
"hint": "此处添加的键值对将被合并到发送给 API 的 extra_body 中。值可以是字符串、数字或布尔值。",
},
"provider": {
"type": "string",
"invisible": True,
@@ -1342,11 +1465,7 @@ CONFIG_METADATA_2 = {
"description": "服务订阅密钥",
"hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)",
},
"dashscope_tts_voice": {
"description": "语音合成模型",
"type": "string",
"hint": "阿里云百炼语音合成模型名称。具体可参考 https://help.aliyun.com/zh/model-studio/developer-reference/cosyvoice-python-api 等内容",
},
"dashscope_tts_voice": {"description": "音色", "type": "string"},
"gm_resp_image_modal": {
"description": "启用图片模态",
"type": "bool",
@@ -1678,6 +1797,26 @@ CONFIG_METADATA_2 = {
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
"obvious": True,
},
"coze_api_key": {
"description": "Coze API Key",
"type": "string",
"hint": "Coze API 密钥,用于访问 Coze 服务。",
},
"bot_id": {
"description": "Bot ID",
"type": "string",
"hint": "Coze 机器人的 ID在 Coze 平台上创建机器人后获得。",
},
"coze_api_base": {
"description": "API Base URL",
"type": "string",
"hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn",
},
"auto_save_history": {
"description": "由 Coze 管理对话记录",
"type": "bool",
"hint": "启用后,将由 Coze 进行对话历史记录管理, 此时 AstrBot 本地保存的上下文不会生效(仅供浏览), 对 AstrBot 的上下文进行的操作也不会生效。如果为禁用, 则使用 AstrBot 管理上下文。",
},
},
},
"provider_settings": {
@@ -1704,6 +1843,9 @@ CONFIG_METADATA_2 = {
"identifier": {
"type": "bool",
},
"group_name_display": {
"type": "bool",
},
"datetime_system_prompt": {
"type": "bool",
},
@@ -1732,6 +1874,10 @@ CONFIG_METADATA_2 = {
"description": "工具调用轮数上限",
"type": "int",
},
"tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
},
},
},
"provider_stt_settings": {
@@ -1883,17 +2029,33 @@ CONFIG_METADATA_3 = {
"_special": "select_provider",
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
},
"provider_stt_settings.enable": {
"description": "启用语音转文本",
"type": "bool",
"hint": "STT 总开关。",
},
"provider_stt_settings.provider_id": {
"description": "语音转文本模型",
"description": "默认语音转文本模型",
"type": "string",
"hint": "留空代表不使用",
"hint": "用户也可使用 /provider 指令单独选择会话的 STT 模型",
"_special": "select_provider_stt",
"condition": {
"provider_stt_settings.enable": True,
},
},
"provider_tts_settings.enable": {
"description": "启用文本转语音",
"type": "bool",
"hint": "TTS 总开关。当关闭时,会话启用 TTS 也不会生效。",
},
"provider_tts_settings.provider_id": {
"description": "文本转语音模型",
"description": "默认文本转语音模型",
"type": "string",
"hint": "留空代表不使用",
"hint": "用户也可使用 /provider 单独选择会话的 TTS 模型",
"_special": "select_provider_tts",
"condition": {
"provider_tts_settings.enable": True,
},
},
"provider_settings.image_caption_prompt": {
"description": "图片转述提示词",
@@ -1934,15 +2096,25 @@ CONFIG_METADATA_3 = {
"provider_settings.websearch_provider": {
"description": "网页搜索提供商",
"type": "string",
"options": ["default", "tavily"],
"options": ["default", "tavily", "baidu_ai_search"],
},
"provider_settings.websearch_tavily_key": {
"description": "Tavily API Key",
"type": "string",
"type": "list",
"items": {"type": "string"},
"hint": "可添加多个 Key 进行轮询。",
"condition": {
"provider_settings.websearch_provider": "tavily",
},
},
"provider_settings.websearch_baidu_app_builder_key": {
"description": "百度千帆智能云 APP Builder API Key",
"type": "string",
"hint": "参考https://console.bce.baidu.com/iam/#/iam/apikey/list",
"condition": {
"provider_settings.websearch_provider": "baidu_ai_search",
},
},
"provider_settings.web_search_link": {
"description": "显示来源引用",
"type": "bool",
@@ -1961,6 +2133,11 @@ CONFIG_METADATA_3 = {
"description": "用户识别",
"type": "bool",
},
"provider_settings.group_name_display": {
"description": "显示群名称",
"type": "bool",
"hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
},
"provider_settings.datetime_system_prompt": {
"description": "现实世界时间感知",
"type": "bool",
@@ -1973,6 +2150,10 @@ CONFIG_METADATA_3 = {
"description": "工具调用轮数上限",
"type": "int",
},
"provider_settings.tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
},
"provider_settings.streaming_response": {
"description": "流式回复",
"type": "bool",
@@ -1994,12 +2175,14 @@ CONFIG_METADATA_3 = {
"provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ",
"type": "string",
"hint": "例子: 如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
},
"provider_settings.prompt_prefix": {
"description": "额外前缀提示词",
"description": "用户提示词",
"type": "string",
"hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。",
},
"provider_settings.dual_output": {
"provider_tts_settings.dual_output": {
"description": "开启 TTS 时同时输出语音和文字内容",
"type": "bool",
},
@@ -2108,41 +2291,41 @@ CONFIG_METADATA_3 = {
"description": "内容安全",
"type": "object",
"items": {
"platform_settings.content_safety.also_use_in_response": {
"content_safety.also_use_in_response": {
"description": "同时检查模型的响应内容",
"type": "bool",
},
"platform_settings.content_safety.baidu_aip.enable": {
"content_safety.baidu_aip.enable": {
"description": "使用百度内容安全审核",
"type": "bool",
"hint": "您需要手动安装 baidu-aip 库。",
},
"platform_settings.content_safety.baidu_aip.app_id": {
"content_safety.baidu_aip.app_id": {
"description": "App ID",
"type": "string",
"condition": {
"platform_settings.content_safety.baidu_aip.enable": True,
"content_safety.baidu_aip.enable": True,
},
},
"platform_settings.content_safety.baidu_aip.api_key": {
"content_safety.baidu_aip.api_key": {
"description": "API Key",
"type": "string",
"condition": {
"platform_settings.content_safety.baidu_aip.enable": True,
"content_safety.baidu_aip.enable": True,
},
},
"platform_settings.content_safety.baidu_aip.secret_key": {
"content_safety.baidu_aip.secret_key": {
"description": "Secret Key",
"type": "string",
"condition": {
"platform_settings.content_safety.baidu_aip.enable": True,
"content_safety.baidu_aip.enable": True,
},
},
"platform_settings.content_safety.internal_keywords.enable": {
"content_safety.internal_keywords.enable": {
"description": "关键词检查",
"type": "bool",
},
"platform_settings.content_safety.internal_keywords.extra_keywords": {
"content_safety.internal_keywords.extra_keywords": {
"description": "额外关键词",
"type": "list",
"items": {"type": "string"},
@@ -2180,6 +2363,32 @@ CONFIG_METADATA_3 = {
"description": "用户权限不足时是否回复",
"type": "bool",
},
"platform_specific.lark.pre_ack_emoji.enable": {
"description": "[飞书] 启用预回应表情",
"type": "bool",
},
"platform_specific.lark.pre_ack_emoji.emojis": {
"description": "表情列表(飞书表情枚举名)",
"type": "list",
"items": {"type": "string"},
"hint": "表情枚举名参考https://open.feishu.cn/document/server-docs/im-v1/message-reaction/emojis-introduce",
"condition": {
"platform_specific.lark.pre_ack_emoji.enable": True,
},
},
"platform_specific.telegram.pre_ack_emoji.enable": {
"description": "[Telegram] 启用预回应表情",
"type": "bool",
},
"platform_specific.telegram.pre_ack_emoji.emojis": {
"description": "表情列表Unicode",
"type": "list",
"items": {"type": "string"},
"hint": "Telegram 仅支持固定反应集合参考https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9",
"condition": {
"platform_specific.telegram.pre_ack_emoji.enable": True,
},
},
},
},
},

View File

@@ -87,17 +87,25 @@ class ConversationManager:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
f = False
if not conversation_id:
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
f = True
if conversation_id:
await self.db.delete_conversation(cid=conversation_id)
if f:
curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
if curr_cid == conversation_id:
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
"""删除会话的所有对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
"""
await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
"""获取会话当前的对话 ID

View File

@@ -154,12 +154,17 @@ class BaseDatabase(abc.ABC):
"""Delete a conversation by its ID."""
...
@abc.abstractmethod
async def delete_conversations_by_user_id(self, user_id: str) -> None:
"""Delete all conversations for a specific user."""
...
@abc.abstractmethod
async def insert_platform_message_history(
self,
platform_id: str,
user_id: str,
content: list[dict],
content: dict,
sender_id: str | None = None,
sender_name: str | None = None,
) -> None:
@@ -282,3 +287,14 @@ class BaseDatabase(abc.ABC):
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
# """Get all LLM messages for a specific conversation."""
# ...
@abc.abstractmethod
async def get_session_conversations(
self,
page: int = 1,
page_size: int = 20,
search_query: str | None = None,
platform: str | None = None,
) -> tuple[list[dict], int]:
"""Get paginated session conversations with joined conversation and persona details, support search and platform filter."""
...

View File

@@ -53,7 +53,7 @@ async def do_migration_v4(
await migration_webchat_data(db_helper, platform_id_map)
# 执行偏好设置迁移
await migration_preferences(db_helper,platform_id_map)
await migration_preferences(db_helper, platform_id_map)
# 执行平台统计表迁移
await migration_platform_table(db_helper, platform_id_map)

View File

@@ -5,6 +5,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
_VT = TypeVar("_VT")
class SharedPreferences:
def __init__(self, path=None):
if path is None:
@@ -42,4 +43,5 @@ class SharedPreferences:
self._data.clear()
self._save_preferences()
sp = SharedPreferences()

View File

@@ -4,6 +4,7 @@ from astrbot.core.db.po import Platform, Stats
from typing import Tuple, List, Dict, Any
from dataclasses import dataclass
@dataclass
class Conversation:
"""LLM 对话存储
@@ -76,7 +77,7 @@ PRAGMA encoding = 'UTF-8';
"""
class SQLiteDatabase():
class SQLiteDatabase:
def __init__(self, db_path: str) -> None:
super().__init__()
self.db_path = db_path

View File

@@ -75,7 +75,9 @@ class Persona(SQLModel, table=True):
__tablename__ = "personas"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
persona_id: str = Field(max_length=255, nullable=False)
system_prompt: str = Field(sa_type=Text, nullable=False)
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
@@ -135,7 +137,9 @@ class PlatformMessageHistory(SQLModel, table=True):
__tablename__ = "platform_message_history"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False) # An id of group, user in platform
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
@@ -158,8 +162,8 @@ class Attachment(SQLModel, table=True):
__tablename__ = "attachments"
inner_attachment_id: int = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}
inner_attachment_id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
attachment_id: str = Field(
max_length=36,

View File

@@ -15,9 +15,8 @@ from astrbot.core.db.po import (
SQLModel,
)
from sqlalchemy import select, update, delete, text
from sqlmodel import select, update, delete, text, func, or_, desc, col
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import func
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
@@ -33,6 +32,12 @@ class SQLiteDatabase(BaseDatabase):
"""Initialize the database by creating tables if they do not exist."""
async with self.engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
await conn.execute(text("PRAGMA journal_mode=WAL"))
await conn.execute(text("PRAGMA synchronous=NORMAL"))
await conn.execute(text("PRAGMA cache_size=20000"))
await conn.execute(text("PRAGMA temp_store=MEMORY"))
await conn.execute(text("PRAGMA mmap_size=134217728"))
await conn.execute(text("PRAGMA optimize"))
await conn.commit()
# ====
@@ -41,10 +46,10 @@ class SQLiteDatabase(BaseDatabase):
async def insert_platform_stats(
self,
platform_id: str,
platform_type: str,
count: int = 1,
timestamp: datetime = None,
platform_id,
platform_type,
count=1,
timestamp=None,
) -> None:
"""Insert a new platform statistic record."""
async with self.get_db() as session:
@@ -75,7 +80,9 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(
select(func.count(PlatformStat.platform_id)).select_from(PlatformStat)
select(func.count(col(PlatformStat.platform_id))).select_from(
PlatformStat
)
)
count = result.scalar_one_or_none()
return count if count is not None else 0
@@ -95,7 +102,7 @@ class SQLiteDatabase(BaseDatabase):
"""),
{"start_time": start_time},
)
return result.scalars().all()
return list(result.scalars().all())
# ====
# Conversation Management
@@ -111,7 +118,7 @@ class SQLiteDatabase(BaseDatabase):
if platform_id:
query = query.where(ConversationV2.platform_id == platform_id)
# order by
query = query.order_by(ConversationV2.created_at.desc())
query = query.order_by(desc(ConversationV2.created_at))
result = await session.execute(query)
return result.scalars().all()
@@ -129,7 +136,7 @@ class SQLiteDatabase(BaseDatabase):
offset = (page - 1) * page_size
result = await session.execute(
select(ConversationV2)
.order_by(ConversationV2.created_at.desc())
.order_by(desc(ConversationV2.created_at))
.offset(offset)
.limit(page_size)
)
@@ -150,11 +157,26 @@ class SQLiteDatabase(BaseDatabase):
if platform_ids:
base_query = base_query.where(
ConversationV2.platform_id.in_(platform_ids)
col(ConversationV2.platform_id).in_(platform_ids)
)
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
base_query = base_query.where(
ConversationV2.title.ilike(f"%{search_query}%")
or_(
col(ConversationV2.title).ilike(f"%{search_query}%"),
col(ConversationV2.content).ilike(f"%{search_query}%"),
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
)
)
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
for msg_type in kwargs["message_types"]:
base_query = base_query.where(
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%")
)
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
base_query = base_query.where(
col(ConversationV2.platform_id).in_(kwargs["platforms"])
)
# Get total count matching the filters
@@ -165,7 +187,7 @@ class SQLiteDatabase(BaseDatabase):
# Get paginated results
offset = (page - 1) * page_size
result_query = (
base_query.order_by(ConversationV2.created_at.desc())
base_query.order_by(desc(ConversationV2.created_at))
.offset(offset)
.limit(page_size)
)
@@ -211,7 +233,7 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
query = update(ConversationV2).where(
ConversationV2.conversation_id == cid
col(ConversationV2.conversation_id) == cid
)
values = {}
if title is not None:
@@ -231,9 +253,126 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(ConversationV2.conversation_id == cid)
delete(ConversationV2).where(
col(ConversationV2.conversation_id) == cid
)
)
async def delete_conversations_by_user_id(self, user_id: str) -> None:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(col(ConversationV2.user_id) == user_id)
)
async def get_session_conversations(
self,
page=1,
page_size=20,
search_query=None,
platform=None,
) -> tuple[list[dict], int]:
"""Get paginated session conversations with joined conversation and persona details."""
async with self.get_db() as session:
session: AsyncSession
offset = (page - 1) * page_size
base_query = (
select(
col(Preference.scope_id).label("session_id"),
func.json_extract(Preference.value, "$.val").label(
"conversation_id"
), # type: ignore
col(ConversationV2.persona_id).label("persona_id"),
col(ConversationV2.title).label("title"),
col(Persona.persona_id).label("persona_name"),
)
.select_from(Preference)
.outerjoin(
ConversationV2,
func.json_extract(Preference.value, "$.val")
== ConversationV2.conversation_id,
)
.outerjoin(
Persona, col(ConversationV2.persona_id) == Persona.persona_id
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
# 搜索筛选
if search_query:
search_pattern = f"%{search_query}%"
base_query = base_query.where(
or_(
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
)
)
# 平台筛选
if platform:
platform_pattern = f"{platform}:%"
base_query = base_query.where(
col(Preference.scope_id).like(platform_pattern)
)
# 排序
base_query = base_query.order_by(Preference.scope_id)
# 分页结果
result_query = base_query.offset(offset).limit(page_size)
result = await session.execute(result_query)
rows = result.fetchall()
# 查询总数(应用相同的筛选条件)
count_base_query = (
select(func.count(col(Preference.scope_id)))
.select_from(Preference)
.outerjoin(
ConversationV2,
func.json_extract(Preference.value, "$.val")
== ConversationV2.conversation_id,
)
.outerjoin(
Persona, col(ConversationV2.persona_id) == Persona.persona_id
)
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
)
# 应用相同的搜索和平台筛选条件到计数查询
if search_query:
search_pattern = f"%{search_query}%"
count_base_query = count_base_query.where(
or_(
col(Preference.scope_id).ilike(search_pattern),
col(ConversationV2.title).ilike(search_pattern),
col(Persona.persona_id).ilike(search_pattern),
)
)
if platform:
platform_pattern = f"{platform}:%"
count_base_query = count_base_query.where(
col(Preference.scope_id).like(platform_pattern)
)
total_result = await session.execute(count_base_query)
total = total_result.scalar() or 0
sessions_data = [
{
"session_id": row.session_id,
"conversation_id": row.conversation_id,
"persona_id": row.persona_id,
"title": row.title,
"persona_name": row.persona_name,
}
for row in rows
]
return sessions_data, total
async def insert_platform_message_history(
self,
platform_id,
@@ -267,9 +406,9 @@ class SQLiteDatabase(BaseDatabase):
cutoff_time = now - timedelta(seconds=offset_sec)
await session.execute(
delete(PlatformMessageHistory).where(
PlatformMessageHistory.platform_id == platform_id,
PlatformMessageHistory.user_id == user_id,
PlatformMessageHistory.created_at < cutoff_time,
col(PlatformMessageHistory.platform_id) == platform_id,
col(PlatformMessageHistory.user_id) == user_id,
col(PlatformMessageHistory.created_at) < cutoff_time,
)
)
@@ -286,7 +425,7 @@ class SQLiteDatabase(BaseDatabase):
PlatformMessageHistory.platform_id == platform_id,
PlatformMessageHistory.user_id == user_id,
)
.order_by(PlatformMessageHistory.created_at.desc())
.order_by(desc(PlatformMessageHistory.created_at))
)
result = await session.execute(query.offset(offset).limit(page_size))
return result.scalars().all()
@@ -308,7 +447,7 @@ class SQLiteDatabase(BaseDatabase):
"""Get an attachment by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(Attachment).where(Attachment.id == attachment_id)
query = select(Attachment).where(Attachment.attachment_id == attachment_id)
result = await session.execute(query)
return result.scalar_one_or_none()
@@ -351,7 +490,7 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = update(Persona).where(Persona.persona_id == persona_id)
query = update(Persona).where(col(Persona.persona_id) == persona_id)
values = {}
if system_prompt is not None:
values["system_prompt"] = system_prompt
@@ -371,7 +510,7 @@ class SQLiteDatabase(BaseDatabase):
session: AsyncSession
async with session.begin():
await session.execute(
delete(Persona).where(Persona.persona_id == persona_id)
delete(Persona).where(col(Persona.persona_id) == persona_id)
)
async def insert_preference_or_update(self, scope, scope_id, key, value):
@@ -426,9 +565,9 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin():
await session.execute(
delete(Preference).where(
Preference.scope == scope,
Preference.scope_id == scope_id,
Preference.key == key,
col(Preference.scope) == scope,
col(Preference.scope_id) == scope_id,
col(Preference.key) == key,
)
)
await session.commit()
@@ -440,7 +579,8 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin():
await session.execute(
delete(Preference).where(
Preference.scope == scope, Preference.scope_id == scope_id
col(Preference.scope) == scope,
col(Preference.scope_id) == scope_id,
)
)
await session.commit()
@@ -467,7 +607,7 @@ class SQLiteDatabase(BaseDatabase):
DeprecatedPlatformStat(
name=data.platform_id,
count=data.count,
timestamp=data.timestamp.timestamp(),
timestamp=int(data.timestamp.timestamp()),
)
)
return deprecated_stats
@@ -525,7 +665,7 @@ class SQLiteDatabase(BaseDatabase):
DeprecatedPlatformStat(
name=platform_id,
count=count,
timestamp=start_time.timestamp(),
timestamp=int(start_time.timestamp()),
)
)
return deprecated_stats

View File

@@ -1,3 +1,3 @@
from .vec_db import FaissVecDB
__all__ = ["FaissVecDB"]
__all__ = ["FaissVecDB"]

View File

@@ -113,7 +113,8 @@ class FaissVecDB(BaseVecDB):
reranked_results, key=lambda x: x.relevance_score, reverse=True
)
top_k_results = [
top_k_results[reranked_result.index] for reranked_result in reranked_results
top_k_results[reranked_result.index]
for reranked_result in reranked_results
]
return top_k_results

View File

@@ -22,6 +22,7 @@ class InitialLoader:
self.db = db
self.logger = logger
self.log_broker = log_broker
self.webui_dir: str | None = None
async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
@@ -35,8 +36,10 @@ class InitialLoader:
core_task = core_lifecycle.start()
webui_dir = self.webui_dir
self.dashboard_server = AstrBotDashboard(
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event, webui_dir
)
task = asyncio.gather(
core_task, self.dashboard_server.run()

View File

@@ -19,7 +19,7 @@ class ContentSafetyCheckStage(Stage):
self.strategy_selector = StrategySelector(config)
async def process(
self, event: AstrMessageEvent, check_text: str = None
self, event: AstrMessageEvent, check_text: str | None = None
) -> Union[None, AsyncGenerator[None, None]]:
"""检查内容安全"""
text = check_text if check_text else event.get_message_str()

View File

@@ -13,7 +13,7 @@ class BaiduAipStrategy(ContentSafetyStrategy):
self.secret_key = sk
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
def check(self, content: str):
def check(self, content: str) -> tuple[bool, str]:
res = self.client.textCensorUserDefined(content)
if "conclusionType" not in res:
return False, ""

View File

@@ -16,7 +16,7 @@ class KeywordsStrategy(ContentSafetyStrategy):
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
# )
def check(self, content: str) -> bool:
def check(self, content: str) -> tuple[bool, str]:
for keyword in self.keywords:
if re.search(keyword, content):
return False, "内容安全检查不通过,匹配到敏感词。"

View File

@@ -10,7 +10,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
async def call_handler(
event: AstrMessageEvent,
handler: T.Awaitable,
handler: T.Callable[..., T.Awaitable[T.Any]],
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
@@ -36,6 +36,9 @@ async def call_handler(
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
if not ready_to_call:
return
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
@@ -77,7 +80,7 @@ async def call_event_hook(
Returns:
bool: 如果事件被终止,返回 True
# """
#"""
handlers = star_handlers_registry.get_handlers_by_event_type(
hook_type, plugins_name=event.plugins_name
)
@@ -94,5 +97,6 @@ async def call_event_hook(
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return True
return event.is_stopped()

View File

@@ -1,5 +1,6 @@
import traceback
import asyncio
import random
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
@@ -22,6 +23,26 @@ class PreProcessStage(Stage):
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""在处理事件之前的预处理"""
# 平台特异配置platform_specific.<platform>.pre_ack_emoji
supported = {"telegram", "lark"}
platform = event.get_platform_name()
cfg = (
self.config.get("platform_specific", {})
.get(platform, {})
.get("pre_ack_emoji", {})
) or {}
emojis = cfg.get("emojis") or []
if (
cfg.get("enable", False)
and platform in supported
and emojis
and event.is_at_or_wake_command
):
try:
await event.react(random.choice(emojis))
except Exception as e:
logger.warning(f"{platform} 预回应表情发送失败: {e}")
# 路径映射
if mappings := self.platform_settings.get("path_mapping", []):
# 支持 RecordImage 消息段的路径映射。
@@ -46,6 +67,9 @@ class PreProcessStage(Stage):
ctx = self.plugin_manager.context
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
if not stt_provider:
logger.warning(
f"会话 {event.unified_msg_origin} 未配置语音转文本模型。"
)
return
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):

View File

@@ -6,7 +6,9 @@ import asyncio
import copy
import json
import traceback
from datetime import timedelta
from typing import AsyncGenerator, Union
from astrbot.core.conversation_mgr import Conversation
from astrbot.core import logger
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
@@ -133,6 +135,15 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
if agent_runner.done():
llm_response = agent_runner.get_final_llm_resp()
if not llm_response:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
logger.debug(
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
)
@@ -148,7 +159,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
)
yield mcp.types.CallToolResult(content=[text_content])
else:
yield mcp.types.TextContent(
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
@@ -175,21 +186,33 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
handler=awaitable,
**tool_args,
)
async for resp in wrapper:
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
# async for resp in wrapper:
while True:
try:
resp = await asyncio.wait_for(
anext(wrapper),
timeout=run_context.context.tool_call_timeout,
)
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
yield None
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
yield None
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds."
)
except StopAsyncIteration:
break
@classmethod
async def _execute_mcp(
@@ -200,9 +223,16 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
):
if not tool.mcp_client:
raise ValueError("MCP client is not available for MCP function tools.")
res = await tool.mcp_client.session.call_tool(
session = tool.mcp_client.session
if not session:
raise ValueError("MCP session is not available for MCP function tools.")
res = await session.call_tool(
name=tool.name,
arguments=tool_args,
read_timeout_seconds=timedelta(
seconds=run_context.context.tool_call_timeout
),
)
if not res:
return
@@ -271,19 +301,12 @@ async def run_agent(
except Exception as e:
logger.error(traceback.format_exc())
astr_event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
)
)
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
if agent_runner.streaming:
yield MessageChain().message(err_msg)
else:
astr_event.set_result(MessageEventResult().message(err_msg))
return
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
)
)
class LLMRequestSubStage(Stage):
@@ -300,6 +323,7 @@ class LLMRequestSubStage(Stage):
)
self.streaming_response: bool = settings["streaming_response"]
self.max_step: int = settings.get("max_agent_step", 30)
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
@@ -325,7 +349,7 @@ class LLMRequestSubStage(Stage):
return _ctx.get_using_provider(umo=event.unified_msg_origin)
async def _get_session_conv(self, event: AstrMessageEvent):
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
umo = event.unified_msg_origin
conv_mgr = self.conv_manager
@@ -337,6 +361,8 @@ class LLMRequestSubStage(Stage):
if not conversation:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
raise RuntimeError("无法创建新的对话。")
return conversation
async def process(
@@ -444,7 +470,10 @@ class LLMRequestSubStage(Stage):
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
plugin = star_map.get(tool.handler_module_path)
mp = tool.handler_module_path
if not mp:
continue
plugin = star_map.get(mp)
if not plugin:
continue
if plugin.name in event.plugins_name or plugin.reserved:
@@ -461,6 +490,7 @@ class LLMRequestSubStage(Stage):
first_provider_request=req,
curr_provider_request=req,
streaming=self.streaming_response,
tool_call_timeout=self.tool_call_timeout,
)
await agent_runner.reset(
provider=provider,
@@ -505,6 +535,14 @@ class LLMRequestSubStage(Stage):
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
)
)
async def _handle_webchat(
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
@@ -517,7 +555,23 @@ class LLMRequestSubStage(Stage):
latest_pair = messages[-2:]
if not latest_pair:
return
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
content = latest_pair[0].get("content", "")
if isinstance(content, list):
# 多模态
text_parts = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif item.get("type") == "image":
text_parts.append("[图片]")
elif isinstance(item, str):
text_parts.append(item)
cleaned_text = "User: " + " ".join(text_parts).strip()
elif isinstance(content, str):
cleaned_text = "User: " + content.strip()
else:
return
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
llm_resp = await prov.text_chat(
system_prompt="You are expert in summarizing user's query.",

View File

@@ -34,12 +34,14 @@ class StarRequestSubStage(Stage):
for handler in activated_handlers:
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_path not in star_map:
continue
logger.debug(
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
md = star_map.get(handler.handler_module_path)
if not md:
logger.warning(
f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
)
continue
logger.debug(f"plugin -> {md.name} - {handler.handler_name}")
try:
wrapper = call_handler(event, handler.handler, **params)
async for ret in wrapper:
yield ret
@@ -49,7 +51,7 @@ class StarRequestSubStage(Stage):
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
if event.is_at_or_wake_command:
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()

View File

@@ -1,17 +1,15 @@
import random
import asyncio
import math
import traceback
import astrbot.core.message.components as Comp
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
from ..context import PipelineContext, call_event_hook
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core import logger
from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.components import BaseMessageComponent, ComponentType
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.path_util import path_Mapping
from astrbot.core.utils.session_lock import session_lock_manager
@@ -114,6 +112,43 @@ class RespondStage(Stage):
# 如果所有组件都为空
return True
def is_seg_reply_required(self, event: AstrMessageEvent) -> bool:
"""检查是否需要分段回复"""
if not self.enable_seg:
return False
if self.only_llm_result and not event.get_result().is_llm_result():
return False
if event.get_platform_name() in [
"qq_official",
"weixin_official_account",
"dingtalk",
]:
return False
return True
def _extract_comp(
self,
raw_chain: list[BaseMessageComponent],
extract_types: set[ComponentType],
modify_raw_chain: bool = True,
):
extracted = []
if modify_raw_chain:
remaining = []
for comp in raw_chain:
if comp.type in extract_types:
extracted.append(comp)
else:
remaining.append(comp)
raw_chain[:] = remaining
else:
extracted = [comp for comp in raw_chain if comp.type in extract_types]
return extracted
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
@@ -123,7 +158,14 @@ class RespondStage(Stage):
if result.result_content_type == ResultContentType.STREAMING_FINISH:
return
logger.info(
f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
)
if result.result_content_type == ResultContentType.STREAMING_RESULT:
if result.async_stream is None:
logger.warning("async_stream 为空,跳过发送。")
return
# 流式结果直接交付平台适配器处理
use_fallback = self.config.get("provider_settings", {}).get(
"streaming_segmented", False
@@ -148,87 +190,81 @@ class RespondStage(Stage):
except Exception as e:
logger.warning(f"空内容检查异常: {e}")
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
non_record_comps = [
c for c in result.chain if not isinstance(c, Comp.Record)
# 将 Plain 为空的消息段移除
result.chain = [
comp
for comp in result.chain
if not (
isinstance(comp, Comp.Plain)
and (not comp.text or not comp.text.strip())
)
]
if (
self.enable_seg
and (
(self.only_llm_result and result.is_llm_result())
or not self.only_llm_result
# 发送消息链
# Record 需要强制单独发送
need_separately = {ComponentType.Record}
if self.is_seg_reply_required(event):
header_comps = self._extract_comp(
result.chain,
{ComponentType.Reply, ComponentType.At},
modify_raw_chain=True,
)
and event.get_platform_name()
not in ["qq_official", "weixin_official_account", "dingtalk"]
):
decorated_comps = []
if self.reply_with_mention:
for comp in result.chain:
if isinstance(comp, Comp.At):
decorated_comps.append(comp)
result.chain.remove(comp)
break
if self.reply_with_quote:
for comp in result.chain:
if isinstance(comp, Comp.Reply):
decorated_comps.append(comp)
result.chain.remove(comp)
break
# leverage lock to guarentee the order of message sending among different events
if not result.chain or len(result.chain) == 0:
# may fix #2670
logger.warning(
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}"
)
return
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
for rcomp in record_comps:
i = await self._calc_comp_interval(rcomp)
await asyncio.sleep(i)
try:
await event.send(MessageChain([rcomp]))
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
# 分段回复
for comp in non_record_comps:
for comp in result.chain:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
try:
await event.send(MessageChain([*decorated_comps, comp]))
decorated_comps = [] # 清空已发送的装饰组件
if comp.type in need_separately:
await event.send(MessageChain([comp]))
else:
await event.send(MessageChain([*header_comps, comp]))
header_comps.clear()
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
logger.error(
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
exc_info=True,
)
else:
for rcomp in record_comps:
if all(
comp.type in {ComponentType.Reply, ComponentType.At}
for comp in result.chain
):
# may fix #2670
logger.warning(
f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}"
)
return
sep_comps = self._extract_comp(
result.chain,
need_separately,
modify_raw_chain=True,
)
for comp in sep_comps:
chain = MessageChain([comp])
try:
await event.send(MessageChain([rcomp]))
await event.send(chain)
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
logger.error(
f"发送消息链失败: chain = {chain}, error = {e}",
exc_info=True,
)
chain = MessageChain(result.chain)
if result.chain and len(result.chain) > 0:
try:
await event.send(chain)
except Exception as e:
logger.error(
f"发送消息链失败: chain = {chain}, error = {e}",
exc_info=True,
)
try:
await event.send(MessageChain(non_record_comps))
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"发送消息失败: {e} chain: {result.chain}")
logger.info(
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name
)
for handler in handlers:
try:
logger.debug(
f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
if await call_event_hook(event, EventType.OnAfterMessageSentEvent):
return
event.clear_result()

View File

@@ -183,56 +183,60 @@ class ResultDecorateStage(Stage):
if (
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
and result.is_llm_result()
and tts_provider
and SessionServiceManager.should_process_tts_request(event)
):
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info(f"TTS 请求: {comp.text}")
audio_path = await tts_provider.get_audio(comp.text)
logger.info(f"TTS 结果: {audio_path}")
if not audio_path:
logger.error(
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
if not tts_provider:
logger.warning(
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。"
)
else:
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info(f"TTS 请求: {comp.text}")
audio_path = await tts_provider.get_audio(comp.text)
logger.info(f"TTS 结果: {audio_path}")
if not audio_path:
logger.error(
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
)
new_chain.append(comp)
continue
use_file_service = self.ctx.astrbot_config[
"provider_tts_settings"
]["use_file_service"]
callback_api_base = self.ctx.astrbot_config[
"callback_api_base"
]
dual_output = self.ctx.astrbot_config[
"provider_tts_settings"
]["dual_output"]
url = None
if use_file_service and callback_api_base:
token = await file_token_service.register_file(
audio_path
)
url = f"{callback_api_base}/api/file/{token}"
logger.debug(f"已注册:{url}")
new_chain.append(
Record(
file=url or audio_path,
url=url or audio_path,
)
)
if dual_output:
new_chain.append(comp)
except Exception:
logger.error(traceback.format_exc())
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
continue
use_file_service = self.ctx.astrbot_config[
"provider_tts_settings"
]["use_file_service"]
callback_api_base = self.ctx.astrbot_config[
"callback_api_base"
]
dual_output = self.ctx.astrbot_config[
"provider_tts_settings"
]["dual_output"]
url = None
if use_file_service and callback_api_base:
token = await file_token_service.register_file(
audio_path
)
url = f"{callback_api_base}/api/file/{token}"
logger.debug(f"已注册:{url}")
new_chain.append(
Record(
file=url or audio_path,
url=url or audio_path,
)
)
if dual_output:
new_chain.append(comp)
except Exception:
logger.error(traceback.format_exc())
logger.error("TTS 失败,使用文本发送。")
else:
new_chain.append(comp)
else:
new_chain.append(comp)
result.chain = new_chain
result.chain = new_chain
# 文本转图片
elif (
@@ -275,7 +279,6 @@ class ResultDecorateStage(Stage):
result.chain = [Image.fromFileSystem(url)]
# 触发转发消息
has_forwarded = False
if event.get_platform_name() == "aiocqhttp":
word_cnt = 0
for comp in result.chain:
@@ -286,9 +289,9 @@ class ResultDecorateStage(Stage):
uin=event.get_self_id(), name="AstrBot", content=[*result.chain]
)
result.chain = [node]
has_forwarded = True
if not has_forwarded:
has_plain = any(isinstance(item, Plain) for item in result.chain)
if has_plain:
# at 回复
if (
self.reply_with_mention

View File

@@ -74,7 +74,7 @@ class PipelineScheduler:
await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if event.get_platform_name() == "webchat":
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
await event.send(None)
logger.debug("pipeline 执行完毕。")

View File

@@ -11,7 +11,8 @@ class SessionStatusCheckStage(Stage):
"""检查会话是否整体启用"""
async def initialize(self, ctx: PipelineContext) -> None:
pass
self.ctx = ctx
self.conv_mgr = ctx.plugin_manager.context.conversation_manager
async def process(
self, event: AstrMessageEvent
@@ -19,4 +20,14 @@ class SessionStatusCheckStage(Stage):
# 检查会话是否整体启用
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
# workaround for #2309
conv_id = await self.conv_mgr.get_curr_conversation_id(
event.unified_msg_origin
)
if not conv_id:
await self.conv_mgr.new_conversation(
event.unified_msg_origin, platform_id=event.get_platform_id()
)
event.stop_event()

View File

@@ -5,6 +5,7 @@ from astrbot.core.message.components import At, AtAll, Reply
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
@@ -170,11 +171,15 @@ class WakingCheckStage(Stage):
is_wake = True
event.is_wake = True
activated_handlers.append(handler)
if "parsed_params" in event.get_extra():
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
"parsed_params"
)
is_group_cmd_handler = any(
isinstance(f, CommandGroupFilter) for f in handler.event_filters
)
if not is_group_cmd_handler:
activated_handlers.append(handler)
if "parsed_params" in event.get_extra(default={}):
handlers_parsed_params[handler.handler_full_name] = (
event.get_extra("parsed_params")
)
event._extras.pop("parsed_params", None)

View File

@@ -4,7 +4,7 @@ import re
import hashlib
import uuid
from typing import List, Union, Optional, AsyncGenerator
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
from astrbot import logger
from astrbot.core.db.po import Conversation
@@ -24,7 +24,9 @@ from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.utils.metrics import Metric
from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
from .message_session import MessageSession, MessageSesion # noqa
from .message_session import MessageSession, MessageSesion # noqa
_VT = TypeVar("_VT")
class AstrMessageEvent(abc.ABC):
@@ -49,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
"""是否唤醒(是否通过 WakingStage)"""
self.is_at_or_wake_command = False
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
self._extras = {}
self._extras: dict[str, Any] = {}
self.session = MessageSesion(
platform_name=platform_meta.id,
message_type=message_obj.type,
@@ -57,7 +59,7 @@ class AstrMessageEvent(abc.ABC):
)
self.unified_msg_origin = str(self.session)
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
self._result: MessageEventResult = None
self._result: MessageEventResult | None = None
"""消息事件的结果"""
self._has_send_oper = False
@@ -173,13 +175,15 @@ class AstrMessageEvent(abc.ABC):
"""
self._extras[key] = value
def get_extra(self, key=None):
def get_extra(
self, key: str | None = None, default: _VT = None
) -> dict[str, Any] | _VT:
"""
获取额外的信息。
"""
if key is None:
return self._extras
return self._extras.get(key, None)
return self._extras.get(key, default)
def clear_extra(self):
"""
@@ -412,6 +416,16 @@ class AstrMessageEvent(abc.ABC):
)
self._has_send_oper = True
async def react(self, emoji: str):
"""
对消息添加表情回应。
默认实现为发送一条包含该表情的消息。
注意:此实现并不一定符合所有平台的原生“表情回应”行为。
如需支持平台原生的消息反应功能,请在对应平台的子类中重写本方法。
"""
await self.send(MessageChain([Plain(emoji)]))
async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]:
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息返回当前群聊的数据。

View File

@@ -55,7 +55,7 @@ class AstrBotMessage:
self_id: str # 机器人的识别id
session_id: str # 会话id。取决于 unique_session 的设置。
message_id: str # 消息id
group_id: str = "" # 群组id如果为私聊则为空
group: Group # 群组
sender: MessageMember # 发送者
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串
@@ -64,6 +64,28 @@ class AstrBotMessage:
def __init__(self) -> None:
self.timestamp = int(time.time())
self.group = None
def __str__(self) -> str:
return str(self.__dict__)
@property
def group_id(self) -> str:
"""
向后兼容的 group_id 属性
群组id如果为私聊则为空
"""
if self.group:
return self.group.group_id
return ""
@group_id.setter
def group_id(self, value: str):
"""设置 group_id"""
if value:
if self.group:
self.group.group_id = value
else:
self.group = Group(group_id=value)
else:
self.group = None

View File

@@ -6,6 +6,7 @@ from typing import List
from asyncio import Queue
from .register import platform_cls_map
from astrbot.core import logger
from astrbot.core.star.star_handler import star_handlers_registry, star_map, EventType
from .sources.webchat.webchat_adapter import WebChatAdapter
@@ -66,27 +67,43 @@ class PlatformManager:
WeChatPadProAdapter, # noqa: F401
)
case "lark":
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
from .sources.lark.lark_adapter import (
LarkPlatformAdapter, # noqa: F401
)
case "dingtalk":
from .sources.dingtalk.dingtalk_adapter import (
DingtalkPlatformAdapter, # noqa: F401
)
case "telegram":
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
from .sources.telegram.tg_adapter import (
TelegramPlatformAdapter, # noqa: F401
)
case "wecom":
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
from .sources.wecom.wecom_adapter import (
WecomPlatformAdapter, # noqa: F401
)
case "wecom_ai_bot":
from .sources.wecom_ai_bot.wecomai_adapter import (
WecomAIBotAdapter, # noqa: F401
)
case "weixin_official_account":
from .sources.weixin_official_account.weixin_offacc_adapter import (
WeixinOfficialAccountPlatformAdapter, # noqa
WeixinOfficialAccountPlatformAdapter, # noqa: F401
)
case "discord":
from .sources.discord.discord_platform_adapter import (
DiscordPlatformAdapter, # noqa: F401
)
case "misskey":
from .sources.misskey.misskey_adapter import (
MisskeyPlatformAdapter, # noqa: F401
)
case "slack":
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
case "satori":
from .sources.satori.satori_adapter import SatoriPlatformAdapter # noqa: F401
from .sources.satori.satori_adapter import (
SatoriPlatformAdapter, # noqa: F401
)
except (ImportError, ModuleNotFoundError) as e:
logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
@@ -115,6 +132,17 @@ class PlatformManager:
)
)
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnPlatformLoadedEvent
)
for handler in handlers:
try:
logger.info(
f"hook(on_platform_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler()
except Exception:
logger.error(traceback.format_exc())
async def _task_wrapper(self, task: asyncio.Task):
try:

View File

@@ -14,3 +14,5 @@ class PlatformMetadata:
"""平台的默认配置模板"""
adapter_display_name: str = None
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
logo_path: str = None
"""平台适配器的 logo 文件路径(相对于插件目录)"""

View File

@@ -13,10 +13,12 @@ def register_platform_adapter(
desc: str,
default_config_tmpl: dict = None,
adapter_display_name: str = None,
logo_path: str = None,
):
"""用于注册平台适配器的带参装饰器。
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
logo_path 指定了平台适配器的 logo 文件路径,是相对于插件目录的路径。
"""
def decorator(cls):
@@ -39,6 +41,7 @@ def register_platform_adapter(
description=desc,
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name,
logo_path=logo_path,
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls

View File

@@ -182,11 +182,13 @@ class AiocqhttpAdapter(Platform):
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(
str(event.sender["user_id"]), event.sender["nickname"]
str(event.sender["user_id"]),
event.sender.get("card") or event.sender.get("nickname", "N/A"),
)
if event["message_type"] == "group":
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id)
abm.group.group_name = event.get("group_name", "N/A")
elif event["message_type"] == "private":
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
@@ -321,7 +323,9 @@ class AiocqhttpAdapter(Platform):
user_id=int(m["data"]["qq"]),
no_cache=False,
)
nickname = at_info.get("nick", "") or at_info.get("nickname", "")
nickname = at_info.get("nick", "") or at_info.get(
"nickname", ""
)
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
abm.message.append(

View File

@@ -54,9 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
logger.debug(f"send image: {ret}")
except Exception as e:
logger.error(f"钉钉图片处理失败: {e}")
logger.warning(f"跳过图片发送: {image_path}")
logger.warning(f"钉钉图片处理失败: {e}, 跳过图片发送")
continue
async def send(self, message: MessageChain):
await self.send_with_client(self.client, message)
await super().send(message)

View File

@@ -41,7 +41,8 @@ class DiscordBotClient(discord.Bot):
await self.on_ready_once_callback()
except Exception as e:
logger.error(
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True)
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True
)
def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典"""
@@ -90,7 +91,6 @@ class DiscordBotClient(discord.Bot):
message_data = self._create_message_data(message)
await self.on_message_received(message_data)
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
"""从交互中提取内容"""
interaction_type = interaction.type

View File

@@ -79,9 +79,12 @@ class DiscordButton(BaseMessageComponent):
self.url = url
self.disabled = disabled
class DiscordReference(BaseMessageComponent):
"""Discord引用组件"""
type: str = "discord_reference"
def __init__(self, message_id: str, channel_id: str):
self.message_id = message_id
self.channel_id = channel_id
@@ -98,7 +101,6 @@ class DiscordView(BaseMessageComponent):
self.components = components or []
self.timeout = timeout
def to_discord_view(self) -> discord.ui.View:
"""转换为Discord View对象"""
view = discord.ui.View(timeout=self.timeout)

View File

@@ -53,7 +53,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
# 解析消息链为 Discord 所需的对象
try:
content, files, view, embeds, reference_message_id = await self._parse_to_discord(message)
(
content,
files,
view,
embeds,
reference_message_id,
) = await self._parse_to_discord(message)
except Exception as e:
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
return
@@ -206,8 +212,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
if await asyncio.to_thread(path.exists):
file_bytes = await asyncio.to_thread(path.read_bytes)
files.append(
discord.File(BytesIO(file_bytes),
filename=i.name)
discord.File(BytesIO(file_bytes), filename=i.name)
)
else:
logger.warning(

View File

@@ -107,6 +107,22 @@ class LarkMessageEvent(AstrMessageEvent):
await super().send(message)
async def react(self, emoji: str):
request = (
CreateMessageReactionRequest.builder()
.message_id(self.message_obj.message_id)
.request_body(
CreateMessageReactionRequestBody.builder()
.reaction_type(Emoji.builder().emoji_type(emoji).build())
.build()
)
.build()
)
response = await self.bot.im.v1.message_reaction.acreate(request)
if not response.success():
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
return None
async def send_streaming(self, generator, use_fallback: bool = False):
buffer = None
async for chain in generator:

View File

@@ -0,0 +1,391 @@
import asyncio
import json
from typing import Dict, Any, Optional, Awaitable
from astrbot.api import logger
from astrbot.api.event import MessageChain
from astrbot.api.platform import (
AstrBotMessage,
Platform,
PlatformMetadata,
register_platform_adapter,
)
from astrbot.core.platform.astr_message_event import MessageSession
import astrbot.api.message_components as Comp
from .misskey_api import MisskeyAPI
from .misskey_event import MisskeyPlatformEvent
from .misskey_utils import (
serialize_message_chain,
resolve_message_visibility,
is_valid_user_session_id,
is_valid_room_session_id,
add_at_mention_if_needed,
process_files,
extract_sender_info,
create_base_message,
process_at_mention,
cache_user_info,
cache_room_info,
)
@register_platform_adapter("misskey", "Misskey 平台适配器")
class MisskeyPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config or {}
self.settings = platform_settings or {}
self.instance_url = self.config.get("misskey_instance_url", "")
self.access_token = self.config.get("misskey_token", "")
self.max_message_length = self.config.get("max_message_length", 3000)
self.default_visibility = self.config.get(
"misskey_default_visibility", "public"
)
self.local_only = self.config.get("misskey_local_only", False)
self.enable_chat = self.config.get("misskey_enable_chat", True)
self.unique_session = platform_settings["unique_session"]
self.api: Optional[MisskeyAPI] = None
self._running = False
self.client_self_id = ""
self._bot_username = ""
self._user_cache = {}
def meta(self) -> PlatformMetadata:
default_config = {
"misskey_instance_url": "",
"misskey_token": "",
"max_message_length": 3000,
"misskey_default_visibility": "public",
"misskey_local_only": False,
"misskey_enable_chat": True,
}
default_config.update(self.config)
return PlatformMetadata(
name="misskey",
description="Misskey 平台适配器",
id=self.config.get("id", "misskey"),
default_config_tmpl=default_config,
)
async def run(self):
if not self.instance_url or not self.access_token:
logger.error("[Misskey] 配置不完整,无法启动")
return
self.api = MisskeyAPI(self.instance_url, self.access_token)
self._running = True
try:
user_info = await self.api.get_current_user()
self.client_self_id = str(user_info.get("id", ""))
self._bot_username = user_info.get("username", "")
logger.info(
f"[Misskey] 已连接用户: {self._bot_username} (ID: {self.client_self_id})"
)
except Exception as e:
logger.error(f"[Misskey] 获取用户信息失败: {e}")
self._running = False
return
await self._start_websocket_connection()
async def _start_websocket_connection(self):
backoff_delay = 1.0
max_backoff = 300.0
backoff_multiplier = 1.5
connection_attempts = 0
while self._running:
try:
connection_attempts += 1
if not self.api:
logger.error("[Misskey] API 客户端未初始化")
break
streaming = self.api.get_streaming_client()
streaming.add_message_handler("notification", self._handle_notification)
if self.enable_chat:
streaming.add_message_handler(
"newChatMessage", self._handle_chat_message
)
streaming.add_message_handler("_debug", self._debug_handler)
if await streaming.connect():
logger.info(
f"[Misskey] WebSocket 已连接 (尝试 #{connection_attempts})"
)
connection_attempts = 0 # 重置计数器
await streaming.subscribe_channel("main")
if self.enable_chat:
await streaming.subscribe_channel("messaging")
await streaming.subscribe_channel("messagingIndex")
logger.info("[Misskey] 聊天频道已订阅")
backoff_delay = 1.0 # 重置延迟
await streaming.listen()
else:
logger.error(
f"[Misskey] WebSocket 连接失败 (尝试 #{connection_attempts})"
)
except Exception as e:
logger.error(
f"[Misskey] WebSocket 异常 (尝试 #{connection_attempts}): {e}"
)
if self._running:
logger.info(
f"[Misskey] {backoff_delay:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})"
)
await asyncio.sleep(backoff_delay)
backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff)
async def _handle_notification(self, data: Dict[str, Any]):
try:
logger.debug(
f"[Misskey] 收到通知事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
)
notification_type = data.get("type")
if notification_type in ["mention", "reply", "quote"]:
note = data.get("note")
if note and self._is_bot_mentioned(note):
logger.info(
f"[Misskey] 处理贴文提及: {note.get('text', '')[:50]}..."
)
message = await self.convert_message(note)
event = MisskeyPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.api,
)
self.commit_event(event)
except Exception as e:
logger.error(f"[Misskey] 处理通知失败: {e}")
async def _handle_chat_message(self, data: Dict[str, Any]):
try:
logger.debug(
f"[Misskey] 收到聊天事件数据:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
)
sender_id = str(
data.get("fromUserId", "") or data.get("fromUser", {}).get("id", "")
)
if sender_id == self.client_self_id:
return
room_id = data.get("toRoomId")
if room_id:
raw_text = data.get("text", "")
logger.debug(
f"[Misskey] 检查群聊消息: '{raw_text}', 机器人用户名: '{self._bot_username}'"
)
message = await self.convert_room_message(data)
logger.info(f"[Misskey] 处理群聊消息: {message.message_str[:50]}...")
else:
message = await self.convert_chat_message(data)
logger.info(f"[Misskey] 处理私聊消息: {message.message_str[:50]}...")
event = MisskeyPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.api,
)
self.commit_event(event)
except Exception as e:
logger.error(f"[Misskey] 处理聊天消息失败: {e}")
async def _debug_handler(self, data: Dict[str, Any]):
logger.debug(
f"[Misskey] 收到未处理事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
)
def _is_bot_mentioned(self, note: Dict[str, Any]) -> bool:
text = note.get("text", "")
if not text:
return False
mentions = note.get("mentions", [])
if self._bot_username and f"@{self._bot_username}" in text:
return True
if self.client_self_id in [str(uid) for uid in mentions]:
return True
reply = note.get("reply")
if reply and isinstance(reply, dict):
reply_user_id = str(reply.get("user", {}).get("id", ""))
if reply_user_id == self.client_self_id:
return bool(self._bot_username and f"@{self._bot_username}" in text)
return False
async def send_by_session(
self, session: MessageSession, message_chain: MessageChain
) -> Awaitable[Any]:
if not self.api:
logger.error("[Misskey] API 客户端未初始化")
return await super().send_by_session(session, message_chain)
try:
session_id = session.session_id
text, has_at_user = serialize_message_chain(message_chain.chain)
if not has_at_user and session_id:
user_info = self._user_cache.get(session_id)
text = add_at_mention_if_needed(text, user_info, has_at_user)
if not text or not text.strip():
logger.warning("[Misskey] 消息内容为空,跳过发送")
return await super().send_by_session(session, message_chain)
if len(text) > self.max_message_length:
text = text[: self.max_message_length] + "..."
if session_id and is_valid_user_session_id(session_id):
from .misskey_utils import extract_user_id_from_session_id
user_id = extract_user_id_from_session_id(session_id)
await self.api.send_message(user_id, text)
elif session_id and is_valid_room_session_id(session_id):
from .misskey_utils import extract_room_id_from_session_id
room_id = extract_room_id_from_session_id(session_id)
await self.api.send_room_message(room_id, text)
else:
visibility, visible_user_ids = resolve_message_visibility(
user_id=session_id,
user_cache=self._user_cache,
self_id=self.client_self_id,
default_visibility=self.default_visibility,
)
await self.api.create_note(
text,
visibility=visibility,
visible_user_ids=visible_user_ids,
local_only=self.local_only,
)
except Exception as e:
logger.error(f"[Misskey] 发送消息失败: {e}")
return await super().send_by_session(session, message_chain)
async def convert_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
"""将 Misskey 贴文数据转换为 AstrBotMessage 对象"""
sender_info = extract_sender_info(raw_data, is_chat=False)
message = create_base_message(
raw_data,
sender_info,
self.client_self_id,
is_chat=False,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
)
message_parts = []
raw_text = raw_data.get("text", "")
if raw_text:
text_parts, processed_text = process_at_mention(
message, raw_text, self._bot_username, self.client_self_id
)
message_parts.extend(text_parts)
files = raw_data.get("files", [])
file_parts = process_files(message, files)
message_parts.extend(file_parts)
message.message_str = (
" ".join(part for part in message_parts if part.strip())
if message_parts
else ""
)
return message
async def convert_chat_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
"""将 Misskey 聊天消息数据转换为 AstrBotMessage 对象"""
sender_info = extract_sender_info(raw_data, is_chat=True)
message = create_base_message(
raw_data,
sender_info,
self.client_self_id,
is_chat=True,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=True
)
raw_text = raw_data.get("text", "")
if raw_text:
message.message.append(Comp.Plain(raw_text))
files = raw_data.get("files", [])
process_files(message, files, include_text_parts=False)
message.message_str = raw_text if raw_text else ""
return message
async def convert_room_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
"""将 Misskey 群聊消息数据转换为 AstrBotMessage 对象"""
sender_info = extract_sender_info(raw_data, is_chat=True)
room_id = raw_data.get("toRoomId", "")
message = create_base_message(
raw_data,
sender_info,
self.client_self_id,
is_chat=False,
room_id=room_id,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
)
cache_room_info(self._user_cache, raw_data, self.client_self_id)
raw_text = raw_data.get("text", "")
message_parts = []
if raw_text:
if self._bot_username and f"@{self._bot_username}" in raw_text:
text_parts, processed_text = process_at_mention(
message, raw_text, self._bot_username, self.client_self_id
)
message_parts.extend(text_parts)
else:
message.message.append(Comp.Plain(raw_text))
message_parts.append(raw_text)
files = raw_data.get("files", [])
file_parts = process_files(message, files)
message_parts.extend(file_parts)
message.message_str = (
" ".join(part for part in message_parts if part.strip())
if message_parts
else ""
)
return message
async def terminate(self):
self._running = False
if self.api:
await self.api.close()
def get_client(self) -> Any:
return self.api

View File

@@ -0,0 +1,404 @@
import json
from typing import Any, Optional, Dict, List, Callable, Awaitable
import uuid
try:
import aiohttp
import websockets
except ImportError as e:
raise ImportError(
"aiohttp and websockets are required for Misskey API. Please install them with: pip install aiohttp websockets"
) from e
from astrbot.api import logger
# Constants
API_MAX_RETRIES = 3
HTTP_OK = 200
class APIError(Exception):
"""Misskey API 基础异常"""
pass
class APIConnectionError(APIError):
"""网络连接异常"""
pass
class APIRateLimitError(APIError):
"""API 频率限制异常"""
pass
class AuthenticationError(APIError):
"""认证失败异常"""
pass
class WebSocketError(APIError):
"""WebSocket 连接异常"""
pass
class StreamingClient:
def __init__(self, instance_url: str, access_token: str):
self.instance_url = instance_url.rstrip("/")
self.access_token = access_token
self.websocket: Optional[Any] = None
self.is_connected = False
self.message_handlers: Dict[str, Callable] = {}
self.channels: Dict[str, str] = {}
self._running = False
self._last_pong = None
async def connect(self) -> bool:
try:
ws_url = self.instance_url.replace("https://", "wss://").replace(
"http://", "ws://"
)
ws_url += f"/streaming?i={self.access_token}"
self.websocket = await websockets.connect(
ws_url, ping_interval=30, ping_timeout=10
)
self.is_connected = True
self._running = True
logger.info("[Misskey WebSocket] 已连接")
return True
except Exception as e:
logger.error(f"[Misskey WebSocket] 连接失败: {e}")
self.is_connected = False
return False
async def disconnect(self):
self._running = False
if self.websocket:
await self.websocket.close()
self.websocket = None
self.is_connected = False
logger.info("[Misskey WebSocket] 连接已断开")
async def subscribe_channel(
self, channel_type: str, params: Optional[Dict] = None
) -> str:
if not self.is_connected or not self.websocket:
raise WebSocketError("WebSocket 未连接")
channel_id = str(uuid.uuid4())
message = {
"type": "connect",
"body": {"channel": channel_type, "id": channel_id, "params": params or {}},
}
await self.websocket.send(json.dumps(message))
self.channels[channel_id] = channel_type
return channel_id
async def unsubscribe_channel(self, channel_id: str):
if (
not self.is_connected
or not self.websocket
or channel_id not in self.channels
):
return
message = {"type": "disconnect", "body": {"id": channel_id}}
await self.websocket.send(json.dumps(message))
del self.channels[channel_id]
def add_message_handler(
self, event_type: str, handler: Callable[[Dict], Awaitable[None]]
):
self.message_handlers[event_type] = handler
async def listen(self):
if not self.is_connected or not self.websocket:
raise WebSocketError("WebSocket 未连接")
try:
async for message in self.websocket:
if not self._running:
break
try:
data = json.loads(message)
await self._handle_message(data)
except json.JSONDecodeError as e:
logger.warning(f"[Misskey WebSocket] 无法解析消息: {e}")
except Exception as e:
logger.error(f"[Misskey WebSocket] 处理消息失败: {e}")
except websockets.exceptions.ConnectionClosedError as e:
logger.warning(f"[Misskey WebSocket] 连接意外关闭: {e}")
self.is_connected = False
except websockets.exceptions.ConnectionClosed as e:
logger.warning(
f"[Misskey WebSocket] 连接已关闭 (代码: {e.code}, 原因: {e.reason})"
)
self.is_connected = False
except websockets.exceptions.InvalidHandshake as e:
logger.error(f"[Misskey WebSocket] 握手失败: {e}")
self.is_connected = False
except Exception as e:
logger.error(f"[Misskey WebSocket] 监听消息失败: {e}")
self.is_connected = False
async def _handle_message(self, data: Dict[str, Any]):
message_type = data.get("type")
body = data.get("body", {})
logger.debug(
f"[Misskey WebSocket] 收到消息类型: {message_type}\n数据: {json.dumps(data, indent=2, ensure_ascii=False)}"
)
if message_type == "channel":
channel_id = body.get("id")
event_type = body.get("type")
event_body = body.get("body", {})
logger.debug(
f"[Misskey WebSocket] 频道消息: {channel_id}, 事件类型: {event_type}"
)
if channel_id in self.channels:
channel_type = self.channels[channel_id]
handler_key = f"{channel_type}:{event_type}"
if handler_key in self.message_handlers:
logger.debug(f"[Misskey WebSocket] 使用处理器: {handler_key}")
await self.message_handlers[handler_key](event_body)
elif event_type in self.message_handlers:
logger.debug(f"[Misskey WebSocket] 使用事件处理器: {event_type}")
await self.message_handlers[event_type](event_body)
else:
logger.debug(
f"[Misskey WebSocket] 未找到处理器: {handler_key}{event_type}"
)
if "_debug" in self.message_handlers:
await self.message_handlers["_debug"](
{
"type": event_type,
"body": event_body,
"channel": channel_type,
}
)
elif message_type in self.message_handlers:
logger.debug(f"[Misskey WebSocket] 直接消息处理器: {message_type}")
await self.message_handlers[message_type](body)
else:
logger.debug(f"[Misskey WebSocket] 未处理的消息类型: {message_type}")
if "_debug" in self.message_handlers:
await self.message_handlers["_debug"](data)
def retry_async(max_retries: int = 3, retryable_exceptions: tuple = ()):
def decorator(func):
async def wrapper(*args, **kwargs):
last_exc = None
for _ in range(max_retries):
try:
return await func(*args, **kwargs)
except retryable_exceptions as e:
last_exc = e
continue
if last_exc:
raise last_exc
return wrapper
return decorator
class MisskeyAPI:
def __init__(self, instance_url: str, access_token: str):
self.instance_url = instance_url.rstrip("/")
self.access_token = access_token
self._session: Optional[aiohttp.ClientSession] = None
self.streaming: Optional[StreamingClient] = None
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
return False
async def close(self) -> None:
if self.streaming:
await self.streaming.disconnect()
self.streaming = None
if self._session:
await self._session.close()
self._session = None
logger.debug("[Misskey API] 客户端已关闭")
def get_streaming_client(self) -> StreamingClient:
if not self.streaming:
self.streaming = StreamingClient(self.instance_url, self.access_token)
return self.streaming
@property
def session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
headers = {"Authorization": f"Bearer {self.access_token}"}
self._session = aiohttp.ClientSession(headers=headers)
return self._session
def _handle_response_status(self, status: int, endpoint: str):
"""处理 HTTP 响应状态码"""
if status == 400:
logger.error(f"API 请求错误: {endpoint} (状态码: {status})")
raise APIError(f"Bad request for {endpoint}")
elif status in (401, 403):
logger.error(f"API 认证失败: {endpoint} (状态码: {status})")
raise AuthenticationError(f"Authentication failed for {endpoint}")
elif status == 429:
logger.warning(f"API 频率限制: {endpoint} (状态码: {status})")
raise APIRateLimitError(f"Rate limit exceeded for {endpoint}")
else:
logger.error(f"API 请求失败: {endpoint} (状态码: {status})")
raise APIConnectionError(f"HTTP {status} for {endpoint}")
async def _process_response(
self, response: aiohttp.ClientResponse, endpoint: str
) -> Any:
"""处理 API 响应"""
if response.status == HTTP_OK:
try:
result = await response.json()
if endpoint == "i/notifications":
notifications_data = (
result
if isinstance(result, list)
else result.get("notifications", [])
if isinstance(result, dict)
else []
)
if notifications_data:
logger.debug(f"获取到 {len(notifications_data)} 条新通知")
else:
logger.debug(f"API 请求成功: {endpoint}")
return result
except json.JSONDecodeError as e:
logger.error(f"响应不是有效的 JSON 格式: {e}")
raise APIConnectionError("Invalid JSON response") from e
else:
try:
error_text = await response.text()
logger.error(
f"API 请求失败: {endpoint} - 状态码: {response.status}, 响应: {error_text}"
)
except Exception:
logger.error(f"API 请求失败: {endpoint} - 状态码: {response.status}")
self._handle_response_status(response.status, endpoint)
raise APIConnectionError(f"Request failed for {endpoint}")
@retry_async(
max_retries=API_MAX_RETRIES,
retryable_exceptions=(APIConnectionError, APIRateLimitError),
)
async def _make_request(
self, endpoint: str, data: Optional[Dict[str, Any]] = None
) -> Any:
url = f"{self.instance_url}/api/{endpoint}"
payload = {"i": self.access_token}
if data:
payload.update(data)
try:
async with self.session.post(url, json=payload) as response:
return await self._process_response(response, endpoint)
except aiohttp.ClientError as e:
logger.error(f"HTTP 请求错误: {e}")
raise APIConnectionError(f"HTTP request failed: {e}") from e
async def create_note(
self,
text: str,
visibility: str = "public",
reply_id: Optional[str] = None,
visible_user_ids: Optional[List[str]] = None,
local_only: bool = False,
) -> Dict[str, Any]:
"""创建新贴文"""
data: Dict[str, Any] = {
"text": text,
"visibility": visibility,
"localOnly": local_only,
}
if reply_id:
data["replyId"] = reply_id
if visible_user_ids and visibility == "specified":
data["visibleUserIds"] = visible_user_ids
result = await self._make_request("notes/create", data)
note_id = result.get("createdNote", {}).get("id", "unknown")
logger.debug(f"发帖成功note_id: {note_id}")
return result
async def get_current_user(self) -> Dict[str, Any]:
"""获取当前用户信息"""
return await self._make_request("i", {})
async def send_message(self, user_id: str, text: str) -> Dict[str, Any]:
"""发送聊天消息"""
result = await self._make_request(
"chat/messages/create-to-user", {"toUserId": user_id, "text": text}
)
message_id = result.get("id", "unknown")
logger.debug(f"聊天发送成功message_id: {message_id}")
return result
async def send_room_message(self, room_id: str, text: str) -> Dict[str, Any]:
"""发送房间消息"""
result = await self._make_request(
"chat/messages/create-to-room", {"toRoomId": room_id, "text": text}
)
message_id = result.get("id", "unknown")
logger.debug(f"房间消息发送成功message_id: {message_id}")
return result
async def get_messages(
self, user_id: str, limit: int = 10, since_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""获取聊天消息历史"""
data: Dict[str, Any] = {"userId": user_id, "limit": limit}
if since_id:
data["sinceId"] = since_id
result = await self._make_request("chat/messages/user-timeline", data)
if isinstance(result, list):
return result
else:
logger.warning(f"获取聊天消息响应格式异常: {type(result)}")
return []
async def get_mentions(
self, limit: int = 10, since_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""获取提及通知"""
data: Dict[str, Any] = {"limit": limit}
if since_id:
data["sinceId"] = since_id
data["includeTypes"] = ["mention", "reply", "quote"]
result = await self._make_request("i/notifications", data)
if isinstance(result, list):
return result
elif isinstance(result, dict) and "notifications" in result:
return result["notifications"]
else:
logger.warning(f"获取提及通知响应格式异常: {type(result)}")
return []

View File

@@ -0,0 +1,123 @@
import asyncio
import re
from typing import AsyncGenerator
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import PlatformMetadata, AstrBotMessage
from astrbot.api.message_components import Plain
from .misskey_utils import (
serialize_message_chain,
resolve_visibility_from_raw_message,
is_valid_user_session_id,
is_valid_room_session_id,
add_at_mention_if_needed,
extract_user_id_from_session_id,
extract_room_id_from_session_id,
)
class MisskeyPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
def _is_system_command(self, message_str: str) -> bool:
"""检测是否为系统指令"""
if not message_str or not message_str.strip():
return False
system_prefixes = ["/", "!", "#", ".", "^"]
message_trimmed = message_str.strip()
return any(message_trimmed.startswith(prefix) for prefix in system_prefixes)
async def send(self, message: MessageChain):
content, has_at = serialize_message_chain(message.chain)
if not content:
logger.debug("[MisskeyEvent] 内容为空,跳过发送")
return
try:
original_message_id = getattr(self.message_obj, "message_id", None)
raw_message = getattr(self.message_obj, "raw_message", {})
if raw_message and not has_at:
user_data = raw_message.get("user", {})
user_info = {
"username": user_data.get("username", ""),
"nickname": user_data.get("name", user_data.get("username", "")),
}
content = add_at_mention_if_needed(content, user_info, has_at)
# 根据会话类型选择发送方式
if hasattr(self.client, "send_message") and is_valid_user_session_id(
self.session_id
):
user_id = extract_user_id_from_session_id(self.session_id)
await self.client.send_message(user_id, content)
elif hasattr(self.client, "send_room_message") and is_valid_room_session_id(
self.session_id
):
room_id = extract_room_id_from_session_id(self.session_id)
await self.client.send_room_message(room_id, content)
elif original_message_id and hasattr(self.client, "create_note"):
visibility, visible_user_ids = resolve_visibility_from_raw_message(
raw_message
)
await self.client.create_note(
content,
reply_id=original_message_id,
visibility=visibility,
visible_user_ids=visible_user_ids,
)
elif hasattr(self.client, "create_note"):
logger.debug("[MisskeyEvent] 创建新帖子")
await self.client.create_note(content)
await super().send(message)
except Exception as e:
logger.error(f"[MisskeyEvent] 发送失败: {e}")
async def send_streaming(
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
):
if not use_fallback:
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
buffer = ""
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
async for chain in generator:
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
buffer = await self.process_buffer(buffer, pattern)
else:
await self.send(MessageChain(chain=[comp]))
await asyncio.sleep(1.5) # 限速
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator, use_fallback)

View File

@@ -0,0 +1,327 @@
"""Misskey 平台适配器通用工具函数"""
from typing import Dict, Any, List, Tuple, Optional, Union
import astrbot.api.message_components as Comp
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]:
"""将消息链序列化为文本字符串"""
text_parts = []
has_at = False
def process_component(component):
nonlocal has_at
if isinstance(component, Comp.Plain):
return component.text
elif isinstance(component, Comp.File):
file_name = getattr(component, "name", "文件")
return f"[文件: {file_name}]"
elif isinstance(component, Comp.At):
has_at = True
return f"@{component.qq}"
elif hasattr(component, "text"):
text = getattr(component, "text", "")
if "@" in text:
has_at = True
return text
else:
return str(component)
for component in chain:
if isinstance(component, Comp.Node) and component.content:
for node_comp in component.content:
result = process_component(node_comp)
if result:
text_parts.append(result)
else:
result = process_component(component)
if result:
text_parts.append(result)
return "".join(text_parts), has_at
def resolve_message_visibility(
user_id: Optional[str],
user_cache: Dict[str, Any],
self_id: Optional[str],
default_visibility: str = "public",
) -> Tuple[str, Optional[List[str]]]:
"""解析 Misskey 消息的可见性设置"""
visibility = default_visibility
visible_user_ids = None
if user_id and user_cache:
user_info = user_cache.get(user_id)
if user_info:
original_visibility = user_info.get("visibility", default_visibility)
if original_visibility == "specified":
visibility = "specified"
original_visible_users = user_info.get("visible_user_ids", [])
users_to_include = [user_id]
if self_id:
users_to_include.append(self_id)
visible_user_ids = list(set(original_visible_users + users_to_include))
visible_user_ids = [uid for uid in visible_user_ids if uid]
else:
visibility = original_visibility
return visibility, visible_user_ids
def resolve_visibility_from_raw_message(
raw_message: Dict[str, Any], self_id: Optional[str] = None
) -> Tuple[str, Optional[List[str]]]:
"""从原始消息数据中解析可见性设置"""
visibility = "public"
visible_user_ids = None
if not raw_message:
return visibility, visible_user_ids
original_visibility = raw_message.get("visibility", "public")
if original_visibility == "specified":
visibility = "specified"
original_visible_users = raw_message.get("visibleUserIds", [])
sender_id = raw_message.get("userId", "")
users_to_include = []
if sender_id:
users_to_include.append(sender_id)
if self_id:
users_to_include.append(self_id)
visible_user_ids = list(set(original_visible_users + users_to_include))
visible_user_ids = [uid for uid in visible_user_ids if uid]
else:
visibility = original_visibility
return visibility, visible_user_ids
def is_valid_user_session_id(session_id: Union[str, Any]) -> bool:
"""检查 session_id 是否是有效的聊天用户 session_id (仅限chat%前缀)"""
if not isinstance(session_id, str) or "%" not in session_id:
return False
parts = session_id.split("%")
return (
len(parts) == 2
and parts[0] == "chat"
and bool(parts[1])
and parts[1] != "unknown"
)
def is_valid_room_session_id(session_id: Union[str, Any]) -> bool:
"""检查 session_id 是否是有效的房间 session_id (仅限room%前缀)"""
if not isinstance(session_id, str) or "%" not in session_id:
return False
parts = session_id.split("%")
return (
len(parts) == 2
and parts[0] == "room"
and bool(parts[1])
and parts[1] != "unknown"
)
def extract_user_id_from_session_id(session_id: str) -> str:
"""从 session_id 中提取用户 ID"""
if "%" in session_id:
parts = session_id.split("%")
if len(parts) >= 2:
return parts[1]
return session_id
def extract_room_id_from_session_id(session_id: str) -> str:
"""从 session_id 中提取房间 ID"""
if "%" in session_id:
parts = session_id.split("%")
if len(parts) >= 2 and parts[0] == "room":
return parts[1]
return session_id
def add_at_mention_if_needed(
text: str, user_info: Optional[Dict[str, Any]], has_at: bool = False
) -> str:
"""如果需要且没有@用户,则添加@用户"""
if has_at or not user_info:
return text
username = user_info.get("username")
nickname = user_info.get("nickname")
if username:
mention = f"@{username}"
if not text.startswith(mention):
text = f"{mention}\n{text}".strip()
elif nickname:
mention = f"@{nickname}"
if not text.startswith(mention):
text = f"{mention}\n{text}".strip()
return text
def create_file_component(file_info: Dict[str, Any]) -> Tuple[Any, str]:
"""创建文件组件和描述文本"""
file_url = file_info.get("url", "")
file_name = file_info.get("name", "未知文件")
file_type = file_info.get("type", "")
if file_type.startswith("image/"):
return Comp.Image(url=file_url, file=file_name), f"图片[{file_name}]"
elif file_type.startswith("audio/"):
return Comp.Record(url=file_url, file=file_name), f"音频[{file_name}]"
elif file_type.startswith("video/"):
return Comp.Video(url=file_url, file=file_name), f"视频[{file_name}]"
else:
return Comp.File(name=file_name, url=file_url), f"文件[{file_name}]"
def process_files(
message: AstrBotMessage, files: list, include_text_parts: bool = True
) -> list:
"""处理文件列表,添加到消息组件中并返回文本描述"""
file_parts = []
for file_info in files:
component, part_text = create_file_component(file_info)
message.message.append(component)
if include_text_parts:
file_parts.append(part_text)
return file_parts
def extract_sender_info(
raw_data: Dict[str, Any], is_chat: bool = False
) -> Dict[str, Any]:
"""提取发送者信息"""
if is_chat:
sender = raw_data.get("fromUser", {})
sender_id = str(sender.get("id", "") or raw_data.get("fromUserId", ""))
else:
sender = raw_data.get("user", {})
sender_id = str(sender.get("id", ""))
return {
"sender": sender,
"sender_id": sender_id,
"nickname": sender.get("name", sender.get("username", "")),
"username": sender.get("username", ""),
}
def create_base_message(
raw_data: Dict[str, Any],
sender_info: Dict[str, Any],
client_self_id: str,
is_chat: bool = False,
room_id: Optional[str] = None,
unique_session: bool = False,
) -> AstrBotMessage:
"""创建基础消息对象"""
message = AstrBotMessage()
message.raw_message = raw_data
message.message = []
message.sender = MessageMember(
user_id=sender_info["sender_id"],
nickname=sender_info["nickname"],
)
if room_id:
session_prefix = "room"
session_id = f"{session_prefix}%{room_id}"
if unique_session:
session_id += f"_{sender_info['sender_id']}"
message.type = MessageType.GROUP_MESSAGE
message.group_id = room_id
elif is_chat:
session_prefix = "chat"
session_id = f"{session_prefix}%{sender_info['sender_id']}"
message.type = MessageType.FRIEND_MESSAGE
else:
session_prefix = "note"
session_id = f"{session_prefix}%{sender_info['sender_id']}"
message.type = MessageType.FRIEND_MESSAGE
message.session_id = (
session_id if sender_info["sender_id"] else f"{session_prefix}%unknown"
)
message.message_id = str(raw_data.get("id", ""))
message.self_id = client_self_id
return message
def process_at_mention(
message: AstrBotMessage, raw_text: str, bot_username: str, client_self_id: str
) -> Tuple[List[str], str]:
"""处理@提及逻辑,返回消息部分列表和处理后的文本"""
message_parts = []
if not raw_text:
return message_parts, ""
if bot_username and raw_text.startswith(f"@{bot_username}"):
at_mention = f"@{bot_username}"
message.message.append(Comp.At(qq=client_self_id))
remaining_text = raw_text[len(at_mention) :].strip()
if remaining_text:
message.message.append(Comp.Plain(remaining_text))
message_parts.append(remaining_text)
return message_parts, remaining_text
else:
message.message.append(Comp.Plain(raw_text))
message_parts.append(raw_text)
return message_parts, raw_text
def cache_user_info(
user_cache: Dict[str, Any],
sender_info: Dict[str, Any],
raw_data: Dict[str, Any],
client_self_id: str,
is_chat: bool = False,
):
"""缓存用户信息"""
if is_chat:
user_cache_data = {
"username": sender_info["username"],
"nickname": sender_info["nickname"],
"visibility": "specified",
"visible_user_ids": [client_self_id, sender_info["sender_id"]],
}
else:
user_cache_data = {
"username": sender_info["username"],
"nickname": sender_info["nickname"],
"visibility": raw_data.get("visibility", "public"),
"visible_user_ids": raw_data.get("visibleUserIds", []),
}
user_cache[sender_info["sender_id"]] = user_cache_data
def cache_room_info(
user_cache: Dict[str, Any], raw_data: Dict[str, Any], client_self_id: str
):
"""缓存房间信息"""
room_data = raw_data.get("toRoom")
room_id = raw_data.get("toRoomId")
if room_data and room_id:
room_cache_key = f"room:{room_id}"
user_cache[room_cache_key] = {
"room_id": room_id,
"room_name": room_data.get("name", ""),
"room_description": room_data.get("description", ""),
"owner_id": room_data.get("ownerId", ""),
"visibility": "specified",
"visible_user_ids": [client_self_id],
}

View File

@@ -94,10 +94,15 @@ class QQOfficialMessageEvent(AstrMessageEvent):
plain_text,
image_base64,
image_path,
record_file_path
record_file_path,
) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
if not plain_text and not image_base64 and not image_path and not record_file_path:
if (
not plain_text
and not image_base64
and not image_path
and not record_file_path
):
return
payload = {
@@ -118,7 +123,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path: # group record msg
if record_file_path: # group record msg
media = await self.upload_group_and_c2c_record(
record_file_path, 3, group_openid=source.group_openid
)
@@ -134,9 +139,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path: # c2c record
if record_file_path: # c2c record
media = await self.upload_group_and_c2c_record(
record_file_path, 3, openid = source.author.user_openid
record_file_path, 3, openid=source.author.user_openid
)
payload["media"] = media
payload["msg_type"] = 7
@@ -190,58 +195,55 @@ class QQOfficialMessageEvent(AstrMessageEvent):
return await self.bot.api._http.request(route, json=payload)
async def upload_group_and_c2c_record(
self,
file_source: str,
file_type: int,
srv_send_msg: bool = False,
**kwargs
self, file_source: str, file_type: int, srv_send_msg: bool = False, **kwargs
) -> Optional[Media]:
"""
上传媒体文件
"""
# 构建基础payload
payload = {
"file_type": file_type,
"srv_send_msg": srv_send_msg
}
payload = {"file_type": file_type, "srv_send_msg": srv_send_msg}
# 处理文件数据
if os.path.exists(file_source):
# 读取本地文件
async with aiofiles.open(file_source, 'rb') as f:
async with aiofiles.open(file_source, "rb") as f:
file_content = await f.read()
# use base64 encode
payload["file_data"] = base64.b64encode(file_content).decode('utf-8')
payload["file_data"] = base64.b64encode(file_content).decode("utf-8")
else:
# 使用URL
payload["url"] = file_source
# 添加接收者信息和确定路由
if "openid" in kwargs:
payload["openid"] = kwargs["openid"]
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
elif "group_openid" in kwargs:
payload["group_openid"] =kwargs["group_openid"]
route = Route("POST", "/v2/groups/{group_openid}/files", group_openid=kwargs["group_openid"])
payload["group_openid"] = kwargs["group_openid"]
route = Route(
"POST",
"/v2/groups/{group_openid}/files",
group_openid=kwargs["group_openid"],
)
else:
return None
try:
# 使用底层HTTP请求
result = await self.bot.api._http.request(route, json=payload)
if result:
return Media(
file_uuid=result.get("file_uuid"),
file_info=result.get("file_info"),
ttl=result.get("ttl", 0),
file_id=result.get("id", "")
file_id=result.get("id", ""),
)
except Exception as e:
logger.error(f"上传请求错误: {e}")
return None
async def post_c2c_message(
self,
openid: str,
@@ -286,19 +288,23 @@ class QQOfficialMessageEvent(AstrMessageEvent):
image_base64 = image_base64.removeprefix("base64://")
elif isinstance(i, Record):
if i.file:
record_wav_path = await i.convert_to_file_path() # wav 路径
record_wav_path = await i.convert_to_file_path() # wav 路径
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
record_tecent_silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
record_tecent_silk_path = os.path.join(
temp_dir, f"{uuid.uuid4()}.silk"
)
try:
duration = await wav_to_tencent_silk(record_wav_path, record_tecent_silk_path)
duration = await wav_to_tencent_silk(
record_wav_path, record_tecent_silk_path
)
if duration > 0:
record_file_path = record_tecent_silk_path
else:
record_file_path = None
record_file_path = None
logger.error("转换音频格式时出错音频时长不大于0")
except Exception as e:
logger.error(f"处理语音时出错: {e}")
record_file_path = None
record_file_path = None
else:
logger.debug(f"qq_official 忽略 {i.type}")
return plain_text, image_base64, image_file_path, record_file_path

View File

@@ -17,7 +17,14 @@ from astrbot.api.platform import (
register_platform_adapter,
)
from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.api.message_components import Plain, Image, At, File, Record
from astrbot.api.message_components import (
Plain,
Image,
At,
File,
Record,
Reply,
)
from xml.etree import ElementTree as ET
@@ -38,12 +45,18 @@ class SatoriPlatformAdapter(Platform):
)
self.token = self.config.get("satori_token", "")
self.endpoint = self.config.get(
"satori_endpoint", "ws://127.0.0.1:5140/satori/v1/events"
"satori_endpoint", "ws://localhost:5140/satori/v1/events"
)
self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)
self.metadata = PlatformMetadata(
name="satori",
description="Satori 通用协议适配器",
id=self.config["id"],
)
self.ws: Optional[ClientConnection] = None
self.session: Optional[ClientSession] = None
self.sequence = 0
@@ -63,7 +76,7 @@ class SatoriPlatformAdapter(Platform):
await super().send_by_session(session, message_chain)
def meta(self) -> PlatformMetadata:
return PlatformMetadata(name="satori", description="Satori 通用协议适配器")
return self.metadata
def _is_websocket_closed(self, ws) -> bool:
"""检查WebSocket连接是否已关闭"""
@@ -312,12 +325,52 @@ class SatoriPlatformAdapter(Platform):
abm.self_id = login.get("user", {}).get("id", "")
content = message.get("content", "")
abm.message = await self.parse_satori_elements(content)
# 消息链
abm.message = []
content = message.get("content", "")
quote = message.get("quote")
content_for_parsing = content # 副本
# 提取<quote>标签
if "<quote" in content:
try:
quote_info = await self._extract_quote_element(content)
if quote_info:
quote = quote_info["quote"]
content_for_parsing = quote_info["content_without_quote"]
except Exception as e:
logger.error(f"解析<quote>标签时发生错误: {e}, 错误内容: {content}")
if quote:
# 引用消息
quote_abm = await self._convert_quote_message(quote)
if quote_abm:
sender_id = quote_abm.sender.user_id
if isinstance(sender_id, str) and sender_id.isdigit():
sender_id = int(sender_id)
elif not isinstance(sender_id, int):
sender_id = 0 # 默认值
reply_component = Reply(
id=quote_abm.message_id,
chain=quote_abm.message,
sender_id=quote_abm.sender.user_id,
sender_nickname=quote_abm.sender.nickname,
time=quote_abm.timestamp,
message_str=quote_abm.message_str,
text=quote_abm.message_str,
qq=sender_id,
)
abm.message.append(reply_component)
# 解析消息内容
content_elements = await self.parse_satori_elements(content_for_parsing)
abm.message.extend(content_elements)
# parse message_str
abm.message_str = ""
for comp in abm.message:
for comp in content_elements:
if isinstance(comp, Plain):
abm.message_str += comp.text
@@ -333,6 +386,163 @@ class SatoriPlatformAdapter(Platform):
logger.error(f"转换 Satori 消息失败: {e}")
return None
def _extract_namespace_prefixes(self, content: str) -> set:
"""提取XML内容中的命名空间前缀"""
prefixes = set()
# 查找所有标签
i = 0
while i < len(content):
# 查找开始标签
if content[i] == "<" and i + 1 < len(content) and content[i + 1] != "/":
# 找到标签结束位置
tag_end = content.find(">", i)
if tag_end != -1:
# 提取标签内容
tag_content = content[i + 1 : tag_end]
# 检查是否有命名空间前缀
if ":" in tag_content and "xmlns:" not in tag_content:
# 分割标签名
parts = tag_content.split()
if parts:
tag_name = parts[0]
if ":" in tag_name:
prefix = tag_name.split(":")[0]
# 确保是有效的命名空间前缀
if (
prefix.isalnum()
or prefix.replace("_", "").isalnum()
):
prefixes.add(prefix)
i = tag_end + 1
else:
i += 1
# 查找结束标签
elif content[i] == "<" and i + 1 < len(content) and content[i + 1] == "/":
# 找到标签结束位置
tag_end = content.find(">", i)
if tag_end != -1:
# 提取标签内容
tag_content = content[i + 2 : tag_end]
# 检查是否有命名空间前缀
if ":" in tag_content:
prefix = tag_content.split(":")[0]
# 确保是有效的命名空间前缀
if prefix.isalnum() or prefix.replace("_", "").isalnum():
prefixes.add(prefix)
i = tag_end + 1
else:
i += 1
else:
i += 1
return prefixes
async def _extract_quote_element(self, content: str) -> Optional[dict]:
"""提取<quote>标签信息"""
try:
# 处理命名空间前缀问题
processed_content = content
if ":" in content and not content.startswith("<root"):
prefixes = self._extract_namespace_prefixes(content)
# 构建命名空间声明
ns_declarations = " ".join(
[
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
for prefix in prefixes
]
)
# 包装内容
processed_content = f"<root {ns_declarations}>{content}</root>"
elif not content.startswith("<root"):
processed_content = f"<root>{content}</root>"
else:
processed_content = content
root = ET.fromstring(processed_content)
# 查找<quote>标签
quote_element = None
for elem in root.iter():
tag_name = elem.tag
if "}" in tag_name:
tag_name = tag_name.split("}")[1]
if tag_name.lower() == "quote":
quote_element = elem
break
if quote_element is not None:
# 提取quote标签的属性
quote_id = quote_element.get("id", "")
# 提取<quote>标签内部的内容
inner_content = ""
if quote_element.text:
inner_content += quote_element.text
for child in quote_element:
inner_content += ET.tostring(
child, encoding="unicode", method="xml"
)
if child.tail:
inner_content += child.tail
# 构造移除了<quote>标签的内容
content_without_quote = content.replace(
ET.tostring(quote_element, encoding="unicode", method="xml"), ""
)
return {
"quote": {"id": quote_id, "content": inner_content},
"content_without_quote": content_without_quote,
}
return None
except Exception as e:
logger.error(f"提取<quote>标签时发生错误: {e}")
return None
async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]:
"""转换引用消息"""
try:
quote_abm = AstrBotMessage()
quote_abm.message_id = quote.get("id", "")
# 解析引用消息的发送者
quote_author = quote.get("author", {})
if quote_author:
quote_abm.sender = MessageMember(
user_id=quote_author.get("id", ""),
nickname=quote_author.get("nick", quote_author.get("name", "")),
)
else:
# 如果没有作者信息,使用默认值
quote_abm.sender = MessageMember(
user_id=quote.get("user_id", ""),
nickname="内容",
)
# 解析引用消息内容
quote_content = quote.get("content", "")
quote_abm.message = await self.parse_satori_elements(quote_content)
quote_abm.message_str = ""
for comp in quote_abm.message:
if isinstance(comp, Plain):
quote_abm.message_str += comp.text
quote_abm.timestamp = int(quote.get("timestamp", time.time()))
# 如果没有任何内容,使用默认文本
if not quote_abm.message_str.strip():
quote_abm.message_str = "[引用消息]"
return quote_abm
except Exception as e:
logger.error(f"转换引用消息失败: {e}")
return None
async def parse_satori_elements(self, content: str) -> list:
"""解析 Satori 消息元素"""
elements = []
@@ -341,12 +551,35 @@ class SatoriPlatformAdapter(Platform):
return elements
try:
wrapped_content = f"<root>{content}</root>"
root = ET.fromstring(wrapped_content)
# 处理命名空间前缀问题
processed_content = content
if ":" in content and not content.startswith("<root"):
prefixes = self._extract_namespace_prefixes(content)
# 构建命名空间声明
ns_declarations = " ".join(
[
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
for prefix in prefixes
]
)
# 包装内容
processed_content = f"<root {ns_declarations}>{content}</root>"
elif not content.startswith("<root"):
processed_content = f"<root>{content}</root>"
else:
processed_content = content
root = ET.fromstring(processed_content)
await self._parse_xml_node(root, elements)
except ET.ParseError as e:
raise ValueError(f"解析 Satori 元素时发生解析错误: {e}")
logger.error(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
# 如果解析失败,将整个内容当作纯文本
if content.strip():
elements.append(Plain(text=content))
except Exception as e:
logger.error(f"解析 Satori 元素时发生未知错误: {e}")
raise e
# 如果没有解析到任何元素,将整个内容当作纯文本
@@ -361,7 +594,12 @@ class SatoriPlatformAdapter(Platform):
elements.append(Plain(text=node.text))
for child in node:
tag_name = child.tag.lower()
# 获取标签名,去除命名空间前缀
tag_name = child.tag
if "}" in tag_name:
tag_name = tag_name.split("}")[1]
tag_name = tag_name.lower()
attrs = child.attrib
if tag_name == "at":
@@ -372,31 +610,59 @@ class SatoriPlatformAdapter(Platform):
src = attrs.get("src", "")
if not src:
continue
if src.startswith("data:image/"):
src = src.split(",")[1]
elements.append(Image.fromBase64(src))
elif src.startswith("http"):
elements.append(Image.fromURL(src))
else:
logger.error(f"未知的图片 src 格式: {str(src)[:16]}")
elements.append(Image(file=src))
elif tag_name == "file":
src = attrs.get("src", "")
name = attrs.get("name", "文件")
if src:
elements.append(File(file=src, name=name))
elements.append(File(name=name, file=src))
elif tag_name in ("audio", "record"):
src = attrs.get("src", "")
if not src:
continue
if src.startswith("data:audio/"):
src = src.split(",")[1]
elements.append(Record.fromBase64(src))
elif src.startswith("http"):
elements.append(Record.fromURL(src))
elements.append(Record(file=src))
elif tag_name == "quote":
# quote标签已经被特殊处理
pass
elif tag_name == "face":
face_id = attrs.get("id", "")
face_name = attrs.get("name", "")
face_type = attrs.get("type", "")
if face_name:
elements.append(Plain(text=f"[表情:{face_name}]"))
elif face_id and face_type:
elements.append(Plain(text=f"[表情ID:{face_id},类型:{face_type}]"))
elif face_id:
elements.append(Plain(text=f"[表情ID:{face_id}]"))
else:
logger.error(f"未知的音频 src 格式: {str(src)[:16]}")
elements.append(Plain(text="[表情]"))
elif tag_name == "ark":
# 作为纯文本添加到消息链中
data = attrs.get("data", "")
if data:
import html
decoded_data = html.unescape(data)
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
else:
elements.append(Plain(text="[ARK卡片]"))
elif tag_name == "json":
# JSON标签 视为ARK卡片消息
data = attrs.get("data", "")
if data:
import html
decoded_data = html.unescape(data)
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
else:
elements.append(Plain(text="[JSON卡片]"))
else:
# 未知标签,递归处理其内容

View File

@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, At, File, Record
from astrbot.api.message_components import Plain, Image, At, File, Record, Video, Reply
if TYPE_CHECKING:
from .satori_adapter import SatoriPlatformAdapter
@@ -17,6 +17,15 @@ class SatoriPlatformEvent(AstrMessageEvent):
session_id: str,
adapter: "SatoriPlatformAdapter",
):
# 更新平台元数据
if adapter and hasattr(adapter, "logins") and adapter.logins:
current_login = adapter.logins[0]
platform_name = current_login.get("platform", "satori")
user = current_login.get("user", {})
user_id = user.get("id", "") if user else ""
if not platform_meta.id and user_id:
platform_meta.id = f"{platform_name}({user_id})"
super().__init__(message_str, message_obj, platform_meta, session_id)
self.adapter = adapter
self.platform = None
@@ -78,6 +87,17 @@ class SatoriPlatformEvent(AstrMessageEvent):
except Exception as e:
logger.error(f"语音转换为base64失败: {e}")
elif isinstance(component, Reply):
content_parts.append(f'<reply id="{component.id}"/>')
elif isinstance(component, Video):
try:
video_path_url = await component.convert_to_file_path()
if video_path_url:
content_parts.append(f'<video src="{video_path_url}"/>')
except Exception as e:
logger.error(f"视频文件转换失败: {e}")
content = "".join(content_parts)
channel_id = session_id
data = {"channel_id": channel_id, "content": content}
@@ -157,6 +177,17 @@ class SatoriPlatformEvent(AstrMessageEvent):
except Exception as e:
logger.error(f"语音转换为base64失败: {e}")
elif isinstance(component, Reply):
content_parts.append(f'<reply id="{component.id}"/>')
elif isinstance(component, Video):
try:
video_path_url = await component.convert_to_file_path()
if video_path_url:
content_parts.append(f'<video src="{video_path_url}"/>')
except Exception as e:
logger.error(f"视频文件转换失败: {e}")
content = "".join(content_parts)
channel_id = self.session_id
data = {"channel_id": channel_id, "content": content}

View File

@@ -308,7 +308,9 @@ class SlackAdapter(Platform):
base64_content = base64.b64encode(content).decode("utf-8")
return base64_content
else:
logger.error(f"Failed to download slack file: {resp.status} {await resp.text()}")
logger.error(
f"Failed to download slack file: {resp.status} {await resp.text()}"
)
raise Exception(f"下载文件失败: {resp.status}")
async def run(self) -> Awaitable[Any]:

View File

@@ -75,7 +75,13 @@ class SlackMessageEvent(AstrMessageEvent):
"text": {"type": "mrkdwn", "text": "文件上传失败"},
}
file_url = response["files"][0]["permalink"]
return {"type": "section", "text": {"type": "mrkdwn", "text": f"文件: <{file_url}|{segment.name or '文件'}>"}}
return {
"type": "section",
"text": {
"type": "mrkdwn",
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
},
}
else:
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}

View File

@@ -95,9 +95,8 @@ class TelegramPlatformAdapter(Platform):
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="telegram", description="telegram 适配器", id=self.config.get("id")
)
id_ = self.config.get("id") or "telegram"
return PlatformMetadata(name="telegram", description="telegram 适配器", id=id_)
@override
async def run(self):
@@ -117,6 +116,10 @@ class TelegramPlatformAdapter(Platform):
)
self.scheduler.start()
if not self.application.updater:
logger.error("Telegram Updater is not initialized. Cannot start polling.")
return
queue = self.application.updater.start_polling()
logger.info("Telegram Platform Adapter is running.")
await queue
@@ -194,6 +197,11 @@ class TelegramPlatformAdapter(Platform):
return cmd_name, description
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
if not update.effective_chat:
logger.warning(
"Received a start command without an effective chat, skipping /start reply."
)
return
await context.bot.send_message(
chat_id=update.effective_chat.id, text=self.config["start_message"]
)
@@ -206,15 +214,20 @@ class TelegramPlatformAdapter(Platform):
async def convert_message(
self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True
) -> AstrBotMessage:
) -> AstrBotMessage | None:
"""转换 Telegram 的消息对象为 AstrBotMessage 对象。
@param update: Telegram 的 Update 对象。
@param context: Telegram 的 Context 对象。
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
"""
if not update.message:
logger.warning("Received an update without a message.")
return None
message = AstrBotMessage()
message.session_id = str(update.message.chat.id)
# 获得是群聊还是私聊
if update.message.chat.type == ChatType.PRIVATE:
message.type = MessageType.FRIEND_MESSAGE
@@ -225,10 +238,13 @@ class TelegramPlatformAdapter(Platform):
# Topic Group
message.group_id += "#" + str(update.message.message_thread_id)
message.session_id = message.group_id
message.message_id = str(update.message.message_id)
_from_user = update.message.from_user
if not _from_user:
logger.warning("[Telegram] Received a message without a from_user.")
return None
message.sender = MessageMember(
str(update.message.from_user.id), update.message.from_user.username
str(_from_user.id), _from_user.username or "Unknown"
)
message.self_id = str(context.bot.username)
message.raw_message = update
@@ -247,22 +263,32 @@ class TelegramPlatformAdapter(Platform):
)
reply_abm = await self.convert_message(reply_update, context, False)
message.message.append(
Comp.Reply(
id=reply_abm.message_id,
chain=reply_abm.message,
sender_id=reply_abm.sender.user_id,
sender_nickname=reply_abm.sender.nickname,
time=reply_abm.timestamp,
message_str=reply_abm.message_str,
text=reply_abm.message_str,
qq=reply_abm.sender.user_id,
if reply_abm:
message.message.append(
Comp.Reply(
id=reply_abm.message_id,
chain=reply_abm.message,
sender_id=reply_abm.sender.user_id,
sender_nickname=reply_abm.sender.nickname,
time=reply_abm.timestamp,
message_str=reply_abm.message_str,
text=reply_abm.message_str,
qq=reply_abm.sender.user_id,
)
)
)
if update.message.text:
# 处理文本消息
plain_text = update.message.text
if (
message.type == MessageType.GROUP_MESSAGE
and update.message
and update.message.reply_to_message
and update.message.reply_to_message.from_user
and update.message.reply_to_message.from_user.id == context.bot.id
):
plain_text2 = f"/@{context.bot.username} " + plain_text
plain_text = plain_text2
# 群聊场景命令特殊处理
if plain_text.startswith("/"):
@@ -328,15 +354,25 @@ class TelegramPlatformAdapter(Platform):
elif update.message.document:
file = await update.message.document.get_file()
message.message = [
Comp.File(file=file.file_path, name=update.message.document.file_name),
]
file_name = update.message.document.file_name or uuid.uuid4().hex
file_path = file.file_path
if file_path is None:
logger.warning(
f"Telegram document file_path is None, cannot save the file {file_name}."
)
else:
message.message.append(Comp.File(file=file_path, name=file_name))
elif update.message.video:
file = await update.message.video.get_file()
message.message = [
Comp.Video(file=file.file_path, path=file.file_path),
]
file_name = update.message.video.file_name or uuid.uuid4().hex
file_path = file.file_path
if file_path is None:
logger.warning(
f"Telegram video file_path is None, cannot save the file {file_name}."
)
else:
message.message.append(Comp.Video(file=file_path, path=file.file_path))
return message

View File

@@ -16,6 +16,7 @@ from telegram.ext import ExtBot
from astrbot.core.utils.io import download_file
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from telegram import ReactionTypeEmoji, ReactionTypeCustomEmoji
class TelegramPlatformEvent(AstrMessageEvent):
@@ -66,7 +67,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
return chunks
@classmethod
async def send_with_client(cls, client: ExtBot, message: MessageChain, user_name: str):
async def send_with_client(
cls, client: ExtBot, message: MessageChain, user_name: str
):
image_path = None
has_reply = False
@@ -133,6 +136,39 @@ class TelegramPlatformEvent(AstrMessageEvent):
await self.send_with_client(self.client, message, self.get_sender_id())
await super().send(message)
async def react(self, emoji: str | None, big: bool = False):
"""
给原消息添加 Telegram 反应:
- 普通 emoji传入 '👍''😂'
- 自定义表情:传入其 custom_emoji_id纯数字字符串
- 取消本机器人的反应:传入 None 或空字符串
"""
try:
# 解析 chat_id去掉超级群的 "#<thread_id>" 片段)
if self.get_message_type() == MessageType.GROUP_MESSAGE:
chat_id = (self.message_obj.group_id or "").split("#")[0]
else:
chat_id = self.get_sender_id()
message_id = int(self.message_obj.message_id)
# 组装 reaction 参数(必须是 ReactionType 的列表)
if not emoji: # 清空本 bot 的反应
reaction_param = [] # 空列表表示移除本 bot 的反应
elif emoji.isdigit(): # 自定义表情:传 custom_emoji_id
reaction_param = [ReactionTypeCustomEmoji(emoji)]
else: # 普通 emoji
reaction_param = [ReactionTypeEmoji(emoji)]
await self.client.set_message_reaction(
chat_id=chat_id,
message_id=message_id,
reaction=reaction_param, # 注意是列表
is_big=big, # 可选:大动画
)
except Exception as e:
logger.error(f"[Telegram] 添加反应失败: {e}")
async def send_streaming(self, generator, use_fallback: bool = False):
message_thread_id = None
@@ -216,7 +252,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
try:
msg = await self.client.send_message(text=delta, **payload)
current_content = delta
delta = ""
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id

View File

@@ -91,7 +91,6 @@ class WebChatAdapter(Platform):
abm = AstrBotMessage()
abm.self_id = "webchat"
abm.tag = "webchat"
abm.sender = MessageMember(username, username)
abm.type = MessageType.FRIEND_MESSAGE

View File

@@ -1,5 +1,6 @@
import asyncio
class WebChatQueueMgr:
def __init__(self) -> None:
self.queues = {}
@@ -30,4 +31,5 @@ class WebChatQueueMgr:
"""Check if a queue exists for the given conversation ID"""
return conversation_id in self.queues
webchat_queue_mgr = WebChatQueueMgr()

View File

@@ -213,10 +213,10 @@ class WeChatPadProAdapter(Platform):
def _extract_auth_key(self, data):
"""Helper method to extract auth_key from response data."""
if isinstance(data, dict):
auth_keys = data.get("authKeys") # 新接口
auth_keys = data.get("authKeys") # 新接口
if isinstance(auth_keys, list) and auth_keys:
return auth_keys[0]
elif isinstance(data, list) and data: # 旧接口
elif isinstance(data, list) and data: # 旧接口
return data[0]
return None
@@ -234,7 +234,9 @@ class WeChatPadProAdapter(Platform):
try:
async with session.post(url, params=params, json=payload) as response:
if response.status != 200:
logger.error(f"生成授权码失败: {response.status}, {await response.text()}")
logger.error(
f"生成授权码失败: {response.status}, {await response.text()}"
)
return
response_data = await response.json()
@@ -245,7 +247,9 @@ class WeChatPadProAdapter(Platform):
if self.auth_key:
logger.info("成功获取授权码")
else:
logger.error(f"生成授权码成功但未找到授权码: {response_data}")
logger.error(
f"生成授权码成功但未找到授权码: {response_data}"
)
else:
logger.error(f"生成授权码失败: {response_data}")
except aiohttp.ClientConnectorError as e:

View File

@@ -185,6 +185,7 @@ class WecomPlatformAdapter(Platform):
return PlatformMetadata(
"wecom",
"wecom 适配器",
id=self.config.get("id", "wecom"),
)
@override

View File

@@ -48,7 +48,12 @@ class WeChatKF(BaseWeChatAPI):
注意可能会出现返回条数少于limit的情况需结合返回的has_more字段判断是否继续请求。
:return: 接口调用结果
"""
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid}
data = {
"token": token,
"cursor": cursor,
"limit": limit,
"open_kfid": open_kfid,
}
return self._post("kf/sync_msg", data=data)
def get_service_state(self, open_kfid, external_userid):
@@ -72,7 +77,9 @@ class WeChatKF(BaseWeChatAPI):
}
return self._post("kf/service_state/get", data=data)
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""):
def trans_service_state(
self, open_kfid, external_userid, service_state, servicer_userid=""
):
"""
变更会话状态
@@ -180,7 +187,9 @@ class WeChatKF(BaseWeChatAPI):
"""
return self._get("kf/customer/get_upgrade_service_config")
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None):
def upgrade_service(
self, open_kfid, external_userid, service_type, member=None, groupchat=None
):
"""
为客户升级为专员或客户群服务
@@ -246,7 +255,9 @@ class WeChatKF(BaseWeChatAPI):
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
return self._post("kf/get_corp_statistic", data=data)
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None):
def get_servicer_statistic(
self, start_time, end_time, open_kfid=None, servicer_userid=None
):
"""
获取「客户数据统计」接待人员明细数据

View File

@@ -26,6 +26,7 @@ from optionaldict import optionaldict
from wechatpy.client.api.base import BaseWeChatAPI
class WeChatKFMessage(BaseWeChatAPI):
"""
发送微信客服消息
@@ -125,35 +126,55 @@ class WeChatKFMessage(BaseWeChatAPI):
msg={"msgtype": "news", "link": {"link": articles_data}},
)
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""):
def send_msgmenu(
self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""
):
return self.send(
user_id,
open_kfid,
msgid,
msg={
"msgtype": "msgmenu",
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content},
"msgmenu": {
"head_content": head_content,
"list": menu_list,
"tail_content": tail_content,
},
},
)
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""):
def send_location(
self, user_id, open_kfid, name, address, latitude, longitude, msgid=""
):
return self.send(
user_id,
open_kfid,
msgid,
msg={
"msgtype": "location",
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude},
"msgmenu": {
"name": name,
"address": address,
"latitude": latitude,
"longitude": longitude,
},
},
)
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""):
def send_miniprogram(
self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""
):
return self.send(
user_id,
open_kfid,
msgid,
msg={
"msgtype": "miniprogram",
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath},
"msgmenu": {
"appid": appid,
"title": title,
"thumb_media_id": thumb_media_id,
"pagepath": pagepath,
},
},
)

View File

@@ -0,0 +1,289 @@
#!/usr/bin/env python
# -*- encoding:utf-8 -*-
"""对企业微信发送给企业后台的消息加解密示例代码.
@copyright: Copyright (c) 1998-2020 Tencent Inc.
"""
# ------------------------------------------------------------------------
import logging
import base64
import random
import hashlib
import time
import struct
from Crypto.Cipher import AES
import socket
import json
from . import ierror
"""
关于Crypto.Cipher模块ImportError: No module named 'Crypto'解决方案
请到官方网站 https://www.dlitz.net/software/pycrypto/ 下载pycrypto。
下载后按照README中的“Installation”小节的提示进行pycrypto安装。
"""
class FormatException(Exception):
pass
def throw_exception(message, exception_class=FormatException):
"""my define raise exception function"""
raise exception_class(message)
class SHA1:
"""计算企业微信的消息签名接口"""
def getSHA1(self, token, timestamp, nonce, encrypt):
"""用SHA1算法生成安全签名
@param token: 票据
@param timestamp: 时间戳
@param encrypt: 密文
@param nonce: 随机字符串
@return: 安全签名
"""
try:
# 确保所有输入都是字符串类型
if isinstance(encrypt, bytes):
encrypt = encrypt.decode("utf-8")
sortlist = [str(token), str(timestamp), str(nonce), str(encrypt)]
sortlist.sort()
sha = hashlib.sha1()
sha.update("".join(sortlist).encode("utf-8"))
return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
except Exception as e:
print(e)
return ierror.WXBizMsgCrypt_ComputeSignature_Error, None
class JsonParse:
"""提供提取消息格式中的密文及生成回复消息格式的接口"""
# json消息模板
AES_TEXT_RESPONSE_TEMPLATE = """{
"encrypt": "%(msg_encrypt)s",
"msgsignature": "%(msg_signaturet)s",
"timestamp": "%(timestamp)s",
"nonce": "%(nonce)s"
}"""
def extract(self, jsontext):
"""提取出json数据包中的加密消息
@param jsontext: 待提取的json字符串
@return: 提取出的加密消息字符串
"""
try:
json_dict = json.loads(jsontext)
return ierror.WXBizMsgCrypt_OK, json_dict["encrypt"]
except Exception as e:
print(e)
return ierror.WXBizMsgCrypt_ParseJson_Error, None
def generate(self, encrypt, signature, timestamp, nonce):
"""生成json消息
@param encrypt: 加密后的消息密文
@param signature: 安全签名
@param timestamp: 时间戳
@param nonce: 随机字符串
@return: 生成的json字符串
"""
resp_dict = {
"msg_encrypt": encrypt,
"msg_signaturet": signature,
"timestamp": timestamp,
"nonce": nonce,
}
resp_json = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict
return resp_json
class PKCS7Encoder:
"""提供基于PKCS7算法的加解密接口"""
block_size = 32
def encode(self, text):
"""对需要加密的明文进行填充补位
@param text: 需要进行填充补位操作的明文(bytes类型)
@return: 补齐明文字符串(bytes类型)
"""
text_length = len(text)
# 计算需要填充的位数
amount_to_pad = self.block_size - (text_length % self.block_size)
if amount_to_pad == 0:
amount_to_pad = self.block_size
# 获得补位所用的字符
pad = bytes([amount_to_pad])
# 确保text是bytes类型
if isinstance(text, str):
text = text.encode("utf-8")
return text + pad * amount_to_pad
def decode(self, decrypted):
"""删除解密后明文的补位字符
@param decrypted: 解密后的明文
@return: 删除补位字符后的明文
"""
pad = ord(decrypted[-1])
if pad < 1 or pad > 32:
pad = 0
return decrypted[:-pad]
class Prpcrypt(object):
"""提供接收和推送给企业微信消息的加解密接口"""
def __init__(self, key):
# self.key = base64.b64decode(key+"=")
self.key = key
# 设置加解密模式为AES的CBC模式
self.mode = AES.MODE_CBC
def encrypt(self, text, receiveid):
"""对明文进行加密
@param text: 需要加密的明文
@return: 加密得到的字符串
"""
# 16位随机字符串添加到明文开头
text = text.encode()
text = (
self.get_random_str()
+ struct.pack("I", socket.htonl(len(text)))
+ text
+ receiveid.encode()
)
# 使用自定义的填充方式对明文进行补位填充
pkcs7 = PKCS7Encoder()
text = pkcs7.encode(text)
# 加密
cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore
try:
ciphertext = cryptor.encrypt(text)
# 使用BASE64对加密后的字符串进行编码
return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext)
except Exception as e:
logger = logging.getLogger("astrbot")
logger.error(e)
return ierror.WXBizMsgCrypt_EncryptAES_Error, None
def decrypt(self, text, receiveid):
"""对解密后的明文进行补位删除
@param text: 密文
@return: 删除填充补位后的明文
"""
try:
cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore
# 使用BASE64对密文进行解码然后AES-CBC解密
plain_text = cryptor.decrypt(base64.b64decode(text))
except Exception as e:
print(e)
return ierror.WXBizMsgCrypt_DecryptAES_Error, None
try:
pad = plain_text[-1]
# 去掉补位字符串
# pkcs7 = PKCS7Encoder()
# plain_text = pkcs7.encode(plain_text)
# 去除16位随机字符串
content = plain_text[16:-pad]
json_len = socket.ntohl(struct.unpack("I", content[:4])[0])
json_content = content[4 : json_len + 4].decode("utf-8")
from_receiveid = content[json_len + 4 :].decode("utf-8")
except Exception as e:
print(e)
return ierror.WXBizMsgCrypt_IllegalBuffer, None
if from_receiveid != receiveid:
print("receiveid not match", receiveid, from_receiveid)
return ierror.WXBizMsgCrypt_ValidateCorpid_Error, None
return 0, json_content
def get_random_str(self):
"""随机生成16位字符串
@return: 16位字符串
"""
return str(random.randint(1000000000000000, 9999999999999999)).encode()
class WXBizJsonMsgCrypt(object):
# 构造函数
def __init__(self, sToken, sEncodingAESKey, sReceiveId):
try:
self.key = base64.b64decode(sEncodingAESKey + "=")
assert len(self.key) == 32
except Exception as e:
throw_exception(f"[error]: EncodingAESKey invalid: {e}", FormatException)
# return ierror.WXBizMsgCrypt_IllegalAesKey,None
self.m_sToken = sToken
self.m_sReceiveId = sReceiveId
# 验证URL
# @param sMsgSignature: 签名串对应URL参数的msg_signature
# @param sTimeStamp: 时间戳对应URL参数的timestamp
# @param sNonce: 随机串对应URL参数的nonce
# @param sEchoStr: 随机串对应URL参数的echostr
# @param sReplyEchoStr: 解密之后的echostr当return返回0时有效
# @return成功0失败返回对应的错误码
def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr):
sha1 = SHA1()
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr)
if ret != 0:
return ret, None
if not signature == sMsgSignature:
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
pc = Prpcrypt(self.key)
ret, sReplyEchoStr = pc.decrypt(sEchoStr, self.m_sReceiveId)
return ret, sReplyEchoStr
def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None):
# 将企业回复用户的消息加密打包
# @param sReplyMsg: 企业号待回复用户的消息json格式的字符串
# @param sTimeStamp: 时间戳可以自己生成也可以用URL参数的timestamp,如为None则自动用当前时间
# @param sNonce: 随机串可以自己生成也可以用URL参数的nonce
# sEncryptMsg: 加密后的可以直接回复用户的密文包括msg_signature, timestamp, nonce, encrypt的json格式的字符串,
# return成功0sEncryptMsg,失败返回对应的错误码None
pc = Prpcrypt(self.key)
ret, encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId)
encrypt = encrypt.decode("utf-8") # type: ignore
if ret != 0:
return ret, None
if timestamp is None:
timestamp = str(int(time.time()))
# 生成安全签名
sha1 = SHA1()
ret, signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt)
if ret != 0:
return ret, None
jsonParse = JsonParse()
return ret, jsonParse.generate(encrypt, signature, timestamp, sNonce)
def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce):
# 检验消息的真实性,并且获取解密后的明文
# @param sMsgSignature: 签名串对应URL参数的msg_signature
# @param sTimeStamp: 时间戳对应URL参数的timestamp
# @param sNonce: 随机串对应URL参数的nonce
# @param sPostData: 密文对应POST请求的数据
# json_content: 解密后的原文当return返回0时有效
# @return: 成功0失败返回对应的错误码
# 验证安全签名
jsonParse = JsonParse()
ret, encrypt = jsonParse.extract(sPostData)
if ret != 0:
return ret, None
sha1 = SHA1()
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt)
if ret != 0:
return ret, None
if not signature == sMsgSignature:
print("signature not match")
print(signature)
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
pc = Prpcrypt(self.key)
ret, json_content = pc.decrypt(encrypt, self.m_sReceiveId)
return ret, json_content

View File

@@ -0,0 +1,17 @@
"""
企业微信智能机器人平台适配器包
"""
from .wecomai_adapter import WecomAIBotAdapter
from .wecomai_api import WecomAIBotAPIClient
from .wecomai_event import WecomAIBotMessageEvent
from .wecomai_server import WecomAIBotServer
from .wecomai_utils import WecomAIBotConstants
__all__ = [
"WecomAIBotAdapter",
"WecomAIBotAPIClient",
"WecomAIBotMessageEvent",
"WecomAIBotServer",
"WecomAIBotConstants",
]

View File

@@ -0,0 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#########################################################################
# Author: jonyqin
# Created Time: Thu 11 Sep 2014 01:53:58 PM CST
# File Name: ierror.py
# Description:定义错误码含义
#########################################################################
WXBizMsgCrypt_OK = 0
WXBizMsgCrypt_ValidateSignature_Error = -40001
WXBizMsgCrypt_ParseJson_Error = -40002
WXBizMsgCrypt_ComputeSignature_Error = -40003
WXBizMsgCrypt_IllegalAesKey = -40004
WXBizMsgCrypt_ValidateCorpid_Error = -40005
WXBizMsgCrypt_EncryptAES_Error = -40006
WXBizMsgCrypt_DecryptAES_Error = -40007
WXBizMsgCrypt_IllegalBuffer = -40008
WXBizMsgCrypt_EncodeBase64_Error = -40009
WXBizMsgCrypt_DecodeBase64_Error = -40010
WXBizMsgCrypt_GenReturnJson_Error = -40011

View File

@@ -0,0 +1,445 @@
"""
企业微信智能机器人平台适配器
基于企业微信智能机器人 API 的消息平台适配器,支持 HTTP 回调
参考webchat_adapter.py的队列机制实现异步消息处理和流式响应
"""
import time
import asyncio
import uuid
import hashlib
import base64
from typing import Awaitable, Any, Dict, Optional, Callable
from astrbot.api.platform import (
Platform,
AstrBotMessage,
MessageMember,
MessageType,
PlatformMetadata,
)
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain, At, Image
from astrbot.api import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from .wecomai_api import (
WecomAIBotAPIClient,
WecomAIBotMessageParser,
WecomAIBotStreamMessageBuilder,
)
from .wecomai_event import WecomAIBotMessageEvent
from .wecomai_server import WecomAIBotServer
from .wecomai_queue_mgr import wecomai_queue_mgr, WecomAIQueueMgr
from .wecomai_utils import (
WecomAIBotConstants,
format_session_id,
generate_random_string,
process_encrypted_image,
)
class WecomAIQueueListener:
"""企业微信智能机器人队列监听器参考webchat的QueueListener设计"""
def __init__(
self, queue_mgr: WecomAIQueueMgr, callback: Callable[[dict], Awaitable[None]]
) -> None:
self.queue_mgr = queue_mgr
self.callback = callback
self.running_tasks = set()
async def listen_to_queue(self, session_id: str):
"""监听特定会话的队列"""
queue = self.queue_mgr.get_or_create_queue(session_id)
while True:
try:
data = await queue.get()
await self.callback(data)
except Exception as e:
logger.error(f"处理会话 {session_id} 消息时发生错误: {e}")
break
async def run(self):
"""监控新会话队列并启动监听器"""
monitored_sessions = set()
while True:
# 检查新会话
current_sessions = set(self.queue_mgr.queues.keys())
new_sessions = current_sessions - monitored_sessions
# 为新会话启动监听器
for session_id in new_sessions:
task = asyncio.create_task(self.listen_to_queue(session_id))
self.running_tasks.add(task)
task.add_done_callback(self.running_tasks.discard)
monitored_sessions.add(session_id)
logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}")
# 清理已不存在的会话
removed_sessions = monitored_sessions - current_sessions
monitored_sessions -= removed_sessions
# 清理过期的待处理响应
self.queue_mgr.cleanup_expired_responses()
await asyncio.sleep(1) # 每秒检查一次新会话
@register_platform_adapter(
"wecom_ai_bot", "企业微信智能机器人适配器,支持 HTTP 回调接收消息"
)
class WecomAIBotAdapter(Platform):
"""企业微信智能机器人适配器"""
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
# 初始化配置参数
self.token = self.config["token"]
self.encoding_aes_key = self.config["encoding_aes_key"]
self.port = int(self.config["port"])
self.host = self.config.get("callback_server_host", "0.0.0.0")
self.bot_name = self.config.get("wecom_ai_bot_name", "")
self.initial_respond_text = self.config.get(
"wecomaibot_init_respond_text", "💭 思考中..."
)
self.friend_message_welcome_text = self.config.get(
"wecomaibot_friend_message_welcome_text", ""
)
# 平台元数据
self.metadata = PlatformMetadata(
name="wecom_ai_bot",
description="企业微信智能机器人适配器,支持 HTTP 回调接收消息",
id=self.config.get("id", "wecom_ai_bot"),
)
# 初始化 API 客户端
self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key)
# 初始化 HTTP 服务器
self.server = WecomAIBotServer(
host=self.host,
port=self.port,
api_client=self.api_client,
message_handler=self._process_message,
)
# 事件循环和关闭信号
self.shutdown_event = asyncio.Event()
# 队列监听器
self.queue_listener = WecomAIQueueListener(
wecomai_queue_mgr, self._handle_queued_message
)
async def _handle_queued_message(self, data: dict):
"""处理队列中的消息类似webchat的callback"""
try:
abm = await self.convert_message(data)
await self.handle_msg(abm)
except Exception as e:
logger.error(f"处理队列消息时发生异常: {e}")
async def _process_message(
self, message_data: Dict[str, Any], callback_params: Dict[str, str]
) -> Optional[str]:
"""处理接收到的消息
Args:
message_data: 解密后的消息数据
callback_params: 回调参数 (nonce, timestamp)
Returns:
加密后的响应消息,无需响应时返回 None
"""
msgtype = message_data.get("msgtype")
if not msgtype:
logger.warning(f"消息类型未知,忽略: {message_data}")
return None
session_id = self._extract_session_id(message_data)
if msgtype in ("text", "image", "mixed"):
# user sent a text / image / mixed message
try:
# create a brand-new unique stream_id for this message session
stream_id = f"{session_id}_{generate_random_string(10)}"
await self._enqueue_message(
message_data, callback_params, stream_id, session_id
)
wecomai_queue_mgr.set_pending_response(stream_id, callback_params)
resp = WecomAIBotStreamMessageBuilder.make_text_stream(
stream_id, self.initial_respond_text, False
)
return await self.api_client.encrypt_message(
resp, callback_params["nonce"], callback_params["timestamp"]
)
except Exception as e:
logger.error("处理消息时发生异常: %s", e)
return None
elif msgtype == "stream":
# wechat server is requesting for updates of a stream
stream_id = message_data["stream"]["id"]
if not wecomai_queue_mgr.has_back_queue(stream_id):
logger.error(f"Cannot find back queue for stream_id: {stream_id}")
# 返回结束标志,告诉微信服务器流已结束
end_message = WecomAIBotStreamMessageBuilder.make_text_stream(
stream_id, "", True
)
resp = await self.api_client.encrypt_message(
end_message,
callback_params["nonce"],
callback_params["timestamp"],
)
return resp
queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
if queue.empty():
logger.debug(
f"No new messages in back queue for stream_id: {stream_id}"
)
return None
# aggregate all delta chains in the back queue
latest_plain_content = ""
image_base64 = []
finish = False
while not queue.empty():
msg = await queue.get()
if msg["type"] == "plain":
latest_plain_content = msg["data"] or ""
elif msg["type"] == "image":
image_base64.append(msg["image_data"])
elif msg["type"] == "end":
# stream end
finish = True
wecomai_queue_mgr.remove_queues(stream_id)
break
else:
pass
logger.debug(
f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}"
)
if latest_plain_content or image_base64:
msg_items = []
if finish and image_base64:
for img_b64 in image_base64:
# get md5 of image
img_data = base64.b64decode(img_b64)
img_md5 = hashlib.md5(img_data).hexdigest()
msg_items.append(
{
"msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE,
"image": {"base64": img_b64, "md5": img_md5},
}
)
image_base64 = []
plain_message = WecomAIBotStreamMessageBuilder.make_mixed_stream(
stream_id, latest_plain_content, msg_items, finish
)
encrypted_message = await self.api_client.encrypt_message(
plain_message,
callback_params["nonce"],
callback_params["timestamp"],
)
if encrypted_message:
logger.debug(
f"Stream message sent successfully, stream_id: {stream_id}"
)
else:
logger.error("消息加密失败")
return encrypted_message
return None
elif msgtype == "event":
event = message_data.get("event")
if event == "enter_chat" and self.friend_message_welcome_text:
# 用户进入会话,发送欢迎消息
try:
resp = WecomAIBotStreamMessageBuilder.make_text(
self.friend_message_welcome_text
)
return await self.api_client.encrypt_message(
resp,
callback_params["nonce"],
callback_params["timestamp"],
)
except Exception as e:
logger.error("处理欢迎消息时发生异常: %s", e)
return None
pass
def _extract_session_id(self, message_data: Dict[str, Any]) -> str:
"""从消息数据中提取会话ID"""
user_id = message_data.get("from", {}).get("userid", "default_user")
return format_session_id("wecomai", user_id)
async def _enqueue_message(
self,
message_data: Dict[str, Any],
callback_params: Dict[str, str],
stream_id: str,
session_id: str,
):
"""将消息放入队列进行异步处理"""
input_queue = wecomai_queue_mgr.get_or_create_queue(stream_id)
_ = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
message_payload = {
"message_data": message_data,
"callback_params": callback_params,
"session_id": session_id,
"stream_id": stream_id,
}
await input_queue.put(message_payload)
logger.debug(f"[WecomAI] 消息已入队: {stream_id}")
async def convert_message(self, payload: dict) -> AstrBotMessage:
"""转换队列中的消息数据为AstrBotMessage类似webchat的convert_message"""
message_data = payload["message_data"]
session_id = payload["session_id"]
# callback_params = payload["callback_params"] # 保留但暂时不使用
# 解析消息内容
msgtype = message_data.get("msgtype")
content = ""
image_base64 = []
_img_url_to_process = []
msg_items = []
if msgtype == WecomAIBotConstants.MSG_TYPE_TEXT:
content = WecomAIBotMessageParser.parse_text_message(message_data)
elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE:
_img_url_to_process.append(
WecomAIBotMessageParser.parse_image_message(message_data)
)
elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED:
# 提取混合消息中的文本内容
msg_items = WecomAIBotMessageParser.parse_mixed_message(message_data)
text_parts = []
for item in msg_items or []:
if item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_TEXT:
text_content = item.get("text", {}).get("content", "")
if text_content:
text_parts.append(text_content)
elif item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_IMAGE:
image_url = item.get("image", {}).get("url", "")
if image_url:
_img_url_to_process.append(image_url)
content = " ".join(text_parts) if text_parts else ""
else:
content = f"[{msgtype}消息]"
# 并行处理图片下载和解密
if _img_url_to_process:
tasks = [
process_encrypted_image(url, self.encoding_aes_key)
for url in _img_url_to_process
]
results = await asyncio.gather(*tasks)
for success, result in results:
if success:
image_base64.append(result)
else:
logger.error(f"处理加密图片失败: {result}")
# 构建 AstrBotMessage
abm = AstrBotMessage()
abm.self_id = self.bot_name
abm.message_str = content or "[未知消息]"
abm.message_id = str(uuid.uuid4())
abm.timestamp = int(time.time())
abm.raw_message = payload
# 发送者信息
abm.sender = MessageMember(
user_id=message_data.get("from", {}).get("userid", "unknown"),
nickname=message_data.get("from", {}).get("userid", "unknown"),
)
# 消息类型
abm.type = (
MessageType.GROUP_MESSAGE
if message_data.get("chattype") == "group"
else MessageType.FRIEND_MESSAGE
)
abm.session_id = session_id
# 消息内容
abm.message = []
# 处理 At
if self.bot_name and f"@{self.bot_name}" in abm.message_str:
abm.message_str = abm.message_str.replace(f"@{self.bot_name}", "").strip()
abm.message.append(At(qq=self.bot_name, name=self.bot_name))
abm.message.append(Plain(abm.message_str))
if image_base64:
for img_b64 in image_base64:
abm.message.append(Image.fromBase64(img_b64))
logger.debug(f"WecomAIAdapter: {abm.message}")
return abm
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
"""通过会话发送消息"""
# 企业微信智能机器人主要通过回调响应,这里记录日志
logger.info("会话发送消息: %s -> %s", session.session_id, message_chain)
await super().send_by_session(session, message_chain)
def run(self) -> Awaitable[Any]:
"""运行适配器同时启动HTTP服务器和队列监听器"""
logger.info("启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port)
async def run_both():
# 同时运行HTTP服务器和队列监听器
await asyncio.gather(
self.server.start_server(),
self.queue_listener.run(),
)
return run_both()
async def terminate(self):
"""终止适配器"""
logger.info("企业微信智能机器人适配器正在关闭...")
self.shutdown_event.set()
await self.server.shutdown()
def meta(self) -> PlatformMetadata:
"""获取平台元数据"""
return self.metadata
async def handle_msg(self, message: AstrBotMessage):
"""处理消息,创建消息事件并提交到事件队列"""
try:
message_event = WecomAIBotMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
api_client=self.api_client,
)
self.commit_event(message_event)
except Exception as e:
logger.error("处理消息时发生异常: %s", e)
def get_client(self) -> WecomAIBotAPIClient:
"""获取 API 客户端"""
return self.api_client
def get_server(self) -> WecomAIBotServer:
"""获取 HTTP 服务器实例"""
return self.server

View File

@@ -0,0 +1,378 @@
"""
企业微信智能机器人 API 客户端
处理消息加密解密、API 调用等
"""
import json
import base64
import hashlib
from typing import Dict, Any, Optional, Tuple, Union
from Crypto.Cipher import AES
import aiohttp
from .WXBizJsonMsgCrypt import WXBizJsonMsgCrypt
from .wecomai_utils import WecomAIBotConstants
from astrbot import logger
class WecomAIBotAPIClient:
"""企业微信智能机器人 API 客户端"""
def __init__(self, token: str, encoding_aes_key: str):
"""初始化 API 客户端
Args:
token: 企业微信机器人 Token
encoding_aes_key: 消息加密密钥
"""
self.token = token
self.encoding_aes_key = encoding_aes_key
self.wxcpt = WXBizJsonMsgCrypt(token, encoding_aes_key, "") # receiveid 为空串
async def decrypt_message(
self, encrypted_data: bytes, msg_signature: str, timestamp: str, nonce: str
) -> Tuple[int, Optional[Dict[str, Any]]]:
"""解密企业微信消息
Args:
encrypted_data: 加密的消息数据
msg_signature: 消息签名
timestamp: 时间戳
nonce: 随机数
Returns:
(错误码, 解密后的消息数据字典)
"""
try:
ret, decrypted_msg = self.wxcpt.DecryptMsg(
encrypted_data, msg_signature, timestamp, nonce
)
if ret != WecomAIBotConstants.SUCCESS:
logger.error(f"消息解密失败,错误码: {ret}")
return ret, None
# 解析 JSON
if decrypted_msg:
try:
message_data = json.loads(decrypted_msg)
logger.debug(f"解密成功,消息内容: {message_data}")
return WecomAIBotConstants.SUCCESS, message_data
except json.JSONDecodeError as e:
logger.error(f"JSON 解析失败: {e}, 原始消息: {decrypted_msg}")
return WecomAIBotConstants.PARSE_XML_ERROR, None
else:
logger.error("解密消息为空")
return WecomAIBotConstants.DECRYPT_ERROR, None
except Exception as e:
logger.error(f"解密过程发生异常: {e}")
return WecomAIBotConstants.DECRYPT_ERROR, None
async def encrypt_message(
self, plain_message: str, nonce: str, timestamp: str
) -> Optional[str]:
"""加密消息
Args:
plain_message: 明文消息
nonce: 随机数
timestamp: 时间戳
Returns:
加密后的消息,失败时返回 None
"""
try:
ret, encrypted_msg = self.wxcpt.EncryptMsg(plain_message, nonce, timestamp)
if ret != WecomAIBotConstants.SUCCESS:
logger.error(f"消息加密失败,错误码: {ret}")
return None
logger.debug("消息加密成功")
return encrypted_msg
except Exception as e:
logger.error(f"加密过程发生异常: {e}")
return None
def verify_url(
self, msg_signature: str, timestamp: str, nonce: str, echostr: str
) -> str:
"""验证回调 URL
Args:
msg_signature: 消息签名
timestamp: 时间戳
nonce: 随机数
echostr: 验证字符串
Returns:
验证结果字符串
"""
try:
ret, echo_result = self.wxcpt.VerifyURL(
msg_signature, timestamp, nonce, echostr
)
if ret != WecomAIBotConstants.SUCCESS:
logger.error(f"URL 验证失败,错误码: {ret}")
return "verify fail"
logger.info("URL 验证成功")
return echo_result if echo_result else "verify fail"
except Exception as e:
logger.error(f"URL 验证发生异常: {e}")
return "verify fail"
async def process_encrypted_image(
self, image_url: str, aes_key_base64: Optional[str] = None
) -> Tuple[bool, Union[bytes, str]]:
"""下载并解密加密图片
Args:
image_url: 加密图片的 URL
aes_key_base64: Base64 编码的 AES 密钥,如果为 None 则使用实例的密钥
Returns:
(是否成功, 图片数据或错误信息)
"""
try:
# 下载图片
logger.info(f"开始下载加密图片: {image_url}")
async with aiohttp.ClientSession() as session:
async with session.get(image_url, timeout=15) as response:
if response.status != 200:
error_msg = f"图片下载失败,状态码: {response.status}"
logger.error(error_msg)
return False, error_msg
encrypted_data = await response.read()
logger.info(f"图片下载成功,大小: {len(encrypted_data)} 字节")
# 准备解密密钥
if aes_key_base64 is None:
aes_key_base64 = self.encoding_aes_key
if not aes_key_base64:
raise ValueError("AES 密钥不能为空")
# Base64 解码密钥
aes_key = base64.b64decode(
aes_key_base64 + "=" * (-len(aes_key_base64) % 4)
)
if len(aes_key) != 32:
raise ValueError("无效的 AES 密钥长度: 应为 32 字节")
iv = aes_key[:16] # 初始向量为密钥前 16 字节
# 解密图片数据
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
decrypted_data = cipher.decrypt(encrypted_data)
# 去除 PKCS#7 填充
pad_len = decrypted_data[-1]
if pad_len > 32: # AES-256 块大小为 32 字节
raise ValueError("无效的填充长度 (大于32字节)")
decrypted_data = decrypted_data[:-pad_len]
logger.info(f"图片解密成功,解密后大小: {len(decrypted_data)} 字节")
return True, decrypted_data
except aiohttp.ClientError as e:
error_msg = f"图片下载失败: {str(e)}"
logger.error(error_msg)
return False, error_msg
except ValueError as e:
error_msg = f"参数错误: {str(e)}"
logger.error(error_msg)
return False, error_msg
except Exception as e:
error_msg = f"图片处理异常: {str(e)}"
logger.error(error_msg)
return False, error_msg
class WecomAIBotStreamMessageBuilder:
"""企业微信智能机器人流消息构建器"""
@staticmethod
def make_text_stream(stream_id: str, content: str, finish: bool = False) -> str:
"""构建文本流消息
Args:
stream_id: 流 ID
content: 文本内容
finish: 是否结束
Returns:
JSON 格式的流消息字符串
"""
plain = {
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
"stream": {"id": stream_id, "finish": finish, "content": content},
}
return json.dumps(plain, ensure_ascii=False)
@staticmethod
def make_image_stream(
stream_id: str, image_data: bytes, finish: bool = False
) -> str:
"""构建图片流消息
Args:
stream_id: 流 ID
image_data: 图片二进制数据
finish: 是否结束
Returns:
JSON 格式的流消息字符串
"""
image_md5 = hashlib.md5(image_data).hexdigest()
image_base64 = base64.b64encode(image_data).decode("utf-8")
plain = {
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
"stream": {
"id": stream_id,
"finish": finish,
"msg_item": [
{
"msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE,
"image": {"base64": image_base64, "md5": image_md5},
}
],
},
}
return json.dumps(plain, ensure_ascii=False)
@staticmethod
def make_mixed_stream(
stream_id: str, content: str, msg_items: list, finish: bool = False
) -> str:
"""构建混合类型流消息
Args:
stream_id: 流 ID
content: 文本内容
msg_items: 消息项列表
finish: 是否结束
Returns:
JSON 格式的流消息字符串
"""
plain = {
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
"stream": {"id": stream_id, "finish": finish, "msg_item": msg_items},
}
if content:
plain["stream"]["content"] = content
return json.dumps(plain, ensure_ascii=False)
@staticmethod
def make_text(content: str) -> str:
"""构建文本消息
Args:
content: 文本内容
Returns:
JSON 格式的文本消息字符串
"""
plain = {"msgtype": "text", "text": {"content": content}}
return json.dumps(plain, ensure_ascii=False)
class WecomAIBotMessageParser:
"""企业微信智能机器人消息解析器"""
@staticmethod
def parse_text_message(data: Dict[str, Any]) -> Optional[str]:
"""解析文本消息
Args:
data: 消息数据
Returns:
文本内容,解析失败返回 None
"""
try:
return data.get("text", {}).get("content")
except (KeyError, TypeError):
logger.warning("文本消息解析失败")
return None
@staticmethod
def parse_image_message(data: Dict[str, Any]) -> Optional[str]:
"""解析图片消息
Args:
data: 消息数据
Returns:
图片 URL解析失败返回 None
"""
try:
return data.get("image", {}).get("url")
except (KeyError, TypeError):
logger.warning("图片消息解析失败")
return None
@staticmethod
def parse_stream_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""解析流消息
Args:
data: 消息数据
Returns:
流消息数据,解析失败返回 None
"""
try:
stream_data = data.get("stream", {})
return {
"id": stream_data.get("id"),
"finish": stream_data.get("finish"),
"content": stream_data.get("content"),
"msg_item": stream_data.get("msg_item", []),
}
except (KeyError, TypeError):
logger.warning("流消息解析失败")
return None
@staticmethod
def parse_mixed_message(data: Dict[str, Any]) -> Optional[list]:
"""解析混合消息
Args:
data: 消息数据
Returns:
消息项列表,解析失败返回 None
"""
try:
return data.get("mixed", {}).get("msg_item", [])
except (KeyError, TypeError):
logger.warning("混合消息解析失败")
return None
@staticmethod
def parse_event_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""解析事件消息
Args:
data: 消息数据
Returns:
事件数据,解析失败返回 None
"""
try:
return data.get("event", {})
except (KeyError, TypeError):
logger.warning("事件消息解析失败")
return None

View File

@@ -0,0 +1,149 @@
"""
企业微信智能机器人事件处理模块,处理消息事件的发送和接收
"""
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import (
Image,
Plain,
)
from astrbot.api import logger
from .wecomai_api import WecomAIBotAPIClient
from .wecomai_queue_mgr import wecomai_queue_mgr
class WecomAIBotMessageEvent(AstrMessageEvent):
"""企业微信智能机器人消息事件"""
def __init__(
self,
message_str: str,
message_obj,
platform_meta,
session_id: str,
api_client: WecomAIBotAPIClient,
):
"""初始化消息事件
Args:
message_str: 消息字符串
message_obj: 消息对象
platform_meta: 平台元数据
session_id: 会话 ID
api_client: API 客户端
"""
super().__init__(message_str, message_obj, platform_meta, session_id)
self.api_client = api_client
@staticmethod
async def _send(
message_chain: MessageChain,
stream_id: str,
streaming: bool = False,
):
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
if not message_chain:
await back_queue.put(
{
"type": "end",
"data": "",
"streaming": False,
}
)
return ""
data = ""
for comp in message_chain.chain:
if isinstance(comp, Plain):
data = comp.text
await back_queue.put(
{
"type": "plain",
"data": data,
"streaming": streaming,
"session_id": stream_id,
}
)
elif isinstance(comp, Image):
# 处理图片消息
try:
image_base64 = await comp.convert_to_base64()
if image_base64:
await back_queue.put(
{
"type": "image",
"image_data": image_base64,
"streaming": streaming,
"session_id": stream_id,
}
)
else:
logger.warning("图片数据为空,跳过")
except Exception as e:
logger.error("处理图片消息失败: %s", e)
else:
logger.warning(f"[WecomAI] 不支持的消息组件类型: {type(comp)}, 跳过")
return data
async def send(self, message: MessageChain):
"""发送消息"""
raw = self.message_obj.raw_message
assert isinstance(raw, dict), (
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
await WecomAIBotMessageEvent._send(message, stream_id)
await super().send(message)
async def send_streaming(self, generator, use_fallback=False):
"""流式发送消息参考webchat的send_streaming设计"""
final_data = ""
raw = self.message_obj.raw_message
assert isinstance(raw, dict), (
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送
increment_plain = ""
async for chain in generator:
# 累积增量内容,并改写 Plain 段
chain.squash_plain()
for comp in chain.chain:
if isinstance(comp, Plain):
comp.text = increment_plain + comp.text
increment_plain = comp.text
break
if chain.type == "break" and final_data:
# 分割符
await back_queue.put(
{
"type": "break", # break means a segment end
"data": final_data,
"streaming": True,
"session_id": self.session_id,
}
)
final_data = ""
continue
final_data += await WecomAIBotMessageEvent._send(
chain,
stream_id=stream_id,
streaming=True,
)
await back_queue.put(
{
"type": "complete", # complete means we return the final result
"data": final_data,
"streaming": True,
"session_id": self.session_id,
}
)
await super().send_streaming(generator, use_fallback)

View File

@@ -0,0 +1,148 @@
"""
企业微信智能机器人队列管理器
参考 webchat_queue_mgr.py为企业微信智能机器人实现队列机制
支持异步消息处理和流式响应
"""
import asyncio
from typing import Dict, Any, Optional
from astrbot.api import logger
class WecomAIQueueMgr:
"""企业微信智能机器人队列管理器"""
def __init__(self) -> None:
self.queues: Dict[str, asyncio.Queue] = {}
"""StreamID 到输入队列的映射 - 用于接收用户消息"""
self.back_queues: Dict[str, asyncio.Queue] = {}
"""StreamID 到输出队列的映射 - 用于发送机器人响应"""
self.pending_responses: Dict[str, Dict[str, Any]] = {}
"""待处理的响应缓存,用于流式响应"""
def get_or_create_queue(self, session_id: str) -> asyncio.Queue:
"""获取或创建指定会话的输入队列
Args:
session_id: 会话ID
Returns:
输入队列实例
"""
if session_id not in self.queues:
self.queues[session_id] = asyncio.Queue()
logger.debug(f"[WecomAI] 创建输入队列: {session_id}")
return self.queues[session_id]
def get_or_create_back_queue(self, session_id: str) -> asyncio.Queue:
"""获取或创建指定会话的输出队列
Args:
session_id: 会话ID
Returns:
输出队列实例
"""
if session_id not in self.back_queues:
self.back_queues[session_id] = asyncio.Queue()
logger.debug(f"[WecomAI] 创建输出队列: {session_id}")
return self.back_queues[session_id]
def remove_queues(self, session_id: str):
"""移除指定会话的所有队列
Args:
session_id: 会话ID
"""
if session_id in self.queues:
del self.queues[session_id]
logger.debug(f"[WecomAI] 移除输入队列: {session_id}")
if session_id in self.back_queues:
del self.back_queues[session_id]
logger.debug(f"[WecomAI] 移除输出队列: {session_id}")
if session_id in self.pending_responses:
del self.pending_responses[session_id]
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
def has_queue(self, session_id: str) -> bool:
"""检查是否存在指定会话的队列
Args:
session_id: 会话ID
Returns:
是否存在队列
"""
return session_id in self.queues
def has_back_queue(self, session_id: str) -> bool:
"""检查是否存在指定会话的输出队列
Args:
session_id: 会话ID
Returns:
是否存在输出队列
"""
return session_id in self.back_queues
def set_pending_response(self, session_id: str, callback_params: Dict[str, str]):
"""设置待处理的响应参数
Args:
session_id: 会话ID
callback_params: 回调参数nonce, timestamp等
"""
self.pending_responses[session_id] = {
"callback_params": callback_params,
"timestamp": asyncio.get_event_loop().time(),
}
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
def get_pending_response(self, session_id: str) -> Optional[Dict[str, Any]]:
"""获取待处理的响应参数
Args:
session_id: 会话ID
Returns:
响应参数如果不存在则返回None
"""
return self.pending_responses.get(session_id)
def cleanup_expired_responses(self, max_age_seconds: int = 300):
"""清理过期的待处理响应
Args:
max_age_seconds: 最大存活时间(秒)
"""
current_time = asyncio.get_event_loop().time()
expired_sessions = []
for session_id, response_data in self.pending_responses.items():
if current_time - response_data["timestamp"] > max_age_seconds:
expired_sessions.append(session_id)
for session_id in expired_sessions:
del self.pending_responses[session_id]
logger.debug(f"[WecomAI] 清理过期响应: {session_id}")
def get_stats(self) -> Dict[str, int]:
"""获取队列统计信息
Returns:
统计信息字典
"""
return {
"input_queues": len(self.queues),
"output_queues": len(self.back_queues),
"pending_responses": len(self.pending_responses),
}
# 全局队列管理器实例
wecomai_queue_mgr = WecomAIQueueMgr()

View File

@@ -0,0 +1,166 @@
"""
企业微信智能机器人 HTTP 服务器
处理企业微信智能机器人的 HTTP 回调请求
"""
import asyncio
from typing import Dict, Any, Optional, Callable
import quart
from astrbot.api import logger
from .wecomai_api import WecomAIBotAPIClient
from .wecomai_utils import WecomAIBotConstants
class WecomAIBotServer:
"""企业微信智能机器人 HTTP 服务器"""
def __init__(
self,
host: str,
port: int,
api_client: WecomAIBotAPIClient,
message_handler: Optional[
Callable[[Dict[str, Any], Dict[str, str]], Any]
] = None,
):
"""初始化服务器
Args:
host: 监听地址
port: 监听端口
api_client: API客户端实例
message_handler: 消息处理回调函数
"""
self.host = host
self.port = port
self.api_client = api_client
self.message_handler = message_handler
self.app = quart.Quart(__name__)
self._setup_routes()
self.shutdown_event = asyncio.Event()
def _setup_routes(self):
"""设置 Quart 路由"""
# 使用 Quart 的 add_url_rule 方法添加路由
self.app.add_url_rule(
"/webhook/wecom-ai-bot",
view_func=self.verify_url,
methods=["GET"],
)
self.app.add_url_rule(
"/webhook/wecom-ai-bot",
view_func=self.handle_message,
methods=["POST"],
)
async def verify_url(self):
"""验证回调 URL"""
args = quart.request.args
msg_signature = args.get("msg_signature")
timestamp = args.get("timestamp")
nonce = args.get("nonce")
echostr = args.get("echostr")
if not all([msg_signature, timestamp, nonce, echostr]):
logger.error("URL 验证参数缺失")
return "verify fail", 400
# 类型检查确保不为 None
assert msg_signature is not None
assert timestamp is not None
assert nonce is not None
assert echostr is not None
logger.info("收到企业微信智能机器人 WebHook URL 验证请求。")
result = self.api_client.verify_url(msg_signature, timestamp, nonce, echostr)
return result, 200, {"Content-Type": "text/plain"}
async def handle_message(self):
"""处理消息回调"""
args = quart.request.args
msg_signature = args.get("msg_signature")
timestamp = args.get("timestamp")
nonce = args.get("nonce")
if not all([msg_signature, timestamp, nonce]):
logger.error("消息回调参数缺失")
return "缺少必要参数", 400
# 类型检查确保不为 None
assert msg_signature is not None
assert timestamp is not None
assert nonce is not None
logger.debug(
f"收到消息回调msg_signature={msg_signature}, timestamp={timestamp}, nonce={nonce}"
)
try:
# 获取请求体
post_data = await quart.request.get_data()
# 确保 post_data 是 bytes 类型
if isinstance(post_data, str):
post_data = post_data.encode("utf-8")
# 解密消息
ret_code, message_data = await self.api_client.decrypt_message(
post_data, msg_signature, timestamp, nonce
)
if ret_code != WecomAIBotConstants.SUCCESS or not message_data:
logger.error("消息解密失败,错误码: %d", ret_code)
return "消息解密失败", 400
# 调用消息处理器
response = None
if self.message_handler:
try:
response = await self.message_handler(
message_data, {"nonce": nonce, "timestamp": timestamp}
)
except Exception as e:
logger.error("消息处理器执行异常: %s", e)
return "消息处理异常", 500
if response:
return response, 200, {"Content-Type": "text/plain"}
else:
return "success", 200, {"Content-Type": "text/plain"}
except Exception as e:
logger.error("处理消息时发生异常: %s", e)
return "内部服务器错误", 500
async def start_server(self):
"""启动服务器"""
logger.info("启动企业微信智能机器人服务器,监听 %s:%d", self.host, self.port)
try:
await self.app.run_task(
host=self.host,
port=self.port,
shutdown_trigger=self.shutdown_trigger,
)
except Exception as e:
logger.error("服务器运行异常: %s", e)
raise
async def shutdown_trigger(self):
"""关闭触发器"""
await self.shutdown_event.wait()
async def shutdown(self):
"""关闭服务器"""
logger.info("企业微信智能机器人服务器正在关闭...")
self.shutdown_event.set()
def get_app(self):
"""获取 Quart 应用实例"""
return self.app

View File

@@ -0,0 +1,199 @@
"""
企业微信智能机器人工具模块
提供常量定义、工具函数和辅助方法
"""
import string
import random
import hashlib
import base64
import aiohttp
import asyncio
from Crypto.Cipher import AES
from typing import Any, Tuple
from astrbot.api import logger
# 常量定义
class WecomAIBotConstants:
"""企业微信智能机器人常量"""
# 消息类型
MSG_TYPE_TEXT = "text"
MSG_TYPE_IMAGE = "image"
MSG_TYPE_MIXED = "mixed"
MSG_TYPE_STREAM = "stream"
MSG_TYPE_EVENT = "event"
# 流消息状态
STREAM_CONTINUE = False
STREAM_FINISH = True
# 错误码
SUCCESS = 0
DECRYPT_ERROR = -40001
VALIDATE_SIGNATURE_ERROR = -40002
PARSE_XML_ERROR = -40003
COMPUTE_SIGNATURE_ERROR = -40004
ILLEGAL_AES_KEY = -40005
VALIDATE_APPID_ERROR = -40006
ENCRYPT_AES_ERROR = -40007
ILLEGAL_BUFFER = -40008
def generate_random_string(length: int = 10) -> str:
"""生成随机字符串
Args:
length: 字符串长度,默认为 10
Returns:
随机字符串
"""
letters = string.ascii_letters + string.digits
return "".join(random.choice(letters) for _ in range(length))
def calculate_image_md5(image_data: bytes) -> str:
"""计算图片数据的 MD5 值
Args:
image_data: 图片二进制数据
Returns:
MD5 哈希值(十六进制字符串)
"""
return hashlib.md5(image_data).hexdigest()
def encode_image_base64(image_data: bytes) -> str:
"""将图片数据编码为 Base64
Args:
image_data: 图片二进制数据
Returns:
Base64 编码的字符串
"""
return base64.b64encode(image_data).decode("utf-8")
def format_session_id(session_type: str, session_id: str) -> str:
"""格式化会话 ID
Args:
session_type: 会话类型 ("user", "group")
session_id: 原始会话 ID
Returns:
格式化后的会话 ID
"""
return f"wecom_ai_bot_{session_type}_{session_id}"
def parse_session_id(formatted_session_id: str) -> Tuple[str, str]:
"""解析格式化的会话 ID
Args:
formatted_session_id: 格式化的会话 ID
Returns:
(会话类型, 原始会话ID)
"""
parts = formatted_session_id.split("_", 3)
if (
len(parts) >= 4
and parts[0] == "wecom"
and parts[1] == "ai"
and parts[2] == "bot"
):
return parts[3], "_".join(parts[4:]) if len(parts) > 4 else ""
return "user", formatted_session_id
def safe_json_loads(json_str: str, default: Any = None) -> Any:
"""安全地解析 JSON 字符串
Args:
json_str: JSON 字符串
default: 解析失败时的默认值
Returns:
解析结果或默认值
"""
import json
try:
return json.loads(json_str)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(f"JSON 解析失败: {e}, 原始字符串: {json_str}")
return default
def format_error_response(error_code: int, error_msg: str) -> str:
"""格式化错误响应
Args:
error_code: 错误码
error_msg: 错误信息
Returns:
格式化的错误响应字符串
"""
return f"Error {error_code}: {error_msg}"
async def process_encrypted_image(
image_url: str, aes_key_base64: str
) -> Tuple[bool, str]:
"""下载并解密加密图片
Args:
image_url: 加密图片的URL
aes_key_base64: Base64编码的AES密钥(与回调加解密相同)
Returns:
Tuple[bool, str]: status 为 True 时 data 是解密后的图片数据的 base64 编码,
status 为 False 时 data 是错误信息
"""
# 1. 下载加密图片
logger.info("开始下载加密图片: %s", image_url)
try:
async with aiohttp.ClientSession() as session:
async with session.get(image_url, timeout=15) as response:
response.raise_for_status()
encrypted_data = await response.read()
logger.info("图片下载成功,大小: %d 字节", len(encrypted_data))
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
error_msg = f"下载图片失败: {str(e)}"
logger.error(error_msg)
return False, error_msg
# 2. 准备AES密钥和IV
if not aes_key_base64:
raise ValueError("AES密钥不能为空")
# Base64解码密钥 (自动处理填充)
aes_key = base64.b64decode(aes_key_base64 + "=" * (-len(aes_key_base64) % 4))
if len(aes_key) != 32:
raise ValueError("无效的AES密钥长度: 应为32字节")
iv = aes_key[:16] # 初始向量为密钥前16字节
# 3. 解密图片数据
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
decrypted_data = cipher.decrypt(encrypted_data)
# 4. 去除PKCS#7填充 (Python 3兼容写法)
pad_len = decrypted_data[-1] # 直接获取最后一个字节的整数值
if pad_len > 32: # AES-256块大小为32字节
raise ValueError("无效的填充长度 (大于32字节)")
decrypted_data = decrypted_data[:-pad_len]
logger.info("图片解密成功,解密后大小: %d 字节", len(decrypted_data))
# 5. 转换为base64编码
base64_data = base64.b64encode(decrypted_data).decode("utf-8")
logger.info("图片已转换为base64编码编码后长度: %d", len(base64_data))
return True, base64_data

View File

@@ -160,7 +160,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.wexin_event_workers[msg.id] = future
await self.convert_message(msg, future)
# I love shield so much!
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
result = await asyncio.wait_for(
asyncio.shield(future), 60
) # wait for 60s
logger.debug(f"Got future result: {result}")
self.wexin_event_workers.pop(msg.id, None)
return result # xml. see weixin_offacc_event.py
@@ -182,6 +184,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
return PlatformMetadata(
"weixin_official_account",
"微信公众平台 适配器",
id=self.config.get("id", "weixin_official_account"),
)
@override

View File

@@ -150,7 +150,6 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
return
logger.info(f"微信公众平台上传语音返回: {response}")
if active_send_mode:
self.client.message.send_voice(
message_obj.sender.user_id,

View File

@@ -65,13 +65,16 @@ class AssistantMessageSegment:
role: str = "assistant"
def to_dict(self):
ret = {
ret: dict[str, str | list[dict]] = {
"role": self.role,
}
if self.content:
ret["content"] = self.content
if self.tool_calls:
ret["tool_calls"] = self.tool_calls
tool_calls_dict = [
tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls
]
ret["tool_calls"] = tool_calls_dict
return ret
@@ -117,7 +120,14 @@ class ProviderRequest:
"""模型名称,为 None 时使用提供商的默认模型"""
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
return (
f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
f"image_count={len(self.image_urls or [])}, "
f"func_tool={self.func_tool}, "
f"contexts={self._print_friendly_context()}, "
f"system_prompt={self.system_prompt}, "
f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, "
)
def __str__(self):
return self.__repr__()
@@ -297,6 +307,7 @@ class LLMResponse:
)
return ret
@dataclass
class RerankResult:
index: int

View File

@@ -4,7 +4,7 @@ import os
import asyncio
import aiohttp
from typing import Dict, List, Awaitable
from typing import Dict, List, Awaitable, Callable, Any
from astrbot import logger
from astrbot.core import sp
@@ -109,7 +109,7 @@ class FunctionToolManager:
name: str,
func_args: list,
desc: str,
handler: Awaitable,
handler: Callable[..., Awaitable[Any]],
) -> FuncTool:
params = {
"type": "object", # hard-coded here
@@ -132,7 +132,7 @@ class FunctionToolManager:
name: str,
func_args: list,
desc: str,
handler: Awaitable,
handler: Callable[..., Awaitable[Any]],
) -> None:
"""添加函数调用工具
@@ -220,7 +220,7 @@ class FunctionToolManager:
name: str,
cfg: dict,
event: asyncio.Event,
ready_future: asyncio.Future = None,
ready_future: asyncio.Future | None = None,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:

View File

@@ -7,7 +7,13 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
from .entities import ProviderType
from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
from .provider import (
Provider,
STTProvider,
TTSProvider,
EmbeddingProvider,
RerankProvider,
)
from .register import llm_tools, provider_cls_map
from ..persona_mgr import PersonaManager
@@ -38,7 +44,12 @@ class ProviderManager:
"""加载的 Text To Speech Provider 的实例"""
self.embedding_provider_insts: List[EmbeddingProvider] = []
"""加载的 Embedding Provider 的实例"""
self.inst_map: dict[str, Provider] = {}
self.rerank_provider_insts: List[RerankProvider] = []
"""加载的 Rerank Provider 的实例"""
self.inst_map: dict[
str,
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools
@@ -87,19 +98,31 @@ class ProviderManager:
)
return
# 不启用提供商会话隔离模式的情况
self.curr_provider_inst = self.inst_map[provider_id]
if provider_type == ProviderType.TEXT_TO_SPEECH:
prov = self.inst_map[provider_id]
if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
prov, TTSProvider
):
self.curr_tts_provider_inst = prov
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
elif provider_type == ProviderType.SPEECH_TO_TEXT:
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
prov, STTProvider
):
self.curr_stt_provider_inst = prov
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
elif provider_type == ProviderType.CHAT_COMPLETION:
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
prov, Provider
):
self.curr_provider_inst = prov
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
"""根据提供商 ID 获取提供商实例"""
return self.inst_map.get(provider_id)
def get_using_provider(self, provider_type: ProviderType, umo=None):
def get_using_provider(
self, provider_type: ProviderType, umo=None
) -> Provider | STTProvider | TTSProvider | None:
"""获取正在使用的提供商实例。
Args:
@@ -211,6 +234,8 @@ class ProviderManager:
)
case "dify":
from .sources.dify_source import ProviderDify as ProviderDify
case "coze":
from .sources.coze_source import ProviderCoze as ProviderCoze
case "dashscope":
from .sources.dashscope_source import (
ProviderDashscope as ProviderDashscope,
@@ -303,12 +328,14 @@ class ProviderManager:
provider_metadata = provider_cls_map[provider_config["type"]]
try:
# 按任务实例化提供商
cls_type = provider_metadata.cls_type
if not cls_type:
logger.error(f"无法找到 {provider_metadata.type} 的类")
return
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
@@ -327,9 +354,7 @@ class ProviderManager:
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
@@ -345,7 +370,7 @@ class ProviderManager:
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = provider_metadata.cls_type(
inst = cls_type(
provider_config,
self.provider_settings,
self.selected_default_persona,
@@ -366,13 +391,16 @@ class ProviderManager:
if not self.curr_provider_inst:
self.curr_provider_inst = inst
elif provider_metadata.provider_type in [ProviderType.EMBEDDING, ProviderType.RERANK]:
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.embedding_provider_insts.append(inst)
elif provider_metadata.provider_type == ProviderType.RERANK:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.rerank_provider_insts.append(inst)
self.inst_map[provider_config["id"]] = inst
except Exception as e:
@@ -388,6 +416,7 @@ class ProviderManager:
# 和配置文件保持同步
config_ids = [provider["id"] for provider in self.providers_config]
logger.debug(f"providers in user's config: {config_ids}")
for key in list(self.inst_map.keys()):
if key not in config_ids:
await self.terminate_provider(key)
@@ -426,11 +455,17 @@ class ProviderManager:
)
if self.inst_map[provider_id] in self.provider_insts:
self.provider_insts.remove(self.inst_map[provider_id])
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, Provider):
self.provider_insts.remove(prov_inst)
if self.inst_map[provider_id] in self.stt_provider_insts:
self.stt_provider_insts.remove(self.inst_map[provider_id])
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, STTProvider):
self.stt_provider_insts.remove(prov_inst)
if self.inst_map[provider_id] in self.tts_provider_insts:
self.tts_provider_insts.remove(self.inst_map[provider_id])
prov_inst = self.inst_map[provider_id]
if isinstance(prov_inst, TTSProvider):
self.tts_provider_insts.remove(prov_inst)
if self.inst_map[provider_id] == self.curr_provider_inst:
self.curr_provider_inst = None

View File

@@ -68,14 +68,15 @@ class Provider(AbstractProvider):
def get_keys(self) -> List[str]:
"""获得提供商 Key"""
return self.provider_config.get("key", [])
keys = self.provider_config.get("key", [""])
return keys or [""]
@abc.abstractmethod
def set_key(self, key: str):
raise NotImplementedError()
@abc.abstractmethod
def get_models(self) -> List[str]:
async def get_models(self) -> List[str]:
"""获得支持的模型列表"""
raise NotImplementedError()

View File

@@ -33,7 +33,7 @@ class ProviderAnthropic(Provider):
)
self.chosen_api_key: str = ""
self.api_keys: List = provider_config.get("key", [])
self.api_keys: List = super().get_keys()
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
self.timeout = provider_config.get("timeout", 120)
@@ -70,9 +70,13 @@ class ProviderAnthropic(Provider):
{
"type": "tool_use",
"name": tool_call["function"]["name"],
"input": json.loads(tool_call["function"]["arguments"])
if isinstance(tool_call["function"]["arguments"], str)
else tool_call["function"]["arguments"],
"input": (
json.loads(tool_call["function"]["arguments"])
if isinstance(
tool_call["function"]["arguments"], str
)
else tool_call["function"]["arguments"]
),
"id": tool_call["id"],
}
)
@@ -355,9 +359,11 @@ class ProviderAnthropic(Provider):
"source": {
"type": "base64",
"media_type": mime_type,
"data": image_data.split("base64,")[1]
if "base64," in image_data
else image_data,
"data": (
image_data.split("base64,")[1]
if "base64," in image_data
else image_data
),
},
}
)

View File

@@ -0,0 +1,314 @@
import json
import asyncio
import aiohttp
import io
from typing import Dict, List, Any, AsyncGenerator
from astrbot.core import logger
class CozeAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
self.api_key = api_key
self.api_base = api_base
self.session = None
async def _ensure_session(self):
"""确保HTTP session存在"""
if self.session is None:
connector = aiohttp.TCPConnector(
ssl=False if self.api_base.startswith("http://") else True,
limit=100,
limit_per_host=30,
keepalive_timeout=30,
enable_cleanup_closed=True,
)
timeout = aiohttp.ClientTimeout(
total=120, # 默认超时时间
connect=30,
sock_read=120,
)
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "text/event-stream",
}
self.session = aiohttp.ClientSession(
headers=headers, timeout=timeout, connector=connector
)
return self.session
async def upload_file(
self,
file_data: bytes,
) -> str:
"""上传文件到 Coze 并返回 file_id
Args:
file_data (bytes): 文件的二进制数据
Returns:
str: 上传成功后返回的 file_id
"""
session = await self._ensure_session()
url = f"{self.api_base}/v1/files/upload"
try:
file_io = io.BytesIO(file_data)
async with session.post(
url,
data={
"file": file_io,
},
timeout=aiohttp.ClientTimeout(total=60),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
response_text = await response.text()
logger.debug(
f"文件上传响应状态: {response.status}, 内容: {response_text}"
)
if response.status != 200:
raise Exception(
f"文件上传失败,状态码: {response.status}, 响应: {response_text}"
)
try:
result = await response.json()
except json.JSONDecodeError:
raise Exception(f"文件上传响应解析失败: {response_text}")
if result.get("code") != 0:
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
file_id = result["data"]["id"]
logger.debug(f"[Coze] 图片上传成功file_id: {file_id}")
return file_id
except asyncio.TimeoutError:
logger.error("文件上传超时")
raise Exception("文件上传超时")
except Exception as e:
logger.error(f"文件上传失败: {str(e)}")
raise Exception(f"文件上传失败: {str(e)}")
async def download_image(self, image_url: str) -> bytes:
"""下载图片并返回字节数据
Args:
image_url (str): 图片的URL
Returns:
bytes: 图片的二进制数据
"""
session = await self._ensure_session()
try:
async with session.get(image_url) as response:
if response.status != 200:
raise Exception(f"下载图片失败,状态码: {response.status}")
image_data = await response.read()
return image_data
except Exception as e:
logger.error(f"下载图片失败 {image_url}: {str(e)}")
raise Exception(f"下载图片失败: {str(e)}")
async def chat_messages(
self,
bot_id: str,
user_id: str,
additional_messages: List[Dict] | None = None,
conversation_id: str | None = None,
auto_save_history: bool = True,
stream: bool = True,
timeout: float = 120,
) -> AsyncGenerator[Dict[str, Any], None]:
"""发送聊天消息并返回流式响应
Args:
bot_id: Bot ID
user_id: 用户ID
additional_messages: 额外消息列表
conversation_id: 会话ID
auto_save_history: 是否自动保存历史
stream: 是否流式响应
timeout: 超时时间
"""
session = await self._ensure_session()
url = f"{self.api_base}/v3/chat"
payload = {
"bot_id": bot_id,
"user_id": user_id,
"stream": stream,
"auto_save_history": auto_save_history,
}
if additional_messages:
payload["additional_messages"] = additional_messages
params = {}
if conversation_id:
params["conversation_id"] = conversation_id
logger.debug(f"Coze chat_messages payload: {payload}, params: {params}")
try:
async with session.post(
url,
json=payload,
params=params,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
if response.status != 200:
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
# SSE
buffer = ""
event_type = None
event_data = None
async for chunk in response.content:
if chunk:
buffer += chunk.decode("utf-8", errors="ignore")
lines = buffer.split("\n")
buffer = lines[-1]
for line in lines[:-1]:
line = line.strip()
if not line:
if event_type and event_data:
yield {"event": event_type, "data": event_data}
event_type = None
event_data = None
elif line.startswith("event:"):
event_type = line[6:].strip()
elif line.startswith("data:"):
data_str = line[5:].strip()
if data_str and data_str != "[DONE]":
try:
event_data = json.loads(data_str)
except json.JSONDecodeError:
event_data = {"content": data_str}
except asyncio.TimeoutError:
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
except Exception as e:
raise Exception(f"Coze API 流式请求失败: {str(e)}")
async def clear_context(self, conversation_id: str):
"""清空会话上下文
Args:
conversation_id: 会话ID
Returns:
dict: API响应结果
"""
session = await self._ensure_session()
url = f"{self.api_base}/v3/conversation/message/clear_context"
payload = {"conversation_id": conversation_id}
try:
async with session.post(url, json=payload) as response:
response_text = await response.text()
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
if response.status != 200:
raise Exception(f"Coze API 请求失败,状态码: {response.status}")
try:
return json.loads(response_text)
except json.JSONDecodeError:
raise Exception("Coze API 返回非JSON格式")
except asyncio.TimeoutError:
raise Exception("Coze API 请求超时")
except aiohttp.ClientError as e:
raise Exception(f"Coze API 请求失败: {str(e)}")
async def get_message_list(
self,
conversation_id: str,
order: str = "desc",
limit: int = 10,
offset: int = 0,
):
"""获取消息列表
Args:
conversation_id: 会话ID
order: 排序方式 (asc/desc)
limit: 限制数量
offset: 偏移量
Returns:
dict: API响应结果
"""
session = await self._ensure_session()
url = f"{self.api_base}/v3/conversation/message/list"
params = {
"conversation_id": conversation_id,
"order": order,
"limit": limit,
"offset": offset,
}
try:
async with session.get(url, params=params) as response:
response.raise_for_status()
return await response.json()
except Exception as e:
logger.error(f"获取Coze消息列表失败: {str(e)}")
raise Exception(f"获取Coze消息列表失败: {str(e)}")
async def close(self):
"""关闭会话"""
if self.session:
await self.session.close()
self.session = None
if __name__ == "__main__":
import os
import asyncio
async def test_coze_api_client():
api_key = os.getenv("COZE_API_KEY", "")
bot_id = os.getenv("COZE_BOT_ID", "")
client = CozeAPIClient(api_key=api_key)
try:
with open("README.md", "rb") as f:
file_data = f.read()
file_id = await client.upload_file(file_data)
print(f"Uploaded file_id: {file_id}")
async for event in client.chat_messages(
bot_id=bot_id,
user_id="test_user",
additional_messages=[
{
"role": "user",
"content": json.dumps(
[
{"type": "text", "text": "这是什么"},
{"type": "file", "file_id": file_id},
],
ensure_ascii=False,
),
"content_type": "object_string",
},
],
stream=True,
):
print(f"Event: {event}")
finally:
await client.close()
asyncio.run(test_coze_api_client())

View File

@@ -0,0 +1,635 @@
import json
import os
import base64
import hashlib
from typing import AsyncGenerator, Dict
from astrbot.core.message.message_event_result import MessageChain
import astrbot.core.message.components as Comp
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.entities import LLMResponse
from ..register import register_provider_adapter
from .coze_api_client import CozeAPIClient
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
class ProviderCoze(Provider):
def __init__(
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:
raise Exception("Coze API Key 不能为空。")
self.bot_id = provider_config.get("bot_id", "")
if not self.bot_id:
raise Exception("Coze Bot ID 不能为空。")
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
if not isinstance(self.api_base, str) or not self.api_base.startswith(
("http://", "https://")
):
raise Exception(
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。"
)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.auto_save_history = provider_config.get("auto_save_history", True)
self.conversation_ids: Dict[str, str] = {}
self.file_id_cache: Dict[str, Dict[str, str]] = {}
# 创建 API 客户端
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
"""生成统一的缓存键
Args:
data: 图片数据或路径
is_base64: 是否是 base64 数据
Returns:
str: 缓存键
"""
try:
if is_base64 and data.startswith("data:image/"):
try:
header, encoded = data.split(",", 1)
image_bytes = base64.b64decode(encoded)
cache_key = hashlib.md5(image_bytes).hexdigest()
return cache_key
except Exception:
cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
return cache_key
else:
if data.startswith(("http://", "https://")):
# URL图片使用URL作为缓存键
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
return cache_key
else:
clean_path = (
data.split("_")[0]
if "_" in data and len(data.split("_")) >= 3
else data
)
if os.path.exists(clean_path):
with open(clean_path, "rb") as f:
file_content = f.read()
cache_key = hashlib.md5(file_content).hexdigest()
return cache_key
else:
cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
return cache_key
except Exception as e:
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
return cache_key
async def _upload_file(
self,
file_data: bytes,
session_id: str | None = None,
cache_key: str | None = None,
) -> str:
"""上传文件到 Coze 并返回 file_id"""
# 使用 API 客户端上传文件
file_id = await self.api_client.upload_file(file_data)
# 缓存 file_id
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
self.file_id_cache[session_id][cache_key] = file_id
logger.debug(f"[Coze] 图片上传成功并缓存file_id: {file_id}")
return file_id
async def _download_and_upload_image(
self, image_url: str, session_id: str | None = None
) -> str:
"""下载图片并上传到 Coze返回 file_id"""
# 计算哈希实现缓存
cache_key = self._generate_cache_key(image_url) if session_id else None
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
return file_id
try:
image_data = await self.api_client.download_image(image_url)
file_id = await self._upload_file(image_data, session_id, cache_key)
if session_id and cache_key:
self.file_id_cache[session_id][cache_key] = file_id
return file_id
except Exception as e:
logger.error(f"处理图片失败 {image_url}: {str(e)}")
raise Exception(f"处理图片失败: {str(e)}")
async def _process_context_images(
self, content: str | list, session_id: str
) -> str:
"""处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
try:
if isinstance(content, str):
return content
processed_content = []
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
for item in content:
if not isinstance(item, dict):
processed_content.append(item)
continue
if item.get("type") == "text":
processed_content.append(item)
elif item.get("type") == "image_url":
# 处理图片逻辑
if "file_id" in item:
# 已经有 file_id
logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
processed_content.append(item)
else:
# 获取图片数据
image_data = ""
if "image_url" in item and isinstance(item["image_url"], dict):
image_data = item["image_url"].get("url", "")
elif "data" in item:
image_data = item.get("data", "")
elif "url" in item:
image_data = item.get("url", "")
if not image_data:
continue
# 计算哈希用于缓存
cache_key = self._generate_cache_key(
image_data, is_base64=image_data.startswith("data:image/")
)
# 检查缓存
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
processed_content.append(
{"type": "image", "file_id": file_id}
)
else:
# 上传图片并缓存
if image_data.startswith("data:image/"):
# base64 处理
_, encoded = image_data.split(",", 1)
image_bytes = base64.b64decode(encoded)
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
elif image_data.startswith(("http://", "https://")):
# URL 图片
file_id = await self._download_and_upload_image(
image_data, session_id
)
# 为URL图片也添加缓存
self.file_id_cache[session_id][cache_key] = file_id
elif os.path.exists(image_data):
# 本地文件
with open(image_data, "rb") as f:
image_bytes = f.read()
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
else:
logger.warning(
f"无法处理的图片格式: {image_data[:50]}..."
)
continue
processed_content.append(
{"type": "image", "file_id": file_id}
)
result = json.dumps(processed_content, ensure_ascii=False)
return result
except Exception as e:
logger.error(f"处理上下文图片失败: {str(e)}")
if isinstance(content, str):
return content
else:
return json.dumps(content, ensure_ascii=False)
async def text_chat(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
"""文本对话, 内部使用流式接口实现非流式
Args:
prompt (str): 用户提示词
session_id (str): 会话ID
image_urls (List[str]): 图片URL列表
func_tool (FuncCall): 函数调用工具(不支持)
contexts (List): 上下文列表
system_prompt (str): 系统提示语
tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
model (str): 模型名称(不支持)
Returns:
LLMResponse: LLM响应对象
"""
accumulated_content = ""
final_response = None
async for llm_response in self.text_chat_stream(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
model=model,
**kwargs,
):
if llm_response.is_chunk:
if llm_response.completion_text:
accumulated_content += llm_response.completion_text
else:
final_response = llm_response
if final_response:
return final_response
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
return LLMResponse(role="assistant", result_chain=chain)
else:
return LLMResponse(role="assistant", completion_text="")
async def text_chat_stream(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话接口"""
# 用户ID参数(参考文档, 可以自定义)
user_id = session_id or kwargs.get("user", "default_user")
# 获取或创建会话ID
conversation_id = self.conversation_ids.get(user_id)
# 构建消息
additional_messages = []
if system_prompt:
if not self.auto_save_history or not conversation_id:
additional_messages.append(
{"role": "system", "content": system_prompt, "content_type": "text"}
)
if not self.auto_save_history and contexts:
# 如果关闭了自动保存历史,传入上下文
for ctx in contexts:
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
content = ctx["content"]
content_type = ctx.get("content_type", "text")
# 处理可能包含图片的上下文
if (
content_type == "object_string"
or (isinstance(content, str) and content.startswith("["))
or (
isinstance(content, list)
and any(
isinstance(item, dict)
and item.get("type") == "image_url"
for item in content
)
)
):
processed_content = await self._process_context_images(
content, user_id
)
additional_messages.append(
{
"role": ctx["role"],
"content": processed_content,
"content_type": "object_string",
}
)
else:
# 纯文本
additional_messages.append(
{
"role": ctx["role"],
"content": (
content
if isinstance(content, str)
else json.dumps(content, ensure_ascii=False)
),
"content_type": "text",
}
)
else:
logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
if prompt or image_urls:
if image_urls:
# 多模态
object_string_content = []
if prompt:
object_string_content.append({"type": "text", "text": prompt})
for url in image_urls:
try:
if url.startswith(("http://", "https://")):
# 网络图片
file_id = await self._download_and_upload_image(
url, user_id
)
else:
# 本地文件或 base64
if url.startswith("data:image/"):
# base64
_, encoded = url.split(",", 1)
image_data = base64.b64decode(encoded)
cache_key = self._generate_cache_key(
url, is_base64=True
)
file_id = await self._upload_file(
image_data, user_id, cache_key
)
else:
# 本地文件
if os.path.exists(url):
with open(url, "rb") as f:
image_data = f.read()
# 用文件路径和修改时间来缓存
file_stat = os.stat(url)
cache_key = self._generate_cache_key(
f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
is_base64=False,
)
file_id = await self._upload_file(
image_data, user_id, cache_key
)
else:
logger.warning(f"图片文件不存在: {url}")
continue
object_string_content.append(
{
"type": "image",
"file_id": file_id,
}
)
except Exception as e:
logger.error(f"处理图片失败 {url}: {str(e)}")
continue
if object_string_content:
content = json.dumps(object_string_content, ensure_ascii=False)
additional_messages.append(
{
"role": "user",
"content": content,
"content_type": "object_string",
}
)
else:
# 纯文本
if prompt:
additional_messages.append(
{
"role": "user",
"content": prompt,
"content_type": "text",
}
)
try:
accumulated_content = ""
message_started = False
async for chunk in self.api_client.chat_messages(
bot_id=self.bot_id,
user_id=user_id,
additional_messages=additional_messages,
conversation_id=conversation_id,
auto_save_history=self.auto_save_history,
stream=True,
timeout=self.timeout,
):
event_type = chunk.get("event")
data = chunk.get("data", {})
if event_type == "conversation.chat.created":
if isinstance(data, dict) and "conversation_id" in data:
self.conversation_ids[user_id] = data["conversation_id"]
elif event_type == "conversation.message.delta":
if isinstance(data, dict):
content = data.get("content", "")
if not content and "delta" in data:
content = data["delta"].get("content", "")
if not content and "text" in data:
content = data.get("text", "")
if content:
message_started = True
accumulated_content += content
yield LLMResponse(
role="assistant",
completion_text=content,
is_chunk=True,
)
elif event_type == "conversation.message.completed":
if isinstance(data, dict):
msg_type = data.get("type")
if msg_type == "answer" and data.get("role") == "assistant":
final_content = data.get("content", "")
if not accumulated_content and final_content:
chain = MessageChain(chain=[Comp.Plain(final_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
elif event_type == "conversation.chat.completed":
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
break
elif event_type == "done":
break
elif event_type == "error":
error_msg = (
data.get("message", "未知错误")
if isinstance(data, dict)
else str(data)
)
logger.error(f"Coze 流式响应错误: {error_msg}")
yield LLMResponse(
role="err",
completion_text=f"Coze 错误: {error_msg}",
is_chunk=False,
)
break
if not message_started and not accumulated_content:
yield LLMResponse(
role="assistant",
completion_text="LLM 未响应任何内容。",
is_chunk=False,
)
elif message_started and accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
except Exception as e:
logger.error(f"Coze 流式请求失败: {str(e)}")
yield LLMResponse(
role="err",
completion_text=f"Coze 流式请求失败: {str(e)}",
is_chunk=False,
)
async def forget(self, session_id: str):
"""清空指定会话的上下文"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if user_id in self.file_id_cache:
self.file_id_cache.pop(user_id, None)
if not conversation_id:
return True
try:
response = await self.api_client.clear_context(conversation_id)
if "code" in response and response["code"] == 0:
self.conversation_ids.pop(user_id, None)
return True
else:
logger.warning(f"清空 Coze 会话上下文失败: {response}")
return False
except Exception as e:
logger.error(f"清空 Coze 会话失败: {str(e)}")
return False
async def get_current_key(self):
"""获取当前API Key"""
return self.api_key
async def set_key(self, key: str):
"""设置新的API Key"""
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
async def get_models(self):
"""获取可用模型列表"""
return [f"bot_{self.bot_id}"]
def get_model(self):
"""获取当前模型"""
return f"bot_{self.bot_id}"
def set_model(self, model: str):
"""设置模型在Coze中是Bot ID"""
if model.startswith("bot_"):
self.bot_id = model[4:]
else:
self.bot_id = model
async def get_human_readable_context(
self, session_id: str, page: int = 1, page_size: int = 10
):
"""获取人类可读的上下文历史"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if not conversation_id:
return []
try:
data = await self.api_client.get_message_list(
conversation_id=conversation_id,
order="desc",
limit=page_size,
offset=(page - 1) * page_size,
)
if data.get("code") != 0:
logger.warning(f"获取 Coze 消息历史失败: {data}")
return []
messages = data.get("data", {}).get("messages", [])
readable_history = []
for msg in messages:
role = msg.get("role", "unknown")
content = msg.get("content", "")
msg_type = msg.get("type", "")
if role == "user":
readable_history.append(f"用户: {content}")
elif role == "assistant" and msg_type == "answer":
readable_history.append(f"助手: {content}")
return readable_history
except Exception as e:
logger.error(f"获取 Coze 消息历史失败: {str(e)}")
return []
async def terminate(self):
"""清理资源"""
await self.api_client.close()

View File

@@ -1,10 +1,22 @@
import os
import dashscope
import uuid
import asyncio
from dashscope.audio.tts_v2 import *
from ..provider import TTSProvider
import base64
import logging
import os
import uuid
from typing import Optional, Tuple
import aiohttp
import dashscope
from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer
try:
from dashscope.aigc.multimodal_conversation import MultiModalConversation
except (
ImportError
): # pragma: no cover - older dashscope versions without Qwen TTS support
MultiModalConversation = None
from ..entities import ProviderType
from ..provider import TTSProvider
from ..register import register_provider_adapter
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@@ -26,16 +38,112 @@ class ProviderDashscopeTTSAPI(TTSProvider):
dashscope.api_key = self.chosen_api_key
async def get_audio(self, text: str) -> str:
model = self.get_model()
if not model:
raise RuntimeError("Dashscope TTS model is not configured.")
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}.wav")
self.synthesizer = SpeechSynthesizer(
model=self.get_model(),
os.makedirs(temp_dir, exist_ok=True)
if self._is_qwen_tts_model(model):
audio_bytes, ext = await self._synthesize_with_qwen_tts(model, text)
else:
audio_bytes, ext = await self._synthesize_with_cosyvoice(model, text)
if not audio_bytes:
raise RuntimeError(
"Audio synthesis failed, returned empty content. The model may not be supported or the service is unavailable."
)
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}{ext}")
with open(path, "wb") as f:
f.write(audio_bytes)
return path
def _call_qwen_tts(self, model: str, text: str):
if MultiModalConversation is None:
raise RuntimeError(
"dashscope SDK missing MultiModalConversation. Please upgrade the dashscope package to use Qwen TTS models."
)
kwargs = {
"model": model,
"text": text,
"api_key": self.chosen_api_key,
"voice": self.voice or "Cherry",
}
if not self.voice:
logging.warning(
"No voice specified for Qwen TTS model, using default 'Cherry'."
)
return MultiModalConversation.call(**kwargs)
async def _synthesize_with_qwen_tts(
self, model: str, text: str
) -> Tuple[Optional[bytes], str]:
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
audio_bytes = await self._extract_audio_from_response(response)
if not audio_bytes:
raise RuntimeError(
f"Audio synthesis failed for model '{model}'. {response}"
)
ext = ".wav"
return audio_bytes, ext
async def _extract_audio_from_response(self, response) -> Optional[bytes]:
output = getattr(response, "output", None)
audio_obj = getattr(output, "audio", None) if output is not None else None
if not audio_obj:
return None
data_b64 = getattr(audio_obj, "data", None)
if data_b64:
try:
return base64.b64decode(data_b64)
except (ValueError, TypeError):
logging.error("Failed to decode base64 audio data.")
return None
url = getattr(audio_obj, "url", None)
if url:
return await self._download_audio_from_url(url)
return None
async def _download_audio_from_url(self, url: str) -> Optional[bytes]:
if not url:
return None
timeout = max(self.timeout_ms / 1000, 1) if self.timeout_ms else 20
try:
async with aiohttp.ClientSession() as session:
async with session.get(
url, timeout=aiohttp.ClientTimeout(total=timeout)
) as response:
return await response.read()
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
logging.error(f"Failed to download audio from URL {url}: {e}")
return None
async def _synthesize_with_cosyvoice(
self, model: str, text: str
) -> Tuple[Optional[bytes], str]:
synthesizer = SpeechSynthesizer(
model=model,
voice=self.voice,
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
)
audio = await asyncio.get_event_loop().run_in_executor(
None, self.synthesizer.call, text, self.timeout_ms
loop = asyncio.get_event_loop()
audio_bytes = await loop.run_in_executor(
None, synthesizer.call, text, self.timeout_ms
)
with open(path, "wb") as f:
f.write(audio)
return path
if not audio_bytes:
resp = synthesizer.get_response()
if resp and isinstance(resp, dict):
raise RuntimeError(
f"Audio synthesis failed for model '{model}'. {resp}".strip()
)
return audio_bytes, ".wav"
def _is_qwen_tts_model(self, model: str) -> bool:
model_lower = model.lower()
return "tts" in model_lower and model_lower.startswith("qwen")

View File

@@ -98,7 +98,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
# FishAudio的reference_id通常是32位十六进制字符串
# 例如: 626bb6d3f3364c9cbc3aa6a67300a664
pattern = r'^[a-fA-F0-9]{32}$'
pattern = r"^[a-fA-F0-9]{32}$"
return bool(re.match(pattern, reference_id.strip()))
async def _generate_request(self, text: str) -> dict:

View File

@@ -3,7 +3,7 @@ import base64
import json
import logging
import random
from typing import Optional
from typing import Optional, List
from collections.abc import AsyncGenerator
from google import genai
@@ -60,7 +60,7 @@ class ProviderGoogleGenAI(Provider):
provider_settings,
default_persona,
)
self.api_keys: list = provider_config.get("key", [])
self.api_keys: List = super().get_keys()
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
self.timeout: int = int(provider_config.get("timeout", 180))
@@ -218,19 +218,21 @@ class ProviderGoogleGenAI(Provider):
response_modalities=modalities,
tools=tool_list,
safety_settings=self.safety_settings if self.safety_settings else None,
thinking_config=types.ThinkingConfig(
thinking_budget=min(
int(
self.provider_config.get("gm_thinking_config", {}).get(
"budget", 0
)
thinking_config=(
types.ThinkingConfig(
thinking_budget=min(
int(
self.provider_config.get("gm_thinking_config", {}).get(
"budget", 0
)
),
24576,
),
24576,
),
)
if "gemini-2.5-flash" in self.get_model()
and hasattr(types.ThinkingConfig, "thinking_budget")
else None,
)
if "gemini-2.5-flash" in self.get_model()
and hasattr(types.ThinkingConfig, "thinking_budget")
else None
),
automatic_function_calling=types.AutomaticFunctionCallingConfig(
disable=True
),
@@ -274,9 +276,11 @@ class ProviderGoogleGenAI(Provider):
if role == "user":
if isinstance(content, list):
parts = [
types.Part.from_text(text=item["text"] or " ")
if item["type"] == "text"
else process_image_url(item["image_url"])
(
types.Part.from_text(text=item["text"] or " ")
if item["type"] == "text"
else process_image_url(item["image_url"])
)
for item in content
]
else:

View File

@@ -38,7 +38,7 @@ class ProviderOpenAIOfficial(Provider):
default_persona,
)
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.api_keys: List = super().get_keys()
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
@@ -99,12 +99,15 @@ class ProviderOpenAIOfficial(Provider):
for key in to_del:
del payloads[key]
model = payloads.get("model", "")
# 针对 qwen3 非 thinking 模型的特殊处理:非流式调用必须设置 enable_thinking=false
if "qwen3" in model.lower() and "thinking" not in model.lower():
extra_body["enable_thinking"] = False
# 读取并合并 custom_extra_body 配置
custom_extra_body = self.provider_config.get("custom_extra_body", {})
if isinstance(custom_extra_body, dict):
extra_body.update(custom_extra_body)
model = payloads.get("model", "").lower()
# 针对 deepseek 模型的特殊处理deepseek-reasoner调用必须移除 tools ,否则将被切换至 deepseek-chat
elif model == "deepseek-reasoner" and "tools" in payloads:
if model == "deepseek-reasoner" and "tools" in payloads:
del payloads["tools"]
completion = await self.client.chat.completions.create(
@@ -137,6 +140,12 @@ class ProviderOpenAIOfficial(Provider):
# 不在默认参数中的参数放在 extra_body 中
extra_body = {}
# 读取并合并 custom_extra_body 配置
custom_extra_body = self.provider_config.get("custom_extra_body", {})
if isinstance(custom_extra_body, dict):
extra_body.update(custom_extra_body)
to_del = []
for key in payloads.keys():
if key not in self.default_params:

View File

@@ -1,4 +1,5 @@
import aiohttp
from astrbot import logger
from ..provider import RerankProvider
from ..register import register_provider_adapter
from ..entities import ProviderType, RerankResult
@@ -44,6 +45,11 @@ class VLLMRerankProvider(RerankProvider):
response_data = await response.json()
results = response_data.get("results", [])
if not results:
logger.warning(
f"Rerank API 返回了空的列表数据。原始响应: {response_data}"
)
return [
RerankResult(
index=result["index"],

View File

@@ -1,12 +1,12 @@
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
# This file was originally created to adapt to glm-4v-flash, which only supports one image in the context.
# It is no longer specifically adapted to Zhipu's models. To ensure compatibility, this
from ..register import register_provider_adapter
from astrbot.core.provider.entities import LLMResponse
from .openai_source import ProviderOpenAIOfficial
@register_provider_adapter("zhipu_chat_completion", " Chat Completion 提供商适配器")
@register_provider_adapter("zhipu_chat_completion", " Chat Completion 提供商适配器")
class ProviderZhipu(ProviderOpenAIOfficial):
def __init__(
self,
@@ -19,63 +19,3 @@ class ProviderZhipu(ProviderOpenAIOfficial):
provider_settings,
default_persona,
)
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts=None,
system_prompt=None,
model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
context_query = [*contexts, new_record]
model_cfgs: dict = self.provider_config.get("model_config", {})
model = model or self.get_model()
# glm-4v-flash 只支持一张图片
if model.lower() == "glm-4v-flash" and image_urls and len(context_query) > 1:
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
logger.debug(context_query)
new_context_query_ = []
for i in range(0, len(context_query) - 1, 2):
if isinstance(context_query[i].get("content", ""), list):
continue
new_context_query_.append(context_query[i])
new_context_query_.append(context_query[i + 1])
new_context_query_.append(context_query[-1]) # 保留最后一条记录
context_query = new_context_query_
logger.debug(context_query)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
payloads = {"messages": context_query, **model_cfgs}
try:
llm_response = await self._query(payloads, func_tool)
return llm_response
except Exception as e:
if "maximum context length" in str(e):
retry_cnt = 10
while retry_cnt > 0:
logger.warning(
f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。"
)
try:
self.pop_record(session_id)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
else:
raise e

View File

@@ -6,6 +6,7 @@ from astrbot.core.provider.provider import (
TTSProvider,
STTProvider,
EmbeddingProvider,
RerankProvider,
)
from astrbot.core.provider.entities import ProviderType
from astrbot.core.db import BaseDatabase
@@ -23,7 +24,7 @@ from .star import star_registry, StarMetadata, star_map
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
from typing import Awaitable, Any, Callable
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
@@ -103,9 +104,14 @@ class Context:
"""
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider | None:
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
return self.provider_manager.inst_map.get(provider_id)
def get_provider_by_id(
self, provider_id: str
) -> (
Provider | TTSProvider | STTProvider | EmbeddingProvider | RerankProvider | None
):
"""通过 ID 获取对应的 LLM Provider。"""
prov = self.provider_manager.inst_map.get(provider_id)
return prov
def get_all_providers(self) -> List[Provider]:
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
@@ -130,34 +136,43 @@ class Context:
Args:
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
"""
return self.provider_manager.get_using_provider(
prov = self.provider_manager.get_using_provider(
provider_type=ProviderType.CHAT_COMPLETION,
umo=umo,
)
if prov and not isinstance(prov, Provider):
raise ValueError("返回的 Provider 不是 Provider 类型")
return prov
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider:
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None:
"""
获取当前使用的用于 TTS 任务的 Provider。
Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
"""
return self.provider_manager.get_using_provider(
prov = self.provider_manager.get_using_provider(
provider_type=ProviderType.TEXT_TO_SPEECH,
umo=umo,
)
if prov and not isinstance(prov, TTSProvider):
raise ValueError("返回的 Provider 不是 TTSProvider 类型")
return prov
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider:
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None:
"""
获取当前使用的用于 STT 任务的 Provider。
Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
"""
return self.provider_manager.get_using_provider(
prov = self.provider_manager.get_using_provider(
provider_type=ProviderType.SPEECH_TO_TEXT,
umo=umo,
)
if prov and not isinstance(prov, STTProvider):
raise ValueError("返回的 Provider 不是 STTProvider 类型")
return prov
def get_config(self, umo: str | None = None) -> AstrBotConfig:
"""获取 AstrBot 的配置。"""
@@ -245,7 +260,11 @@ class Context:
"""
def register_llm_tool(
self, name: str, func_args: list, desc: str, func_obj: Awaitable
self,
name: str,
func_args: list,
desc: str,
func_obj: Callable[..., Awaitable[Any]],
) -> None:
"""
为函数调用function-calling / tools-use添加工具。
@@ -267,9 +286,7 @@ class Context:
desc=desc,
)
star_handlers_registry.append(md)
self.provider_manager.llm_tools.add_func(
name, func_args, desc, func_obj, func_obj
)
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj)
def unregister_llm_tool(self, name: str) -> None:
"""删除一个函数调用工具。如果再要启用,需要重新注册。"""
@@ -281,7 +298,7 @@ class Context:
command_name: str,
desc: str,
priority: int,
awaitable: Awaitable,
awaitable: Callable[..., Awaitable[Any]],
use_regex=False,
ignore_prefix=False,
):

View File

@@ -1,5 +1,7 @@
import re
import inspect
import types
import typing
from typing import List, Any, Type, Dict
from . import HandlerFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -14,6 +16,18 @@ class GreedyStr(str):
pass
def unwrap_optional(annotation) -> tuple:
"""去掉 Optional[T] / Union[T, None] / T|None返回 T"""
args = typing.get_args(annotation)
non_none_args = [a for a in args if a is not type(None)]
if len(non_none_args) == 1:
return (non_none_args[0],)
elif len(non_none_args) > 1:
return tuple(non_none_args)
else:
return ()
# 标准指令受到 wake_prefix 的制约。
class CommandFilter(HandlerFilter):
"""标准指令过滤器"""
@@ -32,11 +46,16 @@ class CommandFilter(HandlerFilter):
self.init_handler_md(handler_md)
self.custom_filter_list: List[CustomFilter] = []
# Cache for complete command names list
self._cmpl_cmd_names: list | None = None
def print_types(self):
result = ""
for k, v in self.handler_params.items():
if isinstance(v, type):
result += f"{k}({v.__name__}),"
elif isinstance(v, types.UnionType) or typing.get_origin(v) is typing.Union:
result += f"{k}({v}),"
else:
result += f"{k}({type(v).__name__})={v},"
result = result.rstrip(",")
@@ -92,7 +111,8 @@ class CommandFilter(HandlerFilter):
# 没有 GreedyStr 的情况
if i >= len(params):
if (
isinstance(param_type_or_default_val, Type)
isinstance(param_type_or_default_val, (Type, types.UnionType))
or typing.get_origin(param_type_or_default_val) is typing.Union
or param_type_or_default_val is inspect.Parameter.empty
):
# 是类型
@@ -129,13 +149,42 @@ class CommandFilter(HandlerFilter):
elif isinstance(param_type_or_default_val, float):
result[param_name] = float(params[i])
else:
result[param_name] = param_type_or_default_val(params[i])
origin = typing.get_origin(param_type_or_default_val)
if origin in (typing.Union, types.UnionType):
# 注解是联合类型
# NOTE: 目前没有处理联合类型嵌套相关的注解写法
nn_types = unwrap_optional(param_type_or_default_val)
if len(nn_types) == 1:
# 只有一个非 NoneType 类型
result[param_name] = nn_types[0](params[i])
else:
# 没有或者有多个非 NoneType 类型,这里我们暂时直接赋值为原始值。
# NOTE: 目前还没有做类型校验
result[param_name] = params[i]
else:
result[param_name] = param_type_or_default_val(params[i])
except ValueError:
raise ValueError(
f"参数 {param_name} 类型错误。完整参数: {self.print_types()}"
)
return result
def get_complete_command_names(self):
if self._cmpl_cmd_names is not None:
return self._cmpl_cmd_names
self._cmpl_cmd_names = [
f"{parent} {cmd}" if parent else cmd
for cmd in [self.command_name] + list(self.alias)
for parent in self.parent_command_names or [""]
]
return self._cmpl_cmd_names
def equals(self, message_str: str) -> bool:
for full_cmd in self.get_complete_command_names():
if message_str == full_cmd:
return True
return False
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
if not event.is_at_or_wake_command:
return False
@@ -145,18 +194,11 @@ class CommandFilter(HandlerFilter):
# 检查是否以指令开头
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
candidates = [self.command_name] + list(self.alias)
ok = False
for candidate in candidates:
for parent_command_name in self.parent_command_names:
if parent_command_name:
_full = f"{parent_command_name} {candidate}"
else:
_full = candidate
if message_str.startswith(f"{_full} ") or message_str == _full:
message_str = message_str[len(_full) :].strip()
ok = True
break
for full_cmd in self.get_complete_command_names():
if message_str.startswith(f"{full_cmd} ") or message_str == full_cmd:
ok = True
message_str = message_str[len(full_cmd) :].strip()
if not ok:
return False

View File

@@ -13,8 +13,8 @@ class CommandGroupFilter(HandlerFilter):
def __init__(
self,
group_name: str,
alias: set = None,
parent_group: CommandGroupFilter = None,
alias: set | None = None,
parent_group: CommandGroupFilter | None = None,
):
self.group_name = group_name
self.alias = alias if alias else set()
@@ -22,6 +22,9 @@ class CommandGroupFilter(HandlerFilter):
self.custom_filter_list: List[CustomFilter] = []
self.parent_group = parent_group
# Cache for complete command names list
self._cmpl_cmd_names: list | None = None
def add_sub_command_filter(
self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]
):
@@ -34,6 +37,9 @@ class CommandGroupFilter(HandlerFilter):
"""遍历父节点获取完整的指令名。
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。"""
if self._cmpl_cmd_names is not None:
return self._cmpl_cmd_names
parent_cmd_names = (
self.parent_group.get_complete_command_names() if self.parent_group else []
)
@@ -47,6 +53,7 @@ class CommandGroupFilter(HandlerFilter):
for parent_cmd_name in parent_cmd_names:
for candidate in candidates:
result.append(parent_cmd_name + " " + candidate)
self._cmpl_cmd_names = result
return result
# 以树的形式打印出来
@@ -54,8 +61,8 @@ class CommandGroupFilter(HandlerFilter):
self,
sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]],
prefix: str = "",
event: AstrMessageEvent = None,
cfg: AstrBotConfig = None,
event: AstrMessageEvent | None = None,
cfg: AstrBotConfig | None = None,
) -> str:
result = ""
for sub_filter in sub_command_filters:
@@ -97,6 +104,12 @@ class CommandGroupFilter(HandlerFilter):
return False
return True
def startswith(self, message_str: str) -> bool:
return message_str.startswith(tuple(self.get_complete_command_names()))
def equals(self, message_str: str) -> bool:
return message_str in self.get_complete_command_names()
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
if not event.is_at_or_wake_command:
return False
@@ -105,18 +118,14 @@ class CommandGroupFilter(HandlerFilter):
if not self.custom_filter_ok(event, cfg):
return False
complete_command_names = self.get_complete_command_names()
if event.message_str.strip() in complete_command_names:
if self.equals(event.message_str.strip()):
tree = (
self.group_name
+ "\n"
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
)
raise ValueError(
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n"
+ tree
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
)
# complete_command_names = [name + " " for name in complete_command_names]
# return event.message_str.startswith(tuple(complete_command_names))
return False
return self.startswith(event.message_str)

View File

@@ -2,7 +2,6 @@ import enum
from . import HandlerFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.config import AstrBotConfig
from typing import Union
class PlatformAdapterType(enum.Flag):
@@ -19,6 +18,7 @@ class PlatformAdapterType(enum.Flag):
VOCECHAT = enum.auto()
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
SATORI = enum.auto()
MISSKEY = enum.auto()
ALL = (
AIOCQHTTP
| QQOFFICIAL
@@ -33,6 +33,7 @@ class PlatformAdapterType(enum.Flag):
| VOCECHAT
| WEIXIN_OFFICIAL_ACCOUNT
| SATORI
| MISSKEY
)
@@ -50,15 +51,19 @@ ADAPTER_NAME_2_TYPE = {
"vocechat": PlatformAdapterType.VOCECHAT,
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
"satori": PlatformAdapterType.SATORI,
"misskey": PlatformAdapterType.MISSKEY,
}
class PlatformAdapterTypeFilter(HandlerFilter):
def __init__(self, platform_adapter_type_or_str: Union[PlatformAdapterType, str]):
self.type_or_str = platform_adapter_type_or_str
def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str):
if isinstance(platform_adapter_type_or_str, str):
self.platform_type = ADAPTER_NAME_2_TYPE.get(platform_adapter_type_or_str)
else:
self.platform_type = platform_adapter_type_or_str
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
adapter_name = event.get_platform_name()
if adapter_name in ADAPTER_NAME_2_TYPE:
return ADAPTER_NAME_2_TYPE[adapter_name] & self.type_or_str
if adapter_name in ADAPTER_NAME_2_TYPE and self.platform_type is not None:
return bool(ADAPTER_NAME_2_TYPE[adapter_name] & self.platform_type)
return False

View File

@@ -8,6 +8,7 @@ from .star_handler import (
register_permission_type,
register_custom_filter,
register_on_astrbot_loaded,
register_on_platform_loaded,
register_on_llm_request,
register_on_llm_response,
register_llm_tool,
@@ -26,6 +27,7 @@ __all__ = [
"register_permission_type",
"register_custom_filter",
"register_on_astrbot_loaded",
"register_on_platform_loaded",
"register_on_llm_request",
"register_on_llm_response",
"register_llm_tool",

Some files were not shown because too many files have changed in this diff Show More