Compare commits

...

353 Commits

Author SHA1 Message Date
Soulter
e83fc570a4 chore: bump version to 4.1.0 2025-09-13 13:31:49 +08:00
Yokami
e841b6af88 feat: 支持在 WebUI 自定义 OpenAI API extra_body 参数 (#2719)
* feat: 支持OPENAI系 模型的自定义标头,以解决qwen模型无法使用的问题

* fix: 修复AI说的问题

* fix: 布尔开关向右对齐
2025-09-13 13:23:49 +08:00
Dt8333
ea6f209557 fix: 修复LLM仍会调用已禁用的工具的问题 (#2729)
* fix: 修复LLM调用已禁用的工具

* feat: 修改工具禁用判断位置,提高效率
未设置可用工具时仍旧循环判断
设置可用工具后在获取工具时即判断
2025-09-12 21:36:10 +08:00
Zhalslar
9bfa726107 feat: 兼容指令名和第一个参数之间没有空格的情况 (#2650)
插件中@filter.command的指令在用户输入“命令+参数” 无空格隔开时无法处理,但只要稍微改动几行代码就可以兼容
2025-09-12 15:40:37 +08:00
shangxue
d24902c66d feat: 添加 --webui-dir 启动参数以支持指定 WebUI 构建文件目录 (#2680)
* Update main.py

* Update server.py

* Update main.py

* Update main.py

* Update main.py

* Update initial_loader.py

* Update server.py

* Update main.py

* chore: update webui_dir type hint and improve dashboard file check logic

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-12 15:15:29 +08:00
RC-CHN
72aea2d3f3 feat: 允许添加多个 tavily API Key 进行轮询 (#2725)
* feat: 允许添加多个tavily API Key进行轮询

* perf: 并发安全的从列表中获取并轮换Tavily API密钥

* fix: 自动迁移旧版 websearch_tavily_key 为列表格式并保存
2025-09-12 15:03:47 +08:00
RC-CHN
dc9612d564 fix: 修复自定义文转图模板更新版本后会被覆盖的问题 (#2677)
* perf: 更新模板管理逻辑,在data目录中管理用户自定义模板,优化热重载逻辑

* refactor: 优化模板管理逻辑,重构模板复制和初始化流程,增强用户模板管理功能

* chore:移除无用注释

* remove:移除了t2i部分中不会走到的异常

* style: format code

* fix: trim whitespace from template names in create, update, and delete operations

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-12 13:34:07 +08:00
Soulter
1770556d56 fix: 修复工具调用时的 content 内容在重新加载后没有显示在 webchat 的问题 (#2727) 2025-09-12 13:05:33 +08:00
Soulter
888fb84aee fix: 修复 WebChat 下,Agent 长时任务时,SSE 连接自动断开的问题 2025-09-12 13:04:27 +08:00
Soulter
d597fd056d fix: 修复知识库不能创建的问题 2025-09-11 17:27:57 +08:00
quirrel
dea0ab3974 fix: 解决插件页表格视图中,点击状态字段表头排序不起作用的问题 (#2714) 2025-09-11 16:20:33 +08:00
Soulter
da6facd7d7 docs: 修复开发者群组错误 2025-09-11 12:44:59 +08:00
Soulter
bb8ab5f173 docs: update readme 2025-09-11 10:40:30 +08:00
Soulter
ac8a541059 docs: remove message stat badge
Removed the old dynamic JSON badge for message volume.
2025-09-11 10:36:18 +08:00
Soulter
0e66771f0e docs: revise acknowledgments and add similar projects
Updated project acknowledgments and added links to similar open-source bot projects.
2025-09-11 10:35:11 +08:00
Soulter
d3a295a801 ci: add auto_assign.yml for auto PR reviewer assignment 2025-09-10 13:21:34 +08:00
shangxue
f2df771771 fix: 修复 Satori 适配器教程链接 (#2668)
* Update PlatformPage.vue

* Update PlatformPage.vue
2025-09-09 21:59:06 +08:00
dependabot[bot]
7b72cd87a5 chore(deps): bump the github-actions group with 2 updates (#2674)
Bumps the github-actions group with 2 updates: [actions/setup-python](https://github.com/actions/setup-python) and [actions/stale](https://github.com/actions/stale).


Updates `actions/setup-python` from 5 to 6
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](https://github.com/actions/setup-python/compare/v5...v6)

Updates `actions/stale` from 9 to 10
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v9...v10)

---
updated-dependencies:
- dependency-name: actions/setup-python
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/stale
  dependency-version: '10'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-09 08:46:04 +08:00
anka
9431efc6d1 feat: 增加 on_platform_loaded 钩子以在消息平台适配器实例化完成后触发 (#2651)
* feat⚒️: 增加平台加载时的钩子

* fix: 补充api

* fix: 只捕获Exception
2025-09-09 08:44:37 +08:00
Soulter
7c3f5431ba chore: bump version to 4.0.0 2025-09-07 21:19:19 +08:00
anka
d98cf16a4c feat: 增加根据qq号/群号主动发送消息的封装, 增加事件构造 (#2629)
* feat: 增加根据qq号/群号主动发送消息的封装, 增加事件构造

* fix: 增加不支持平台提示, 修正文档字符串

* chore: lint

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-07 17:04:29 +08:00
Soulter
2c3c3ae546 fix: 移除无用的调试日志以简化命令注册逻辑 2025-09-07 11:37:34 +08:00
Soulter
905eef48e3 feat: 增加 OneBot 服务 Token 为空时的安全提醒 (#2648) 2025-09-07 00:51:46 +08:00
RC-CHN
b31b520c7c feat: 支持管理 T2I 模版 (#2638)
* feat:添加t2i模板管理后端api,移除config.py中重复功能

* feat: 添加T2I模板管理功能前端,支持模板的创建、应用和重置

* refactor: 修复错误的保存逻辑,将t2i注册时打印路由信息部分移到基类实现

* remove:移除了路由注册时的打印

* chore: format code

* fix: update input variant from solo to outlined for better UI consistency

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-07 00:14:28 +08:00
shangxue
17aee086a3 feat: 添加 Satori 协议适配器支持 (#2633)
* Create satori_adapter.py

* Add files via upload

* Update default.py

* Update manager.py

* Update platform_adapter_type.py

* Update PlatformPage.vue

* Add files via upload

* Update default.py

* Update manager.py

* Update platform_adapter_type.py

* Update PlatformPage.vue

* Add files via upload

* Update default.py

* chore: format code

* feat: 修复 Image, Audio 的解析,修复 message_str 的解析

* perf: 增强鲁棒性

* feat: 添加 Satori 配置项描述,移除适配器默认配置

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-06 23:52:00 +08:00
Zhalslar
c1756e5767 fix: 修复组件 type 属性为枚举值 (#2628)
当前components.py中每个组件的type属性都是直接字符串赋值,IDE会爆红。
修正为使用本就定义好的ComponentType枚举类
用时修正多个组件中当url为空时convert_to_base64检查路径导致的报错
2025-09-06 19:22:49 +08:00
anka
2920279c64 Fix: 修正 QQ 群成员昵称获取 (#2626)
* feat: 修正群昵称获取

* fix: 增加兜底机制
2025-09-06 19:16:57 +08:00
Soulter
1f0f985b01 docs: update readme
Updated README to improve clarity and organization. Changed section titles and removed unnecessary content.
2025-09-06 16:21:33 +08:00
Soulter
0762c81633 Update Discord link in README.md 2025-09-06 11:49:05 +08:00
Soulter
28ef301ccc docs: update readme 2025-09-06 11:48:37 +08:00
Soulter
26c6a2950f 📦 release: bump version to v4.0.0-beta.5 2025-09-05 17:42:38 +08:00
Soulter
5082876de3 fix: 修复 v4.0.0 版本下可能无法得到 LLM 的响应的问题
closes: #2622
2025-09-05 17:20:29 +08:00
Soulter
e50e7ad3d5 fix: ensure deep copy of config_data before posting 2025-09-05 16:45:34 +08:00
卢小辉
45a4a6b6da feat: 给添加 edge_tts 新增 rate, volume, pitch 参数 (#2625)
* 修复python执行器文件上传qq提示参数错误问题,修改策略为本地url

* 给edge_tts 添加3个默认参数,方便通过ui配置
2025-09-05 15:23:21 +08:00
Soulter
02918b7267 perf: 增加 abconf_data 缓存,优化性能 2025-09-04 23:48:33 +08:00
Soulter
6c662a36c1 fix: 适配 qwen3 的 thinking 类模型
fixes: #2631
2025-09-04 20:26:52 +08:00
Soulter
b78fe3822a perf: 完善对 rerank model 的可用性检测 2025-09-04 15:46:23 +08:00
Soulter
35eda37e83 Merge remote-tracking branch 'origin/releases/v3.5.27' 2025-09-04 15:30:15 +08:00
Soulter
176a8e7067 chore: add no_proxy config item 2025-09-04 15:23:58 +08:00
Soulter
61d4f1fd4b 📦 release: v3.5.27 2025-09-04 15:01:44 +08:00
Soulter
121b68995e chore: update changelog 2025-09-04 14:34:23 +08:00
Soulter
d11f1d8dae perf: enhance update checks to consider pre-release versions 2025-09-04 14:33:19 +08:00
Soulter
c0ef2b5064 📦 release: v3.5.27 2025-09-04 13:56:47 +08:00
Soulter
2a7308363e fix: 下载 WebUI 时,明确版本号 2025-09-04 13:54:16 +08:00
Soulter
dc0c556f96 ci: build docker image 时同时 build webui,并放入 image 中 2025-09-04 13:42:26 +08:00
Soulter
ba2ee1c0aa fix: 初次下载 webui 构建文件时下载指定版本而非 latest 2025-09-04 13:27:55 +08:00
Zhalslar
0f8b550d68 fix: aiocqhttp优先使用session_id发送消息 (#2623)
* fix: aiocqhttp优先使用session_id发送消息

当前aiocqhttp依赖raw_message来发送消息,raw_message为空时也无法有效回退到用group_id或user_id来发送,更符合逻辑的应该:优先使用session_id(group_id or user_id),raw_message兜底

* Update aiocqhttp_message_event.py

* fix: validate session_id as integer and improve send_message docstring

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-04 11:34:40 +08:00
Soulter
ed1fc98821 Merge pull request #2621 from AstrBotDevs/fix/gemini-api-error-handle
Fix: 修复 e.message 为 None 时报错的问题和部分 lint error
2025-09-04 11:20:05 +08:00
Soulter
fa53b468fd fix: ensure function call name and args are not None before processing 2025-09-04 11:18:58 +08:00
Soulter
4e2533d320 feat: add pre-release check for Docker image tagging 2025-09-04 09:30:21 +08:00
Soulter
388ae49e55 fix: 修复 e.message 为 None 时报错的问题和一些 lint error 2025-09-03 22:25:18 +08:00
Soulter
f3f347dcba 📦 release: bump verstion to v4.0.0-beta.4 2025-09-03 13:29:20 +08:00
Soulter
655be3519c perf: 数据迁移完毕之后引导重启程序
closes: #2613
2025-09-03 13:21:56 +08:00
Soulter
06df2940af chore: change identifier description 2025-09-03 12:46:10 +08:00
Soulter
4149549e42 fix: KeyError arprompt 2025-09-03 12:45:34 +08:00
Soulter
da351991f8 📦 release: bump verstion to v4.0.0-beta.3 2025-09-03 01:01:48 +08:00
Soulter
3305152e50 fix: 修复当人格 ID 为中文时,不可保存的问题 2025-09-03 00:59:07 +08:00
Soulter
bea7bae674 fix: dict read 2025-09-03 00:56:41 +08:00
Soulter
45773d38ed 📦 release: bump verstion to v4.0.0-beta.2 2025-09-03 00:32:49 +08:00
Soulter
8d4c176314 fix: correct image_caption logic and remove redundant config call 2025-09-03 00:31:18 +08:00
Soulter
9ca5c87c4c fix: complete requirements.txt 2025-09-03 00:05:43 +08:00
Soulter
36a6f00e5f Merge pull request #2610 from AstrBotDevs/releases/4.0.0 (#2610)
Release: v4.0.0-beta.1
2025-09-02 23:47:21 +08:00
Soulter
e24a5b4cb5 Revert "Release: v4.0.0-beta.1 (#2509)" (#2609)
This reverts commit f88031b0c9.
2025-09-02 23:44:36 +08:00
Soulter
f88031b0c9 Release: v4.0.0-beta.1 (#2509)
* Refactor: using sqlmodel(sqlchemy+pydantic) as ORM framework and switch to async-based sqlite operation (#2294)

* stage

* stage

* refactor: using sqlchemy as ORM framework, switch to async-based sqlite operation

- using sqlmodel as ORM(based on sqlchemy and pydantic)
- add Persona, Preference, PlatformMessageHistory table

* fix: conversation

* fix: remove redundant explicit session.commit, and fix some type error

* fix: conversation context issue

* chore: remove comments

* chore: remove exclude_content param

* Fix: 当多个相同消息平台实例部署时上下文可能混乱(共享) (#2298)

* perf: update astrbot event session format, using platfrom id to ensure uniqueness

fixes: #1000

* fix: 更新 MessageSession 类以使用 platform_id 作为唯一标识符,并调整相关方法以确保一致性

* fix: 更新 MessageSession 文档以明确 platform_id 的赋值规则,并调整 get_platform 和 get_platform_inst 方法的返回类型

* Improve: 引入全新的人格管理模式以及重构函数工具管理器 (#2305)

* feat: add persona management

* refactor:  重构函数工具管理器,引入 ToolSet,并让 Persona 支持绑定 Tools

* feat: 更新 Persona 工具选择逻辑,支持全选和指定工具的切换

* feat: 更新 BaseDatabase 中的 persona 方法返回类型,支持返回 None

* fix: platform id

* feat: add support to sync mcp servers from ModelScope (#2313)

* fix: 修复访问令牌的空格问题

* chore: 移除 MCP 市场相关逻辑 (#2314)

* chore: 移除 MCP 市场相关路由

* Refactor: 重构配置文件管理,以支持更灵活的、会话粒度的(基于 umo part)配置文件隔离 (#2328)

* refactor: 重构配置文件管理,以支持更灵活的、基于 umo part 的配置文件隔离

* Refactor: 重构配置前端页面,新增数个配置项 (#2331)

* refactor: 重构配置前端页面,新增数个配置项

* feat: 完善多配置文件结构

* perf: 系统配置入口

* fix: normal config item list not display

* fix: 修复 axios 请求中的上下文引用问题

* chore: remove status checking in chat page

* fix: 修复 stage 在不同 pipeline 中被重复使用的问题和 persona 相关问题

* Feature: 增加图片转述提供商配置、支持用户自定义模型模态能力 (#2422)

* feat: 增加图片转述提供商配置、支持用户自定义模型模态能力

* fix: 修复 LLMRequestSubStage 中会话管理方法参数不一致的问题,简化方法调用

* Feature: 优化 WebSearch 的爬取网页速度并且支持使用 Tavily 作为搜索引擎 (#2427)

* feat: 优化了 websearch 的速度;支持 Tavily 作为搜索引擎

* fix: 优化日志记录格式,修复搜索结果处理中的索引和内容显示问题

* feat: 添加对话选中状态管理,优化默认对话加载逻辑

* feat: 支持通过解析URL 的方式导入网页数据到知识库 (#2280)

* feat:为webchat页面添加一个手动上传文件按钮(目前只处理图片)

* fix:上传后清空value,允许触发change事件以多次上传同一张图片

* perf:webchat页面消息发送后清空图片预览缩略图,维持与文本信息行为一致

* perf:将文件输入的值重置为空字符串以提升浏览器兼容性

* feat:webchat文件上传按钮支持多选文件上传

* fix:释放blob URL以防止内存泄漏

* perf:并行化sendMessage中的图片获取逻辑

* feat:完成从url获取部分的UI

* feat: 添加从URL导入功能的组件

* fix: 优化导入结果处理,添加整体摘要和主题摘要的文件命名

* perf: 更新url导入选项添加默认值

* perf: 在导入url的部分配置项未启用时隐藏暂不使用的下拉框选项

* feat: 添加上传前提提示信息至导入url至知识库功能

* feat: 更新导入功能提示信息,添加上传状态通知

* fix: 优化url转知识库错误处理

* feat: 合并知识库的上传文件和 URL 标签页

* feat: 删除导入URL至知识库功能的相关组件

---------

Co-authored-by: Soulter <905617992@qq.com>

* feat: 添加条件显示逻辑以优化插件配置项的可见性管理 (#2433)

* Feature: 支持在 WebUI 配置文件页中配置默认知识库 (#2437)

* feat: 支持配置默认知识库

* chore: clean code

* refactor: 重构 Function Tool 管理并初步引入 Multi Agent 及 Agent Handsoff 机制  (#2454)

* stage

* refactor: 重构 Function Tool 管理并引入 multi agent handsoff 机制

- Updated `star_request.py` to use the global `call_handler` instead of context-specific calls.
- Modified `entities.py` to remove the dependency on `FunctionToolManager` and streamline the function tool handling.
- Refactored `func_tool_manager.py` to simplify the `FunctionTool` class and its methods, removing deprecated code and enhancing clarity.
- Adjusted `provider.py` to align with the new function tool structure, removing unnecessary type unions.
- Enhanced `star_handler.py` to support agent registration and tool association, introducing `RegisteringAgent` for better encapsulation.
- Updated `star_manager.py` to handle tool registration for agents, ensuring proper binding of handlers.
- Revised `main.py` in the web searcher package to utilize the new agent registration system for web search tools.

* chore: websearch

* perf: 减少嵌套

* chore: 移除未使用的 mcp 导入

* feat: 添加 WebUI 迁移助手以及相关迁移方法 (#2477)

* fix: 修复迁移对话时的一些问题

* feat: 增加工具使用模型能力选项

* feat: 添加知识库插件更新检查和更新功能

* perf: 调整 WebUI sidebar 顺序

* refactor: 重构 SharedPreference 类并采用数据库存储替换 json 存储 (#2482)

* perf: 使用 run_coroutine_threadsafe

Co-authored-by: Raven95676 <raven95676@gmail.com>

* Feature: 支持配置重排序模型(vLLM API 格式)用于 score 任务 (#2496)

* feat: 支持添加重排序模型(vLLM API 格式)用于 score 任务

* fix: update rerank API base URL to use localhost

* feat: 知识库支持配置重排序模型

* fix: remove debug print statement for reranked results in FaissVecDB

* fix: 移除知识库中的提示文本

* Feature: 支持在配置文件配置可用的插件组 (#2505)

* feat: 增加可用插件集合配置项

* remove: 旧版平台可用性配置

已经基于多配置文件实现。

* feat: 应用配置文件插件可用性配置

* perf: hoist if from if

* feat: llm_tool 装饰器返回值支持返回 mcp 库中 tool 的返回值类型(mcp.type.CallToolResult) (#2507)

* fix: add type definition for migrationDialog and ensure open method exists before calling

* chore: update project version to 4.0.0

* feat: 多 t2i 服务的随机负载均衡 (#2529)

* fix: bugfixes

* Improve: 扩大配置文件生效范围的自定义程度到会话粒度 (#2532)

* feat: 扩大配置文件生效范围的自定义程度

* perf: 冲突检测

* refactor: simplify config form validation and improve conflict message clarity

* chore: clean code

* feat: 插件配置支持多个快捷魔法配置项

* chore: 修复当自动更新 webchat title 时,history 被重置的问题

* bugfixes

* feat: add custom T2I template editor (#2581)

* perf: add option to clear provider selection in ProviderSelector component

* 📦 release: bump verstion to v4.0.0-beta.1

* chore: delete uv.lock

---------

Co-authored-by: RC-CHN <67079377+RC-CHN@users.noreply.github.com>
Co-authored-by: Raven95676 <raven95676@gmail.com>
2025-09-02 23:39:24 +08:00
Soulter
830151e6da chore: delete uv.lock 2025-09-02 23:31:51 +08:00
Soulter
1e14fba81a 📦 release: bump verstion to v4.0.0-beta.1 2025-09-02 23:27:55 +08:00
Soulter
7b8800c4eb perf: add option to clear provider selection in ProviderSelector component 2025-09-02 21:49:11 +08:00
Soulter
8f4625f53b Merge remote-tracking branch 'origin/master' into releases/4.0.0 2025-08-31 20:37:53 +08:00
Soulter
1e5f243edb 📦 release: v3.5.26 2025-08-31 20:25:05 +08:00
Soulter
e5eab2af34 fix: specify type for devCommits to enhance type safety 2025-08-31 20:16:18 +08:00
Soulter
c10973e160 fix: update getDevCommits function to support GitHub proxy and handle errors more gracefully 2025-08-31 20:06:37 +08:00
Soulter
b1e4bff3ec feat: 支持升级的同时更新到指定版本的 WebUI 2025-08-31 19:55:46 +08:00
Soulter
c1202cda63 fix: update GitHub release action to use correct commit SHA variable 2025-08-31 11:52:42 +08:00
Soulter
32d6cd7776 fix: update GitHub release action parameters for clarity 2025-08-31 11:50:26 +08:00
Soulter
2f78d30e93 feat: automated release from every commit in master branch 2025-08-31 11:42:28 +08:00
Junhua Don
33407c9f0d fix: 修复编辑会话名称窗口的圆角和左右边距问题 (#2583) 2025-08-31 11:12:25 +08:00
Soulter
d2d5ef1c5c feat: add custom T2I template editor (#2581) 2025-08-31 11:11:55 +08:00
RC-CHN
98d8eaee02 feat: 添加 no_proxy 配置支持以优化代理设置 (#2564) 2025-08-26 21:08:46 +08:00
ZvZPvz
10b9228060 feat: 调用 deepseek-reasoner 时自动移除 tools (#2531)
* 调用DeepSeek为思考模式时自动移除tools

* Update astrbot/core/provider/sources/openai_source.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Update openai_source.py

* Update openai_source.py

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-08-26 20:44:56 +08:00
xiewoc
5872f1e017 feat: 支持官方 QQ 接口发送语音 (#2525)
* Update dingtalk_event.py

* Add files via upload

* Add files via upload

* Update qqofficial_platform_adapter.py

* Add files via upload

* chore: clean comments

* chore: clean code

---------

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-26 20:40:59 +08:00
Soulter
5073f21002 bugfixes 2025-08-24 23:40:17 +08:00
Soulter
69aaf09ac8 chore: 修复当自动更新 webchat title 时,history 被重置的问题 2025-08-24 00:23:08 +08:00
Soulter
6e61ee81d8 feat: 插件配置支持多个快捷魔法配置项 2025-08-23 21:53:26 +08:00
Soulter
cfd05a8d17 chore: clean code 2025-08-23 21:46:59 +08:00
Raven95676
29845fcc4c feat: Gemini添加对LLMResponse的raw_completion支持 2025-08-23 14:14:25 +08:00
Soulter
e204b180a8 Improve: 扩大配置文件生效范围的自定义程度到会话粒度 (#2532)
* feat: 扩大配置文件生效范围的自定义程度

* perf: 冲突检测

* refactor: simplify config form validation and improve conflict message clarity
2025-08-22 19:31:55 +08:00
Soulter
563972fd29 fix: bugfixes 2025-08-22 17:41:06 +08:00
AkkoYK
cbe94b84fc feat: 为 FishAudio TTS 添加可选的 reference_id 直接指定功能 (#2513)
* 移除TTS提供商:FishAudio TTS的角色名称查询机制,改为直接使用参考模型ID

/// 修改内容 ///
- 移除复杂的角色查询逻辑
删除了 get_reference_id_by_character 方法
移除了通过角色名称搜索模型ID的API调用逻辑
- 简化配置字段
将 fishaudio-tts-character 字段替换为 fishaudio-tts-reference-id
设置默认值为可莉的模型ID:626bb6d3f3364c9cbc3aa6a67300a664
- 优化代码结构
直接在初始化时获取reference_id
简化请求生成逻辑,直接使用配置的模型ID

/// 修改原因 ///
避免同名冲突:不同模型可能使用相同的角色名称,导致获取错误的模型
提高性能:移除了额外的API查询步骤,减少延迟
增强可靠性:用户直接指定准确的模型ID,避免搜索失败的情况
简化维护:减少了代码复杂度,降低维护成本

/// 新的使用方式 ///
用户需要从 FishAudio 模型的详情页面/URL 中获取具体的模型ID(如 626bb6d3f3364c9cbc3aa6a67300a664),并在配置中直接填入 fishaudio-tts-reference-id 字段。

这个修改使得FishAudio TTS的配置更加直观和可靠,同时提升了系统的整体性能。

* Refactor: 添加FishAudio TTS reference_id格式验证

添加ID格式验证逻辑,防止无效的reference_id调用API失败。
验证32位十六进制格式并提供详细错误提示。

* Feat: 添加FishAudio TTS可选reference_id配置实现向前兼容

新增可选的reference_id字段,优先使用直接ID,未配置时回退到角色名称查询。
保持完全向前兼容,现有配置无需修改。
2025-08-22 16:55:07 +08:00
Soulter
aa6f73574d feat: 多 t2i 服务的随机负载均衡 (#2529) 2025-08-22 16:43:59 +08:00
Soulter
94f0419ef7 docs: update readme 2025-08-20 16:41:22 +08:00
Soulter
cefd2d7f49 chore: update project version to 4.0.0 2025-08-20 15:48:37 +08:00
Soulter
81e1e545fb fix: add type definition for migrationDialog and ensure open method exists before calling 2025-08-20 15:48:12 +08:00
Soulter
d516920e72 Merge remote-tracking branch 'origin/master' into releases/4.0.0 2025-08-20 15:43:54 +08:00
Soulter
2171372246 feat: llm_tool 装饰器返回值支持返回 mcp 库中 tool 的返回值类型(mcp.type.CallToolResult) (#2507) 2025-08-20 15:33:46 +08:00
Soulter
d2df4d0cce Feature: 支持在配置文件配置可用的插件组 (#2505)
* feat: 增加可用插件集合配置项

* remove: 旧版平台可用性配置

已经基于多配置文件实现。

* feat: 应用配置文件插件可用性配置

* perf: hoist if from if
2025-08-20 15:25:41 +08:00
Soulter
6ab90fc123 fix: 移除知识库中的提示文本 2025-08-20 11:27:02 +08:00
Soulter
1a84ebbb1e fix: remove debug print statement for reranked results in FaissVecDB 2025-08-19 17:57:16 +08:00
Soulter
c9c0352369 feat: 知识库支持配置重排序模型 2025-08-19 17:51:01 +08:00
Soulter
9903b028a3 Feature: 支持配置重排序模型(vLLM API 格式)用于 score 任务 (#2496)
* feat: 支持添加重排序模型(vLLM API 格式)用于 score 任务

* fix: update rerank API base URL to use localhost
2025-08-19 16:15:31 +08:00
Soulter
49def5d883 📦 release: v3.5.25 2025-08-19 01:32:24 +08:00
Soulter
6975525b70 feat: 添加预发布版本提醒和检测功能 2025-08-19 01:15:56 +08:00
Soulter
fbc4f8527b Merge remote-tracking branch 'origin/master' into releases/4.0.0 2025-08-19 00:53:36 +08:00
Soulter
90cb5a1951 fix: 当返回文本为空并且存在工具调用时错误地被终止事件,导致工具调用结果未被返回 (#2491)
fixes: #2448 #2379
2025-08-19 00:52:13 +08:00
Soulter
ac71d9f034 perf: 使用 run_coroutine_threadsafe
Co-authored-by: Raven95676 <raven95676@gmail.com>
2025-08-18 19:32:35 +08:00
Soulter
64bcbc9fc0 refactor: 重构 SharedPreference 类并采用数据库存储替换 json 存储 (#2482) 2025-08-18 19:12:26 +08:00
Soulter
9e7d46f956 perf: 调整 WebUI sidebar 顺序 2025-08-18 11:57:01 +08:00
Soulter
e911896cfb feat: 添加知识库插件更新检查和更新功能 2025-08-18 11:27:48 +08:00
Soulter
9c6d66093f feat: 增加工具使用模型能力选项 2025-08-18 10:37:10 +08:00
Soulter
b2e39b9701 fix: 修复迁移对话时的一些问题 2025-08-17 23:44:08 +08:00
Soulter
e95ad4049b feat: 添加 WebUI 迁移助手以及相关迁移方法 (#2477) 2025-08-17 23:24:30 +08:00
Soulter
1df49d1d6f refactor: 重构 Function Tool 管理并初步引入 Multi Agent 及 Agent Handsoff 机制 (#2454)
* stage

* refactor: 重构 Function Tool 管理并引入 multi agent handsoff 机制

- Updated `star_request.py` to use the global `call_handler` instead of context-specific calls.
- Modified `entities.py` to remove the dependency on `FunctionToolManager` and streamline the function tool handling.
- Refactored `func_tool_manager.py` to simplify the `FunctionTool` class and its methods, removing deprecated code and enhancing clarity.
- Adjusted `provider.py` to align with the new function tool structure, removing unnecessary type unions.
- Enhanced `star_handler.py` to support agent registration and tool association, introducing `RegisteringAgent` for better encapsulation.
- Updated `star_manager.py` to handle tool registration for agents, ensuring proper binding of handlers.
- Revised `main.py` in the web searcher package to utilize the new agent registration system for web search tools.

* chore: websearch

* perf: 减少嵌套

* chore: 移除未使用的 mcp 导入
2025-08-17 10:57:25 +08:00
Junhua Don
b71000e2f3 fix: 修复无法清空 http_proxy 代理的问题 (#2434)
* fix: 修复无法清空http_proxy代理的问题

* perf: 将“127.0.0.1”和“::1”添加到“no_proxy”以确保所有本地流量绕过代理。

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

---------

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-08-17 10:49:23 +08:00
Soulter
47e6ed455e Feature: 支持在 WebUI 配置文件页中配置默认知识库 (#2437)
* feat: 支持配置默认知识库

* chore: clean code
2025-08-15 12:40:46 +08:00
Soulter
92592fb9d9 Merge remote-tracking branch 'origin/master' into releases/4.0.0 2025-08-14 23:55:08 +08:00
Soulter
02a9769b35 fix: 补充工具调用轮数上限配置 2025-08-14 23:51:22 +08:00
Soulter
7640f11bfc docs: update readme 2025-08-14 17:28:51 +08:00
Soulter
be8a0991ed feat: 添加条件显示逻辑以优化插件配置项的可见性管理 (#2433) 2025-08-14 14:56:31 +08:00
Soulter
9fa44dbcfa docs: update readme 2025-08-14 14:16:22 +08:00
RC-CHN
61aac9c80c feat: 支持通过解析URL 的方式导入网页数据到知识库 (#2280)
* feat:为webchat页面添加一个手动上传文件按钮(目前只处理图片)

* fix:上传后清空value,允许触发change事件以多次上传同一张图片

* perf:webchat页面消息发送后清空图片预览缩略图,维持与文本信息行为一致

* perf:将文件输入的值重置为空字符串以提升浏览器兼容性

* feat:webchat文件上传按钮支持多选文件上传

* fix:释放blob URL以防止内存泄漏

* perf:并行化sendMessage中的图片获取逻辑

* feat:完成从url获取部分的UI

* feat: 添加从URL导入功能的组件

* fix: 优化导入结果处理,添加整体摘要和主题摘要的文件命名

* perf: 更新url导入选项添加默认值

* perf: 在导入url的部分配置项未启用时隐藏暂不使用的下拉框选项

* feat: 添加上传前提提示信息至导入url至知识库功能

* feat: 更新导入功能提示信息,添加上传状态通知

* fix: 优化url转知识库错误处理

* feat: 合并知识库的上传文件和 URL 标签页

* feat: 删除导入URL至知识库功能的相关组件

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-08-14 14:01:11 +08:00
Soulter
60af83cfee feat: 添加对话选中状态管理,优化默认对话加载逻辑 2025-08-14 13:53:36 +08:00
Soulter
cf64e6c231 Merge remote-tracking branch 'origin/master' into releases/4.0.0 2025-08-14 13:36:19 +08:00
Copilot
2cae941bae Fix incomplete Gemini streaming responses in chat history (#2429)
* Initial plan

* Fix incomplete Gemini streaming responses in chat history

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Raven95676 <Raven95676@gmail.com>
2025-08-14 11:56:50 +08:00
Copilot
bc0784f41d fix: enable_thinking parameter for qwen3 models in non-streaming calls (#2424)
* Initial plan

* Fix ModelScope enable_thinking parameter for non-streaming calls

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* Tighten enable_thinking condition to only Qwen/Qwen3 models

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* qwen3 model handle

* Update astrbot/core/provider/sources/openai_source.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-08-14 11:18:29 +08:00
Soulter
b711140f26 Feature: 优化 WebSearch 的爬取网页速度并且支持使用 Tavily 作为搜索引擎 (#2427)
* feat: 优化了 websearch 的速度;支持 Tavily 作为搜索引擎

* fix: 优化日志记录格式,修复搜索结果处理中的索引和内容显示问题
2025-08-14 10:52:35 +08:00
Copilot
c57d75e01a feat: add comprehensive GitHub Copilot instructions for AstrBot development (#2426)
* Initial plan

* Initial progress - completed repository exploration and dependency installation

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* Complete copilot-instructions.md with comprehensive development guide

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* Update copilot-instructions.md

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-13 23:31:28 +08:00
Soulter
1d766001bb Feature: 增加图片转述提供商配置、支持用户自定义模型模态能力 (#2422)
* feat: 增加图片转述提供商配置、支持用户自定义模型模态能力

* fix: 修复 LLMRequestSubStage 中会话管理方法参数不一致的问题,简化方法调用
2025-08-13 19:11:17 +08:00
Soulter
0759a11a85 fix: 修复 stage 在不同 pipeline 中被重复使用的问题和 persona 相关问题 2025-08-13 13:13:04 +08:00
Soulter
cb749a38ab chore: remove status checking in chat page 2025-08-13 10:45:50 +08:00
Soulter
369eab18ab Refactor: 重构配置文件管理,以支持更灵活的、会话粒度的(基于 umo part)配置文件隔离 (#2328)
* refactor: 重构配置文件管理,以支持更灵活的、基于 umo part 的配置文件隔离

* Refactor: 重构配置前端页面,新增数个配置项 (#2331)

* refactor: 重构配置前端页面,新增数个配置项

* feat: 完善多配置文件结构

* perf: 系统配置入口

* fix: normal config item list not display

* fix: 修复 axios 请求中的上下文引用问题
2025-08-13 09:18:49 +08:00
RC-CHN
73edeae013 perf: 优化hint渲染方式,为部分类型供应商添加默认的温度选项 (#2321)
* feat:为webchat页面添加一个手动上传文件按钮(目前只处理图片)

* fix:上传后清空value,允许触发change事件以多次上传同一张图片

* perf:webchat页面消息发送后清空图片预览缩略图,维持与文本信息行为一致

* perf:将文件输入的值重置为空字符串以提升浏览器兼容性

* feat:webchat文件上传按钮支持多选文件上传

* fix:释放blob URL以防止内存泄漏

* perf:并行化sendMessage中的图片获取逻辑

* perf:优化hint渲染方式,为部分类型供应商添加默认的温度选项
2025-08-12 21:53:06 +08:00
MUKAPP
7d46314dc8 fix: 修复注册文件时由于 file:/// 前缀,导致文件被误判为不存在的问题 (#2325)
fixes #2222
2025-08-12 21:47:31 +08:00
你们的饺子
d5a53a89eb fix: 修复插件的 terminate 无法被正常调用的问题 (#2352) 2025-08-12 21:41:19 +08:00
dependabot[bot]
a85bc510dd chore(deps): bump actions/checkout in the github-actions group (#2400)
Bumps the github-actions group with 1 update: [actions/checkout](https://github.com/actions/checkout).


Updates `actions/checkout` from 4 to 5
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v5)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-08-12 15:15:28 +08:00
Soulter
2beea7d218 📦 release: v3.5.24 2025-08-07 20:36:59 +08:00
Soulter
a93cd3dd5f feat: compshare provider 2025-08-07 20:25:45 +08:00
Soulter
6c1f540170 chore: 移除 MCP 市场相关路由 2025-08-04 19:21:27 +08:00
Soulter
d026a9f009 chore: 移除 MCP 市场相关逻辑 (#2314) 2025-08-04 17:40:21 +08:00
Soulter
a8e7dadd39 fix: 修复访问令牌的空格问题 2025-08-04 17:26:22 +08:00
Soulter
2f8d921adf feat: add support to sync mcp servers from ModelScope (#2313) 2025-08-04 17:24:07 +08:00
Soulter
0c6e526f94 fix: platform id 2025-08-04 13:10:37 +08:00
Soulter
b1e3018b6b Improve: 引入全新的人格管理模式以及重构函数工具管理器 (#2305)
* feat: add persona management

* refactor:  重构函数工具管理器,引入 ToolSet,并让 Persona 支持绑定 Tools

* feat: 更新 Persona 工具选择逻辑,支持全选和指定工具的切换

* feat: 更新 BaseDatabase 中的 persona 方法返回类型,支持返回 None
2025-08-04 00:56:26 +08:00
Soulter
87f05fce66 Fix: 当多个相同消息平台实例部署时上下文可能混乱(共享) (#2298)
* perf: update astrbot event session format, using platfrom id to ensure uniqueness

fixes: #1000

* fix: 更新 MessageSession 类以使用 platform_id 作为唯一标识符,并调整相关方法以确保一致性

* fix: 更新 MessageSession 文档以明确 platform_id 的赋值规则,并调整 get_platform 和 get_platform_inst 方法的返回类型
2025-08-02 21:38:55 +08:00
Soulter
1b37530c96 Merge remote-tracking branch 'origin/master' into releases/4.0.0 2025-08-02 20:14:18 +08:00
Soulter
db4d02c2e2 docs: add 1panel deployment method 2025-08-02 19:01:49 +08:00
Soulter
fd7811402b fix: 添加对 metadata 中 description 字段的支持,确保元数据完整性
fixes: #2245
2025-08-02 16:01:10 +08:00
你们的饺子
eb0325e627 fix: 修复了 OpenAI 类型的 LLM 空内容响应导致的无法解析 completion 的错误。 (#2279) 2025-08-02 15:46:11 +08:00
Soulter
842c3c8ea9 Refactor: using sqlmodel(sqlchemy+pydantic) as ORM framework and switch to async-based sqlite operation (#2294)
* stage

* stage

* refactor: using sqlchemy as ORM framework, switch to async-based sqlite operation

- using sqlmodel as ORM(based on sqlchemy and pydantic)
- add Persona, Preference, PlatformMessageHistory table

* fix: conversation

* fix: remove redundant explicit session.commit, and fix some type error

* fix: conversation context issue

* chore: remove comments

* chore: remove exclude_content param
2025-08-02 15:44:00 +08:00
IGCrystal
8b4b04ec09 fix(i18n): add missing noTemplates key (#2292) 2025-08-02 14:16:59 +08:00
Larch-C
9f32c9280f chore: update and rename PLUGIN_PUBLISH.md to PLUGIN_PUBLISH.yml (#2289) 2025-08-02 14:16:19 +08:00
yrk111222
4fcd09cfa8 feat: add ModelScope API support (#2230)
* add ModelScope API support

* update
2025-08-02 14:14:08 +08:00
Misaka Mikoto
7a8d65d37d feat: add plugins local cache and remote file MD5 validation (#2211)
* 修改openai的嵌入模型默认维度为1024

* 为插件市场添加本地缓存
- 优先使用api获取,获取失败时则使用本地缓存
- 每次获取后会更新本地缓存
- 如果获取结果为空,判定为获取失败,使用本地缓存
- 前端页面添加刷新按钮,用于手动刷新本地缓存

* feat: 增强插件市场缓存机制,支持MD5校验以确保数据有效性

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-08-02 14:03:53 +08:00
Raven95676
23129a9ba2 Merge branch 'releases/3.5.23' 2025-07-26 16:49:38 +08:00
Raven95676
7f791e730b fix: changelogs 2025-07-26 16:49:05 +08:00
Raven95676
f7e296b349 Merge branch 'releases/3.5.23' 2025-07-26 16:34:30 +08:00
Raven95676
712d4acaaa release: v3.5.23 2025-07-26 16:32:06 +08:00
Raven95676
74a5c01f21 refactor: remove code and documentation references related to gewechat 2025-07-26 14:19:17 +08:00
Raven95676
3ba8724d77 Merge branch 'master' into dev 2025-07-26 14:02:05 +08:00
鸦羽
6313a7d8a9 Merge pull request #2221 from Raven95676/fix/axios-dependency
fix: update axios version range for vulnerability fix
2025-07-24 18:34:14 +08:00
Raven95676
432a3f520c fix: update axios version range for vulnerability fix 2025-07-24 18:28:02 +08:00
Soulter
191b3e42d4 feat: implement log history retrieval and improve log streaming handling (#2190) 2025-07-23 23:36:08 +08:00
Misaka Mikoto
a27f05fcb4 chore: 修改 OpenAI 嵌入模型提供商默认向量维度为1024 (#2209) 2025-07-23 23:35:04 +08:00
Soulter
2f33e0b873 chore: remove adapters of wechat personal account 2025-07-23 10:51:42 +08:00
Soulter
f0359467f1 chore: remove adapters of wechat personal account 2025-07-23 10:50:43 +08:00
Soulter
d1db8cf2c8 chore: remove adapters of wechat personal account 2025-07-23 10:48:58 +08:00
Soulter
b1985ed2ce Merge branch 'dev' 2025-07-23 00:38:08 +08:00
Gao Jinzhe
140ddc70e6 feat: 使用会话锁保证分段回复时的消息发送顺序 (#2130)
* 优化分段消息发送逻辑,为分段消息添加消息队列

* 删除了不必要的代码

* style: code quality

* 将消息队列机制重构为会话锁机制

* perf: narrow the lock scope

* refactor: replace get_lock with async context manager for session locks

* refactor: optimize session lock management with defaultdict

---------

Co-authored-by: Soulter <905617992@qq.com>
Co-authored-by: Raven95676 <Raven95676@gmail.com>
2025-07-23 00:37:29 +08:00
Soulter
d7fd616470 style: code quality 2025-07-21 17:04:29 +08:00
Soulter
3ccbef141e perf: extension ui 2025-07-21 15:16:49 +08:00
Soulter
e92fbb0443 feat: add ProxySelector component for GitHub proxy configuration and connection testing (#2185) 2025-07-21 15:05:49 +08:00
Soulter
bd270aed68 fix: handle event construction errors in message reply processing 2025-07-20 22:52:14 +08:00
Soulter
28d7864393 perf: tool use page UI (#2182)
* perf: tool use UI

* fix: update background color of item cards in ToolUsePage
2025-07-20 20:24:03 +08:00
RC-CHN
b5d8173ee3 feat: add a file uplod button in WebChat page (#2136)
* feat:为webchat页面添加一个手动上传文件按钮(目前只处理图片)

* fix:上传后清空value,允许触发change事件以多次上传同一张图片

* perf:webchat页面消息发送后清空图片预览缩略图,维持与文本信息行为一致

* perf:将文件输入的值重置为空字符串以提升浏览器兼容性

* feat:webchat文件上传按钮支持多选文件上传

* fix:释放blob URL以防止内存泄漏

* perf:并行化sendMessage中的图片获取逻辑
2025-07-20 16:02:28 +08:00
Soulter
17d62a9af7 refactor: mcp server reload mechanism (#2161)
* refactor: mcp server reload mechanism

* fix: wait for client events

* fix: all other mcp servers are terminated when disable selected server

* fix: resolve type hinting issues in MCPClient and FuncCall methods

* perf: optimize mcp server loaders

* perf: improve MCP client connection testing

* perf: improve error message

* perf: clean code

* perf: increase default timeout for MCP connection and reset dialog message on close

---------

Co-authored-by: Raven95676 <Raven95676@gmail.com>
2025-07-20 15:53:13 +08:00
Soulter
d89fb863ed fix: improve logging and error message details in LLMRequestSubStage 2025-07-18 16:13:27 +08:00
Soulter
a21ad77820 Merge pull request #2146 from Raven95676/fix/mcp
fix: 修复MCP导致的持续占用100% CPU
2025-07-18 13:04:50 +08:00
Raven95676
f86c8e8cab perf: ensure MCP client termination in cleanup process 2025-07-17 23:17:23 +08:00
Raven95676
cb12cbdd3d fix: managing MCP connections with AsyncExitStack 2025-07-16 23:44:51 +08:00
Soulter
6661fa996c fix: audio block does not display 2025-07-14 22:20:03 +08:00
Soulter
c19bca798b fix: xfyun model tool use error workaround
fixes: #1359
2025-07-14 22:07:33 +08:00
Soulter
8f98b411db Merge pull request #2129 from AstrBotDevs/perf-refine-webui-chatpage
Improve: WebUI ChatPage markdown code block background
2025-07-14 21:49:18 +08:00
Soulter
a8aa03847e feat: enhance theme customization with new background properties and markdown styling 2025-07-14 21:47:25 +08:00
Soulter
1bfd747cc6 perf: add system_prompt to payload_vars in dify text_chat method 2025-07-14 11:00:13 +08:00
Soulter
ae06d945a7 Merge pull request #2054 from RC-CHN/master
Feature: Add provider_type field for ProviderMetadata and improve provider availabiliby test
2025-07-13 17:38:22 +08:00
Soulter
9f41d5f34d Merge remote-tracking branch 'origin/master' into RC-CHN/master 2025-07-13 17:35:53 +08:00
Soulter
ef61c52908 fix: remove non-existent Response field 2025-07-13 17:33:13 +08:00
Soulter
d8842ef274 perf: code quality 2025-07-13 17:27:40 +08:00
Soulter
c88fdaf353 Merge pull request #1949 from advent259141/Astrbot_session_manage
[Feature] 支持在 WebUI 上管理会话
2025-07-13 17:23:52 +08:00
Soulter
af295da871 chore: remove /mcp command 2025-07-13 17:11:02 +08:00
Soulter
083235a2fe feat: enhance session management page with tooltips and layout adjustments 2025-07-13 17:06:15 +08:00
Soulter
2a3a5f7eb2 perf: refine session management page ui 2025-07-13 16:57:36 +08:00
Soulter
77c48f280f fix: session management paginator error 2025-07-13 16:36:25 +08:00
Soulter
0ee1eb2f9f chore: remove useless file 2025-07-13 16:30:00 +08:00
Soulter
c2b20365bb Merge pull request #2097 from SheffeyG/fix-status-checking
Fix: add status checking for embedding model providers
2025-07-13 16:19:37 +08:00
Soulter
cfdc7e4452 fix: add debug logging for provider request handling in LLMRequestSubStage
fixes: #2104
2025-07-13 16:12:48 +08:00
Soulter
2363f61aa9 chore: remove 'obvious_hint' fields from configuration metadata and remove some deprecated config 2025-07-13 16:03:47 +08:00
Soulter
557ac6f9fa Merge pull request #2112 from AstrBotDevs/perf-provider-logo
Improve: WebUI provider logo display
2025-07-13 15:34:19 +08:00
Soulter
a49b871cf9 fix: update Azure provider icon URL in getProviderIcon method 2025-07-13 15:33:47 +08:00
Soulter
a0d6b3efba perf: improve provider logo display in webui 2025-07-13 15:27:53 +08:00
Gao Jinzhe
6cabf07bc0 Merge branch 'AstrBotDevs:master' into Astrbot_session_manage 2025-07-13 00:23:29 +08:00
advent259141
a15444ee8c 移除了mcp会话级的启停,增加了批量设置的选项,对相关问题进行了修复 2025-07-13 00:15:21 +08:00
Soulter
ceb5f5669e fix: update active reply bot prefix in logging for clarity 2025-07-13 00:12:31 +08:00
Gao Jinzhe
25b75e05e4 Merge branch 'AstrBotDevs:master' into Astrbot_session_manage 2025-07-12 22:25:20 +08:00
sheffey
4d214bb5c1 check general numbers type instead 2025-07-11 18:36:47 +08:00
sheffey
7cbaed8c6c fix: add status checking for embedding model providers 2025-07-11 18:36:40 +08:00
Soulter
2915fdf665 release: v3.5.22 2025-07-11 12:29:26 +08:00
Soulter
a66c385b08 fix: deadlock when docker is not available 2025-07-11 12:27:49 +08:00
Raven95676
4dace7c5d8 chore: format code 2025-07-11 11:23:53 +08:00
Soulter
8ebf087dbf chore: optimize codes 2025-07-10 23:28:00 +08:00
Soulter
2fa8bda5bb chore: ruff lint 2025-07-10 23:23:29 +08:00
Soulter
a5ae833945 📦 release: v3.5.21 2025-07-10 17:46:36 +08:00
Soulter
d21d42b312 chore: update icon URL for 302.AI to use color version 2025-07-10 17:44:11 +08:00
Soulter
78575f0f0a fix: failed to delete conversation in webchat
fixes: #2071
2025-07-10 17:04:34 +08:00
Soulter
8ccd292d16 Merge pull request #2082 from AstrBotDevs/fix-webchat-segment-reply
fix: 修复 WebChat 下可能消息错位的问题
2025-07-10 17:00:14 +08:00
Soulter
2534f59398 chore: remove debug print statement from chat route 2025-07-10 16:59:58 +08:00
Soulter
5c60dbe2b1 fix: 修复 WebChat 下可能消息错位的问题 2025-07-10 16:52:16 +08:00
Soulter
c99ecde15f Merge pull request #2078 from AstrBotDevs/fix-webchat-image-cannot-render
Fix: webchat cannot render image and audio image normally
2025-07-10 11:57:50 +08:00
Soulter
219f3403d9 fix: webchat cannot render image and audio image normally 2025-07-10 11:51:47 +08:00
Soulter
00f417bad6 Merge pull request #2073 from Raven95676/fix/register_star
fix: 提升兼容性,并尽可能避免数据竞争
2025-07-10 11:03:57 +08:00
Soulter
81649f053b perf: improve log 2025-07-10 10:58:56 +08:00
Raven95676
e5bde50f2d fix: 提升兼容性,并尽可能避免数据竞争 2025-07-09 22:39:30 +08:00
Raven95676
0321e00b0d perf: 移除nh3 2025-07-09 20:32:14 +08:00
Soulter
09528e3292 docs: add model providers 2025-07-09 14:18:59 +08:00
Soulter
e7412a9cbf docs: add model providers 2025-07-09 14:17:39 +08:00
Soulter
01efe5f869 📦 release: v3.5.20 2025-07-09 13:35:44 +08:00
Soulter
28a178a55c Merge pull request #2067 from AstrBotDevs/refactor-aiocqhttp-send-message
Fix: active message cannot handle forward type message properly in aiocqhttp adapter
2025-07-09 13:23:08 +08:00
Soulter
88f130014c perf: streamline message dispatching logic in AiocqhttpMessageEvent 2025-07-09 12:10:18 +08:00
Soulter
af258c590c Merge pull request #2068 from AstrBotDevs/fix-tool-call-result-wrongly-sent
Fix: 修复工具调用被错误地发出到了消息平台上
2025-07-09 12:02:07 +08:00
Soulter
b0eb5733be Merge pull request #2065 from AstrBotDevs/fix-plugin-metadata-load
Improve: add fallback for missing 'desc' in plugin metadata
2025-07-09 12:01:06 +08:00
Soulter
fe35bfba37 Merge pull request #2064 from uersula/fix-image-removal-flag-logic
Fix: 移除 _remove_image_from_context中的flag逻辑
2025-07-09 12:00:30 +08:00
advent259141
7cfbc4ab8f 增加了针对整个会话启停的开关 2025-07-09 11:58:52 +08:00
Soulter
7a9d4f0abd fix: 修复工具调用被错误地发出到了消息平台上
fixes: #2060
2025-07-09 11:43:25 +08:00
Soulter
6f6a5b565c fix: active message cannot handle forward type message properly in aiocqhttp adapter 2025-07-09 11:19:32 +08:00
Soulter
e57deb873c perf: add fallback for missing 'desc' in plugin metadata and improve error logging 2025-07-09 10:47:03 +08:00
Gao Jinzhe
0f692b1608 Merge branch 'master' into Astrbot_session_manage 2025-07-09 10:13:51 +08:00
uersula
8c03e79f99 Fix: Remove buggy flag logic in _remove_image_from_context 2025-07-08 23:01:11 +08:00
Soulter
71290f0929 Merge pull request #2061 from AstrBotDevs/feat-handle-image-in-quote-message
Feature: 支持对引用消息中的图片内容进行理解
2025-07-08 22:11:17 +08:00
Soulter
22364ef7de feat: 支持对引用消息中的图片内容进行理解
fixes: #2056
2025-07-08 22:08:40 +08:00
Ruochen
2cc1eb1abc feat:实现了speech_to_text类型的供应商可用性检查 2025-07-08 21:55:31 +08:00
RC-CHN
90dbcbb4e2 Merge branch 'AstrBotDevs:master' into master 2025-07-08 21:28:50 +08:00
Ruochen
66503d58be feat:实现了text_to_speech类型的供应商可用性测试 2025-07-08 17:52:22 +08:00
Ruochen
8e10f0ce2b feat:实现了embedding类型的供应商可用性检查 2025-07-08 16:51:57 +08:00
Soulter
f51f510f2e perf: enhance date handle in reminder
fixes: #1901
2025-07-08 16:33:46 +08:00
Ruochen
c44f085b47 fix:对非文本生成类供应商暂时跳过测试 2025-07-08 16:32:39 +08:00
RC-CHN
a35f36eeaf Merge branch 'AstrBotDevs:master' into master 2025-07-08 15:34:19 +08:00
Ruochen
14564c392a feat:meta方法增加provider_type字段 2025-07-08 15:33:02 +08:00
Soulter
76e05ea749 Merge pull request #2022 from AstrBotDevs/deprecate/register_star-decorator
[Deprecation] 弃用register_star装饰器
2025-07-08 11:57:28 +08:00
Soulter
ab599dceed Merge branch 'master' into deprecate/register_star-decorator 2025-07-08 11:52:33 +08:00
Soulter
4c37604445 perf: only output deprecation warning once for @register_star decorator 2025-07-08 11:50:55 +08:00
Soulter
bb74018d19 Merge pull request #1998 from diudiu62/feat-wechatpadpro-adapter
增加监听wechatpadpro消息平台的事件
2025-07-08 11:40:13 +08:00
Soulter
575289e5bc feat: complete platform adapter types and update mapping 2025-07-08 11:39:42 +08:00
Soulter
e89da2a7b4 Merge pull request #2035 from cclauss/patch-1
pytest recommendation: `pip install --editable .`
2025-07-08 11:35:34 +08:00
Soulter
bd34959f68 📦 release: v3.5.19 2025-07-08 01:34:08 +08:00
Soulter
622dcf8fd5 fix: 通过指令选择提供商重启后失效 2025-07-08 01:24:19 +08:00
Soulter
9e315739b7 Merge pull request #2051 from AstrBotDevs/perf-ui
Improve: 改善 WebUI 效果
2025-07-08 00:35:52 +08:00
Soulter
7b01adc5df perf: better webui 2025-07-08 00:33:22 +08:00
Soulter
432fc47443 feat: add 302.ai llm provider 2025-07-07 23:01:28 +08:00
Soulter
d8fba44c5e Merge pull request #2049 from uersula/fix/keyerror-in-recovery-handler
Fix: 防止错误恢复机制_remove_image_from_context发生KeyError
2025-07-07 22:13:43 +08:00
Soulter
e29d3d8c01 Merge pull request #2043 from Zhenyi-Wang/master
fix(wechatpadpro): 修复授权码提取逻辑以兼容新旧接口格式
2025-07-07 22:10:20 +08:00
uersula
e678413214 Fix: Prevent KeyError in _remove_image_from_context 2025-07-07 02:30:50 +08:00
Soulter
eaa9d9d087 Merge pull request #2027 from IGCrystal/Branch-2
🐞 fix(WebUI): 解决XSS注入的问题
2025-07-06 18:13:40 +08:00
Soulter
9e3cc076b7 🐞 fix(ReadmeDialog): add variant attribute to close button for consistency 2025-07-06 18:13:00 +08:00
IGCrystal
3bb01fa52c feat(ChatPage): 添加图像预览 2025-07-06 18:08:17 +08:00
IGCrystal
008e49d144 🎈 perf: 优化音频附件的显示 2025-07-06 18:08:17 +08:00
IGCrystal
4e275384b0 🐞 fix(VerticalHeader): 允许HTML渲染 2025-07-06 18:08:17 +08:00
IGCrystal
63ec99f67a 🐞 fix: 添加不存在的翻译键 2025-07-06 18:08:17 +08:00
IGCrystal
14a8bb57df 🐞 fix(WebUI): 解决XSS注入的问题 2025-07-06 18:08:17 +08:00
Soulter
7512bfc710 fix: update user message bubble styling for improved appearance 2025-07-06 18:06:28 +08:00
Soulter
3c3b6dadc3 Merge pull request #2037 from AstrBotDevs/fix/tool_call_result
fix: direct send tool_call_result
2025-07-06 18:05:59 +08:00
Soulter
cd722a0e39 fix: handle direct tool call results 2025-07-06 18:04:46 +08:00
Soulter
a1b5d0a100 Merge remote-tracking branch 'origin/master' into fix/tool_call_result 2025-07-06 17:47:09 +08:00
Raven95676
69d3ae709c fix: direct send tool_call_result 2025-07-06 17:45:07 +08:00
Soulter
67ef993d61 fix: webchat message bubble style 2025-07-06 17:21:57 +08:00
Soulter
20f49890ad fix: provider selection for updating webchat title 2025-07-06 17:18:37 +08:00
Zhenyi Wang
3e4917f0a1 refactor: 重构 wechatpadpro 授权码生成并增强安全性
- 将 generate_auth_key 方法中的授权码提取逻辑重构为新的辅助方法 _extract_auth_key ,以提高代码的可读性和可测试性。
- 在访问 data.get('authKeys') 之前添加 isinstance(data, dict) 检查,以防止潜在的 AttributeError 。
- 移除了 auth_key 的明文日志记录,以避免敏感信息泄露。
- 在生成新密钥之前,将 self.auth_key 初始化为 None ,以避免在失败时保留旧值。
2025-07-06 16:34:55 +08:00
Soulter
99ee75aec6 Merge pull request #2029 from jiongjiongJOJO/master
fix: 增加演示模式下校验插件开启/关闭/安装指令
2025-07-06 16:24:02 +08:00
Zhenyi Wang
1674653a42 fix(wechatpadpro): 修复授权码提取逻辑以兼容新旧接口格式
新接口返回多了一层authKeys字段,同时兼容二者
2025-07-06 16:18:31 +08:00
Christian Clauss
d2f7e55bf5 Run the tests on pull requests 2025-07-05 13:57:58 +02:00
Christian Clauss
9f31df7f3a pytest recommendation: pip install --editable .
https://docs.pytest.org/en/stable/how-to/existingtestsuite.html

This makes setting `PYTHONPATH` unnecessary and will pull requirements from `pyproject.toml` instead of `requirements.txt`, so it is similar to end-user installations.

`makedir -p data/plugins` will do both `mkdir data` and `mkdir data/plugins`.

The `$CI` environment variable might be better to use than `$TESTING` because it is preset to `true` in GitHub Actions.
* https://docs.github.com/en/actions/reference/variables-reference#default-environment-variables
* https://docs.pytest.org/en/stable/explanation/ci.html
2025-07-05 13:52:28 +02:00
Soulter
b8c1b53d67 Merge pull request #2034 from AstrBotDevs/dependabot/github_actions/github-actions-50e66c4123
chore(deps): bump the github-actions group with 4 updates
2025-07-05 19:24:16 +08:00
dependabot[bot]
2495837791 chore(deps): bump the github-actions group with 4 updates
Bumps the github-actions group with 4 updates: [actions/checkout](https://github.com/actions/checkout), [actions/setup-python](https://github.com/actions/setup-python), [codecov/codecov-action](https://github.com/codecov/codecov-action) and [actions/stale](https://github.com/actions/stale).


Updates `actions/checkout` from 3 to 4
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v3...v4)

Updates `actions/setup-python` from 4 to 5
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](https://github.com/actions/setup-python/compare/v4...v5)

Updates `codecov/codecov-action` from 4 to 5
- [Release notes](https://github.com/codecov/codecov-action/releases)
- [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md)
- [Commits](https://github.com/codecov/codecov-action/compare/v4...v5)

Updates `actions/stale` from 5 to 9
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v5...v9)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/setup-python
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: codecov/codecov-action
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/stale
  dependency-version: '9'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-07-05 11:20:25 +00:00
Soulter
b6562e3c47 Merge pull request #2005 from cclauss/patch-1
Keep GitHub Actions up to date with GitHub's Dependabot
2025-07-05 19:19:18 +08:00
Soulter
c57da046ee Merge pull request #2013 from AstrBotDevs/feat/danger-plugin
[Copilot] feat: 添加风险插件安装确认对话框以及风险插件标签特殊处理
2025-07-05 19:18:14 +08:00
JOJO
ff63134c14 fix: 增加演示模式下校验插件开启/关闭/安装指令 2025-07-05 12:43:19 +08:00
鸦羽
3f5210c587 chore: update plugin publish template 2025-07-04 22:28:00 +08:00
IGCrystal
3df5e7b9b9 🐞 fix: 添加tags.danger的翻译键 2025-07-04 17:28:39 +08:00
Soulter
225db66738 fix: refine streaming logic in chat response handling 2025-07-04 16:59:49 +08:00
Soulter
383ebb8f57 feat: add copy functionality for bot messages with success feedback 2025-07-04 16:27:52 +08:00
Raven95676
e1bed60f1f fix: adjust timing of adding to star_registry 2025-07-04 16:13:10 +08:00
Raven95676
edbb856023 refactor: deprecate register_star decorator 2025-07-04 15:54:23 +08:00
Raven95676
98d3ab646f chore: convert some methods to static 2025-07-04 15:07:14 +08:00
Soulter
81be556f1b Merge pull request #2018 from AstrBotDevs/fix-extension-btn-z-index
Fix: adjust z-index for the add button on ExtensionPage
2025-07-04 11:41:10 +08:00
Soulter
f45a085469 fix: adjust z-index for the add button on ExtensionPage
fixes: #1985
2025-07-04 11:40:14 +08:00
Raven95676
210cc58cc3 fix: 更新风险插件警告对话框内容和按钮文本,修正样式 By @Soulter
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-07-04 11:23:19 +08:00
Soulter
1063b11ef6 fix: check provider availability errors on dify 2025-07-04 10:19:58 +08:00
Raven95676
a4e999c47f feat: 添加风险插件安装确认对话框以及风险插件标签特殊处理 2025-07-03 22:16:00 +08:00
Soulter
543e01c301 perf: webui 删除对话使用 conversation_mgr,以保持状态同步 2025-07-03 15:44:45 +08:00
Soulter
14e0aa3ec5 perf: history 和 persona 指令当对话不存在的时候自动创建
fixes: #1997
2025-07-03 15:40:00 +08:00
Christian Clauss
1a8a171f8b Keep GitHub Actions up to date with GitHub's Dependabot
* [Keeping your software supply chain secure with Dependabot](https://docs.github.com/en/code-security/dependabot)
* [Keeping your actions up to date with Dependabot](https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot)
* [Configuration options for the `dependabot.yml` file - package-ecosystem](https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem)
2025-07-03 08:46:42 +02:00
Soulter
f1954f9a43 Merge pull request #1984 from RC-CHN/master
refactor:将前端测试供应商部分修改为独立并发异步获取各个文本供应商的状态
2025-07-03 10:55:51 +08:00
Soulter
441b148501 Merge pull request #1991 from AstrBotDevs/perf/webchat-title
perf: 优化WebChat对话标题生成
2025-07-03 10:53:35 +08:00
Soulter
bd0f30b81c Merge pull request #2003 from AstrBotDevs/feat-webchat-select-provider
Feature: WebChat 增加可选择提供商和模型的功能
2025-07-03 10:52:42 +08:00
Soulter
ad14e9bf40 chore: remove unnecessary logging of payloads in chat completion 2025-07-03 10:50:03 +08:00
Soulter
6f71301aaf fix: log error when selected provider is not found 2025-07-03 10:49:12 +08:00
Soulter
5f0d601baa feat: add support for selecting provider and models in webchat 2025-07-03 10:42:20 +08:00
Soulter
f234a5bcc2 fix: enhance event hook handling to return status and prevent propagation 2025-07-03 00:23:56 +08:00
chenpeng
ab677ea100 修正pilk依赖提示文案
增加监听wechatpadpro消息平台的事件
2025-07-02 17:30:37 +08:00
Soulter
f3ad53e949 feat: add supports for selecting provider and models in webchat 2025-07-02 17:12:30 +08:00
Soulter
d324cfa84d Merge pull request #1987 from AstrBotDevs/refactor-webchat-streaming
Refactor: 重构 WebChat 的 SSE 监听逻辑
2025-07-02 17:11:12 +08:00
Soulter
dd4319d72a Merge pull request #1990 from AstrBotDevs/fix-stream-multi-tool-use-err
fix: Multi-turn tools use error when using streaming output
2025-07-02 15:44:29 +08:00
Raven95676
1f2de3d3d8 perf: 优化WebChat对话标题生成 2025-07-02 10:43:54 +08:00
Raven95676
72702beb0b chore: clean code 2025-07-02 10:29:10 +08:00
Soulter
adb0cbc5dd fix: handle tool_calls_result as list or single object in context query in streaming mode 2025-07-02 10:16:44 +08:00
Soulter
6a503b82c3 refactor: web chat queue management and streamline chat route handling 2025-07-01 22:34:17 +08:00
advent259141
28a87351f1 新增对会话重命名的功能 2025-07-01 21:41:19 +08:00
Soulter
bcc97378b0 feat: implement code copy functionality and enhance code highlighting in ChatPage 2025-07-01 21:15:01 +08:00
Soulter
eb8a138713 feat: enhance conversation actions with delete functionality and improved styling 2025-07-01 21:00:43 +08:00
advent259141
dcd7dcbbdf 解决了conflict 2025-07-01 17:24:56 +08:00
Gao Jinzhe
1538759ba7 Merge branch 'master' into Astrbot_session_manage 2025-07-01 17:19:30 +08:00
Soulter
30e8ea7fd8 chore: add deploy badge 2025-07-01 16:59:58 +08:00
Ruochen
879b7b582c perf:提取重复的错误处理逻辑,优化循环调用 2025-07-01 16:02:56 +08:00
Ruochen
8ba4236402 refactor:将前端测试供应商部分修改为独立并发异步获取各个文本供应商的状态 2025-07-01 15:41:30 +08:00
鸦羽
5eef8fa9b9 Merge pull request #1981 from AstrBotDevs/feat/r1_filter-integration
feat: 集成r1_filter至框架
2025-07-01 13:56:01 +08:00
Raven95676
d03d035437 perf: 合并嵌套的if条件 2025-07-01 13:53:22 +08:00
Raven95676
68e8e1f70b feat: 集成r1_filter至框架 2025-07-01 12:40:52 +08:00
Soulter
7acb45b157 Update README.md 2025-07-01 11:35:14 +08:00
Soulter
c36142deaf perf: chatpage UI 2025-06-30 15:20:46 +08:00
Soulter
5fd6e316fa Merge pull request #1966 from railgun19457/master
修改了一对大括号
2025-06-30 13:33:10 +08:00
railgun19457
39a9d7765a 修改了一对大括号 2025-06-30 00:21:28 +08:00
advent259141
ec5d71d0e1 修复了一下重复的代码问题,删除了不必要的会话级别 LLM 启停状态检查。 2025-06-29 10:02:04 +08:00
advent259141
d121d08d05 大致凭借自己理解修复了一下整个检查流程,防止钩子出现问题 2025-06-29 09:57:31 +08:00
Gao Jinzhe
be08f4a558 Merge branch 'AstrBotDevs:master' into Astrbot_session_manage 2025-06-29 09:11:25 +08:00
Soulter
4df8606ab6 style: code quality 2025-06-28 20:08:57 +08:00
Soulter
71442d26ec chore: 移除不必要的 MCP 会话控制 2025-06-28 19:58:36 +08:00
advent259141
4f5528869c Merge branch 'Astrbot_session_manage' of https://github.com/advent259141/AstrBot into Astrbot_session_manage 2025-06-28 17:00:00 +08:00
advent259141
f16feff17b 根据会话mcp开关情况选择性传入 func_tool
修改import的位置
deleted:    astrbot/core/star/session_tts_manager.py
复原被覆盖的修改
2025-06-28 16:59:00 +08:00
Gao Jinzhe
d8aae538cd Merge branch 'AstrBotDevs:master' into Astrbot_session_manage 2025-06-28 14:55:38 +08:00
advent259141
31670e75e5 Merge branch 'Astrbot_session_manage' of https://github.com/advent259141/AstrBot into Astrbot_session_manage 2025-06-27 18:47:25 +08:00
advent259141
ed6011a2be modified: dashboard/src/i18n/loader.ts
modified:   dashboard/src/i18n/locales/en-US/core/navigation.json
增加会话管理英文页面
	modified:   dashboard/src/i18n/locales/zh-CN/core/navigation.json
增加会话管理中文页面
	modified:   dashboard/src/i18n/translations.ts
	modified:   dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts
	modified:   dashboard/src/views/SessionManagementPage.vue
增加会话管理国际化适配
2025-06-27 18:46:02 +08:00
Gao Jinzhe
cdded38ade Merge branch 'AstrBotDevs:master' into Astrbot_session_manage 2025-06-27 17:10:08 +08:00
advent259141
f536f24833 astrbot/core/pipeline/process_stage/method/llm_request.py
astrbot/core/pipeline/result_decorate/stage.py
   astrbot/core/star/session_llm_manager.py
   astrbot/core/star/session_tts_manager.py
   astrbot/dashboard/routes/session_management.py
   astrbot/dashboard/server.py
   dashboard/src/views/SessionManagementPage.vue
   packages/astrbot/main.py
2025-06-27 17:08:05 +08:00
Gao Jinzhe
646b18d910 Merge branch 'AstrBotDevs:master' into master 2025-06-27 12:26:15 +08:00
Gao Jinzhe
e24225c828 Merge branch 'master' into master 2025-06-23 15:21:08 +08:00
Gao Jinzhe
50a296de20 Merge branch 'AstrBotDevs:master' into master 2025-06-20 14:39:57 +08:00
Gao Jinzhe
c79e38e044 Merge branch 'AstrBotDevs:master' into master 2025-06-17 20:29:32 +08:00
Gao Jinzhe
dae745d925 Update server.py 2025-06-16 10:03:18 +08:00
Gao Jinzhe
791db65526 resolve conflict with master branch 2025-06-16 09:50:35 +08:00
Gao Jinzhe
02e2e617f5 Merge branch 'AstrBotDevs:master' into master 2025-06-15 22:04:06 +08:00
advent259141
bfc8024119 modified: astrbot/core/pipeline/process_stage/method/llm_request.py
new file:   astrbot/core/star/session_llm_manager.py
	modified:   astrbot/dashboard/routes/session_management.py
	modified:   dashboard/src/views/SessionManagementPage.vue

 增加了精确到会话的LLM启停管理以及插件启停管理
2025-06-14 03:42:21 +08:00
Gao Jinzhe
f26cf6ed6f Merge branch 'AstrBotDevs:master' into master 2025-06-14 03:03:41 +08:00
advent259141
f2be55bd8e Merge branch 'master' of https://github.com/advent259141/AstrBot 2025-06-13 06:20:49 +08:00
advent259141
d241dd17ca Merge branch 'master' of https://github.com/advent259141/AstrBot 2025-06-13 06:20:09 +08:00
advent259141
cecafdfe6c Merge branch 'master' of https://github.com/advent259141/AstrBot 2025-06-13 03:54:35 +08:00
Soulter
6fecfd1a0e Merge pull request #1800 from AstrBotDevs/feat-weixinkefu-record
feat: 微信客服支持语音的收发
2025-06-13 03:52:15 +08:00
241 changed files with 23628 additions and 10163 deletions

View File

@@ -1,30 +0,0 @@
---
name: '🥳 发布插件'
title: "[Plugin] 插件名"
about: 提交插件到插件市场
labels: [ "plugin-publish" ]
assignees: ''
---
欢迎发布插件到插件市场!
## 插件基本信息
请将插件信息填写到下方的 Json 代码块中。`tags`(插件标签)和 `social_link`(社交链接)选填。
```json
{
"name": "插件名",
"desc": "插件介绍",
"repo": "插件仓库链接",
"tags": [],
"social_link": ""
}
```
## 检查
- [ ] 我的插件经过完整的测试
- [ ] 我的插件不包含恶意代码
- [ ] 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。

View File

@@ -0,0 +1,56 @@
name: 🥳 发布插件
description: 提交插件到插件市场
title: "[Plugin] 插件名"
labels: ["plugin-publish"]
assignees: []
body:
- type: markdown
attributes:
value: |
欢迎发布插件到插件市场!
- type: markdown
attributes:
value: |
## 插件基本信息
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
不熟悉 JSON ?现在可以从 [这里](https://plugins.astrbot.app/#/submit) 获取你的 JSON 啦!获取到了记得复制粘贴过来哦!
- type: textarea
id: plugin-info
attributes:
label: 插件信息
description: 请在下方代码块中填写您的插件信息确保反引号包裹了JSON
value: |
```json
{
"name": "插件名",
"desc": "插件介绍",
"author": "作者名",
"repo": "插件仓库链接",
"tags": [],
"social_link": ""
}
```
validations:
required: true
- type: markdown
attributes:
value: |
## 检查
- type: checkboxes
id: checks
attributes:
label: 插件检查清单
description: 请确认以下所有项目
options:
- label: 我的插件经过完整的测试
required: true
- label: 我的插件不包含恶意代码
required: true
- label: 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
required: true

36
.github/auto_assign.yml vendored Normal file
View File

@@ -0,0 +1,36 @@
# Set to true to add reviewers to pull requests
addReviewers: true
# Set to true to add assignees to pull requests
addAssignees: false
# A list of reviewers to be added to pull requests (GitHub user name)
reviewers:
- Soulter
- Raven95676
- Larch-C
- anka-afk
- advent259141
# - zouyonghe
# A number of reviewers added to the pull request
# Set 0 to add all the reviewers (default: 0)
numberOfReviewers: 2
# A list of assignees, overrides reviewers if set
# assignees:
# - assigneeA
# A number of assignees to add to the pull request
# Set to 0 to add all of the assignees.
# Uses numberOfReviewers if unset.
# numberOfAssignees: 2
# A list of keywords to be skipped the process that add reviewers if pull requests include it
skipKeywords:
- wip
- draft
# A list of users to be skipped by both the add reviewers and add assignees processes
# skipUsers:
# - dependabot[bot]

63
.github/copilot-instructions.md vendored Normal file
View File

@@ -0,0 +1,63 @@
# AstrBot Development Instructions
AstrBot is a multi-platform LLM chatbot and development framework written in Python with a Vue.js dashboard. It supports multiple messaging platforms (QQ, Telegram, Discord, etc.) and various LLM providers (OpenAI, Anthropic, Google Gemini, etc.).
Always reference these instructions first and fallback to search or bash commands only when you encounter unexpected information that does not match the info here.
## Working Effectively
### Bootstrap and Install Dependencies
- **Python 3.10+ required** - Check `.python-version` file
- Install UV package manager: `pip install uv`
- Install project dependencies: `uv sync` -- takes 6-7 minutes. NEVER CANCEL. Set timeout to 10+ minutes.
- Create required directories: `mkdir -p data/plugins data/config data/temp`
### Running the Application
- Run main application: `uv run main.py` -- starts in ~3 seconds
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
### Dashboard Build (Vue.js/Node.js)
- **Prerequisites**: Node.js 20+ and npm 10+ required
- Navigate to dashboard: `cd dashboard`
- Install dashboard dependencies: `npm install` -- takes 2-3 minutes. NEVER CANCEL. Set timeout to 5+ minutes.
- Build dashboard: `npm run build` -- takes 25-30 seconds. NEVER CANCEL.
- Dashboard creates optimized production build in `dashboard/dist/`
### Testing
- Do not generate test files for now.
### Code Quality and Linting
- Install ruff linter: `uv add --dev ruff`
- Check code style: `uv run ruff check .` -- takes <1 second
- Check formatting: `uv run ruff format --check .` -- takes <1 second
- Fix formatting: `uv run ruff format .`
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
### Plugin Development
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
- Plugin system supports function tools and message handlers
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
### Common Issues and Workarounds
- **Dashboard download fails**: Known issue with "division by zero" error - application still works
- **Import errors in tests**: Ensure `uv run` is used to run tests in proper environment
=- **Build timeouts**: Always set appropriate timeouts (10+ minutes for uv sync, 5+ minutes for npm install)
## CI/CD Integration
- GitHub Actions workflows in `.github/workflows/`
- Docker builds supported via `Dockerfile`
- Pre-commit hooks enforce ruff formatting and linting
## Docker Support
- Primary deployment method: `docker run soulter/astrbot:latest`
- Compose file available: `compose.yml`
- Exposes ports: 6185 (WebUI), 6195 (WeChat), 6199 (QQ), etc.
- Volume mount required: `./data:/AstrBot/data`
## Multi-language Support
- Documentation in Chinese (README.md), English (README_en.md), Japanese (README_ja.md)
- UI supports internationalization
- Default language is Chinese
Remember: This is a production chatbot framework with real users. Always test thoroughly and ensure changes don't break existing functionality.

13
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,13 @@
# Keep GitHub Actions up to date with GitHub's Dependabot...
# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
version: 2
updates:
- package-ecosystem: github-actions
directory: /
groups:
github-actions:
patterns:
- "*" # Group all Actions updates into a single larger pull request
schedule:
interval: weekly

View File

@@ -13,7 +13,7 @@ jobs:
contents: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v5
- name: Dashboard Build
run: |
@@ -70,10 +70,10 @@ jobs:
needs: build-and-publish-to-github-release
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: '3.10'

View File

@@ -56,7 +56,7 @@ jobs:
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v5
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL

View File

@@ -8,6 +8,7 @@ on:
- 'README.md'
- 'changelogs/**'
- 'dashboard/**'
pull_request:
workflow_dispatch:
jobs:
@@ -16,30 +17,29 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v5
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest pytest-cov pytest-asyncio
pip install pytest pytest-asyncio pytest-cov
pip install --editable .
- name: Run tests
run: |
mkdir data
mkdir data/plugins
mkdir data/config
mkdir data/temp
mkdir -p data/plugins
mkdir -p data/config
mkdir -p data/temp
export TESTING=true
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
pytest --cov=. -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov
uses: codecov/codecov-action@v4
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v5
- name: npm install, build
run: |
@@ -25,6 +25,8 @@ jobs:
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
mkdir -p dashboard/dist/assets
echo $COMMIT_SHA > dashboard/dist/assets/version
cd dashboard
zip -r dist.zip dist
- name: Archive production artifacts
uses: actions/upload-artifact@v4
@@ -33,3 +35,13 @@ jobs:
path: |
dashboard/dist
!dist/**/*.md
- name: Create GitHub Release
uses: ncipollo/release-action@v1
with:
tag: release-${{ github.sha }}
owner: AstrBotDevs
repo: astrbot-release-harbour
body: "Automated release from commit ${{ github.sha }}"
token: ${{ secrets.ASTRBOT_HARBOUR_TOKEN }}
artifacts: "dashboard/dist.zip"

View File

@@ -12,7 +12,7 @@ jobs:
steps:
- name: Pull The Codes
uses: actions/checkout@v3
uses: actions/checkout@v5
with:
fetch-depth: 0 # Must be 0 so we can fetch tags
@@ -27,6 +27,33 @@ jobs:
if: github.event_name == 'workflow_dispatch'
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
- name: Check if version is pre-release
id: check-prerelease
run: |
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
version="${{ steps.get-latest-tag.outputs.latest_tag }}"
else
version="${{ github.ref_name }}"
fi
if [[ "$version" == *"beta"* ]] || [[ "$version" == *"alpha"* ]]; then
echo "is_prerelease=true" >> $GITHUB_OUTPUT
echo "Version $version is a pre-release, will not push latest tag"
else
echo "is_prerelease=false" >> $GITHUB_OUTPUT
echo "Version $version is a stable release, will push latest tag"
fi
- name: Build Dashboard
run: |
cd dashboard
npm install
npm run build
mkdir -p dist/assets
echo $(git rev-parse HEAD) > dist/assets/version
cd ..
mkdir -p data
cp -r dashboard/dist data/
- name: Set QEMU
uses: docker/setup-qemu-action@v3
@@ -53,9 +80,9 @@ jobs:
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && format('{0}/astrbot:latest', secrets.DOCKER_HUB_USERNAME) || '' }}
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
ghcr.io/soulter/astrbot:latest
${{ steps.check-prerelease.outputs.is_prerelease == 'false' && 'ghcr.io/soulter/astrbot:latest' || '' }}
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
- name: Post build notifications

View File

@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/stale@v5
- uses: actions/stale@v10
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'Stale issue message'

177
README.md
View File

@@ -1,4 +1,4 @@
<p align="center">
<img width="430" height="31" alt="image" src="https://github.com/user-attachments/assets/474c822c-fab7-41be-8c23-6dae252823ed" /><p align="center">
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
@@ -6,8 +6,6 @@
<div align="center">
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot?style=for-the-badge&color=76bad9)](https://github.com/Soulter/AstrBot/releases/latest)
@@ -16,7 +14,6 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B4%BB%E8%B7%83%E9%87%8F&cacheSeconds=3600&style=for-the-badge&color=3b618e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
@@ -25,59 +22,52 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
</div>
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型LLM接入功能的聊天机器人及开发框架。
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
## 主要功能
<!-- [![codecov](https://img.shields.io/codecov/c/github/soulter/astrbot?style=for-the-badge)](https://codecov.io/gh/Soulter/AstrBot)
-->
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
3. **Agent**。完善适配的 Agentic 能力。支持多轮工具调用、内置沙盒代码执行器、网页搜索等功能。
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
5. **WebUI**。可视化配置和管理机器人,功能齐全。
> [!WARNING]
>
> 请务必修改默认密码以及保证 AstrBot 版本 >= 3.5.13。
## ✨ 近期更新
<details><summary>1. AstrBot 现已自带知识库能力</summary>
📚 详见[文档](https://astrbot.app/use/knowledge-base.html)
![image](https://github.com/user-attachments/assets/28b639b0-bb5c-4958-8e94-92ae8cfd1ab4)
</details>
2. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
## ✨ 主要功能
> [!NOTE]
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力支持图片理解、语音转文字Whisper
2. **多消息平台接入**。支持接入 QQOneBot、QQ 官方机器人平台、QQ 频道、微信、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK、VoceChat。支持速率限制、白名单、关键词过滤、百度内容审核。
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat可在面板上与大模型对话。
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
> [!TIP]
> WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> 用户名: `astrbot`, 密码: `astrbot`。
## ✨ 使用方式
## 部署方式
#### Docker 部署
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
#### 宝塔面板部署
AstrBot 与宝塔面板合作,已上架至宝塔面板。
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
#### 1Panel 部署
AstrBot 已由 1Panel 官方上架至 1Panel 面板。
请参阅官方文档 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html) 。
#### 在 雨云 上部署
AstrBot 已由雨云官方上架至云应用平台,可一键部署。
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
#### 在 Replit 上部署
社区贡献的部署方式。
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### Windows 一键安装器部署
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
#### 宝塔面板部署
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
#### CasaOS 部署
社区贡献的部署方式。
@@ -86,9 +76,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
#### 手动部署
> 推荐使用 `uv`。
首先,安装 uv
首先安装 uv
```bash
pip install uv
@@ -101,19 +89,27 @@ git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
uv run main.py
```
或者,直接通过 uvx 安装 AstrBot
```bash
mkdir astrbot && cd astrbot
uvx astrbot init
# uvx astrbot run
```
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
#### Replit 部署
## 🌍 社区
### QQ 群组
- 1 群322154837
- 3 群630166526
- 5 群822130018
- 6 群753075035
- 开发者群975206796
- 开发者群备份295657329
### Telegram 群组
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
### Discord 群组
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
## ⚡ 消息平台支持情况
@@ -121,7 +117,6 @@ uvx astrbot init
| -------- | ------- |
| QQ(官方机器人接口) | ✔ |
| QQ(OneBot) | ✔ |
| 微信个人号 | ✔ |
| Telegram | ✔ |
| 企业微信 | ✔ |
| 微信客服 | ✔ |
@@ -132,22 +127,20 @@ uvx astrbot init
| Discord | ✔ |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
| 微信对话开放平台 | 🚧 |
| WhatsApp | 🚧 |
| 小爱音响 | 🚧 |
## ⚡ 提供商支持情况
| 名称 | 支持性 | 类型 | 备注 |
| -------- | ------- | ------- | ------- |
| OpenAI API | ✔ | 文本生成 | 支持 DeepSeek、Google Gemini、GLM、Kimi、xAI 等兼容 OpenAI API 的服务 |
| Claude API | ✔ | 文本生成 | |
| Google Gemini API | ✔ | 文本生成 | |
| OpenAI | ✔ | 文本生成 | 支持任何兼容 OpenAI API 的服务 |
| Anthropic | ✔ | 文本生成 | |
| Google Gemini | ✔ | 文本生成 | |
| Dify | ✔ | LLMOps | |
| 阿里云百炼应用 | ✔ | LLMOps | |
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
| 硅基流动 | ✔ | 模型 API 服务平台 | |
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
| OneAPI | ✔ | LLM 分发系统 | |
@@ -161,7 +154,6 @@ uvx astrbot init
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
## ❤️ 贡献
欢迎任何 Issues/Pull Requests只需要将你的更改提交到此项目 )
@@ -180,38 +172,6 @@ pip install pre-commit
pre-commit install
```
## 🌟 支持
- Star 这个项目!
- 在[爱发电](https://afdian.com/a/soulter)支持我!
## ✨ Demo
<details><summary>👉 点击展开多张 Demo 截图 👈</summary>
<div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
_✨基于 Docker 的沙箱化代码执行器Beta 测试✨_
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ 插件系统——部分插件展示 ✨_
<img src="https://github.com/user-attachments/assets/0cdbf564-2f59-4da5-b524-ce0e7ef3d978" width=600>
_✨ WebUI ✨_
</div>
</details>
## ❤️ Special Thanks
@@ -221,10 +181,18 @@ _✨ WebUI ✨_
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>
此外,本项目的诞生离不开以下开源项目:
此外,本项目的诞生离不开以下开源项目的帮助
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
另外,一些同类型其他的活跃开源 Bot 项目:
- [nonebot/nonebot2](https://github.com/nonebot/nonebot2) - 扩展性极强的 Bot 框架
- [koishijs/koishi](https://github.com/koishijs/koishi) - 扩展性极强的 Bot 框架
- [MaiM-with-u/MaiBot](https://github.com/MaiM-with-u/MaiBot) - 注重拟人功能的 ChatBot
- [langbot-app/LangBot](https://github.com/langbot-app/LangBot) - 功能丰富的 Bot 平台
- [LroMiose/nekro-agent](https://github.com/KroMiose/nekro-agent) - 注重 Agent 的 ChatBot
- [zhenxun-org/zhenxun_bot](https://github.com/zhenxun-org/zhenxun_bot) - 功能完善的 ChatBot
## ⭐ Star History
@@ -237,14 +205,9 @@ _✨ WebUI ✨_
</div>
![10k-star-banner-credit-by-kevin](https://github.com/user-attachments/assets/c97fc5fb-20b9-4bc8-9998-c20b930ab097)
</details>
## Disclaimer
1. The project is protected under the `AGPL-v3` opensource license.
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
3. Please ensure compliance with local laws and regulations when using this project.
_私は、高性能ですから!_

View File

@@ -27,7 +27,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
## ✨ 主な機能
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換Whisperをサポートします。
2. **複数のメッセージプラットフォームの接続**。QQOneBot、QQ チャンネル、WeChatGewechatFeishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
2. **複数のメッセージプラットフォームの接続**。QQOneBot、QQ チャンネル、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
@@ -152,8 +152,7 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
## 免責事項
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
2. WeChat個人アカウントのデプロイメントには [Gewechat](https://github.com/Devo919/Gewechat) サービスを利用しています。AstrBot は Gewechat との接続を保証するだけであり、アカウントのリスク管理に関しては、このプロジェクトの著者は一切の責任を負いません
3. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
2. このプロジェクトを使用する際は、現地の法律および規制を遵守してください
<!-- ## ✨ ATRI [ベータテスト]
@@ -165,6 +164,4 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
4. TTS
-->
_私は、高性能ですから!_

View File

@@ -3,5 +3,18 @@ from astrbot import logger
from astrbot.core import html_renderer
from astrbot.core import sp
from astrbot.core.star.register import register_llm_tool as llm_tool
from astrbot.core.star.register import register_agent as agent
from astrbot.core.agent.tool import ToolSet, FunctionTool
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
__all__ = ["AstrBotConfig", "logger", "html_renderer", "llm_tool", "sp"]
__all__ = [
"AstrBotConfig",
"logger",
"html_renderer",
"llm_tool",
"agent",
"sp",
"ToolSet",
"FunctionTool",
"BaseFunctionToolExecutor",
]

View File

@@ -7,6 +7,7 @@ from astrbot.core.star.register import (
register_permission_type as permission_type,
register_custom_filter as custom_filter,
register_on_astrbot_loaded as on_astrbot_loaded,
register_on_platform_loaded as on_platform_loaded,
register_on_llm_request as on_llm_request,
register_on_llm_response as on_llm_response,
register_llm_tool as llm_tool,
@@ -41,6 +42,7 @@ __all__ = [
"custom_filter",
"PermissionType",
"on_astrbot_loaded",
"on_platform_loaded",
"on_llm_request",
"llm_tool",
"on_decorating_result",

View File

@@ -1 +1 @@
__version__ = "3.5.8"
__version__ = "3.5.23"

View File

@@ -37,7 +37,10 @@ async def check_dashboard(astrbot_root: Path) -> None:
):
click.echo("正在安装管理面板...")
await download_dashboard(
path="data/dashboard.zip", extract_path=str(astrbot_root)
path="data/dashboard.zip",
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
)
click.echo("管理面板安装完成")
@@ -50,7 +53,10 @@ async def check_dashboard(astrbot_root: Path) -> None:
version = dashboard_version.split("v")[1]
click.echo(f"管理面板版本: {version}")
await download_dashboard(
path="data/dashboard.zip", extract_path=str(astrbot_root)
path="data/dashboard.zip",
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
)
except Exception as e:
click.echo(f"下载管理面板失败: {e}")
@@ -59,7 +65,10 @@ async def check_dashboard(astrbot_root: Path) -> None:
click.echo("初始化管理面板目录...")
try:
await download_dashboard(
path=str(astrbot_root / "dashboard.zip"), extract_path=str(astrbot_root)
path=str(astrbot_root / "dashboard.zip"),
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
)
click.echo("管理面板初始化完成")
except Exception as e:

View File

@@ -117,6 +117,9 @@ def build_plug_list(plugins_dir: Path) -> list:
# 从 metadata.yaml 加载元数据
metadata = load_yaml_metadata(plugin_dir)
if "desc" not in metadata and "description" in metadata:
metadata["desc"] = metadata["description"]
# 如果成功加载元数据,添加到结果列表
if metadata and all(
k in metadata for k in ["name", "desc", "version", "author", "repo"]

View File

@@ -1,5 +1,4 @@
import os
import asyncio
from .log import LogManager, LogBroker # noqa
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from astrbot.core.utils.shared_preferences import SharedPreferences
@@ -21,12 +20,10 @@ html_renderer = HtmlRenderer(t2i_base_url)
logger = LogManager.GetLogger(log_name="astrbot")
db_helper = SQLiteDatabase(DB_PATH)
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
sp = SharedPreferences()
sp = SharedPreferences(db_helper=db_helper)
# 文件令牌服务
file_token_service = FileTokenService()
pip_installer = PipInstaller(
astrbot_config.get("pip_install_arg", ""),
astrbot_config.get("pypi_index_url", None),
)
web_chat_queue = asyncio.Queue(maxsize=32)
web_chat_back_queue = asyncio.Queue(maxsize=32)

View File

@@ -0,0 +1,13 @@
from dataclasses import dataclass
from .tool import FunctionTool
from typing import Generic
from .run_context import TContext
from .hooks import BaseAgentRunHooks
@dataclass
class Agent(Generic[TContext]):
name: str
instructions: str | None = None
tools: list[str, FunctionTool] | None = None
run_hooks: BaseAgentRunHooks[TContext] | None = None

View File

@@ -0,0 +1,34 @@
from typing import Generic
from .tool import FunctionTool
from .agent import Agent
from .run_context import TContext
class HandoffTool(FunctionTool, Generic[TContext]):
"""Handoff tool for delegating tasks to another agent."""
def __init__(
self, agent: Agent[TContext], parameters: dict | None = None, **kwargs
):
self.agent = agent
super().__init__(
name=f"transfer_to_{agent.name}",
parameters=parameters or self.default_parameters(),
description=agent.instructions or self.default_description(agent.name),
**kwargs,
)
def default_parameters(self) -> dict:
return {
"type": "object",
"properties": {
"input": {
"type": "string",
"description": "The input to be handed off to another agent. This should be a clear and concise request or task.",
},
},
}
def default_description(self, agent_name: str | None) -> str:
agent_name = agent_name or "another"
return f"Delegate tasks to {self.name} agent to handle the request."

View File

@@ -0,0 +1,27 @@
import mcp
from dataclasses import dataclass
from .run_context import ContextWrapper, TContext
from typing import Generic
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.agent.tool import FunctionTool
@dataclass
class BaseAgentRunHooks(Generic[TContext]):
async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
async def on_tool_start(
self,
run_context: ContextWrapper[TContext],
tool: FunctionTool,
tool_args: dict | None,
): ...
async def on_tool_end(
self,
run_context: ContextWrapper[TContext],
tool: FunctionTool,
tool_args: dict | None,
tool_result: mcp.types.CallToolResult | None,
): ...
async def on_agent_done(
self, run_context: ContextWrapper[TContext], llm_response: LLMResponse
): ...

View File

@@ -0,0 +1,208 @@
import asyncio
import logging
from datetime import timedelta
from typing import Optional
from contextlib import AsyncExitStack
from astrbot import logger
from astrbot.core.utils.log_pipe import LogPipe
try:
import mcp
from mcp.client.sse import sse_client
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
try:
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
)
def _prepare_config(config: dict) -> dict:
"""准备配置,处理嵌套格式"""
if "mcpServers" in config and config["mcpServers"]:
first_key = next(iter(config["mcpServers"]))
config = config["mcpServers"][first_key]
config.pop("active", None)
return config
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"""快速测试 MCP 服务器可达性"""
import aiohttp
cfg = _prepare_config(config.copy())
url = cfg["url"]
headers = cfg.get("headers", {})
timeout = cfg.get("timeout", 10)
try:
async with aiohttp.ClientSession() as session:
if cfg.get("transport") == "streamable_http":
test_payload = {
"jsonrpc": "2.0",
"method": "initialize",
"id": 0,
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.2.3"},
},
}
async with session.post(
url,
headers={
**headers,
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
},
json=test_payload,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
else:
return False, f"HTTP {response.status}: {response.reason}"
else:
async with session.get(
url,
headers={
**headers,
"Accept": "application/json, text/event-stream",
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
else:
return False, f"HTTP {response.status}: {response.reason}"
except asyncio.TimeoutError:
return False, f"连接超时: {timeout}"
except Exception as e:
return False, f"{e!s}"
class MCPClient:
def __init__(self):
# Initialize session and client objects
self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack()
self.name = None
self.active: bool = True
self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = []
self.running_event = asyncio.Event()
async def connect_to_server(self, mcp_server_config: dict, name: str):
"""连接到 MCP 服务器
如果 `url` 参数存在:
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
cfg = _prepare_config(mcp_server_config.copy())
def logging_callback(msg: str):
# 处理 MCP 服务的错误日志
print(f"MCP Server {name} Error: {msg}")
self.server_errlogs.append(msg)
if "url" in cfg:
success, error_msg = await _quick_test_mcp_connection(cfg)
if not success:
raise Exception(error_msg)
if cfg.get("transport") != "streamable_http":
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=cfg.get("timeout", 5),
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
)
streams = await self.exit_stack.enter_async_context(
self._streams_context
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
*streams,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
)
)
else:
timeout = timedelta(seconds=cfg.get("timeout", 30))
sse_read_timeout = timedelta(
seconds=cfg.get("sse_read_timeout", 60 * 5)
)
self._streams_context = streamablehttp_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
read_s, write_s, _ = await self.exit_stack.enter_async_context(
self._streams_context
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
read_stream=read_s,
write_stream=write_s,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
)
)
else:
server_params = mcp.StdioServerParameters(
**cfg,
)
def callback(msg: str):
# 处理 MCP 服务的错误日志
self.server_errlogs.append(msg)
stdio_transport = await self.exit_stack.enter_async_context(
mcp.stdio_client(
server_params,
errlog=LogPipe(
level=logging.ERROR,
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
), # type: ignore
),
)
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport)
)
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
response = await self.session.list_tools()
self.tools = response.tools
return response
async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()
self.running_event.set() # Set the running event to indicate cleanup is done

View File

@@ -0,0 +1,12 @@
from dataclasses import dataclass
import typing as T
from astrbot.core.message.message_event_result import MessageChain
class AgentResponseData(T.TypedDict):
chain: MessageChain
@dataclass
class AgentResponse:
type: str
data: AgentResponseData

View File

@@ -0,0 +1,17 @@
from dataclasses import dataclass
from typing import Any, Generic
from typing_extensions import TypeVar
from astrbot.core.platform.astr_message_event import AstrMessageEvent
TContext = TypeVar("TContext", default=Any)
@dataclass
class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state."""
context: TContext
event: AstrMessageEvent
NoContext = ContextWrapper[None]

View File

@@ -0,0 +1,3 @@
from .base import BaseAgentRunner
__all__ = ["BaseAgentRunner"]

View File

@@ -1,32 +1,33 @@
import abc
import typing as T
from dataclasses import dataclass
from astrbot.core.provider.entities import LLMResponse
from ....message.message_event_result import MessageChain
from enum import Enum, auto
from ..run_context import ContextWrapper, TContext
from ..response import AgentResponse
from ..hooks import BaseAgentRunHooks
from ..tool_executor import BaseFunctionToolExecutor
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import LLMResponse
class AgentState(Enum):
"""Agent 状态枚举"""
IDLE = auto() # 初始状态
RUNNING = auto() # 运行中
DONE = auto() # 完成
ERROR = auto() # 错误状态
"""Defines the state of the agent."""
IDLE = auto() # Initial state
RUNNING = auto() # Currently processing
DONE = auto() # Completed
ERROR = auto() # Error state
class AgentResponseData(T.TypedDict):
chain: MessageChain
@dataclass
class AgentResponse:
type: str
data: AgentResponseData
class BaseAgentRunner:
class BaseAgentRunner(T.Generic[TContext]):
@abc.abstractmethod
async def reset(self) -> None:
async def reset(
self,
provider: Provider,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
**kwargs: T.Any,
) -> None:
"""
Reset the agent to its initial state.
This method should be called before starting a new run.

View File

@@ -1,10 +1,12 @@
import sys
import traceback
import typing as T
from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState
from ...context import PipelineContext
from .base import BaseAgentRunner, AgentResponse, AgentState
from ..hooks import BaseAgentRunHooks
from ..tool_executor import BaseFunctionToolExecutor
from ..run_context import ContextWrapper, TContext
from ..response import AgentResponseData
from astrbot.core.provider.provider import Provider
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import (
MessageChain,
)
@@ -21,8 +23,8 @@ from mcp.types import (
EmbeddedResource,
TextResourceContents,
BlobResourceContents,
CallToolResult,
)
from astrbot.core.star.star_handler import EventType
from astrbot import logger
if sys.version_info >= (3, 12):
@@ -31,28 +33,25 @@ else:
from typing_extensions import override
# TODO:
# 1. 处理平台不兼容的处理器
class ToolLoopAgent(BaseAgentRunner):
def __init__(
self, provider: Provider, event: AstrMessageEvent, pipeline_ctx: PipelineContext
) -> None:
self.provider = provider
self.req = None
self.event = event
self.pipeline_ctx = pipeline_ctx
self._state = AgentState.IDLE
self.final_llm_resp = None
self.streaming = False
class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
@override
async def reset(self, req: ProviderRequest, streaming: bool) -> None:
self.req = req
self.streaming = streaming
async def reset(
self,
provider: Provider,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.provider = provider
self.final_llm_resp = None
self._state = AgentState.IDLE
self.tool_executor = tool_executor
self.agent_hooks = agent_hooks
self.run_context = run_context
def _transition_state(self, new_state: AgentState) -> None:
"""转换 Agent 状态"""
@@ -78,6 +77,12 @@ class ToolLoopAgent(BaseAgentRunner):
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
if self._state == AgentState.IDLE:
try:
await self.agent_hooks.on_agent_begin(self.run_context)
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# 开始处理,转换到运行状态
self._transition_state(AgentState.RUNNING)
llm_resp_result = None
@@ -106,7 +111,6 @@ class ToolLoopAgent(BaseAgentRunner):
# 处理 LLM 响应
llm_resp = llm_resp_result
logger.debug(f"LLMResp: {llm_resp}")
if llm_resp.role == "err":
# 如果 LLM 响应错误,转换到错误状态
@@ -125,11 +129,10 @@ class ToolLoopAgent(BaseAgentRunner):
# 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE)
# 执行事件钩子
await self.pipeline_ctx.call_event_hook(
self.event, EventType.OnLLMResponseEvent, llm_resp
)
try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
# 返回 LLM 结果
if llm_resp.result_chain:
@@ -193,14 +196,23 @@ class ToolLoopAgent(BaseAgentRunner):
if not req.func_tool:
return
func_tool = req.func_tool.get_func(func_tool_name)
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
try:
await self.agent_hooks.on_tool_start(
self.run_context, func_tool, func_tool_args
)
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
res = await client.session.call_tool(func_tool.name, func_tool_args)
if not res:
continue
except Exception as e:
logger.error(f"Error in on_tool_start hook: {e}", exc_info=True)
executor = self.tool_executor.execute(
tool=func_tool,
run_context=self.run_context,
**func_tool_args,
)
async for resp in executor:
if isinstance(resp, CallToolResult):
res = resp
if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
@@ -218,7 +230,9 @@ class ToolLoopAgent(BaseAgentRunner):
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain().base64_image(res.content[0].data)
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
@@ -242,7 +256,9 @@ class ToolLoopAgent(BaseAgentRunner):
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain().base64_image(res.content[0].data)
yield MessageChain(
type="tool_direct_result"
).base64_image(res.content[0].data)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
@@ -252,32 +268,50 @@ class ToolLoopAgent(BaseAgentRunner):
)
)
yield MessageChain().message("返回的数据类型不受支持。")
else:
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
# 尝试调用工具函数
wrapper = self.pipeline_ctx.call_handler(
self.event, func_tool.handler, **func_tool_args
try:
await self.agent_hooks.on_tool_end(
self.run_context,
func_tool_name,
func_tool_args,
resp,
)
async for resp in wrapper:
if resp is not None:
# Tool 返回结果
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resp,
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
)
yield MessageChain().message(resp)
else:
elif resp is None:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
self._transition_state(AgentState.DONE)
if res := self.event.get_result():
if res := self.run_context.event.get_result():
if res.chain:
yield MessageChain(chain=res.chain)
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool_name, func_tool_args, None
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
else:
logger.warning(
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
)
self.event.clear_result()
try:
await self.agent_hooks.on_tool_end(
self.run_context, func_tool_name, func_tool_args, None
)
except Exception as e:
logger.error(
f"Error in on_tool_end hook: {e}", exc_info=True
)
self.run_context.event.clear_result()
except Exception as e:
logger.warning(traceback.format_exc())
tool_call_result_blocks.append(

256
astrbot/core/agent/tool.py Normal file
View File

@@ -0,0 +1,256 @@
from dataclasses import dataclass
from deprecated import deprecated
from typing import Awaitable, Literal, Any, Optional
from .mcp_client import MCPClient
@dataclass
class FunctionTool:
"""A class representing a function tool that can be used in function calling."""
name: str | None = None
parameters: dict | None = None
description: str | None = None
handler: Awaitable | None = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str | None = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
"""
active: bool = True
"""是否激活"""
origin: Literal["local", "mcp"] = "local"
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
# MCP 相关字段
mcp_server_name: str | None = None
"""MCP 服务名称,当 origin 为 mcp 时有效"""
mcp_client: MCPClient | None = None
"""MCP 客户端,当 origin 为 mcp 时有效"""
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
def __dict__(self) -> dict[str, Any]:
"""将 FunctionTool 转换为字典格式"""
return {
"name": self.name,
"parameters": self.parameters,
"description": self.description,
"active": self.active,
"origin": self.origin,
"mcp_server_name": self.mcp_server_name,
}
class ToolSet:
"""A set of function tools that can be used in function calling.
This class provides methods to add, remove, and retrieve tools, as well as
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
def __init__(self, tools: list[FunctionTool] = None):
self.tools: list[FunctionTool] = tools or []
def empty(self) -> bool:
"""Check if the tool set is empty."""
return len(self.tools) == 0
def add_tool(self, tool: FunctionTool):
"""Add a tool to the set."""
# 检查是否已存在同名工具
for i, existing_tool in enumerate(self.tools):
if existing_tool.name == tool.name:
self.tools[i] = tool
return
self.tools.append(tool)
def remove_tool(self, name: str):
"""Remove a tool by its name."""
self.tools = [tool for tool in self.tools if tool.name != name]
def get_tool(self, name: str) -> Optional[FunctionTool]:
"""Get a tool by its name."""
for tool in self.tools:
if tool.name == name:
return tool
return None
@deprecated(reason="Use add_tool() instead", version="4.0.0")
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
"""Add a function tool to the set."""
params = {
"type": "object", # hard-coded here
"properties": {},
}
for param in func_args:
params["properties"][param["name"]] = {
"type": param["type"],
"description": param["description"],
}
_func = FunctionTool(
name=name,
parameters=params,
description=desc,
handler=handler,
)
self.add_tool(_func)
@deprecated(reason="Use remove_tool() instead", version="4.0.0")
def remove_func(self, name: str):
"""Remove a function tool by its name."""
self.remove_tool(name)
@deprecated(reason="Use get_tool() instead", version="4.0.0")
def get_func(self, name: str) -> list[FunctionTool]:
"""Get all function tools."""
return self.get_tool(name)
@property
def func_list(self) -> list[FunctionTool]:
"""Get the list of function tools."""
return self.tools
def openai_schema(self, omit_empty_parameter_field: bool = False) -> list[dict]:
"""Convert tools to OpenAI API function calling schema format."""
result = []
for tool in self.tools:
func_def = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
},
}
if tool.parameters.get("properties") or not omit_empty_parameter_field:
func_def["function"]["parameters"] = tool.parameters
result.append(func_def)
return result
def anthropic_schema(self) -> list[dict]:
"""Convert tools to Anthropic API format."""
result = []
for tool in self.tools:
tool_def = {
"name": tool.name,
"description": tool.description,
"input_schema": {
"type": "object",
"properties": tool.parameters.get("properties", {}),
"required": tool.parameters.get("required", []),
},
}
result.append(tool_def)
return result
def google_schema(self) -> dict:
"""Convert tools to Google GenAI API format."""
def convert_schema(schema: dict) -> dict:
"""Convert schema to Gemini API format."""
supported_types = {
"string",
"number",
"integer",
"boolean",
"array",
"object",
"null",
}
supported_formats = {
"string": {"enum", "date-time"},
"integer": {"int32", "int64"},
"number": {"float", "double"},
}
if "anyOf" in schema:
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
result = {}
if "type" in schema and schema["type"] in supported_types:
result["type"] = schema["type"]
if "format" in schema and schema["format"] in supported_formats.get(
result["type"], set()
):
result["format"] = schema["format"]
else:
result["type"] = "null"
support_fields = {
"title",
"description",
"enum",
"minimum",
"maximum",
"maxItems",
"minItems",
"nullable",
"required",
}
result.update({k: schema[k] for k in support_fields if k in schema})
if "properties" in schema:
properties = {}
for key, value in schema["properties"].items():
prop_value = convert_schema(value)
if "default" in prop_value:
del prop_value["default"]
properties[key] = prop_value
if properties:
result["properties"] = properties
if "items" in schema:
result["items"] = convert_schema(schema["items"])
return result
tools = [
{
"name": tool.name,
"description": tool.description,
"parameters": convert_schema(tool.parameters),
}
for tool in self.tools
]
declarations = {}
if tools:
declarations["function_declarations"] = tools
return declarations
@deprecated(reason="Use openai_schema() instead", version="4.0.0")
def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False):
return self.openai_schema(omit_empty_parameter_field)
@deprecated(reason="Use anthropic_schema() instead", version="4.0.0")
def get_func_desc_anthropic_style(self):
return self.anthropic_schema()
@deprecated(reason="Use google_schema() instead", version="4.0.0")
def get_func_desc_google_genai_style(self):
return self.google_schema()
def names(self) -> list[str]:
"""获取所有工具的名称列表"""
return [tool.name for tool in self.tools]
def __len__(self):
return len(self.tools)
def __bool__(self):
return len(self.tools) > 0
def __iter__(self):
return iter(self.tools)
def __repr__(self):
return f"ToolSet(tools={self.tools})"
def __str__(self):
return f"ToolSet(tools={self.tools})"

View File

@@ -0,0 +1,11 @@
import mcp
from typing import Any, Generic, AsyncGenerator
from .run_context import TContext, ContextWrapper
from .tool import FunctionTool
class BaseFunctionToolExecutor(Generic[TContext]):
@classmethod
async def execute(
cls, tool: FunctionTool, run_context: ContextWrapper[TContext], **tool_args
) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ...

View File

@@ -0,0 +1,11 @@
from dataclasses import dataclass
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest
@dataclass
class AstrAgentContext:
provider: Provider
first_provider_request: ProviderRequest
curr_provider_request: ProviderRequest
streaming: bool

View File

@@ -0,0 +1,283 @@
import os
import uuid
from astrbot.core import AstrBotConfig, logger
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
from astrbot.core.config.default import DEFAULT_CONFIG
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.utils.astrbot_path import get_astrbot_config_path
from typing import TypeVar, TypedDict
_VT = TypeVar("_VT")
class ConfInfo(TypedDict):
"""Configuration information for a specific session or platform."""
id: str # UUID of the configuration or "default"
umop: list[str] # Unified Message Origin Pattern
name: str
path: str # File name to the configuration file
DEFAULT_CONFIG_CONF_INFO = ConfInfo(
id="default",
umop=["::"],
name="default",
path=ASTRBOT_CONFIG_PATH,
)
class AstrBotConfigManager:
"""A class to manage the system configuration of AstrBot, aka ACM"""
def __init__(self, default_config: AstrBotConfig, sp: SharedPreferences):
self.sp = sp
self.confs: dict[str, AstrBotConfig] = {}
"""uuid / "default" -> AstrBotConfig"""
self.confs["default"] = default_config
self.abconf_data = None
self._load_all_configs()
def _get_abconf_data(self) -> dict:
"""获取所有的 abconf 数据"""
if self.abconf_data is None:
self.abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
return self.abconf_data
def _load_all_configs(self):
"""Load all configurations from the shared preferences."""
abconf_data = self._get_abconf_data()
self.abconf_data = abconf_data
for uuid_, meta in abconf_data.items():
filename = meta["path"]
conf_path = os.path.join(get_astrbot_config_path(), filename)
if os.path.exists(conf_path):
conf = AstrBotConfig(config_path=conf_path)
self.confs[uuid_] = conf
else:
logger.warning(
f"Config file {conf_path} for UUID {uuid_} does not exist, skipping."
)
continue
def _is_umo_match(self, p1: str, p2: str) -> bool:
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
p1_ls = p1.split(":")
p2_ls = p2.split(":")
if len(p1_ls) != 3 or len(p2_ls) != 3:
return False # 非法格式
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo:
"""获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default")
Returns:
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
"""
# uuid -> { "umop": list, "path": str, "name": str }
abconf_data = self._get_abconf_data()
if isinstance(umo, MessageSession):
umo = str(umo)
else:
try:
umo = str(MessageSession.from_str(umo)) # validate
except Exception:
return DEFAULT_CONFIG_CONF_INFO
for uuid_, meta in abconf_data.items():
for pattern in meta["umop"]:
if self._is_umo_match(pattern, umo):
return ConfInfo(**meta, id=uuid_)
return DEFAULT_CONFIG_CONF_INFO
def _save_conf_mapping(
self,
abconf_path: str,
abconf_id: str,
umo_parts: list[str] | list[MessageSession],
abconf_name: str | None = None,
) -> None:
"""保存配置文件的映射关系"""
for part in umo_parts:
if isinstance(part, MessageSession):
part = str(part)
elif not isinstance(part, str):
raise ValueError(
"umo_parts must be a list of strings or MessageSession instances"
)
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
random_word = abconf_name or uuid.uuid4().hex[:8]
abconf_data[abconf_id] = {
"umop": umo_parts,
"path": abconf_path,
"name": random_word,
}
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig:
"""获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。"""
if not umo:
return self.confs["default"]
if isinstance(umo, MessageSession):
umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}"
uuid_ = self._load_conf_mapping(umo)["id"]
conf = self.confs.get(uuid_)
if not conf:
conf = self.confs["default"] # default MUST exists
return conf
@property
def default_conf(self) -> AstrBotConfig:
"""获取默认配置文件"""
return self.confs["default"]
def get_conf_info(self, umo: str | MessageSession) -> ConfInfo:
"""获取指定 umo 的配置文件元数据"""
if isinstance(umo, MessageSession):
umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}"
return self._load_conf_mapping(umo)
def get_conf_list(self) -> list[ConfInfo]:
"""获取所有配置文件的元数据列表"""
conf_list = []
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
abconf_mapping = self._get_abconf_data()
for uuid_, meta in abconf_mapping.items():
conf_list.append(ConfInfo(**meta, id=uuid_))
return conf_list
def create_conf(
self,
umo_parts: list[str] | list[MessageSession],
config: dict = DEFAULT_CONFIG,
name: str | None = None,
) -> str:
"""
umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
umo_parts 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。
"""
conf_uuid = str(uuid.uuid4())
conf_file_name = f"abconf_{conf_uuid}.json"
conf_path = os.path.join(get_astrbot_config_path(), conf_file_name)
conf = AstrBotConfig(config_path=conf_path, default_config=config)
conf.save_config()
self._save_conf_mapping(conf_file_name, conf_uuid, umo_parts, abconf_name=name)
self.confs[conf_uuid] = conf
return conf_uuid
def delete_conf(self, conf_id: str) -> bool:
"""删除指定配置文件
Args:
conf_id: 配置文件的 UUID
Returns:
bool: 删除是否成功
Raises:
ValueError: 如果试图删除默认配置文件
"""
if conf_id == "default":
raise ValueError("不能删除默认配置文件")
# 从映射中移除
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
return False
# 获取配置文件路径
conf_path = os.path.join(
get_astrbot_config_path(), abconf_data[conf_id]["path"]
)
# 删除配置文件
try:
if os.path.exists(conf_path):
os.remove(conf_path)
logger.info(f"已删除配置文件: {conf_path}")
except Exception as e:
logger.error(f"删除配置文件 {conf_path} 失败: {e}")
return False
# 从内存中移除
if conf_id in self.confs:
del self.confs[conf_id]
# 从映射中移除
del abconf_data[conf_id]
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
logger.info(f"成功删除配置文件 {conf_id}")
return True
def update_conf_info(
self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None
) -> bool:
"""更新配置文件信息
Args:
conf_id: 配置文件的 UUID
name: 新的配置文件名称 (可选)
umo_parts: 新的 UMO 部分列表 (可选)
Returns:
bool: 更新是否成功
"""
if conf_id == "default":
raise ValueError("不能更新默认配置文件的信息")
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
if conf_id not in abconf_data:
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
return False
# 更新名称
if name is not None:
abconf_data[conf_id]["name"] = name
# 更新 UMO 部分
if umo_parts is not None:
# 验证 UMO 部分格式
for part in umo_parts:
if isinstance(part, MessageSession):
part = str(part)
elif not isinstance(part, str):
raise ValueError(
"umo_parts must be a list of strings or MessageSession instances"
)
abconf_data[conf_id]["umop"] = umo_parts
# 保存更新
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
logger.info(f"成功更新配置文件 {conf_id} 的信息")
return True
def g(
self, umo: str | None = None, key: str | None = None, default: _VT = None
) -> _VT:
"""获取配置项。umo 为 None 时使用默认配置"""
if umo is None:
return self.confs["default"].get(key, default)
conf = self.get_conf(umo)
return conf.get(key, default)

File diff suppressed because it is too large Load Diff

View File

@@ -5,40 +5,44 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
"""
import uuid
import json
import asyncio
from astrbot.core import sp
from typing import Dict, List
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation
from astrbot.core.db.po import Conversation, ConversationV2
class ConversationManager:
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
def __init__(self, db_helper: BaseDatabase):
# session_conversations 字典记录会话ID-对话ID 映射关系
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
self.session_conversations: Dict[str, str] = {}
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
self._start_periodic_save()
def _start_periodic_save(self):
"""启动定时保存任务"""
asyncio.create_task(self._periodic_save())
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
"""将 ConversationV2 对象转换为 Conversation 对象"""
created_at = int(conv_v2.created_at.timestamp())
updated_at = int(conv_v2.updated_at.timestamp())
return Conversation(
platform_id=conv_v2.platform_id,
user_id=conv_v2.user_id,
cid=conv_v2.conversation_id,
history=json.dumps(conv_v2.content or []),
title=conv_v2.title,
persona_id=conv_v2.persona_id,
created_at=created_at,
updated_at=updated_at,
)
async def _periodic_save(self):
"""定时保存会话对话映射关系到存储中"""
while True:
await asyncio.sleep(self.save_interval)
self._save_to_storage()
def _save_to_storage(self):
"""保存会话对话映射关系到存储中"""
sp.put("session_conversation", self.session_conversations)
async def new_conversation(self, unified_msg_origin: str) -> str:
async def new_conversation(
self,
unified_msg_origin: str,
platform_id: str | None = None,
content: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
) -> str:
"""新建对话,并将当前会话的对话转移到新对话
Args:
@@ -46,11 +50,23 @@ class ConversationManager:
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
conversation_id = str(uuid.uuid4())
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
return conversation_id
if not platform_id:
# 如果没有提供 platform_id则从 unified_msg_origin 中解析
parts = unified_msg_origin.split(":")
if len(parts) >= 3:
platform_id = parts[0]
if not platform_id:
platform_id = "unknown"
conv = await self.db.create_conversation(
user_id=unified_msg_origin,
platform_id=platform_id,
content=content,
title=title,
persona_id=persona_id,
)
self.session_conversations[unified_msg_origin] = conv.conversation_id
await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
return conv.conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
"""切换会话的对话
@@ -60,10 +76,10 @@ class ConversationManager:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)
async def delete_conversation(
self, unified_msg_origin: str, conversation_id: str = None
self, unified_msg_origin: str, conversation_id: str | None = None
):
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
@@ -71,13 +87,18 @@ class ConversationManager:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
f = False
if not conversation_id:
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
del self.session_conversations[unified_msg_origin]
sp.put("session_conversation", self.session_conversations)
f = True
if conversation_id:
await self.db.delete_conversation(cid=conversation_id)
if f:
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
"""获取会话当前的对话 ID
Args:
@@ -85,11 +106,19 @@ class ConversationManager:
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
return self.session_conversations.get(unified_msg_origin, None)
ret = self.session_conversations.get(unified_msg_origin, None)
if not ret:
ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None)
if ret:
self.session_conversations[unified_msg_origin] = ret
return ret
async def get_conversation(
self, unified_msg_origin: str, conversation_id: str
) -> Conversation:
self,
unified_msg_origin: str,
conversation_id: str,
create_if_not_exists: bool = False,
) -> Conversation | None:
"""获取会话的对话
Args:
@@ -98,20 +127,74 @@ class ConversationManager:
Returns:
conversation (Conversation): 对话对象
"""
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
conv = await self.db.get_conversation_by_id(cid=conversation_id)
if not conv and create_if_not_exists:
# 如果对话不存在且需要创建,则新建一个对话
conversation_id = await self.new_conversation(unified_msg_origin)
conv = await self.db.get_conversation_by_id(cid=conversation_id)
conv_res = None
if conv:
conv_res = self._convert_conv_from_v2_to_v1(conv)
return conv_res
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
"""获取会话的所有对话
async def get_conversations(
self, unified_msg_origin: str | None = None, platform_id: str | None = None
) -> List[Conversation]:
"""获取对话列表
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id,可选
platform_id (str): 平台 ID, 可选参数, 用于过滤对话
Returns:
conversations (List[Conversation]): 对话对象列表
"""
return self.db.get_conversations(unified_msg_origin)
convs = await self.db.get_conversations(
user_id=unified_msg_origin, platform_id=platform_id
)
convs_res = []
for conv in convs:
conv_res = self._convert_conv_from_v2_to_v1(conv)
convs_res.append(conv_res)
return convs_res
async def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platform_ids: list[str] | None = None,
search_query: str = "",
**kwargs,
) -> tuple[list[Conversation], int]:
"""获取过滤后的对话列表
Args:
page (int): 页码, 默认为 1
page_size (int): 每页大小, 默认为 20
platform_ids (list[str]): 平台 ID 列表, 可选
search_query (str): 搜索查询字符串, 可选
Returns:
conversations (list[Conversation]): 对话对象列表
"""
convs, cnt = await self.db.get_filtered_conversations(
page=page,
page_size=page_size,
platform_ids=platform_ids,
search_query=search_query,
**kwargs,
)
convs_res = []
for conv in convs:
conv_res = self._convert_conv_from_v2_to_v1(conv)
convs_res.append(conv_res)
return convs_res, cnt
async def update_conversation(
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
self,
unified_msg_origin: str,
conversation_id: str | None = None,
history: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
):
"""更新会话的对话
@@ -120,39 +203,54 @@ class ConversationManager:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
"""
if not conversation_id:
# 如果没有提供 conversation_id则获取当前的
conversation_id = await self.get_curr_conversation_id(unified_msg_origin)
if conversation_id:
self.db.update_conversation(
user_id=unified_msg_origin,
await self.db.update_conversation(
cid=conversation_id,
history=json.dumps(history),
title=title,
persona_id=persona_id,
content=history,
)
async def update_conversation_title(self, unified_msg_origin: str, title: str):
async def update_conversation_title(
self, unified_msg_origin: str, title: str, conversation_id: str | None = None
):
"""更新会话的对话标题
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
title (str): 对话标题
Deprecated:
Use `update_conversation` with `title` parameter instead.
"""
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_title(
user_id=unified_msg_origin, cid=conversation_id, title=title
await self.update_conversation(
unified_msg_origin=unified_msg_origin,
conversation_id=conversation_id,
title=title,
)
async def update_conversation_persona_id(
self, unified_msg_origin: str, persona_id: str
self,
unified_msg_origin: str,
persona_id: str,
conversation_id: str | None = None,
):
"""更新会话的对话 Persona ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
persona_id (str): 对话 Persona ID
Deprecated:
Use `update_conversation` with `persona_id` parameter instead.
"""
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_persona_id(
user_id=unified_msg_origin, cid=conversation_id, persona_id=persona_id
await self.update_conversation(
unified_msg_origin=unified_msg_origin,
conversation_id=conversation_id,
persona_id=persona_id,
)
async def get_human_readable_context(

View File

@@ -15,20 +15,23 @@ import time
import threading
import os
from .event_bus import EventBus
from . import astrbot_config
from . import astrbot_config, html_renderer
from asyncio import Queue
from typing import List
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.star.context import Context
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core import LogBroker
from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core import logger, sp
from astrbot.core.config.default import VERSION
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map
@@ -47,12 +50,23 @@ class AstrBotCoreLifecycle:
self.db = db # 初始化数据库
# 设置代理
if self.astrbot_config.get("http_proxy", ""):
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
if proxy := os.environ.get("https_proxy"):
logger.debug(f"Using proxy: {proxy}")
os.environ["no_proxy"] = "localhost"
proxy_config = self.astrbot_config.get("http_proxy", "")
if proxy_config != "":
os.environ["https_proxy"] = proxy_config
os.environ["http_proxy"] = proxy_config
logger.debug(f"Using proxy: {proxy_config}")
# 设置 no_proxy
no_proxy_list = self.astrbot_config.get("no_proxy", [])
os.environ["no_proxy"] = ",".join(no_proxy_list)
else:
# 清空代理环境变量
if "https_proxy" in os.environ:
del os.environ["https_proxy"]
if "http_proxy" in os.environ:
del os.environ["http_proxy"]
if "no_proxy" in os.environ:
del os.environ["no_proxy"]
logger.debug("HTTP proxy cleared")
async def initialize(self):
"""
@@ -66,11 +80,26 @@ class AstrBotCoreLifecycle:
else:
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
await self.db.initialize()
await html_renderer.initialize()
# 初始化 AstrBot 配置管理器
self.astrbot_config_mgr = AstrBotConfigManager(
default_config=self.astrbot_config, sp=sp
)
# 初始化事件队列
self.event_queue = Queue()
# 初始化人格管理器
self.persona_mgr = PersonaManager(self.db, self.astrbot_config_mgr)
await self.persona_mgr.initialize()
# 初始化供应商管理器
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
self.provider_manager = ProviderManager(
self.astrbot_config_mgr, self.db, self.persona_mgr
)
# 初始化平台管理器
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
@@ -78,6 +107,9 @@ class AstrBotCoreLifecycle:
# 初始化对话管理器
self.conversation_manager = ConversationManager(self.db)
# 初始化平台消息历史管理器
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
# 初始化提供给插件的上下文
self.star_context = Context(
self.event_queue,
@@ -86,6 +118,9 @@ class AstrBotCoreLifecycle:
self.provider_manager,
self.platform_manager,
self.conversation_manager,
self.platform_message_history_manager,
self.persona_mgr,
self.astrbot_config_mgr,
)
# 初始化插件管理器
@@ -98,16 +133,16 @@ class AstrBotCoreLifecycle:
await self.provider_manager.initialize()
# 初始化消息事件流水线调度器
self.pipeline_scheduler = PipelineScheduler(
PipelineContext(self.astrbot_config, self.plugin_manager)
)
await self.pipeline_scheduler.initialize()
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
# 初始化更新器
self.astrbot_updator = AstrBotUpdator()
# 初始化事件总线
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
self.event_bus = EventBus(
self.event_queue, self.pipeline_scheduler_mapping, self.astrbot_config_mgr
)
# 记录启动时间
self.start_time = int(time.time())
@@ -224,6 +259,39 @@ class AstrBotCoreLifecycle:
platform_insts = self.platform_manager.get_insts()
for platform_inst in platform_insts:
tasks.append(
asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name)
asyncio.create_task(
platform_inst.run(),
name=f"{platform_inst.meta().id}({platform_inst.meta().name})",
)
)
return tasks
async def load_pipeline_scheduler(self) -> dict[str, PipelineScheduler]:
"""加载消息事件流水线调度器
Returns:
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
"""
mapping = {}
for conf_id, ab_config in self.astrbot_config_mgr.confs.items():
scheduler = PipelineScheduler(
PipelineContext(ab_config, self.plugin_manager, conf_id)
)
await scheduler.initialize()
mapping[conf_id] = scheduler
return mapping
async def reload_pipeline_scheduler(self, conf_id: str):
"""重新加载消息事件流水线调度器
Returns:
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
"""
ab_config = self.astrbot_config_mgr.confs.get(conf_id)
if not ab_config:
raise ValueError(f"配置文件 {conf_id} 不存在")
scheduler = PipelineScheduler(
PipelineContext(ab_config, self.plugin_manager, conf_id)
)
await scheduler.initialize()
self.pipeline_scheduler_mapping[conf_id] = scheduler

View File

@@ -1,7 +1,20 @@
import abc
import datetime
import typing as T
from deprecated import deprecated
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
from astrbot.core.db.po import (
Stats,
PlatformStat,
ConversationV2,
PlatformMessageHistory,
Attachment,
Persona,
Preference,
)
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
@dataclass
@@ -10,152 +23,262 @@ class BaseDatabase(abc.ABC):
数据库基类
"""
DATABASE_URL = ""
def __init__(self) -> None:
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
future=True,
)
self.AsyncSessionLocal = sessionmaker(
self.engine, class_=AsyncSession, expire_on_commit=False
)
async def initialize(self):
"""初始化数据库连接"""
pass
def insert_base_metrics(self, metrics: dict):
"""插入基础指标数据"""
self.insert_platform_metrics(metrics["platform_stats"])
self.insert_plugin_metrics(metrics["plugin_stats"])
self.insert_command_metrics(metrics["command_stats"])
self.insert_llm_metrics(metrics["llm_stats"])
@abc.abstractmethod
def insert_platform_metrics(self, metrics: dict):
"""插入平台指标数据"""
raise NotImplementedError
@abc.abstractmethod
def insert_plugin_metrics(self, metrics: dict):
"""插入插件指标数据"""
raise NotImplementedError
@abc.abstractmethod
def insert_command_metrics(self, metrics: dict):
"""插入指令指标数据"""
raise NotImplementedError
@abc.abstractmethod
def insert_llm_metrics(self, metrics: dict):
"""插入 LLM 指标数据"""
raise NotImplementedError
@abc.abstractmethod
def update_llm_history(self, session_id: str, content: str, provider_type: str):
"""更新 LLM 历史记录。当不存在 session_id 时插入"""
raise NotImplementedError
@abc.abstractmethod
def get_llm_history(
self, session_id: str = None, provider_type: str = None
) -> List[LLMHistory]:
"""获取 LLM 历史记录, 如果 session_id 为 None, 返回所有"""
raise NotImplementedError
@asynccontextmanager
async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
"""Get a database session."""
if not self.inited:
await self.initialize()
self.inited = True
async with self.AsyncSessionLocal() as session:
yield session
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
@abc.abstractmethod
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取基础统计数据"""
raise NotImplementedError
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
@abc.abstractmethod
def get_total_message_count(self) -> int:
"""获取总消息数"""
raise NotImplementedError
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
@abc.abstractmethod
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取基础统计数据(合并)"""
raise NotImplementedError
@abc.abstractmethod
def insert_atri_vision_data(self, vision_data: ATRIVision):
"""插入 ATRI 视觉数据"""
raise NotImplementedError
# New methods in v4.0.0
@abc.abstractmethod
def get_atri_vision_data(self) -> List[ATRIVision]:
"""获取 ATRI 视觉数据"""
raise NotImplementedError
async def insert_platform_stats(
self,
platform_id: str,
platform_type: str,
count: int = 1,
timestamp: datetime.datetime | None = None,
) -> None:
"""Insert a new platform statistic record."""
...
@abc.abstractmethod
def get_atri_vision_data_by_path_or_id(
self, url_or_path: str, id: str
) -> ATRIVision:
"""通过 url 或 path 获取 ATRI 视觉数据"""
raise NotImplementedError
async def count_platform_stats(self) -> int:
"""Count the number of platform statistics records."""
...
@abc.abstractmethod
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
"""通过 user_id 和 cid 获取 Conversation"""
raise NotImplementedError
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
...
@abc.abstractmethod
def new_conversation(self, user_id: str, cid: str):
"""新建 Conversation"""
raise NotImplementedError
async def get_conversations(
self, user_id: str | None = None, platform_id: str | None = None
) -> list[ConversationV2]:
"""Get all conversations for a specific user and platform_id(optional).
@abc.abstractmethod
def get_conversations(self, user_id: str) -> List[Conversation]:
raise NotImplementedError
@abc.abstractmethod
def update_conversation(self, user_id: str, cid: str, history: str):
"""更新 Conversation"""
raise NotImplementedError
@abc.abstractmethod
def delete_conversation(self, user_id: str, cid: str):
"""删除 Conversation"""
raise NotImplementedError
@abc.abstractmethod
def update_conversation_title(self, user_id: str, cid: str, title: str):
"""更新 Conversation 标题"""
raise NotImplementedError
@abc.abstractmethod
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
"""更新 Conversation Persona ID"""
raise NotImplementedError
@abc.abstractmethod
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页
Args:
page: 页码从1开始
page_size: 每页数量
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
content is not included in the result.
"""
raise NotImplementedError
...
@abc.abstractmethod
def get_filtered_conversations(
async def get_conversation_by_id(self, cid: str) -> ConversationV2:
"""Get a specific conversation by its ID."""
...
@abc.abstractmethod
async def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> list[ConversationV2]:
"""Get all conversations with pagination."""
...
@abc.abstractmethod
async def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表
platform_ids: list[str] | None = None,
search_query: str = "",
**kwargs,
) -> tuple[list[ConversationV2], int]:
"""Get conversations filtered by platform IDs and search query."""
...
Args:
page: 页码
page_size: 每页数量
platforms: 平台筛选列表
message_types: 消息类型筛选列表
search_query: 搜索关键词
exclude_ids: 排除的用户ID列表
exclude_platforms: 排除的平台列表
@abc.abstractmethod
async def create_conversation(
self,
user_id: str,
platform_id: str,
content: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
cid: str | None = None,
created_at: datetime.datetime | None = None,
updated_at: datetime.datetime | None = None,
) -> ConversationV2:
"""Create a new conversation."""
...
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
"""
raise NotImplementedError
@abc.abstractmethod
async def update_conversation(
self,
cid: str,
title: str | None = None,
persona_id: str | None = None,
content: list[dict] | None = None,
) -> None:
"""Update a conversation's history."""
...
@abc.abstractmethod
async def delete_conversation(self, cid: str) -> None:
"""Delete a conversation by its ID."""
...
@abc.abstractmethod
async def insert_platform_message_history(
self,
platform_id: str,
user_id: str,
content: list[dict],
sender_id: str | None = None,
sender_name: str | None = None,
) -> None:
"""Insert a new platform message history record."""
...
@abc.abstractmethod
async def delete_platform_message_offset(
self, platform_id: str, user_id: str, offset_sec: int = 86400
) -> None:
"""Delete platform message history records older than the specified offset."""
...
@abc.abstractmethod
async def get_platform_message_history(
self,
platform_id: str,
user_id: str,
page: int = 1,
page_size: int = 20,
) -> list[PlatformMessageHistory]:
"""Get platform message history for a specific user."""
...
@abc.abstractmethod
async def insert_attachment(
self,
path: str,
type: str,
mime_type: str,
):
"""Insert a new attachment record."""
...
@abc.abstractmethod
async def get_attachment_by_id(self, attachment_id: str) -> Attachment:
"""Get an attachment by its ID."""
...
@abc.abstractmethod
async def insert_persona(
self,
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] | None = None,
tools: list[str] | None = None,
) -> Persona:
"""Insert a new persona record."""
...
@abc.abstractmethod
async def get_persona_by_id(self, persona_id: str) -> Persona:
"""Get a persona by its ID."""
...
@abc.abstractmethod
async def get_personas(self) -> list[Persona]:
"""Get all personas for a specific bot."""
...
@abc.abstractmethod
async def update_persona(
self,
persona_id: str,
system_prompt: str | None = None,
begin_dialogs: list[str] | None = None,
tools: list[str] | None = None,
) -> Persona | None:
"""Update a persona's system prompt or begin dialogs."""
...
@abc.abstractmethod
async def delete_persona(self, persona_id: str) -> None:
"""Delete a persona by its ID."""
...
@abc.abstractmethod
async def insert_preference_or_update(
self, scope: str, scope_id: str, key: str, value: dict
) -> Preference:
"""Insert a new preference record."""
...
@abc.abstractmethod
async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference:
"""Get a preference by scope ID and key."""
...
@abc.abstractmethod
async def get_preferences(
self, scope: str, scope_id: str | None = None, key: str | None = None
) -> list[Preference]:
"""Get all preferences for a specific scope ID or key."""
...
@abc.abstractmethod
async def remove_preference(self, scope: str, scope_id: str, key: str) -> None:
"""Remove a preference by scope ID and key."""
...
@abc.abstractmethod
async def clear_preferences(self, scope: str, scope_id: str) -> None:
"""Clear all preferences for a specific scope ID."""
...
# @abc.abstractmethod
# async def insert_llm_message(
# self,
# cid: str,
# role: str,
# content: list,
# tool_calls: list = None,
# tool_call_id: str = None,
# parent_id: str = None,
# ) -> LLMMessage:
# """Insert a new LLM message into the conversation."""
# ...
# @abc.abstractmethod
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
# """Get all LLM messages for a specific conversation."""
# ...

View File

@@ -0,0 +1,64 @@
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.db import BaseDatabase
from astrbot.core.config import AstrBotConfig
from astrbot.api import logger, sp
from .migra_3_to_4 import (
migration_conversation_table,
migration_platform_table,
migration_webchat_data,
migration_persona_data,
migration_preferences,
)
async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
"""
检查是否需要进行数据库迁移
如果存在 data_v3.db 并且 preference 中没有 migration_done_v4则需要进行迁移。
"""
data_v3_exists = os.path.exists(get_astrbot_data_path())
if not data_v3_exists:
return False
migration_done = await db_helper.get_preference(
"global", "global", "migration_done_v4"
)
if migration_done:
return False
return True
async def do_migration_v4(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
astrbot_config: AstrBotConfig,
):
"""
执行数据库迁移
迁移旧的 webchat_conversation 表到新的 conversation 表。
迁移旧的 platform 到新的 platform_stats 表。
"""
if not await check_migration_needed_v4(db_helper):
return
logger.info("开始执行数据库迁移...")
# 执行会话表迁移
await migration_conversation_table(db_helper, platform_id_map)
# 执行人格数据迁移
await migration_persona_data(db_helper, astrbot_config)
# 执行 WebChat 数据迁移
await migration_webchat_data(db_helper, platform_id_map)
# 执行偏好设置迁移
await migration_preferences(db_helper,platform_id_map)
# 执行平台统计表迁移
await migration_platform_table(db_helper, platform_id_map)
# 标记迁移完成
await sp.put_async("global", "global", "migration_done_v4", True)
logger.info("数据库迁移完成。")

View File

@@ -0,0 +1,338 @@
import json
import datetime
from .. import BaseDatabase
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
from .shared_preferences_v3 import sp as sp_v3
from astrbot.core.config.default import DB_PATH
from astrbot.api import logger, sp
from astrbot.core.config import AstrBotConfig
from astrbot.core.platform.astr_message_event import MessageSesion
from sqlalchemy.ext.asyncio import AsyncSession
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
from sqlalchemy import text
"""
1. 迁移旧的 webchat_conversation 表到新的 conversation 表。
2. 迁移旧的 platform 到新的 platform_stats 表。
"""
def get_platform_id(
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
) -> str:
return platform_id_map.get(
old_platform_name,
{"platform_id": old_platform_name, "platform_type": old_platform_name},
).get("platform_id", old_platform_name)
def get_platform_type(
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
) -> str:
return platform_id_map.get(
old_platform_name,
{"platform_id": old_platform_name, "platform_type": old_platform_name},
).get("platform_type", old_platform_name)
async def migration_conversation_table(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
)
conversations, total_cnt = db_helper_v3.get_all_conversations(
page=1, page_size=10000000
)
logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...")
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
for idx, conversation in enumerate(conversations):
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
progress = int((idx + 1) / total_cnt * 100)
if progress % 10 == 0:
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
try:
conv = db_helper_v3.get_conversation_by_user_id(
user_id=conversation.get("user_id", "unknown"),
cid=conversation.get("cid", "unknown"),
)
if not conv:
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" not in conv.user_id:
continue
session = MessageSesion.from_str(session_str=conv.user_id)
platform_id = get_platform_id(
platform_id_map, session.platform_name
)
session.platform_id = platform_id # 更新平台名称为新的 ID
conv_v2 = ConversationV2(
user_id=str(session),
content=json.loads(conv.history) if conv.history else [],
platform_id=platform_id,
title=conv.title,
persona_id=conv.persona_id,
conversation_id=conv.cid,
created_at=datetime.datetime.fromtimestamp(conv.created_at),
updated_at=datetime.datetime.fromtimestamp(conv.updated_at),
)
dbsession.add(conv_v2)
except Exception as e:
logger.error(
f"迁移旧会话 {conversation.get('cid', 'unknown')} 失败: {e}",
exc_info=True,
)
logger.info(f"成功迁移 {total_cnt} 条旧的会话数据到新表。")
async def migration_platform_table(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
)
secs_from_2023_4_10_to_now = (
datetime.datetime.now(datetime.timezone.utc)
- datetime.datetime(2023, 4, 10, tzinfo=datetime.timezone.utc)
).total_seconds()
offset_sec = int(secs_from_2023_4_10_to_now)
logger.info(f"迁移旧平台数据offset_sec: {offset_sec} 秒。")
stats = db_helper_v3.get_base_stats(offset_sec=offset_sec)
logger.info(f"迁移 {len(stats.platform)} 条旧的平台数据到新的表中...")
platform_stats_v3 = stats.platform
if not platform_stats_v3:
logger.info("没有找到旧平台数据,跳过迁移。")
return
first_time_stamp = platform_stats_v3[0].timestamp
end_time_stamp = platform_stats_v3[-1].timestamp
start_time = first_time_stamp - (first_time_stamp % 3600) # 向下取整到小时
end_time = end_time_stamp + (3600 - (end_time_stamp % 3600)) # 向上取整到小时
idx = 0
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
total_buckets = (end_time - start_time) // 3600
for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)):
if bucket_idx % 500 == 0:
progress = int((bucket_idx + 1) / total_buckets * 100)
logger.info(f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})")
cnt = 0
while (
idx < len(platform_stats_v3)
and platform_stats_v3[idx].timestamp < bucket_end
):
cnt += platform_stats_v3[idx].count
idx += 1
if cnt == 0:
continue
platform_id = get_platform_id(
platform_id_map, platform_stats_v3[idx].name
)
platform_type = get_platform_type(
platform_id_map, platform_stats_v3[idx].name
)
try:
await dbsession.execute(
text("""
INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
VALUES (:timestamp, :platform_id, :platform_type, :count)
ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET
count = platform_stats.count + EXCLUDED.count
"""),
{
"timestamp": datetime.datetime.fromtimestamp(
bucket_end, tz=datetime.timezone.utc
),
"platform_id": platform_id,
"platform_type": platform_type,
"count": cnt,
},
)
except Exception:
logger.error(
f"迁移平台统计数据失败: {platform_id}, {platform_type}, 时间戳: {bucket_end}",
exc_info=True,
)
logger.info(f"成功迁移 {len(platform_stats_v3)} 条旧的平台数据到新表。")
async def migration_webchat_data(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
"""迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中"""
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
)
conversations, total_cnt = db_helper_v3.get_all_conversations(
page=1, page_size=10000000
)
logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...")
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
for idx, conversation in enumerate(conversations):
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
progress = int((idx + 1) / total_cnt * 100)
if progress % 10 == 0:
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
try:
conv = db_helper_v3.get_conversation_by_user_id(
user_id=conversation.get("user_id", "unknown"),
cid=conversation.get("cid", "unknown"),
)
if not conv:
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" in conv.user_id:
continue
platform_id = "webchat"
history = json.loads(conv.history) if conv.history else []
for msg in history:
type_ = msg.get("type") # user type, "bot" or "user"
new_history = PlatformMessageHistory(
platform_id=platform_id,
user_id=conv.cid, # we use conv.cid as user_id for webchat
content=msg,
sender_id=type_,
sender_name=type_,
)
dbsession.add(new_history)
except Exception:
logger.error(
f"迁移旧 WebChat 会话 {conversation.get('cid', 'unknown')} 失败",
exc_info=True,
)
logger.info(f"成功迁移 {total_cnt} 条旧的 WebChat 会话数据到新表。")
async def migration_persona_data(
db_helper: BaseDatabase, astrbot_config: AstrBotConfig
):
"""
迁移 Persona 数据到新的表中。
旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
"""
v3_persona_config: list[dict] = astrbot_config.get("persona", [])
total_personas = len(v3_persona_config)
logger.info(f"迁移 {total_personas} 个 Persona 配置到新表中...")
for idx, persona in enumerate(v3_persona_config):
if total_personas > 0 and (idx + 1) % max(1, total_personas // 10) == 0:
progress = int((idx + 1) / total_personas * 100)
if progress % 10 == 0:
logger.info(f"进度: {progress}% ({idx + 1}/{total_personas})")
try:
begin_dialogs = persona.get("begin_dialogs", [])
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
mood_prompt = ""
user_turn = True
for mood_dialog in mood_imitation_dialogs:
if user_turn:
mood_prompt += f"A: {mood_dialog}\n"
else:
mood_prompt += f"B: {mood_dialog}\n"
user_turn = not user_turn
system_prompt = persona.get("prompt", "")
if mood_prompt:
system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
persona_new = await db_helper.insert_persona(
persona_id=persona["name"],
system_prompt=system_prompt,
begin_dialogs=begin_dialogs,
)
logger.info(
f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
async def migration_preferences(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
):
# 1. global scope migration
keys = [
"inactivated_llm_tools",
"inactivated_plugins",
"curr_provider",
"curr_provider_tts",
"curr_provider_stt",
"alter_cmd",
]
for key in keys:
value = sp_v3.get(key)
if value is not None:
await sp.put_async("global", "global", key, value)
logger.info(f"迁移全局偏好设置 {key} 成功,值: {value}")
# 2. umo scope migration
session_conversation = sp_v3.get("session_conversation", default={})
for umo, conversation_id in session_conversation.items():
if not umo or not conversation_id:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
await sp.put_async("umo", str(session), "sel_conv_id", conversation_id)
logger.info(f"迁移会话 {umo} 的对话数据到新表成功,平台 ID: {platform_id}")
except Exception as e:
logger.error(f"迁移会话 {umo} 的对话数据失败: {e}", exc_info=True)
session_service_config = sp_v3.get("session_service_config", default={})
for umo, config in session_service_config.items():
if not umo or not config:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
await sp.put_async("umo", str(session), "session_service_config", config)
logger.info(f"迁移会话 {umo} 的服务配置到新表成功,平台 ID: {platform_id}")
except Exception as e:
logger.error(f"迁移会话 {umo} 的服务配置失败: {e}", exc_info=True)
session_variables = sp_v3.get("session_variables", default={})
for umo, variables in session_variables.items():
if not umo or not variables:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
await sp.put_async("umo", str(session), "session_variables", variables)
except Exception as e:
logger.error(f"迁移会话 {umo} 的变量失败: {e}", exc_info=True)
session_provider_perf = sp_v3.get("session_provider_perf", default={})
for umo, perf in session_provider_perf.items():
if not umo or not perf:
continue
try:
session = MessageSesion.from_str(session_str=umo)
platform_id = get_platform_id(platform_id_map, session.platform_name)
session.platform_id = platform_id
for provider_type, provider_id in perf.items():
await sp.put_async(
"umo", str(session), f"provider_perf_{provider_type}", provider_id
)
logger.info(
f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}"
)
except Exception as e:
logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True)

View File

@@ -0,0 +1,45 @@
import json
import os
from typing import TypeVar
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
_VT = TypeVar("_VT")
class SharedPreferences:
def __init__(self, path=None):
if path is None:
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
self.path = path
self._data = self._load_preferences()
def _load_preferences(self):
if os.path.exists(self.path):
try:
with open(self.path, "r") as f:
return json.load(f)
except json.JSONDecodeError:
os.remove(self.path)
return {}
def _save_preferences(self):
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4, ensure_ascii=False)
f.flush()
def get(self, key, default: _VT = None) -> _VT:
return self._data.get(key, default)
def put(self, key, value):
self._data[key] = value
self._save_preferences()
def remove(self, key):
if key in self._data:
del self._data[key]
self._save_preferences()
def clear(self):
self._data.clear()
self._save_preferences()
sp = SharedPreferences()

View File

@@ -0,0 +1,493 @@
import sqlite3
import time
from astrbot.core.db.po import Platform, Stats
from typing import Tuple, List, Dict, Any
from dataclasses import dataclass
@dataclass
class Conversation:
"""LLM 对话存储
对于网页聊天history 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
"""
user_id: str
cid: str
history: str = ""
"""字符串格式的列表。"""
created_at: int = 0
updated_at: int = 0
title: str = ""
persona_id: str = ""
INIT_SQL = """
CREATE TABLE IF NOT EXISTS platform(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS plugin(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS command(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm_history(
provider_type VARCHAR(32),
session_id VARCHAR(32),
content TEXT
);
-- ATRI
CREATE TABLE IF NOT EXISTS atri_vision(
id TEXT,
url_or_path TEXT,
caption TEXT,
is_meme BOOLEAN,
keywords TEXT,
platform_name VARCHAR(32),
session_id VARCHAR(32),
sender_nickname VARCHAR(32),
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT, -- 会话 id
cid TEXT, -- 对话 id
history TEXT,
created_at INTEGER,
updated_at INTEGER,
title TEXT,
persona_id TEXT
);
PRAGMA encoding = 'UTF-8';
"""
class SQLiteDatabase():
def __init__(self, db_path: str) -> None:
super().__init__()
self.db_path = db_path
sql = INIT_SQL
# 初始化数据库
self.conn = self._get_conn(self.db_path)
c = self.conn.cursor()
c.executescript(sql)
self.conn.commit()
# 检查 webchat_conversation 的 title 字段是否存在
c.execute(
"""
PRAGMA table_info(webchat_conversation)
"""
)
res = c.fetchall()
has_title = False
has_persona_id = False
for row in res:
if row[1] == "title":
has_title = True
if row[1] == "persona_id":
has_persona_id = True
if not has_title:
c.execute(
"""
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
"""
)
self.conn.commit()
if not has_persona_id:
c.execute(
"""
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
"""
)
self.conn.commit()
c.close()
def _get_conn(self, db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
conn.text_factory = str
return conn
def _exec_sql(self, sql: str, params: Tuple = None):
conn = self.conn
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
conn = self._get_conn(self.db_path)
c = conn.cursor()
if params:
c.execute(sql, params)
c.close()
else:
c.execute(sql)
c.close()
conn.commit()
def insert_platform_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
"""
INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
""",
(k, v, int(time.time())),
)
def insert_llm_metrics(self, metrics: dict):
for k, v in metrics.items():
self._exec_sql(
"""
INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
""",
(k, v, int(time.time())),
)
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取 offset_sec 秒前到现在的基础统计数据"""
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT * FROM platform
"""
+ where_clause
)
platform = []
for row in c.fetchall():
platform.append(Platform(*row))
c.close()
return Stats(platform=platform)
def get_total_message_count(self) -> int:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT SUM(count) FROM platform
"""
)
res = c.fetchone()
c.close()
return res[0]
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
"""获取 offset_sec 秒前到现在的基础统计数据(合并)"""
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT name, SUM(count), timestamp FROM platform
"""
+ where_clause
+ " GROUP BY name"
)
platform = []
for row in c.fetchall():
platform.append(Platform(*row))
c.close()
return Stats(platform, [], [])
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
""",
(user_id, cid),
)
res = c.fetchone()
c.close()
if not res:
return
return Conversation(*res)
def new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
self._exec_sql(
"""
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
""",
(user_id, cid, history, updated_at, created_at),
)
def get_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
"""
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
""",
(user_id,),
)
res = c.fetchall()
c.close()
conversations = []
for row in res:
cid = row[0]
created_at = row[1]
updated_at = row[2]
title = row[3]
persona_id = row[4]
conversations.append(
Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
)
return conversations
def update_conversation(self, user_id: str, cid: str, history: str):
"""更新对话,并且同时更新时间"""
updated_at = int(time.time())
self._exec_sql(
"""
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
""",
(history, updated_at, user_id, cid),
)
def update_conversation_title(self, user_id: str, cid: str, title: str):
self._exec_sql(
"""
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
""",
(title, user_id, cid),
)
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
self._exec_sql(
"""
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
""",
(persona_id, user_id, cid),
)
def delete_conversation(self, user_id: str, cid: str):
self._exec_sql(
"""
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
""",
(user_id, cid),
)
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页,按更新时间降序排序"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 获取总记录数
c.execute("""
SELECT COUNT(*) FROM webchat_conversation
""")
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 获取分页数据,按更新时间降序排序
c.execute(
"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
""",
(page_size, offset),
)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型且至少有8个字符否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0确保即使出错也有有效的返回值
return [], 0
finally:
c.close()
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 构建查询条件
where_clauses = []
params = []
# 平台筛选
if platforms and len(platforms) > 0:
platform_conditions = []
for platform in platforms:
platform_conditions.append("user_id LIKE ?")
params.append(f"{platform}:%")
if platform_conditions:
where_clauses.append(f"({' OR '.join(platform_conditions)})")
# 消息类型筛选
if message_types and len(message_types) > 0:
message_type_conditions = []
for msg_type in message_types:
message_type_conditions.append("user_id LIKE ?")
params.append(f"%:{msg_type}:%")
if message_type_conditions:
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
# 搜索关键词
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
where_clauses.append(
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
)
search_param = f"%{search_query}%"
params.extend([search_param, search_param, search_param, search_param])
# 排除特定用户ID
if exclude_ids and len(exclude_ids) > 0:
for exclude_id in exclude_ids:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_id}%")
# 排除特定平台
if exclude_platforms and len(exclude_platforms) > 0:
for exclude_platform in exclude_platforms:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_platform}:%")
# 构建完整的 WHERE 子句
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
# 构建计数查询
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
# 获取总记录数
c.execute(count_sql, params)
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 构建分页数据查询
data_sql = f"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
{where_sql}
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
"""
query_params = params + [page_size, offset]
# 获取分页数据
c.execute(data_sql, query_params)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型,否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0确保即使出错也有有效的返回值
return [], 0
finally:
c.close()

View File

@@ -1,7 +1,233 @@
"""指标数据"""
import uuid
from datetime import datetime, timezone
from dataclasses import dataclass, field
from typing import List
from sqlmodel import (
SQLModel,
Text,
JSON,
UniqueConstraint,
Field,
)
from typing import Optional, TypedDict
class PlatformStat(SQLModel, table=True):
"""This class represents the statistics of bot usage across different platforms.
Note: In astrbot v4, we moved `platform` table to here.
"""
__tablename__ = "platform_stats"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
timestamp: datetime = Field(nullable=False)
platform_id: str = Field(nullable=False)
platform_type: str = Field(nullable=False) # such as "aiocqhttp", "slack", etc.
count: int = Field(default=0, nullable=False)
__table_args__ = (
UniqueConstraint(
"timestamp",
"platform_id",
"platform_type",
name="uix_platform_stats",
),
)
class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations"
inner_conversation_id: int = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}
)
conversation_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False)
content: Optional[list] = Field(default=None, sa_type=JSON)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
title: Optional[str] = Field(default=None, max_length=255)
persona_id: Optional[str] = Field(default=None)
__table_args__ = (
UniqueConstraint(
"conversation_id",
name="uix_conversation_id",
),
)
class Persona(SQLModel, table=True):
"""Persona is a set of instructions for LLMs to follow.
It can be used to customize the behavior of LLMs.
"""
__tablename__ = "personas"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
persona_id: str = Field(max_length=255, nullable=False)
system_prompt: str = Field(sa_type=Text, nullable=False)
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
"""a list of strings, each representing a dialog to start with"""
tools: Optional[list] = Field(default=None, sa_type=JSON)
"""None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"persona_id",
name="uix_persona_id",
),
)
class Preference(SQLModel, table=True):
"""This class represents preferences for bots."""
__tablename__ = "preferences"
id: int | None = Field(
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
)
scope: str = Field(nullable=False)
"""Scope of the preference, such as 'global', 'umo', 'plugin'."""
scope_id: str = Field(nullable=False)
"""ID of the scope, such as 'global', 'umo', 'plugin_name'."""
key: str = Field(nullable=False)
value: dict = Field(sa_type=JSON, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"scope",
"scope_id",
"key",
name="uix_preference_scope_scope_id_key",
),
)
class PlatformMessageHistory(SQLModel, table=True):
"""This class represents the message history for a specific platform.
It is used to store messages that are not LLM-generated, such as user messages
or platform-specific messages.
"""
__tablename__ = "platform_message_history"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False) # An id of group, user in platform
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
sender_name: Optional[str] = Field(
default=None
) # Name of the sender in the platform
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
class Attachment(SQLModel, table=True):
"""This class represents attachments for messages in AstrBot.
Attachments can be images, files, or other media types.
"""
__tablename__ = "attachments"
inner_attachment_id: int = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}
)
attachment_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
)
path: str = Field(nullable=False) # Path to the file on disk
type: str = Field(nullable=False) # Type of the file (e.g., 'image', 'file')
mime_type: str = Field(nullable=False) # MIME type of the file
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"attachment_id",
name="uix_attachment_id",
),
)
@dataclass
class Conversation:
"""LLM 对话类
对于 WebChathistory 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
在 v4.0.0 版本及之后WebChat 的历史记录被迁移至 `PlatformMessageHistory` 表中,
"""
platform_id: str
user_id: str
cid: str
"""对话 ID, 是 uuid 格式的字符串"""
history: str = ""
"""字符串格式的对话列表。"""
title: str | None = ""
persona_id: str | None = ""
created_at: int = 0
updated_at: int = 0
class Personality(TypedDict):
"""LLM 人格类。
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
"""
prompt: str = ""
name: str = ""
begin_dialogs: list[str] = []
mood_imitation_dialogs: list[str] = []
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
tools: list[str] | None = None
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
# cache
_begin_dialogs_processed: list[dict] = []
_mood_imitation_dialogs_processed: str = ""
# ====
# Deprecated, and will be removed in future versions.
# ====
@dataclass
@@ -13,77 +239,6 @@ class Platform:
timestamp: int
@dataclass
class Provider:
"""供应商使用统计数据"""
name: str
count: int
timestamp: int
@dataclass
class Plugin:
"""插件使用统计数据"""
name: str
count: int
timestamp: int
@dataclass
class Command:
"""命令使用统计数据"""
name: str
count: int
timestamp: int
@dataclass
class Stats:
platform: List[Platform] = field(default_factory=list)
command: List[Command] = field(default_factory=list)
llm: List[Provider] = field(default_factory=list)
@dataclass
class LLMHistory:
"""LLM 聊天时持久化的信息"""
provider_type: str
session_id: str
content: str
@dataclass
class ATRIVision:
"""Deprecated"""
id: str
url_or_path: str
caption: str
is_meme: bool
keywords: List[str]
platform_name: str
session_id: str
sender_nickname: str
timestamp: int = -1
@dataclass
class Conversation:
"""LLM 对话存储
对于网页聊天history 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
"""
user_id: str
cid: str
history: str = ""
"""字符串格式的列表。"""
created_at: int = 0
updated_at: int = 0
title: str = ""
persona_id: str = ""
platform: list[Platform] = field(default_factory=list)

File diff suppressed because it is too large Load Diff

View File

@@ -1,50 +0,0 @@
CREATE TABLE IF NOT EXISTS platform(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS plugin(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS command(
name VARCHAR(32),
count INTEGER,
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm_history(
provider_type VARCHAR(32),
session_id VARCHAR(32),
content TEXT
);
-- ATRI
CREATE TABLE IF NOT EXISTS atri_vision(
id TEXT,
url_or_path TEXT,
caption TEXT,
is_meme BOOLEAN,
keywords TEXT,
platform_name VARCHAR(32),
session_id VARCHAR(32),
sender_nickname VARCHAR(32),
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT, -- 会话 id
cid TEXT, -- 对话 id
history TEXT,
created_at INTEGER,
updated_at INTEGER,
title TEXT,
persona_id TEXT
);
PRAGMA encoding = 'UTF-8';

View File

@@ -5,6 +5,7 @@ from .document_storage import DocumentStorage
from .embedding_storage import EmbeddingStorage
from ..base import Result, BaseVecDB
from astrbot.core.provider.provider import EmbeddingProvider
from astrbot.core.provider.provider import RerankProvider
class FaissVecDB(BaseVecDB):
@@ -17,6 +18,7 @@ class FaissVecDB(BaseVecDB):
doc_store_path: str,
index_store_path: str,
embedding_provider: EmbeddingProvider,
rerank_provider: RerankProvider | None = None,
):
self.doc_store_path = doc_store_path
self.index_store_path = index_store_path
@@ -26,11 +28,14 @@ class FaissVecDB(BaseVecDB):
embedding_provider.get_dim(), index_store_path
)
self.embedding_provider = embedding_provider
self.rerank_provider = rerank_provider
async def initialize(self):
await self.document_storage.initialize()
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
async def insert(
self, content: str, metadata: dict | None = None, id: str | None = None
) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
@@ -53,7 +58,12 @@ class FaissVecDB(BaseVecDB):
return int_id
async def retrieve(
self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None
self,
query: str,
k: int = 5,
fetch_k: int = 20,
rerank: bool = False,
metadata_filters: dict | None = None,
) -> list[Result]:
"""
搜索最相似的文档。
@@ -62,6 +72,7 @@ class FaissVecDB(BaseVecDB):
query (str): 查询文本
k (int): 返回的最相似文档的数量
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
rerank (bool): 是否使用重排序。这需要在实例化时提供 rerank_provider, 如果未提供并且 rerank 为 True, 不会抛出异常。
metadata_filters (dict): 元数据过滤器
Returns:
@@ -72,7 +83,6 @@ class FaissVecDB(BaseVecDB):
vector=np.array([embedding]).astype("float32"),
k=fetch_k if metadata_filters else k,
)
# TODO: rerank
if len(indices[0]) == 0 or indices[0][0] == -1:
return []
# normalize scores
@@ -83,7 +93,7 @@ class FaissVecDB(BaseVecDB):
)
if not fetched_docs:
return []
result_docs = []
result_docs: list[Result] = []
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
for i, indice_idx in enumerate(indices[0]):
@@ -93,7 +103,20 @@ class FaissVecDB(BaseVecDB):
fetch_doc = fetched_docs[pos]
score = scores[0][i]
result_docs.append(Result(similarity=float(score), data=fetch_doc))
return result_docs[:k]
top_k_results = result_docs[:k]
if rerank and self.rerank_provider:
documents = [doc.data["text"] for doc in top_k_results]
reranked_results = await self.rerank_provider.rerank(query, documents)
reranked_results = sorted(
reranked_results, key=lambda x: x.relevance_score, reverse=True
)
top_k_results = [
top_k_results[reranked_result.index] for reranked_result in reranked_results
]
return top_k_results
async def delete(self, doc_id: int):
"""

View File

@@ -16,30 +16,32 @@ from asyncio import Queue
from astrbot.core.pipeline.scheduler import PipelineScheduler
from astrbot.core import logger
from .platform import AstrMessageEvent
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
class EventBus:
"""事件总线: 用于处理事件的分发和处理
"""用于处理事件的分发和处理"""
维护一个异步队列, 来接受各种消息事件
"""
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
def __init__(
self,
event_queue: Queue,
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
astrbot_config_mgr: AstrBotConfigManager = None,
):
self.event_queue = event_queue # 事件队列
self.pipeline_scheduler = pipeline_scheduler # 管道调度器
# abconf uuid -> scheduler
self.pipeline_scheduler_mapping = pipeline_scheduler_mapping
self.astrbot_config_mgr = astrbot_config_mgr
async def dispatch(self):
"""无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑"""
while True:
event: AstrMessageEvent = (
await self.event_queue.get()
) # 从事件队列中获取新的事件
self._print_event(event) # 打印日志
asyncio.create_task(
self.pipeline_scheduler.execute(event)
) # 创建新的异步任务来执行管道调度器的处理逻辑
event: AstrMessageEvent = await self.event_queue.get()
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
self._print_event(event, conf_info["name"])
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
asyncio.create_task(scheduler.execute(event))
def _print_event(self, event: AstrMessageEvent):
def _print_event(self, event: AstrMessageEvent, conf_name: str):
"""用于记录事件信息
Args:
@@ -48,10 +50,10 @@ class EventBus:
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
if event.get_sender_name():
logger.info(
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
)
# 没有发送者名称: [平台名] 发送者ID: 消息概要
else:
logger.info(
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}"
)

View File

@@ -2,6 +2,8 @@ import asyncio
import os
import uuid
import time
from urllib.parse import urlparse, unquote
import platform
class FileTokenService:
@@ -15,7 +17,9 @@ class FileTokenService:
async def _cleanup_expired_tokens(self):
"""清理过期的令牌"""
now = time.time()
expired_tokens = [token for token, (_, expire) in self.staged_files.items() if expire < now]
expired_tokens = [
token for token, (_, expire) in self.staged_files.items() if expire < now
]
for token in expired_tokens:
self.staged_files.pop(token, None)
@@ -32,15 +36,35 @@ class FileTokenService:
Raises:
FileNotFoundError: 当路径不存在时抛出
"""
# 处理 file:///
try:
parsed_uri = urlparse(file_path)
if parsed_uri.scheme == "file":
local_path = unquote(parsed_uri.path)
if platform.system() == "Windows" and local_path.startswith("/"):
local_path = local_path[1:]
else:
# 如果没有 file:/// 前缀,则认为是普通路径
local_path = file_path
except Exception:
# 解析失败时,按原路径处理
local_path = file_path
async with self.lock:
await self._cleanup_expired_tokens()
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
if not os.path.exists(local_path):
raise FileNotFoundError(
f"文件不存在: {local_path} (原始输入: {file_path})"
)
file_token = str(uuid.uuid4())
expire_time = time.time() + (timeout if timeout is not None else self.default_timeout)
self.staged_files[file_token] = (file_path, expire_time)
expire_time = time.time() + (
timeout if timeout is not None else self.default_timeout
)
# 存储转换后的真实路径
self.staged_files[file_token] = (local_path, expire_time)
return file_token
async def handle_file(self, file_token: str) -> str:

View File

@@ -22,6 +22,7 @@ class InitialLoader:
self.db = db
self.logger = logger
self.log_broker = log_broker
self.webui_dir: str | None = None
async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
@@ -35,8 +36,10 @@ class InitialLoader:
core_task = core_lifecycle.start()
webui_dir = self.webui_dir
self.dashboard_server = AstrBotDashboard(
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event, webui_dir
)
task = asyncio.gather(
core_task, self.dashboard_server.run()

View File

@@ -96,8 +96,6 @@ class LogBroker:
Queue: 订阅者的队列, 可用于接收日志消息
"""
q = Queue(maxsize=CACHED_SIZE + 10)
for log in self.log_cache:
q.put_nowait(log)
self.subscribers.append(q)
return q

View File

@@ -37,7 +37,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
class ComponentType(Enum):
class ComponentType(str, Enum):
Plain = "Plain" # 纯文本消息
Face = "Face" # QQ表情
Record = "Record" # 语音
@@ -108,7 +108,7 @@ class BaseMessageComponent(BaseModel):
class Plain(BaseMessageComponent):
type: ComponentType = "Plain"
type = ComponentType.Plain
text: str
convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息
@@ -128,8 +128,9 @@ class Plain(BaseMessageComponent):
async def to_dict(self):
return {"type": "text", "data": {"text": self.text}}
class Face(BaseMessageComponent):
type: ComponentType = "Face"
type = ComponentType.Face
id: int
def __init__(self, **_):
@@ -137,7 +138,7 @@ class Face(BaseMessageComponent):
class Record(BaseMessageComponent):
type: ComponentType = "Record"
type = ComponentType.Record
file: T.Optional[str] = ""
magic: T.Optional[bool] = False
url: T.Optional[str] = ""
@@ -164,19 +165,24 @@ class Record(BaseMessageComponent):
return Record(file=url, **_)
raise Exception("not a valid url")
@staticmethod
def fromBase64(bs64_data: str, **_):
return Record(file=f"base64://{bs64_data}", **_)
async def convert_to_file_path(self) -> str:
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
Returns:
str: 语音的本地路径,以绝对路径表示。
"""
if self.file and self.file.startswith("file:///"):
file_path = self.file[8:]
return file_path
elif self.file and self.file.startswith("http"):
if not self.file:
raise Exception(f"not a valid file: {self.file}")
if self.file.startswith("file:///"):
return self.file[8:]
elif self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
return os.path.abspath(file_path)
elif self.file and self.file.startswith("base64://"):
elif self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -185,8 +191,7 @@ class Record(BaseMessageComponent):
f.write(image_bytes)
return os.path.abspath(file_path)
elif os.path.exists(self.file):
file_path = self.file
return os.path.abspath(file_path)
return os.path.abspath(self.file)
else:
raise Exception(f"not a valid file: {self.file}")
@@ -197,12 +202,14 @@ class Record(BaseMessageComponent):
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
if self.file and self.file.startswith("file:///"):
if not self.file:
raise Exception(f"not a valid file: {self.file}")
if self.file.startswith("file:///"):
bs64_data = file_to_base64(self.file[8:])
elif self.file and self.file.startswith("http"):
elif self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
bs64_data = file_to_base64(file_path)
elif self.file and self.file.startswith("base64://"):
elif self.file.startswith("base64://"):
bs64_data = self.file
elif os.path.exists(self.file):
bs64_data = file_to_base64(self.file)
@@ -236,7 +243,7 @@ class Record(BaseMessageComponent):
class Video(BaseMessageComponent):
type: ComponentType = "Video"
type = ComponentType.Video
file: str
cover: T.Optional[str] = ""
c: T.Optional[int] = 2
@@ -322,7 +329,7 @@ class Video(BaseMessageComponent):
class At(BaseMessageComponent):
type: ComponentType = "At"
type = ComponentType.At
qq: T.Union[int, str] # 此处str为all时代表所有人
name: T.Optional[str] = ""
@@ -344,28 +351,28 @@ class AtAll(At):
class RPS(BaseMessageComponent): # TODO
type: ComponentType = "RPS"
type = ComponentType.RPS
def __init__(self, **_):
super().__init__(**_)
class Dice(BaseMessageComponent): # TODO
type: ComponentType = "Dice"
type = ComponentType.Dice
def __init__(self, **_):
super().__init__(**_)
class Shake(BaseMessageComponent): # TODO
type: ComponentType = "Shake"
type = ComponentType.Shake
def __init__(self, **_):
super().__init__(**_)
class Anonymous(BaseMessageComponent): # TODO
type: ComponentType = "Anonymous"
type = ComponentType.Anonymous
ignore: T.Optional[bool] = False
def __init__(self, **_):
@@ -373,7 +380,7 @@ class Anonymous(BaseMessageComponent): # TODO
class Share(BaseMessageComponent):
type: ComponentType = "Share"
type = ComponentType.Share
url: str
title: str
content: T.Optional[str] = ""
@@ -384,7 +391,7 @@ class Share(BaseMessageComponent):
class Contact(BaseMessageComponent): # TODO
type: ComponentType = "Contact"
type = ComponentType.Contact
_type: str # type 字段冲突
id: T.Optional[int] = 0
@@ -393,7 +400,7 @@ class Contact(BaseMessageComponent): # TODO
class Location(BaseMessageComponent): # TODO
type: ComponentType = "Location"
type = ComponentType.Location
lat: float
lon: float
title: T.Optional[str] = ""
@@ -404,7 +411,7 @@ class Location(BaseMessageComponent): # TODO
class Music(BaseMessageComponent):
type: ComponentType = "Music"
type = ComponentType.Music
_type: str
id: T.Optional[int] = 0
url: T.Optional[str] = ""
@@ -421,7 +428,7 @@ class Music(BaseMessageComponent):
class Image(BaseMessageComponent):
type: ComponentType = "Image"
type = ComponentType.Image
file: T.Optional[str] = ""
_type: T.Optional[str] = ""
subType: T.Optional[int] = 0
@@ -464,14 +471,15 @@ class Image(BaseMessageComponent):
Returns:
str: 图片的本地路径,以绝对路径表示。
"""
url = self.url if self.url else self.file
if url and url.startswith("file:///"):
image_file_path = url[8:]
return image_file_path
elif url and url.startswith("http"):
url = self.url or self.file
if not url:
raise ValueError("No valid file or URL provided")
if url.startswith("file:///"):
return url[8:]
elif url.startswith("http"):
image_file_path = await download_image_by_url(url)
return os.path.abspath(image_file_path)
elif url and url.startswith("base64://"):
elif url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
@@ -480,8 +488,7 @@ class Image(BaseMessageComponent):
f.write(image_bytes)
return os.path.abspath(image_file_path)
elif os.path.exists(url):
image_file_path = url
return os.path.abspath(image_file_path)
return os.path.abspath(url)
else:
raise Exception(f"not a valid file: {url}")
@@ -492,13 +499,15 @@ class Image(BaseMessageComponent):
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
url = self.url if self.url else self.file
if url and url.startswith("file:///"):
url = self.url or self.file
if not url:
raise ValueError("No valid file or URL provided")
if url.startswith("file:///"):
bs64_data = file_to_base64(url[8:])
elif url and url.startswith("http"):
elif url.startswith("http"):
image_file_path = await download_image_by_url(url)
bs64_data = file_to_base64(image_file_path)
elif url and url.startswith("base64://"):
elif url.startswith("base64://"):
bs64_data = url
elif os.path.exists(url):
bs64_data = file_to_base64(url)
@@ -532,7 +541,7 @@ class Image(BaseMessageComponent):
class Reply(BaseMessageComponent):
type: ComponentType = "Reply"
type = ComponentType.Reply
id: T.Union[str, int]
"""所引用的消息 ID"""
chain: T.Optional[T.List["BaseMessageComponent"]] = []
@@ -558,7 +567,7 @@ class Reply(BaseMessageComponent):
class RedBag(BaseMessageComponent):
type: ComponentType = "RedBag"
type = ComponentType.RedBag
title: str
def __init__(self, **_):
@@ -566,7 +575,7 @@ class RedBag(BaseMessageComponent):
class Poke(BaseMessageComponent):
type: str = ""
type: str = ComponentType.Poke
id: T.Optional[int] = 0
qq: T.Optional[int] = 0
@@ -576,7 +585,7 @@ class Poke(BaseMessageComponent):
class Forward(BaseMessageComponent):
type: ComponentType = "Forward"
type = ComponentType.Forward
id: str
def __init__(self, **_):
@@ -586,7 +595,7 @@ class Forward(BaseMessageComponent):
class Node(BaseMessageComponent):
"""群合并转发消息"""
type: ComponentType = "Node"
type = ComponentType.Node
id: T.Optional[int] = 0 # 忽略
name: T.Optional[str] = "" # qq昵称
uin: T.Optional[str] = "0" # qq号
@@ -638,7 +647,7 @@ class Node(BaseMessageComponent):
class Nodes(BaseMessageComponent):
type: ComponentType = "Nodes"
type = ComponentType.Nodes
nodes: T.List[Node]
def __init__(self, nodes: T.List[Node], **_):
@@ -664,7 +673,7 @@ class Nodes(BaseMessageComponent):
class Xml(BaseMessageComponent):
type: ComponentType = "Xml"
type = ComponentType.Xml
data: str
resid: T.Optional[int] = 0
@@ -673,7 +682,7 @@ class Xml(BaseMessageComponent):
class Json(BaseMessageComponent):
type: ComponentType = "Json"
type = ComponentType.Json
data: T.Union[str, dict]
resid: T.Optional[int] = 0
@@ -684,7 +693,7 @@ class Json(BaseMessageComponent):
class CardImage(BaseMessageComponent):
type: ComponentType = "CardImage"
type = ComponentType.CardImage
file: str
cache: T.Optional[bool] = True
minwidth: T.Optional[int] = 400
@@ -703,7 +712,7 @@ class CardImage(BaseMessageComponent):
class TTS(BaseMessageComponent):
type: ComponentType = "TTS"
type = ComponentType.TTS
text: str
def __init__(self, **_):
@@ -711,7 +720,7 @@ class TTS(BaseMessageComponent):
class Unknown(BaseMessageComponent):
type: ComponentType = "Unknown"
type = ComponentType.Unknown
text: str
def toString(self):
@@ -723,7 +732,7 @@ class File(BaseMessageComponent):
文件消息段
"""
type: ComponentType = "File"
type = ComponentType.File
name: T.Optional[str] = "" # 名字
file_: T.Optional[str] = "" # 本地路径
url: T.Optional[str] = "" # url
@@ -853,7 +862,7 @@ class File(BaseMessageComponent):
class WechatEmoji(BaseMessageComponent):
type: ComponentType = "WechatEmoji"
type = ComponentType.WechatEmoji
md5: T.Optional[str] = ""
md5_len: T.Optional[int] = 0
cdnurl: T.Optional[str] = ""

183
astrbot/core/persona_mgr.py Normal file
View File

@@ -0,0 +1,183 @@
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Persona, Personality
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.platform.message_session import MessageSession
from astrbot import logger
DEFAULT_PERSONALITY = Personality(
prompt="You are a helpful and friendly assistant.",
name="default",
begin_dialogs=[],
mood_imitation_dialogs=[],
tools=None,
_begin_dialogs_processed=[],
_mood_imitation_dialogs_processed="",
)
class PersonaManager:
def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager):
self.db = db_helper
self.acm = acm
default_ps = acm.default_conf.get("provider_settings", {})
self.default_persona: str = default_ps.get("default_personality", "default")
self.personas: list[Persona] = []
self.selected_default_persona: Persona | None = None
self.personas_v3: list[Personality] = []
self.selected_default_persona_v3: Personality | None = None
self.persona_v3_config: list[dict] = []
async def initialize(self):
self.personas = await self.get_all_personas()
self.get_v3_persona_data()
logger.info(f"已加载 {len(self.personas)} 个人格。")
async def get_persona(self, persona_id: str):
"""获取指定 persona 的信息"""
persona = await self.db.get_persona_by_id(persona_id)
if not persona:
raise ValueError(f"Persona with ID {persona_id} does not exist.")
return persona
async def get_default_persona_v3(
self, umo: str | MessageSession | None = None
) -> Personality:
"""获取默认 persona"""
cfg = self.acm.get_conf(umo)
default_persona_id = cfg.get("provider_settings", {}).get(
"default_personality", "default"
)
if not default_persona_id or default_persona_id == "default":
return DEFAULT_PERSONALITY
try:
return next(p for p in self.personas_v3 if p["name"] == default_persona_id)
except Exception:
return DEFAULT_PERSONALITY
async def delete_persona(self, persona_id: str):
"""删除指定 persona"""
if not await self.db.get_persona_by_id(persona_id):
raise ValueError(f"Persona with ID {persona_id} does not exist.")
await self.db.delete_persona(persona_id)
self.personas = [p for p in self.personas if p.persona_id != persona_id]
self.get_v3_persona_data()
async def update_persona(
self,
persona_id: str,
system_prompt: str = None,
begin_dialogs: list[str] = None,
tools: list[str] = None,
):
"""更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
existing_persona = await self.db.get_persona_by_id(persona_id)
if not existing_persona:
raise ValueError(f"Persona with ID {persona_id} does not exist.")
persona = await self.db.update_persona(
persona_id, system_prompt, begin_dialogs, tools=tools
)
if persona:
for i, p in enumerate(self.personas):
if p.persona_id == persona_id:
self.personas[i] = persona
break
self.get_v3_persona_data()
return persona
async def get_all_personas(self) -> list[Persona]:
"""获取所有 personas"""
return await self.db.get_personas()
async def create_persona(
self,
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] = None,
tools: list[str] = None,
) -> Persona:
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
if await self.db.get_persona_by_id(persona_id):
raise ValueError(f"Persona with ID {persona_id} already exists.")
new_persona = await self.db.insert_persona(
persona_id, system_prompt, begin_dialogs, tools=tools
)
self.personas.append(new_persona)
self.get_v3_persona_data()
return new_persona
def get_v3_persona_data(
self,
) -> tuple[list[dict], list[Personality], Personality]:
"""获取 AstrBot <4.0.0 版本的 persona 数据。
Returns:
- list[dict]: 包含 persona 配置的字典列表。
- list[Personality]: 包含 Personality 对象的列表。
- Personality: 默认选择的 Personality 对象。
"""
v3_persona_config = [
{
"prompt": persona.system_prompt,
"name": persona.persona_id,
"begin_dialogs": persona.begin_dialogs or [],
"mood_imitation_dialogs": [], # deprecated
"tools": persona.tools,
}
for persona in self.personas
]
personas_v3: list[Personality] = []
selected_default_persona: Personality | None = None
for persona_cfg in v3_persona_config:
begin_dialogs = persona_cfg.get("begin_dialogs", [])
bd_processed = []
if begin_dialogs:
if len(begin_dialogs) % 2 != 0:
logger.error(
f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。"
)
begin_dialogs = []
user_turn = True
for dialog in begin_dialogs:
bd_processed.append(
{
"role": "user" if user_turn else "assistant",
"content": dialog,
"_no_save": None, # 不持久化到 db
}
)
user_turn = not user_turn
try:
persona = Personality(
**persona_cfg,
_begin_dialogs_processed=bd_processed,
_mood_imitation_dialogs_processed="", # deprecated
)
if persona["name"] == self.default_persona:
selected_default_persona = persona
personas_v3.append(persona)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
if not selected_default_persona and len(personas_v3) > 0:
# 默认选择第一个
selected_default_persona = personas_v3[0]
if not selected_default_persona:
selected_default_persona = DEFAULT_PERSONALITY
personas_v3.append(selected_default_persona)
self.personas_v3 = personas_v3
self.selected_default_persona_v3 = selected_default_persona
self.persona_v3_config = v3_persona_config
self.selected_default_persona = Persona(
persona_id=selected_default_persona["name"],
system_prompt=selected_default_persona["prompt"],
begin_dialogs=selected_default_persona["begin_dialogs"],
tools=selected_default_persona["tools"] or None,
)
return v3_persona_config, personas_v3, selected_default_persona

View File

@@ -1,25 +1,25 @@
from astrbot.core.message.message_event_result import (
MessageEventResult,
EventResultType,
MessageEventResult,
)
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .rate_limit_check.stage import RateLimitStage
from .content_safety_check.stage import ContentSafetyCheckStage
from .platform_compatibility.stage import PlatformCompatibilityStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage
from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage
from .result_decorate.stage import ResultDecorateStage
from .session_status_check.stage import SessionStatusCheckStage
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
# 管道阶段顺序
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果比如添加回复前缀、t2i、转换为语音 等
@@ -29,9 +29,9 @@ STAGES_ORDER = [
__all__ = [
"WakingCheckStage",
"WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PlatformCompatibilityStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",

View File

@@ -1,14 +1,7 @@
import inspect
import traceback
import typing as T
from dataclasses import dataclass
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.config import AstrBotConfig
from astrbot.core.star import PluginManager
from astrbot.api import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
from .context_utils import call_handler, call_event_hook
@dataclass
@@ -17,91 +10,6 @@ class PipelineContext:
astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: PluginManager # 插件管理器对象
async def call_event_hook(
self,
event: AstrMessageEvent,
hook_type: EventType,
*args,
):
platform_id = event.get_platform_id()
handlers = star_handlers_registry.get_handlers_by_event_type(
hook_type, platform_id=platform_id
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, *args)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
async def call_handler(
self,
event: AstrMessageEvent,
handler: T.Awaitable,
*args,
**kwargs,
) -> T.AsyncGenerator[None, None]:
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
2. 协程: 执行一次并处理返回值
Args:
ctx (PipelineContext): 消息管道上下文对象
event (AstrMessageEvent): 事件对象
handler (Awaitable): 事件处理函数
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError as _:
# 向下兼容
trace_ = traceback.format_exc()
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook

View File

@@ -0,0 +1,98 @@
import inspect
import traceback
import typing as T
from astrbot import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
async def call_handler(
event: AstrMessageEvent,
handler: T.Awaitable,
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
2. 协程: 执行一次并处理返回值
Args:
event (AstrMessageEvent): 事件对象
handler (Awaitable): 事件处理函数
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret
async def call_event_hook(
event: AstrMessageEvent,
hook_type: EventType,
*args,
**kwargs,
) -> bool:
"""调用事件钩子函数
Returns:
bool: 如果事件被终止,返回 True
# """
handlers = star_handlers_registry.get_handlers_by_event_type(
hook_type, plugins_name=event.plugins_name
)
for handler in handlers:
try:
logger.debug(
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, *args, **kwargs)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return event.is_stopped()

View File

@@ -1,56 +0,0 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core import logger
@register_stage
class PlatformCompatibilityStage(Stage):
"""检查所有处理器的平台兼容性。
这个阶段会检查所有处理器是否在当前平台启用如果未启用则设置platform_compatible属性为False。
"""
async def initialize(self, ctx: PipelineContext) -> None:
"""初始化平台兼容性检查阶段
Args:
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
"""
self.ctx = ctx
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
# 获取当前平台ID
platform_id = event.get_platform_id()
# 获取已激活的处理器
activated_handlers = event.get_extra("activated_handlers")
if activated_handlers is None:
activated_handlers = []
# 标记不兼容的处理器
for handler in activated_handlers:
if not isinstance(handler, StarHandlerMetadata):
continue
# 检查处理器是否在当前平台启用
enabled = handler.is_enabled_for_platform(platform_id)
if not enabled:
if handler.handler_module_path in star_map:
plugin_name = star_map[handler.handler_module_path].name
logger.debug(
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
)
# 设置处理器为平台不兼容状态
# TODO: 更好的标记方式
handler.platform_compatible = False
else:
# 确保处理器为平台兼容状态
handler.platform_compatible = True
# 更新已激活的处理器列表
event.set_extra("activated_handlers", activated_handlers)

View File

@@ -2,29 +2,288 @@
本地 Agent 模式的 LLM 调用 Stage
"""
import traceback
import copy
import asyncio
import copy
import json
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
import traceback
from typing import AsyncGenerator, Union
from astrbot.core import logger
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
ResultContentType,
MessageChain,
)
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import (
ProviderRequest,
LLMResponse,
ProviderRequest,
)
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolSet, FunctionTool
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star_handler import EventType
from astrbot.core import web_chat_back_queue
from ..agent_runner.tool_loop_agent import ToolLoopAgent
from astrbot.core.utils.metrics import Metric
from ...context import PipelineContext, call_event_hook, call_handler
from ..stage import Stage
from astrbot.core.provider.register import llm_tools
from astrbot.core.star.star_handler import star_map
from astrbot.core.astr_agent_context import AstrAgentContext
try:
import mcp
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
AgentContextWrapper = ContextWrapper[AstrAgentContext]
AgentRunner = ToolLoopAgentRunner[AgentContextWrapper]
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
async def execute(cls, tool, run_context, **tool_args):
"""执行函数调用。
Args:
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
**kwargs: 函数调用的参数。
Returns:
AsyncGenerator[None | mcp.types.CallToolResult, None]
"""
if isinstance(tool, HandoffTool):
async for r in cls._execute_handoff(tool, run_context, **tool_args):
yield r
return
if tool.origin == "local":
async for r in cls._execute_local(tool, run_context, **tool_args):
yield r
return
elif tool.origin == "mcp":
async for r in cls._execute_mcp(tool, run_context, **tool_args):
yield r
return
raise Exception(f"Unknown function origin: {tool.origin}")
@classmethod
async def _execute_handoff(
cls,
tool: HandoffTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
input_ = tool_args.get("input", "agent")
agent_runner = AgentRunner()
# make toolset for the agent
tools = tool.agent.tools
if tools:
toolset = ToolSet()
for t in tools:
if isinstance(t, str):
_t = llm_tools.get_func(t)
if _t:
toolset.add_tool(_t)
elif isinstance(t, FunctionTool):
toolset.add_tool(t)
else:
toolset = None
request = ProviderRequest(
prompt=input_,
system_prompt=tool.description,
image_urls=[], # 暂时不传递原始 agent 的上下文
contexts=[], # 暂时不传递原始 agent 的上下文
func_tool=toolset,
)
astr_agent_ctx = AstrAgentContext(
provider=run_context.context.provider,
first_provider_request=run_context.context.first_provider_request,
curr_provider_request=request,
streaming=run_context.context.streaming,
)
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
await run_context.event.send(
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name)
)
await agent_runner.reset(
provider=run_context.context.provider,
request=request,
run_context=AgentContextWrapper(
context=astr_agent_ctx, event=run_context.event
),
tool_executor=FunctionToolExecutor(),
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
streaming=run_context.context.streaming,
)
async for _ in run_agent(agent_runner, 15, True):
pass
if agent_runner.done():
llm_response = agent_runner.get_final_llm_resp()
logger.debug(
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
)
result = (
f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n"
"Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)."
)
text_content = mcp.types.TextContent(
type="text",
text=result,
)
yield mcp.types.CallToolResult(content=[text_content])
else:
yield mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return
@classmethod
async def _execute_local(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
if not run_context.event:
raise ValueError("Event must be provided for local function tools.")
# 检查 tool 下有没有 run 方法
if not tool.handler and not hasattr(tool, "run"):
raise ValueError("Tool must have a valid handler or 'run' method.")
awaitable = tool.handler or getattr(tool, "run")
wrapper = call_handler(
event=run_context.event,
handler=awaitable,
**tool_args,
)
async for resp in wrapper:
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
yield None
@classmethod
async def _execute_mcp(
cls,
tool: FunctionTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
):
if not tool.mcp_client:
raise ValueError("MCP client is not available for MCP function tools.")
res = await tool.mcp_client.session.call_tool(
name=tool.name,
arguments=tool_args,
)
if not res:
return
yield res
class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
run_context.event, EventType.OnLLMResponseEvent, llm_response
)
MAIN_AGENT_HOOKS = MainAgentHooks()
async def run_agent(
agent_runner: AgentRunner, max_step: int = 30, show_tool_use: bool = True
) -> AsyncGenerator[MessageChain, None]:
step_idx = 0
astr_event = agent_runner.run_context.event
while step_idx < max_step:
step_idx += 1
try:
async for resp in agent_runner.step():
if astr_event.is_stopped():
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
resp.data["chain"].type = "tool_call_result"
await astr_event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if agent_runner.streaming:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if show_tool_use or astr_event.get_platform_name() == "webchat":
resp.data["chain"].type = "tool_call"
await astr_event.send(resp.data["chain"])
continue
if not agent_runner.streaming:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
astr_event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
)
)
yield
astr_event.clear_result()
else:
if resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
if agent_runner.done():
break
except Exception as e:
logger.error(traceback.format_exc())
astr_event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
)
)
return
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
)
)
class LLMRequestSubStage(Stage):
@@ -40,7 +299,9 @@ class LLMRequestSubStage(Stage):
self.max_context_length - 1,
)
self.streaming_response: bool = settings["streaming_response"]
self.max_step: int = settings.get("max_agent_step", 10)
self.max_step: int = settings.get("max_agent_step", 30)
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
for bwp in self.bot_wake_prefixs:
@@ -52,16 +313,47 @@ class LLMRequestSubStage(Stage):
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
"""选择使用的 LLM 提供商"""
sel_provider = event.get_extra("selected_provider")
_ctx = self.ctx.plugin_manager.context
if sel_provider and isinstance(sel_provider, str):
provider = _ctx.get_provider_by_id(sel_provider)
if not provider:
logger.error(f"未找到指定的提供商: {sel_provider}")
return provider
return _ctx.get_using_provider(umo=event.unified_msg_origin)
async def _get_session_conv(self, event: AstrMessageEvent):
umo = event.unified_msg_origin
conv_mgr = self.conv_manager
# 获取对话上下文
cid = await conv_mgr.get_curr_conversation_id(umo)
if not cid:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
return conversation
async def process(
self, event: AstrMessageEvent, _nested: bool = False
) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
umo = event.unified_msg_origin
provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo)
# 检查会话级别的LLM启停状态
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM跳过处理。")
return
provider = self._select_provider(event)
if provider is None:
return
@@ -76,34 +368,20 @@ class LLMRequestSubStage(Stage):
else:
req = ProviderRequest(prompt="", image_urls=[])
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if self.provider_wake_prefix:
if not event.message_str.startswith(self.provider_wake_prefix):
return
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
# 获取对话上下文
conversation_id = await self.conv_manager.get_curr_conversation_id(
event.unified_msg_origin
)
if not conversation_id:
conversation_id = await self.conv_manager.new_conversation(
event.unified_msg_origin
)
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, conversation_id
)
if not conversation:
conversation_id = await self.conv_manager.new_conversation(
event.unified_msg_origin
)
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, conversation_id
)
conversation = await self._get_session_conv(event)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
@@ -113,7 +391,8 @@ class LLMRequestSubStage(Stage):
return
# 执行请求 LLM 前事件钩子。
await self.ctx.call_event_hook(event, EventType.OnLLMRequestEvent, req)
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
@@ -146,65 +425,50 @@ class LLMRequestSubStage(Stage):
# fix messages
req.contexts = self.fix_messages(req.contexts)
# Call Agent
tool_loop_agent = ToolLoopAgent(
provider=provider,
event=event,
pipeline_ctx=self.ctx,
# check provider modalities
# 如果提供商不支持图像/工具使用,但请求中包含图像/工具列表,则清空。图片转述等的检测和调用发生在这之前,因此这里可以这样处理。
if req.image_urls:
provider_cfg = provider.provider_config.get("modalities", ["image"])
if "image" not in provider_cfg:
logger.debug(f"用户设置提供商 {provider} 不支持图像,清空图像列表。")
req.image_urls = []
if req.func_tool:
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
if "tool_use" not in provider_cfg:
logger.debug(
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。"
)
await tool_loop_agent.reset(req=req, streaming=self.streaming_response)
async def requesting():
step_idx = 0
while step_idx < self.max_step:
step_idx += 1
try:
async for resp in tool_loop_agent.step():
if resp.type == "tool_call_result":
continue # 跳过工具调用结果
if resp.type == "tool_call":
if self.streaming_response:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if self.show_tool_use or event.get_platform_name() == "webchat":
resp.data["chain"].type = "tool_call"
await event.send(resp.data["chain"])
req.func_tool = None
# 插件可用性设置
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
plugin = star_map.get(tool.handler_module_path)
if not plugin:
continue
if plugin.name in event.plugins_name or plugin.reserved:
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set
if not self.streaming_response:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
# run agent
agent_runner = AgentRunner()
logger.debug(
f"handle provider[id: {provider.provider_config['id']}] request: {req}"
)
event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
)
)
yield
event.clear_result()
else:
if resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
if tool_loop_agent.done():
break
except Exception as e:
logger.error(traceback.format_exc())
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
)
)
return
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=provider.get_model(),
provider_type=provider.meta().type,
astr_agent_ctx = AstrAgentContext(
provider=provider,
first_provider_request=req,
curr_provider_request=req,
streaming=self.streaming_response,
)
await agent_runner.reset(
provider=provider,
request=req,
run_context=AgentContextWrapper(context=astr_agent_ctx, event=event),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=self.streaming_response,
)
if self.streaming_response:
@@ -212,11 +476,13 @@ class LLMRequestSubStage(Stage):
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(requesting())
.set_async_stream(
run_agent(agent_runner, self.max_step, self.show_tool_use)
)
)
yield
if tool_loop_agent.done():
if final_llm_resp := tool_loop_agent.get_final_llm_resp():
if agent_runner.done():
if final_llm_resp := agent_runner.get_final_llm_resp():
if final_llm_resp.completion_text:
chain = (
MessageChain().message(final_llm_resp.completion_text).chain
@@ -230,16 +496,18 @@ class LLMRequestSubStage(Stage):
)
)
else:
async for _ in requesting():
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
yield
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
# 异步处理 WebChat 特殊情况
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req))
asyncio.create_task(self._handle_webchat(event, req, provider))
await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp())
async def _handle_webchat(self, event: AstrMessageEvent, req: ProviderRequest):
async def _handle_webchat(
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, req.conversation.cid
@@ -249,17 +517,16 @@ class LLMRequestSubStage(Stage):
latest_pair = messages[-2:]
if not latest_pair:
return
provider = self.ctx.plugin_manager.context.get_using_provider()
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
llm_resp = await provider.text_chat(
llm_resp = await prov.text_chat(
system_prompt="You are expert in summarizing user's query.",
prompt=(
f"Please summarize the following query of user:\n"
f"{cleaned_text}\n"
"Only output the summary within 10 words, DO NOT INCLUDE any other text."
"You must use the same language as the user."
"If you think the dialog is too short to summarize, only output a special mark: `None`"
"If you think the dialog is too short to summarize, only output a special mark: `<None>`"
),
)
if llm_resp and llm_resp.completion_text:
@@ -267,28 +534,12 @@ class LLMRequestSubStage(Stage):
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}"
)
title = llm_resp.completion_text.strip()
if not title or "None" == title:
if not title or "<None>" in title:
return
await self.conv_manager.update_conversation_title(
event.unified_msg_origin, title=title
)
# 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题
# webchat adapter 中session_id 的格式是 f"webchat!{username}!{cid}"
# TODO: 优化 WebChat 适配器的对话管理
if event.session_id:
username, cid = event.session_id.split("!")[1:3]
db_helper = self.ctx.plugin_manager.context._db
db_helper.update_conversation_title(
user_id=username,
cid=cid,
unified_msg_origin=event.unified_msg_origin,
title=title,
)
web_chat_back_queue.put_nowait(
{
"type": "update_title",
"cid": cid,
"data": title,
}
conversation_id=req.conversation.cid,
)
async def _save_to_history(
@@ -305,6 +556,10 @@ class LLMRequestSubStage(Stage):
):
return
if not llm_response.completion_text and not req.tool_calls_result:
logger.debug("LLM 响应为空,不保存记录。")
return
# 历史上下文
messages = copy.deepcopy(req.contexts)
# 这一轮对话请求的用户输入
@@ -321,7 +576,6 @@ class LLMRequestSubStage(Stage):
await self.conv_manager.update_conversation(
event.unified_msg_origin, req.conversation.cid, history=messages
)
logger.debug(f"messages persisted: {messages}")
def fix_messages(self, messages: list[dict]) -> list[dict]:
"""验证并且修复上下文"""

View File

@@ -2,7 +2,7 @@
本地 Agent 模式的 AstrBot 插件调用 Stage
"""
from ...context import PipelineContext
from ...context import PipelineContext, call_handler
from ..stage import Stage
from typing import Dict, Any, List, AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -33,16 +33,6 @@ class StarRequestSubStage(Stage):
handlers_parsed_params = {}
for handler in activated_handlers:
# 检查处理器是否在当前平台兼容
if (
hasattr(handler, "platform_compatible")
and handler.platform_compatible is False
):
logger.debug(
f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行"
)
continue
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_path not in star_map:
@@ -50,7 +40,7 @@ class StarRequestSubStage(Stage):
logger.debug(
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
)
wrapper = self.ctx.call_handler(event, handler.handler, **params)
wrapper = call_handler(event, handler.handler, **params)
async for ret in wrapper:
yield ret
event.clear_result() # 清除上一个 handler 的结果

View File

@@ -13,6 +13,7 @@ from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.utils.path_util import path_Mapping
from astrbot.core.utils.session_lock import session_lock_manager
@register_stage
@@ -127,7 +128,7 @@ class RespondStage(Stage):
use_fallback = self.config.get("provider_settings", {}).get(
"streaming_segmented", False
)
logger.info(f"应用流式输出({event.get_platform_name()})")
logger.info(f"应用流式输出({event.get_platform_id()})")
await event.send_streaming(result.async_stream, use_fallback)
return
elif len(result.chain) > 0:
@@ -143,8 +144,6 @@ class RespondStage(Stage):
try:
if await self._is_empty_message_chain(result.chain):
logger.info("消息为空,跳过发送阶段")
event.clear_result()
event.stop_event()
return
except Exception as e:
logger.warning(f"空内容检查异常: {e}")
@@ -177,6 +176,8 @@ class RespondStage(Stage):
result.chain.remove(comp)
break
# leverage lock to guarentee the order of message sending among different events
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
for rcomp in record_comps:
i = await self._calc_comp_interval(rcomp)
await asyncio.sleep(i)
@@ -185,7 +186,6 @@ class RespondStage(Stage):
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
# 分段回复
for comp in non_record_comps:
i = await self._calc_comp_interval(comp)
@@ -214,7 +214,7 @@ class RespondStage(Stage):
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name
)
for handler in handlers:
try:

View File

@@ -3,11 +3,12 @@ import time
import traceback
from typing import AsyncGenerator, Union
from astrbot.core import html_renderer, logger, file_token_service
from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
@@ -35,6 +36,7 @@ class ResultDecorateStage(Stage):
self.t2i_word_threshold = 150
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
self.t2i_use_network = self.t2i_strategy == "remote"
self.t2i_active_template = ctx.astrbot_config["t2i_active_template"]
self.forward_threshold = ctx.astrbot_config["platform_settings"][
"forward_threshold"
@@ -63,9 +65,10 @@ class ResultDecorateStage(Stage):
]
self.content_safe_check_stage = None
if self.content_safe_check_reply:
for stage in registered_stages:
if stage.__class__.__name__ == "ContentSafetyCheckStage":
self.content_safe_check_stage = stage
for stage_cls in registered_stages:
if stage_cls.__name__ == "ContentSafetyCheckStage":
self.content_safe_check_stage = stage_cls()
await self.content_safe_check_stage.initialize(ctx)
async def process(
self, event: AstrMessageEvent
@@ -97,7 +100,7 @@ class ResultDecorateStage(Stage):
# 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id()
EventType.OnDecoratingResultEvent, plugins_name=event.plugins_name
)
for handler in handlers:
try:
@@ -176,10 +179,12 @@ class ResultDecorateStage(Stage):
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
event.unified_msg_origin
)
if (
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
and result.is_llm_result()
and tts_provider
and SessionServiceManager.should_process_tts_request(event)
):
new_chain = []
for comp in result.chain:
@@ -243,7 +248,10 @@ class ResultDecorateStage(Stage):
render_start = time.time()
try:
url = await html_renderer.render_t2i(
plain_str, return_url=True, use_network=self.t2i_use_network
plain_str,
return_url=True,
use_network=self.t2i_use_network,
template_name=self.t2i_active_template,
)
except BaseException:
logger.error("文本转图片失败,使用文本发送。")

View File

@@ -11,16 +11,17 @@ class PipelineScheduler:
def __init__(self, context: PipelineContext):
registered_stages.sort(
key=lambda x: STAGES_ORDER.index(x.__class__.__name__)
key=lambda x: STAGES_ORDER.index(x.__name__)
) # 按照顺序排序
self.ctx = context # 上下文对象
self.stages = [] # 存储阶段实例
async def initialize(self):
"""初始化管道调度器时, 初始化所有阶段"""
for stage in registered_stages:
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
await stage.initialize(self.ctx)
for stage_cls in registered_stages:
stage_instance = stage_cls() # 创建实例
await stage_instance.initialize(self.ctx)
self.stages.append(stage_instance)
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
"""依次执行各个阶段
@@ -29,9 +30,9 @@ class PipelineScheduler:
event (AstrMessageEvent): 事件对象
from_stage (int): 从第几个阶段开始执行, 默认从0开始
"""
for i in range(from_stage, len(registered_stages)):
stage = registered_stages[i] # 获取当前要执行的阶段
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
for i in range(from_stage, len(self.stages)):
stage = self.stages[i] # 获取当前要执行的阶段
# logger.debug(f"执行阶段 {stage.__class__.__name__}")
coroutine = stage.process(
event
) # 调用阶段的process方法, 返回协程或者异步生成器
@@ -73,7 +74,7 @@ class PipelineScheduler:
await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if not event._has_send_oper and event.get_platform_name() == "webchat":
if event.get_platform_name() == "webchat":
await event.send(None)
logger.debug("pipeline 执行完毕。")

View File

@@ -0,0 +1,22 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core import logger
@register_stage
class SessionStatusCheckStage(Stage):
"""检查会话是否整体启用"""
async def initialize(self, ctx: PipelineContext) -> None:
pass
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
# 检查会话是否整体启用
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
event.stop_event()

View File

@@ -1,15 +1,15 @@
from __future__ import annotations
import abc
from typing import List, AsyncGenerator, Union
from typing import List, AsyncGenerator, Union, Type
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
registered_stages: List[Type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型
def register_stage(cls):
"""一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类"""
registered_stages.append(cls())
registered_stages.append(cls)
return cls

View File

@@ -1,13 +1,16 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import AsyncGenerator, Union
from astrbot import logger
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.message.components import At, AtAll, Reply
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
from ..context import PipelineContext
from ..stage import Stage, register_stage
@register_stage
@@ -109,8 +112,17 @@ class WakingCheckStage(Stage):
activated_handlers = []
handlers_parsed_params = {} # 注册了指令的 handler
# 将 plugins_name 设置到 event 中
enabled_plugins_name = self.ctx.astrbot_config.get("plugin_set", ["*"])
if enabled_plugins_name == ["*"]:
# 如果是 *,则表示所有插件都启用
event.plugins_name = None
else:
event.plugins_name = enabled_plugins_name
logger.debug(f"enabled_plugins_name: {enabled_plugins_name}")
for handler in star_handlers_registry.get_handlers_by_event_type(
EventType.AdapterMessageEvent
EventType.AdapterMessageEvent, plugins_name=event.plugins_name
):
# filter 需满足 AND 逻辑关系
passed = True
@@ -164,7 +176,12 @@ class WakingCheckStage(Stage):
"parsed_params"
)
event.clear_extra()
event._extras.pop("parsed_params", None)
# 根据会话配置过滤插件处理器
activated_handlers = SessionPluginManager.filter_handlers_by_session(
event, activated_handlers
)
event.set_extra("activated_handlers", activated_handlers)
event.set_extra("handlers_parsed_params", handlers_parsed_params)

View File

@@ -3,9 +3,10 @@ import asyncio
import re
import hashlib
import uuid
from dataclasses import dataclass
from typing import List, Union, Optional, AsyncGenerator
from astrbot import logger
from astrbot.core.db.po import Conversation
from astrbot.core.message.components import (
Plain,
@@ -23,21 +24,7 @@ from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.utils.metrics import Metric
from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
@dataclass
class MessageSesion:
platform_name: str
message_type: MessageType
session_id: str
def __str__(self):
return f"{self.platform_name}:{self.message_type.value}:{self.session_id}"
@staticmethod
def from_str(session_str: str):
platform_name, message_type, session_id = session_str.split(":")
return MessageSesion(platform_name, MessageType(message_type), session_id)
from .message_session import MessageSession, MessageSesion # noqa
class AstrMessageEvent(abc.ABC):
@@ -64,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
self._extras = {}
self.session = MessageSesion(
platform_name=platform_meta.name,
platform_name=platform_meta.id,
message_type=message_obj.type,
session_id=session_id,
)
@@ -78,13 +65,23 @@ class AstrMessageEvent(abc.ABC):
self.call_llm = False
"""是否在此消息事件中禁止默认的 LLM 请求"""
self.plugins_name: list[str] | None = None
"""该事件启用的插件名称列表。None 表示所有插件都启用。空列表表示没有启用任何插件。"""
# back_compability
self.platform = platform_meta
def get_platform_name(self):
"""获取这个事件所属的平台的类型(如 aiocqhttp, slack, discord 等)。
NOTE: 用户可能会同时运行多个相同类型的平台适配器。"""
return self.platform_meta.name
def get_platform_id(self):
"""获取这个事件所属的平台的 ID。
NOTE: 用户可能会同时运行多个相同类型的平台适配器,但能确定的是 ID 是唯一的。
"""
return self.platform_meta.id
def get_message_str(self) -> str:
@@ -188,6 +185,7 @@ class AstrMessageEvent(abc.ABC):
"""
清除额外的信息。
"""
logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}")
self._extras.clear()
def is_private_chat(self) -> bool:
@@ -227,7 +225,7 @@ class AstrMessageEvent(abc.ABC):
):
"""发送流式消息到消息平台,使用异步生成器。
目前仅支持: telegramqq official 私聊。
Fallback仅支持 aiocqhttp, gewechat
Fallback仅支持 aiocqhttp。
"""
asyncio.create_task(
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
@@ -419,7 +417,6 @@ class AstrMessageEvent(abc.ABC):
适配情况:
- gewechat
- aiocqhttp(OneBotv11)
"""
...

View File

@@ -6,6 +6,7 @@ from typing import List
from asyncio import Queue
from .register import platform_cls_map
from astrbot.core import logger
from astrbot.core.star.star_handler import star_handlers_registry, star_map, EventType
from .sources.webchat.webchat_adapter import WebChatAdapter
@@ -18,6 +19,9 @@ class PlatformManager:
self.platforms_config = config["platform"]
self.settings = config["platform_settings"]
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
这个配置中的 unique_session 需要特殊处理,
约定整个项目中对 unique_session 的引用都从 default 的配置中获取"""
self.event_queue = event_queue
async def initialize(self):
@@ -58,24 +62,26 @@ class PlatformManager:
from .sources.qqofficial_webhook.qo_webhook_adapter import (
QQOfficialWebhookPlatformAdapter, # noqa: F401
)
case "gewechat":
from .sources.gewechat.gewechat_platform_adapter import (
GewechatPlatformAdapter, # noqa: F401
)
case "wechatpadpro":
from .sources.wechatpadpro.wechatpadpro_adapter import (
WeChatPadProAdapter, # noqa: F401
)
case "lark":
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
from .sources.lark.lark_adapter import (
LarkPlatformAdapter,
) # noqa: F401
case "dingtalk":
from .sources.dingtalk.dingtalk_adapter import (
DingtalkPlatformAdapter, # noqa: F401
)
case "telegram":
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
from .sources.telegram.tg_adapter import (
TelegramPlatformAdapter,
) # noqa: F401
case "wecom":
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
from .sources.wecom.wecom_adapter import (
WecomPlatformAdapter,
) # noqa: F401
case "weixin_official_account":
from .sources.weixin_official_account.weixin_offacc_adapter import (
WeixinOfficialAccountPlatformAdapter, # noqa
@@ -86,6 +92,10 @@ class PlatformManager:
)
case "slack":
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
case "satori":
from .sources.satori.satori_adapter import (
SatoriPlatformAdapter,
) # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
@@ -114,6 +124,17 @@ class PlatformManager:
)
)
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnPlatformLoadedEvent
)
for handler in handlers:
try:
logger.info(
f"hook(on_platform_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler()
except Exception:
logger.error(traceback.format_exc())
async def _task_wrapper(self, task: asyncio.Task):
try:

View File

@@ -0,0 +1,28 @@
from astrbot.core.platform.message_type import MessageType
from dataclasses import dataclass
@dataclass
class MessageSession:
"""描述一条消息在 AstrBot 中对应的会话的唯一标识。
如果您需要实例化 MessageSession请不要给 platform_id 赋值(或者同时给 platform_name 和 platform_id 赋值相同值)。它会在 __post_init__ 中自动设置为 platform_name 的值。"""
platform_name: str
"""平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。"""
message_type: MessageType
session_id: str
platform_id: str = None
def __str__(self):
return f"{self.platform_id}:{self.message_type.value}:{self.session_id}"
def __post_init__(self):
self.platform_id = self.platform_name
@staticmethod
def from_str(session_str: str):
platform_id, message_type, session_id = session_str.split(":")
return MessageSession(platform_id, MessageType(message_type), session_id)
MessageSesion = MessageSession # back compatibility

View File

@@ -5,7 +5,7 @@ from asyncio import Queue
from .platform_metadata import PlatformMetadata
from .astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain
from .astr_message_event import MessageSesion
from .message_session import MessageSesion
from astrbot.core.utils.metrics import Metric

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass
@dataclass
class PlatformMetadata:
name: str
"""平台的名称"""
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
description: str
"""平台的描述"""
id: str = None

View File

@@ -1,7 +1,7 @@
import asyncio
import re
from typing import AsyncGenerator, Dict, List
from aiocqhttp import CQHttp
from aiocqhttp import CQHttp, Event
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import (
Image,
@@ -58,50 +58,97 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
ret.append(d)
return ret
async def send(self, message: MessageChain):
@classmethod
async def _dispatch_send(
cls,
bot: CQHttp,
event: Event | None,
is_group: bool,
session_id: str,
messages: list[dict],
):
# session_id 必须是纯数字字符串
session_id = int(session_id) if session_id.isdigit() else None
if is_group and isinstance(session_id, int):
await bot.send_group_msg(group_id=session_id, message=messages)
elif not is_group and isinstance(session_id, int):
await bot.send_private_msg(user_id=session_id, message=messages)
elif isinstance(event, Event): # 最后兜底
await bot.send(event=event, message=messages)
else:
raise ValueError(
f"无法发送消息:缺少有效的数字 session_id({session_id}) 或 event({event})"
)
@classmethod
async def send_message(
cls,
bot: CQHttp,
message_chain: MessageChain,
event: Event | None = None,
is_group: bool = False,
session_id: str = None,
):
"""发送消息至 QQ 协议端aiocqhttp
Args:
bot (CQHttp): aiocqhttp 机器人实例
message_chain (MessageChain): 要发送的消息链
event (Event | None, optional): aiocqhttp 事件对象.
is_group (bool, optional): 是否为群消息.
session_id (str | None, optional): 会话 ID群号或 QQ 号
"""
# 转发消息、文件消息不能和普通消息混在一起发送
send_one_by_one = any(
isinstance(seg, (Node, Nodes, File)) for seg in message.chain
isinstance(seg, (Node, Nodes, File)) for seg in message_chain.chain
)
if send_one_by_one:
for seg in message.chain:
if not send_one_by_one:
ret = await cls._parse_onebot_json(message_chain)
if not ret:
return
await cls._dispatch_send(bot, event, is_group, session_id, ret)
return
for seg in message_chain.chain:
if isinstance(seg, (Node, Nodes)):
# 合并转发消息
if isinstance(seg, Node):
nodes = Nodes([seg])
seg = nodes
payload = await seg.to_dict()
if self.get_group_id():
payload["group_id"] = self.get_group_id()
await self.bot.call_action("send_group_forward_msg", **payload)
if is_group:
payload["group_id"] = session_id
await bot.call_action("send_group_forward_msg", **payload)
else:
payload["user_id"] = self.get_sender_id()
await self.bot.call_action(
"send_private_forward_msg", **payload
)
payload["user_id"] = session_id
await bot.call_action("send_private_forward_msg", **payload)
elif isinstance(seg, File):
d = await AiocqhttpMessageEvent._from_segment_to_dict(seg)
await self.bot.send(
self.message_obj.raw_message,
[d],
)
d = await cls._from_segment_to_dict(seg)
await cls._dispatch_send(bot, event, is_group, session_id, [d])
else:
await self.bot.send(
self.message_obj.raw_message,
await AiocqhttpMessageEvent._parse_onebot_json(
MessageChain([seg])
),
)
messages = await cls._parse_onebot_json(MessageChain([seg]))
if not messages:
continue
await cls._dispatch_send(bot, event, is_group, session_id, messages)
await asyncio.sleep(0.5)
else:
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
if not ret:
return
await self.bot.send(self.message_obj.raw_message, ret)
async def send(self, message: MessageChain):
"""发送消息"""
event = getattr(self.message_obj, "raw_message", None)
is_group = bool(self.get_group_id())
session_id = self.get_group_id() if is_group else self.get_sender_id()
await self.send_message(
bot=self.bot,
message_chain=message,
event=event, # 不强制要求一定是 Event
is_group=is_group,
session_id=session_id,
)
await super().send(message)
async def send_streaming(

View File

@@ -83,19 +83,18 @@ class AiocqhttpAdapter(Platform):
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
match session.message_type.value:
case MessageType.GROUP_MESSAGE.value:
if "_" in session.session_id:
# 独立会话
_, group_id = session.session_id.split("_")
await self.bot.send_group_msg(group_id=group_id, message=ret)
is_group = session.message_type == MessageType.GROUP_MESSAGE
if is_group:
session_id = session.session_id.split("_")[-1]
else:
await self.bot.send_group_msg(
group_id=session.session_id, message=ret
session_id = session.session_id
await AiocqhttpMessageEvent.send_message(
bot=self.bot,
message_chain=message_chain,
event=None, # 这里不需要 event因为是通过 session 发送的
is_group=is_group,
session_id=session_id,
)
case MessageType.FRIEND_MESSAGE.value:
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
await super().send_by_session(session, message_chain)
async def convert_message(self, event: Event) -> AstrBotMessage:
@@ -273,8 +272,14 @@ class AiocqhttpAdapter(Platform):
)
# 添加必要的 post_type 字段,防止 Event.from_payload 报错
reply_event_data["post_type"] = "message"
new_event = Event.from_payload(reply_event_data)
if not new_event:
logger.error(
f"无法从回复消息数据构造 Event 对象: {reply_event_data}"
)
continue
abm_reply = await self._convert_handle_message_event(
Event.from_payload(reply_event_data), get_reply=False
new_event, get_reply=False
)
reply_seg = Reply(
@@ -303,10 +308,19 @@ class AiocqhttpAdapter(Platform):
continue
at_info = await self.bot.call_action(
action="get_stranger_info",
action="get_group_member_info",
group_id=event.group_id,
user_id=int(m["data"]["qq"]),
no_cache=False,
)
if at_info:
nickname = at_info.get("card", "")
if nickname == "":
at_info = await self.bot.call_action(
action="get_stranger_info",
user_id=int(m["data"]["qq"]),
no_cache=False,
)
nickname = at_info.get("nick", "") or at_info.get("nickname", "")
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}

View File

@@ -26,7 +26,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
await asyncio.get_event_loop().run_in_executor(
None,
client.reply_markdown,
"AstrBot",
segment.text,
segment.text,
self.message_obj.raw_message,
)

View File

@@ -1,812 +0,0 @@
import asyncio
import base64
import datetime
import os
import re
import uuid
import threading
import aiohttp
import anyio
import quart
from astrbot.api import logger, sp
from astrbot.api.message_components import Plain, Image, At, Record, Video
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.core.utils.io import download_image_by_url
from .downloader import GeweDownloader
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
try:
from .xml_data_parser import GeweDataParser
except (ImportError, ModuleNotFoundError) as e:
logger.warning(
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
)
class SimpleGewechatClient:
"""针对 Gewechat 的简单实现。
@author: Soulter
@website: https://github.com/Soulter
"""
def __init__(
self,
base_url: str,
nickname: str,
host: str,
port: int,
event_queue: asyncio.Queue,
):
self.base_url = base_url
if self.base_url.endswith("/"):
self.base_url = self.base_url[:-1]
self.download_base_url = self.base_url.split(":")[:-1] # 去掉端口
self.download_base_url = ":".join(self.download_base_url) + ":2532/download/"
self.base_url += "/v2/api"
logger.info(f"Gewechat API: {self.base_url}")
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
if isinstance(port, str):
port = int(port)
self.token = None
self.headers = {}
self.nickname = nickname
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
self.server = quart.Quart(__name__)
self.server.add_url_rule(
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
)
self.server.add_url_rule(
"/astrbot-gewechat/file/<file_token>",
view_func=self._handle_file,
methods=["GET"],
)
self.host = host
self.port = port
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
self.event_queue = event_queue
self.multimedia_downloader = None
self.userrealnames = {}
self.shutdown_event = asyncio.Event()
self.staged_files = {}
"""存储了允许外部访问的文件列表。auth_token: file_path。通过 register_file 方法注册。"""
self.lock = asyncio.Lock()
async def get_token_id(self):
"""获取 Gewechat Token。"""
async with aiohttp.ClientSession() as session:
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
json_blob = await resp.json()
self.token = json_blob["data"]
logger.info(f"获取到 Gewechat Token: {self.token}")
self.headers = {"X-GEWE-TOKEN": self.token}
async def _convert(self, data: dict) -> AstrBotMessage:
if "TypeName" in data:
type_name = data["TypeName"]
elif "type_name" in data:
type_name = data["type_name"]
else:
raise Exception("无法识别的消息类型")
# 以下没有业务处理,只是避免控制台打印太多的日志
if type_name == "ModContacts":
logger.info("gewechat下发ModContacts消息通知。")
return
if type_name == "DelContacts":
logger.info("gewechat下发DelContacts消息通知。")
return
if type_name == "Offline":
logger.critical("收到 gewechat 下线通知。")
return
d = None
if "Data" in data:
d = data["Data"]
elif "data" in data:
d = data["data"]
if not d:
logger.warning(f"消息不含 data 字段: {data}")
return
if "CreateTime" in d:
# 得到系统 UTF+8 的 ts
tz_offset = datetime.timedelta(hours=8)
tz = datetime.timezone(tz_offset)
ts = datetime.datetime.now(tz).timestamp()
create_time = d["CreateTime"]
if create_time < ts - 30:
logger.warning(f"消息时间戳过旧: {create_time},当前时间戳: {ts}")
return
abm = AstrBotMessage()
from_user_name = d["FromUserName"]["string"] # 消息来源
d["to_wxid"] = from_user_name # 用于发信息
abm.message_id = str(d.get("MsgId"))
abm.session_id = from_user_name
abm.self_id = data["Wxid"] # 机器人的 wxid
user_id = "" # 发送人 wxid
content = d["Content"]["string"] # 消息内容
at_me = False
at_wxids = []
if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE
_t = content.split(":\n")
user_id = _t[0]
content = _t[1]
# at
msg_source = d["MsgSource"]
if "\u2005" in content:
# at
# content = content.split('\u2005')[1]
content = re.sub(r"@[^\u2005]*\u2005", "", content)
at_wxids = re.findall(
r"<atuserlist><!\[CDATA\[.*?(?:,|\b)([^,]+?)(?=,|\]\]></atuserlist>)",
msg_source,
)
abm.group_id = from_user_name
if (
f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source
or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source
):
at_me = True
if "在群聊中@了你" in d.get("PushContent", ""):
at_me = True
else:
abm.type = MessageType.FRIEND_MESSAGE
user_id = from_user_name
# 检查消息是否由自己发送,若是则忽略
# 已经有可配置项专门配置是否需要响应自己的消息,因此这里注释掉。
# if user_id == abm.self_id:
# logger.info("忽略自己发送的消息")
# return None
abm.message = []
# 解析用户真实名字
user_real_name = "unknown"
if abm.group_id:
if (
abm.group_id not in self.userrealnames
or user_id not in self.userrealnames[abm.group_id]
):
# 获取群成员列表,并且缓存
if abm.group_id not in self.userrealnames:
self.userrealnames[abm.group_id] = {}
member_list = await self.get_chatroom_member_list(abm.group_id)
logger.debug(f"获取到 {abm.group_id} 的群成员列表。")
if member_list and "memberList" in member_list:
for member in member_list["memberList"]:
self.userrealnames[abm.group_id][member["wxid"]] = member[
"nickName"
]
if user_id in self.userrealnames[abm.group_id]:
user_real_name = self.userrealnames[abm.group_id][user_id]
else:
user_real_name = self.userrealnames[abm.group_id][user_id]
else:
try:
info = (await self.get_user_or_group_info(user_id))["data"][0]
user_real_name = info["nickName"]
except Exception as e:
logger.debug(f"获取用户 {user_id} 昵称失败: {e}")
user_real_name = user_id
if at_me:
abm.message.insert(0, At(qq=abm.self_id, name=self.nickname))
for wxid in at_wxids:
# 群聊里 At 其他人的列表
_username = self.userrealnames.get(abm.group_id, {}).get(wxid, wxid)
abm.message.append(At(qq=wxid, name=_username))
abm.sender = MessageMember(user_id, user_real_name)
abm.raw_message = d
abm.message_str = ""
if user_id == "weixin":
# 忽略微信团队消息
return
# 不同消息类型
match d["MsgType"]:
case 1:
# 文本消息
abm.message.append(Plain(content))
abm.message_str = content
case 3:
# 图片消息
file_url = await self.multimedia_downloader.download_image(
self.appid, content
)
logger.debug(f"下载图片: {file_url}")
file_path = await download_image_by_url(file_url)
abm.message.append(Image(file=file_path, url=file_path))
case 34:
# 语音消息
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(
temp_dir, f"gewe_voice_{abm.message_id}.silk"
)
async with await anyio.open_file(file_path, "wb") as f:
await f.write(voice_data)
abm.message.append(Record(file=file_path, url=file_path))
# 以下已知消息类型,没有业务处理,只是避免控制台打印太多的日志
case 37: # 好友申请
logger.info("消息类型(37):好友申请")
case 42: # 名片
logger.info("消息类型(42):名片")
case 43: # 视频
video = Video(file="", cover=content)
abm.message.append(video)
case 47: # emoji
data_parser = GeweDataParser(content, abm.group_id == "")
emoji = data_parser.parse_emoji()
abm.message.append(emoji)
case 48: # 地理位置
logger.info("消息类型(48):地理位置")
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
data_parser = GeweDataParser(content, abm.group_id == "")
segments = data_parser.parse_mutil_49()
if segments:
abm.message.extend(segments)
for seg in segments:
if isinstance(seg, Plain):
abm.message_str += seg.text
case 51: # 帐号消息同步?
logger.info("消息类型(51):帐号消息同步?")
case 10000: # 被踢出群聊/更换群主/修改群名称
logger.info("消息类型(10000):被踢出群聊/更换群主/修改群名称")
case 10002: # 撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办
logger.info(
"消息类型(10002):撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办"
)
case _:
logger.info(f"未实现的消息类型: {d['MsgType']}")
abm.raw_message = d
logger.debug(f"abm: {abm}")
return abm
async def _callback(self):
data = await quart.request.json
logger.debug(f"收到 gewechat 回调: {data}")
if data.get("testMsg", None):
return quart.jsonify({"r": "AstrBot ACK"})
abm = None
try:
abm = await self._convert(data)
except BaseException as e:
logger.warning(
f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}"
)
if abm:
coro = getattr(self, "on_event_received")
if coro:
await coro(abm)
return quart.jsonify({"r": "AstrBot ACK"})
async def _register_file(self, file_path: str) -> str:
"""向 AstrBot 回调服务器 注册一个允许外部访问的文件。
Args:
file_path (str): 文件路径。
Returns:
str: 返回一个 auth_token文件路径为 file_path。通过 /astrbot-gewechat/file/auth_token 得到文件。
"""
async with self.lock:
if not os.path.exists(file_path):
raise Exception(f"文件不存在: {file_path}")
file_token = str(uuid.uuid4())
self.staged_files[file_token] = file_path
return file_token
async def _handle_file(self, file_token):
async with self.lock:
if file_token not in self.staged_files:
logger.warning(f"请求的文件 {file_token} 不存在。")
return quart.abort(404)
if not os.path.exists(self.staged_files[file_token]):
logger.warning(f"请求的文件 {self.staged_files[file_token]} 不存在。")
return quart.abort(404)
file_path = self.staged_files[file_token]
self.staged_files.pop(file_token, None)
return await quart.send_file(file_path)
async def _set_callback_url(self):
logger.info("设置回调,请等待...")
await asyncio.sleep(3)
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/tools/setCallback",
headers=self.headers,
json={"token": self.token, "callbackUrl": self.callback_url},
) as resp:
json_blob = await resp.json()
logger.info(f"设置回调结果: {json_blob}")
if json_blob["ret"] != 200:
raise Exception(f"设置回调失败: {json_blob}")
logger.info(
f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。如果仍没收到请到管理面板聊天页输入 /gewe_logout 重新登录。"
)
async def start_polling(self):
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
await self.server.run_task(
host="0.0.0.0",
port=self.port,
shutdown_trigger=self.shutdown_trigger,
)
async def shutdown_trigger(self):
await self.shutdown_event.wait()
async def check_online(self, appid: str):
"""检查 APPID 对应的设备是否在线。"""
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/checkOnline",
headers=self.headers,
json={"appId": appid},
) as resp:
json_blob = await resp.json()
return json_blob["data"]
async def logout(self):
"""登出 gewechat。"""
if self.appid:
online = await self.check_online(self.appid)
if online:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/logout",
headers=self.headers,
json={"appId": self.appid},
) as resp:
json_blob = await resp.json()
logger.info(f"登出结果: {json_blob}")
async def login(self):
"""登录 gewechat。一般来说插件用不到这个方法。"""
if self.token is None:
await self.get_token_id()
self.multimedia_downloader = GeweDownloader(
self.base_url, self.download_base_url, self.token
)
if self.appid:
try:
online = await self.check_online(self.appid)
if online:
logger.info(f"APPID: {self.appid} 已在线")
return
except Exception as e:
logger.error(f"检查在线状态失败: {e}")
sp.put(f"gewechat-appid-{self.nickname}", "")
self.appid = None
payload = {"appId": self.appid}
if self.appid:
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/getLoginQrCode",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
if json_blob["ret"] != 200:
error_msg = json_blob.get("data", {}).get("msg", "")
if "设备不存在" in error_msg:
logger.error(
f"检测到无效的appid: {self.appid},将清除并重新登录。"
)
sp.put(f"gewechat-appid-{self.nickname}", "")
self.appid = None
return await self.login()
else:
raise Exception(f"获取二维码失败: {json_blob}")
qr_data = json_blob["data"]["qrData"]
qr_uuid = json_blob["data"]["uuid"]
appid = json_blob["data"]["appId"]
logger.info(f"APPID: {appid}")
logger.warning(
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
)
except Exception as e:
raise e
# 执行登录
retry_cnt = 64
payload.update({"uuid": qr_uuid, "appId": appid})
while retry_cnt > 0:
retry_cnt -= 1
# 需要验证码
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
code_file_path = os.path.join(temp_dir, "gewe_code")
if os.path.exists(code_file_path):
with open(code_file_path, "r") as f:
code = f.read().strip()
if not code:
logger.warning(
"未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
)
await asyncio.sleep(5)
continue
payload["captchCode"] = code
logger.info(f"使用验证码: {code}")
try:
os.remove(code_file_path)
except Exception:
logger.warning(f"删除验证码文件 {code_file_path} 失败。")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/checkLogin",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.info(f"检查登录状态: {json_blob}")
ret = json_blob["ret"]
msg = ""
if json_blob["data"] and "msg" in json_blob["data"]:
msg = json_blob["data"]["msg"]
if ret == 500 and "安全验证码" in msg:
logger.warning(
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
)
else:
if "status" in json_blob["data"]:
status = json_blob["data"]["status"]
nickname = json_blob["data"].get("nickName", "")
if status == 1:
logger.info(f"等待确认...{nickname}")
elif status == 2:
logger.info(f"绿泡泡平台登录成功: {nickname}")
break
elif status == 0:
logger.info("等待扫码...")
else:
logger.warning(f"未知状态: {status}")
await asyncio.sleep(5)
if appid:
sp.put(f"gewechat-appid-{self.nickname}", appid)
self.appid = appid
logger.info(f"已保存 APPID: {appid}")
"""API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1
"""
async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict:
"""获取群成员列表。
Args:
chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。
Returns:
dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。
"""
payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomMemberList",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
return json_blob["data"]
async def post_text(self, to_wxid, content: str, ats: str = ""):
"""发送纯文本消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"content": content,
}
if ats:
payload["ats"] = ats
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postText", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送消息结果: {json_blob}")
async def post_image(self, to_wxid, image_url: str):
"""发送图片消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"imgUrl": image_url,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postImage", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送图片结果: {json_blob}")
async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""):
"""发送emoji消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"emojiMd5": emoji_md5,
"emojiSize": emoji_size,
}
# 优先表情包若拿不到表情包的md5就用当作图片发
try:
if emoji_md5 != "" and emoji_size != "":
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postEmoji",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.info(
f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}"
)
else:
await self.post_image(to_wxid, cdnurl)
except Exception as e:
logger.error(e)
async def post_video(
self, to_wxid, video_url: str, thumb_url: str, video_duration: int
):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"videoUrl": video_url,
"thumbUrl": thumb_url,
"videoDuration": video_duration,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postVideo", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送视频结果: {json_blob}")
async def forward_video(self, to_wxid, cnd_xml: str):
"""转发视频
Args:
to_wxid (str): 发送给谁
cnd_xml (str): 视频消息的cdn信息
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"xml": cnd_xml,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/forwardVideo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"转发视频结果: {json_blob}")
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
"""发送语音信息
Args:
voice_url (str): 语音文件的网络链接
voice_duration (int): 语音时长,毫秒
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"voiceUrl": voice_url,
"voiceDuration": voice_duration,
}
logger.debug(f"发送语音: {payload}")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}")
async def post_file(self, to_wxid, file_url: str, file_name: str):
"""发送文件
Args:
to_wxid (string): 微信ID
file_url (str): 文件的网络链接
file_name (str): 文件名
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"fileUrl": file_url,
"fileName": file_name,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postFile", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送文件结果: {json_blob}")
async def add_friend(self, v3: str, v4: str, content: str):
"""申请添加好友"""
payload = {
"appId": self.appid,
"scene": 3,
"content": content,
"v4": v4,
"v3": v3,
"option": 2,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/addContacts",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"申请添加好友结果: {json_blob}")
return json_blob
async def get_group(self, group_id: str):
payload = {
"appId": self.appid,
"chatroomId": group_id,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomInfo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_group_member(self, group_id: str):
payload = {
"appId": self.appid,
"chatroomId": group_id,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomMemberList",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def accept_group_invite(self, url: str):
"""同意进群"""
payload = {"appId": self.appid, "url": url}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/agreeJoinRoom",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def add_group_member_to_friend(
self, group_id: str, to_wxid: str, content: str
):
payload = {
"appId": self.appid,
"chatroomId": group_id,
"content": content,
"memberWxid": to_wxid,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/addGroupMemberAsFriend",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_user_or_group_info(self, *ids):
"""
获取用户或群组信息。
:param ids: 可变数量的 wxid 参数
"""
wxids_str = list(ids)
payload = {
"appId": self.appid,
"wxids": wxids_str, # 使用逗号分隔的字符串
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/getDetailInfo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_contacts_list(self):
"""
获取通讯录列表
见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504
"""
payload = {"appId": self.appid}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/fetchContactsList",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取通讯录列表结果: {json_blob}")
return json_blob

View File

@@ -1,55 +0,0 @@
from astrbot import logger
import aiohttp
import json
class GeweDownloader:
def __init__(self, base_url: str, download_base_url: str, token: str):
self.base_url = base_url
self.download_base_url = download_base_url
self.headers = {"Content-Type": "application/json", "X-GEWE-TOKEN": token}
async def _post_json(self, baseurl: str, route: str, payload: dict):
async with aiohttp.ClientSession() as session:
async with session.post(
f"{baseurl}{route}", headers=self.headers, json=payload
) as resp:
return await resp.read()
async def download_voice(self, appid: str, xml: str, msg_id: str):
payload = {"appId": appid, "xml": xml, "msgId": msg_id}
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
async def download_image(self, appid: str, xml: str) -> str:
"""返回一个可下载的 URL"""
choices = [2, 3] # 2:常规图片 3:缩略图
for choice in choices:
try:
payload = {"appId": appid, "xml": xml, "type": choice}
data = await self._post_json(
self.base_url, "/message/downloadImage", payload
)
json_blob = json.loads(data)
if "fileUrl" in json_blob["data"]:
return self.download_base_url + json_blob["data"]["fileUrl"]
except BaseException as e:
logger.error(f"gewe download image: {e}")
continue
raise Exception("无法下载图片")
async def download_emoji_md5(self, app_id, emoji_md5):
"""下载emoji"""
try:
payload = {"appId": app_id, "emojiMd5": emoji_md5}
# gewe 计划中的接口暂时没有实现。返回代码404
data = await self._post_json(
self.base_url, "/message/downloadEmojiMd5", payload
)
json_blob = json.loads(data)
return json_blob
except BaseException as e:
logger.error(f"gewe download emoji: {e}")

View File

@@ -1,264 +0,0 @@
import asyncio
import re
import wave
import uuid
import traceback
import os
from typing import AsyncGenerator
from astrbot.core.utils.io import download_file
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember
from astrbot.api.message_components import (
Plain,
Image,
Record,
At,
File,
Video,
WechatEmoji as Emoji,
)
from .client import SimpleGewechatClient
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
def get_wav_duration(file_path):
with wave.open(file_path, "rb") as wav_file:
file_size = os.path.getsize(file_path)
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
if n_frames == 2147483647:
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
elif n_frames == 0:
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
else:
duration = n_frames / float(framerate)
return duration
class GewechatPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client: SimpleGewechatClient,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
@staticmethod
async def send_with_client(
message: MessageChain, to_wxid: str, client: SimpleGewechatClient
):
if not to_wxid:
logger.error("无法获取到 to_wxid。")
return
# 检查@
ats = []
ats_names = []
for comp in message.chain:
if isinstance(comp, At):
ats.append(comp.qq)
ats_names.append(comp.name)
has_at = False
for comp in message.chain:
if isinstance(comp, Plain):
text = comp.text
payload = {
"to_wxid": to_wxid,
"content": text,
}
if not has_at and ats:
ats = f"{','.join(ats)}"
ats_names = f"@{' @'.join(ats_names)}"
text = f"{ats_names} {text}"
payload["content"] = text
payload["ats"] = ats
has_at = True
await client.post_text(**payload)
elif isinstance(comp, Image):
img_path = await comp.convert_to_file_path()
# 为了安全,向 AstrBot 回调服务注册可被 gewechat 访问的文件,并获得文件 token
token = await client._register_file(img_path)
img_url = f"{client.file_server_url}/{token}"
logger.debug(f"gewe callback img url: {img_url}")
await client.post_image(to_wxid, img_url)
elif isinstance(comp, Video):
if comp.cover != "":
await client.forward_video(to_wxid, comp.cover)
else:
try:
from pyffmpeg import FFmpeg
except (ImportError, ModuleNotFoundError):
logger.error(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
)
raise ModuleNotFoundError(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
)
video_url = comp.file
# 根据 url 下载视频
if video_url.startswith("http"):
video_filename = f"{uuid.uuid4()}.mp4"
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
video_path = os.path.join(temp_dir, video_filename)
await download_file(video_url, video_path)
else:
video_path = video_url
video_token = await client._register_file(video_path)
video_callback_url = f"{client.file_server_url}/{video_token}"
# 获取视频第一帧
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
thumb_path = os.path.join(
temp_dir, f"gewechat_video_thumb_{uuid.uuid4()}.jpg"
)
video_path = video_path.replace(" ", "\\ ")
try:
ff = FFmpeg()
command = f"-i {video_path} -ss 0 -vframes 1 {thumb_path}"
ff.options(command)
thumb_token = await client._register_file(thumb_path)
thumb_url = f"{client.file_server_url}/{thumb_token}"
except Exception as e:
logger.error(f"获取视频第一帧失败: {e}")
# 获取视频时长
try:
from pyffmpeg import FFprobe
# 创建 FFprobe 实例
ffprobe = FFprobe(video_url)
# 获取时长字符串
duration_str = ffprobe.duration
# 处理时长字符串
video_duration = float(duration_str.replace(":", ""))
except Exception as e:
logger.error(f"获取时长失败: {e}")
video_duration = 10
# 发送视频
await client.post_video(
to_wxid, video_callback_url, thumb_url, video_duration
)
# 删除临时缩略图文件
if os.path.exists(thumb_path):
os.remove(thumb_path)
elif isinstance(comp, Record):
# 默认已经存在 data/temp 中
record_url = comp.file
record_path = await comp.convert_to_file_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
try:
duration = await wav_to_tencent_silk(record_path, silk_path)
except Exception as e:
logger.error(traceback.format_exc())
await client.post_text(to_wxid, f"语音文件转换失败。{str(e)}")
logger.info("Silk 语音文件格式转换至: " + record_path)
if duration == 0:
duration = get_wav_duration(record_path)
token = await client._register_file(silk_path)
record_url = f"{client.file_server_url}/{token}"
logger.debug(f"gewe callback record url: {record_url}")
await client.post_voice(to_wxid, record_url, duration * 1000)
elif isinstance(comp, File):
file_path = comp.file
file_name = comp.name
if file_path.startswith("file:///"):
file_path = file_path[8:]
elif file_path.startswith("http"):
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
temp_file_path = os.path.join(temp_dir, file_name)
await download_file(file_path, temp_file_path)
file_path = temp_file_path
else:
file_path = file_path
token = await client._register_file(file_path)
file_url = f"{client.file_server_url}/{token}"
logger.debug(f"gewe callback file url: {file_url}")
await client.post_file(to_wxid, file_url, file_name)
elif isinstance(comp, Emoji):
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
elif isinstance(comp, At):
pass
else:
logger.debug(f"gewechat 忽略: {comp.type}")
async def send(self, message: MessageChain):
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client)
await super().send(message)
async def get_group(self, group_id=None, **kwargs):
# 确定有效的 group_id
if group_id is None:
group_id = self.get_group_id()
if not group_id:
return None
res = await self.client.get_group(group_id)
data: dict = res["data"]
if not data["chatroomId"]:
return None
members = [
MessageMember(user_id=member["wxid"], nickname=member["nickName"])
for member in data.get("memberList", [])
]
return Group(
group_id=data["chatroomId"],
group_name=data.get("nickName"),
group_avatar=data.get("smallHeadImgUrl"),
group_owner=data.get("chatRoomOwner"),
members=members,
)
async def send_streaming(
self, generator: AsyncGenerator, use_fallback: bool = False
):
if not use_fallback:
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
buffer = ""
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
async for chain in generator:
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
buffer = await self.process_buffer(buffer, pattern)
else:
await self.send(MessageChain(chain=[comp]))
await asyncio.sleep(1.5) # 限速
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator, use_fallback)

View File

@@ -1,103 +0,0 @@
import sys
import asyncio
import os
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from .gewechat_event import GewechatPlatformEvent
from .client import SimpleGewechatClient
from astrbot import logger
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器")
class GewechatPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settingss = platform_settings
self.test_mode = os.environ.get("TEST_MODE", "off") == "on"
self.client = None
self.client = SimpleGewechatClient(
self.config["base_url"],
self.config["nickname"],
self.config["host"],
self.config["port"],
self._event_queue,
)
async def on_event_received(abm: AstrBotMessage):
await self.handle_msg(abm)
self.client.on_event_received = on_event_received
@override
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
session_id = session.session_id
if "#" in session_id:
# unique session
to_wxid = session_id.split("#")[1]
else:
to_wxid = session_id
await GewechatPlatformEvent.send_with_client(
message_chain, to_wxid, self.client
)
await super().send_by_session(session, message_chain)
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="gewechat",
description="基于 gewechat 的 Wechat 适配器",
id=self.config.get("id"),
)
async def terminate(self):
self.client.shutdown_event.set()
try:
await self.client.server.shutdown()
except Exception as _:
pass
logger.info("Gewechat 适配器已被优雅地关闭。")
async def logout(self):
await self.client.logout()
@override
def run(self):
return self._run()
async def _run(self):
await self.client.login()
await self.client.start_polling()
async def handle_msg(self, message: AstrBotMessage):
if message.type == MessageType.GROUP_MESSAGE:
if self.settingss["unique_session"]:
message.session_id = message.sender.user_id + "#" + message.group_id
message_event = GewechatPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client,
)
self.commit_event(message_event)
def get_client(self) -> SimpleGewechatClient:
return self.client

View File

@@ -1,110 +0,0 @@
from defusedxml import ElementTree as eT
from astrbot.api import logger
from astrbot.api.message_components import (
WechatEmoji as Emoji,
Reply,
Plain,
BaseMessageComponent,
)
class GeweDataParser:
def __init__(self, data, is_private_chat):
self.data = data
self.is_private_chat = is_private_chat
def _format_to_xml(self):
return eT.fromstring(self.data)
def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
appmsg_type = self._format_to_xml().find(".//appmsg/type")
if appmsg_type is None:
return
match appmsg_type.text:
case "57":
return self.parse_reply()
def parse_emoji(self) -> Emoji | None:
try:
emoji_element = self._format_to_xml().find(".//emoji")
# 提取 md5 和 len 属性
if emoji_element is not None:
md5_value = emoji_element.get("md5")
emoji_size = emoji_element.get("len")
cdnurl = emoji_element.get("cdnurl")
return Emoji(md5=md5_value, md5_len=emoji_size, cdnurl=cdnurl)
except Exception as e:
logger.error(f"gewechat: parse_emoji failed, {e}")
def parse_reply(self) -> list[Reply, Plain] | None:
"""解析引用消息
Returns:
list[Reply, Plain]: 一个包含两个元素的列表。Reply 消息对象和引用者说的文本内容。微信平台下引用消息时只能发送文本消息。
"""
try:
replied_id = -1
replied_uid = 0
replied_nickname = ""
replied_content = "" # 被引用者说的内容
content = "" # 引用者说的内容
root = self._format_to_xml()
refermsg = root.find(".//refermsg")
if refermsg is not None:
# 被引用的信息
svrid = refermsg.find("svrid")
fromusr = refermsg.find("fromusr")
displayname = refermsg.find("displayname")
refermsg_content = refermsg.find("content")
if svrid is not None:
replied_id = svrid.text
if fromusr is not None:
replied_uid = fromusr.text
if displayname is not None:
replied_nickname = displayname.text
if refermsg_content is not None:
# 处理引用嵌套,包括嵌套公众号消息
if refermsg_content.text.startswith(
"<msg>"
) or refermsg_content.text.startswith("<?xml"):
try:
logger.debug("gewechat: Reference message is nested")
refer_root = eT.fromstring(refermsg_content.text)
img = refer_root.find("img")
if img is not None:
replied_content = "[图片]"
else:
app_msg = refer_root.find("appmsg")
refermsg_content_title = app_msg.find("title")
logger.debug(
f"gewechat: Reference message nesting: {refermsg_content_title.text}"
)
replied_content = refermsg_content_title.text
except Exception as e:
logger.error(f"gewechat: nested failed, {e}")
# 处理异常情况
replied_content = refermsg_content.text
else:
replied_content = refermsg_content.text
# 提取引用者说的内容
title = root.find(".//appmsg/title")
if title is not None:
content = title.text
reply_seg = Reply(
id=replied_id,
chain=[Plain(replied_content)],
sender_id=replied_uid,
sender_nickname=replied_nickname,
message_str=replied_content,
)
plain_seg = Plain(content)
return [reply_seg, plain_seg]
except Exception as e:
logger.error(f"gewechat: parse_reply failed, {e}")

View File

@@ -3,15 +3,23 @@ import botpy.message
import botpy.types
import botpy.types.message
import asyncio
import base64
import aiofiles
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image
from astrbot.api.message_components import Plain, Image, Record
from botpy import Client
from botpy.http import Route
from astrbot.api import logger
from botpy.types.message import Media
from botpy.types import message
from typing import Optional
import random
import uuid
import os
class QQOfficialMessageEvent(AstrMessageEvent):
@@ -36,6 +44,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
last_edit_time = 0 # 上次编辑消息的时间
throttle_interval = 1 # 编辑消息的间隔时间 (秒)
ret = None
try:
async for chain in generator:
source = self.message_obj.raw_message
@@ -85,9 +94,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
plain_text,
image_base64,
image_path,
record_file_path
) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
if not plain_text and not image_base64 and not image_path:
if not plain_text and not image_base64 and not image_path and not record_file_path:
return
payload = {
@@ -98,6 +108,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
if not isinstance(source, (botpy.message.Message, botpy.message.DirectMessage)):
payload["msg_seq"] = random.randint(1, 10000)
ret = None
match type(source):
case botpy.message.GroupMessage:
if image_base64:
@@ -106,6 +118,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path: # group record msg
media = await self.upload_group_and_c2c_record(
record_file_path, 3, group_openid=source.group_openid
)
payload["media"] = media
payload["msg_type"] = 7
ret = await self.bot.api.post_group_message(
group_openid=source.group_openid, **payload
)
@@ -116,6 +134,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path: # c2c record
media = await self.upload_group_and_c2c_record(
record_file_path, 3, openid = source.author.user_openid
)
payload["media"] = media
payload["msg_type"] = 7
if stream:
ret = await self.post_c2c_message(
openid=source.author.user_openid,
@@ -165,6 +189,59 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
return await self.bot.api._http.request(route, json=payload)
async def upload_group_and_c2c_record(
self,
file_source: str,
file_type: int,
srv_send_msg: bool = False,
**kwargs
) -> Optional[Media]:
"""
上传媒体文件
"""
# 构建基础payload
payload = {
"file_type": file_type,
"srv_send_msg": srv_send_msg
}
# 处理文件数据
if os.path.exists(file_source):
# 读取本地文件
async with aiofiles.open(file_source, 'rb') as f:
file_content = await f.read()
# use base64 encode
payload["file_data"] = base64.b64encode(file_content).decode('utf-8')
else:
# 使用URL
payload["url"] = file_source
# 添加接收者信息和确定路由
if "openid" in kwargs:
payload["openid"] = kwargs["openid"]
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
elif "group_openid" in kwargs:
payload["group_openid"] =kwargs["group_openid"]
route = Route("POST", "/v2/groups/{group_openid}/files", group_openid=kwargs["group_openid"])
else:
return None
try:
# 使用底层HTTP请求
result = await self.bot.api._http.request(route, json=payload)
if result:
return Media(
file_uuid=result.get("file_uuid"),
file_info=result.get("file_info"),
ttl=result.get("ttl", 0),
file_id=result.get("id", "")
)
except Exception as e:
logger.error(f"上传请求错误: {e}")
return None
async def post_c2c_message(
self,
openid: str,
@@ -191,6 +268,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
plain_text = ""
image_base64 = None # only one img supported
image_file_path = None
record_file_path = None
for i in message.chain:
if isinstance(i, Plain):
plain_text += i.text
@@ -206,6 +284,21 @@ class QQOfficialMessageEvent(AstrMessageEvent):
else:
image_base64 = file_to_base64(i.file)
image_base64 = image_base64.removeprefix("base64://")
elif isinstance(i, Record):
if i.file:
record_wav_path = await i.convert_to_file_path() # wav 路径
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
record_tecent_silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
try:
duration = await wav_to_tencent_silk(record_wav_path, record_tecent_silk_path)
if duration > 0:
record_file_path = record_tecent_silk_path
else:
record_file_path = None
logger.error("转换音频格式时出错音频时长不大于0")
except Exception as e:
logger.error(f"处理语音时出错: {e}")
record_file_path = None
else:
logger.debug(f"qq_official 忽略 {i.type}")
return plain_text, image_base64, image_file_path
return plain_text, image_base64, image_file_path, record_file_path

View File

@@ -0,0 +1,482 @@
import asyncio
import json
import time
import websockets
from websockets.asyncio.client import connect
from typing import Optional
from aiohttp import ClientSession, ClientTimeout
from websockets.asyncio.client import ClientConnection
from astrbot.api import logger
from astrbot.api.event import MessageChain
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
MessageType,
Platform,
PlatformMetadata,
register_platform_adapter,
)
from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.api.message_components import Plain, Image, At, File, Record
from xml.etree import ElementTree as ET
@register_platform_adapter(
"satori",
"Satori 协议适配器",
)
class SatoriPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
self.api_base_url = self.config.get(
"satori_api_base_url", "http://localhost:5140/satori/v1"
)
self.token = self.config.get("satori_token", "")
self.endpoint = self.config.get(
"satori_endpoint", "ws://127.0.0.1:5140/satori/v1/events"
)
self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)
self.ws: Optional[ClientConnection] = None
self.session: Optional[ClientSession] = None
self.sequence = 0
self.logins = []
self.running = False
self.heartbeat_task: Optional[asyncio.Task] = None
self.ready_received = False
async def send_by_session(
self, session: MessageSession, message_chain: MessageChain
):
from .satori_event import SatoriPlatformEvent
await SatoriPlatformEvent.send_with_adapter(
self, message_chain, session.session_id
)
await super().send_by_session(session, message_chain)
def meta(self) -> PlatformMetadata:
return PlatformMetadata(name="satori", description="Satori 通用协议适配器")
def _is_websocket_closed(self, ws) -> bool:
"""检查WebSocket连接是否已关闭"""
if not ws:
return True
try:
if hasattr(ws, "closed"):
return ws.closed
elif hasattr(ws, "close_code"):
return ws.close_code is not None
else:
return False
except AttributeError:
return False
async def run(self):
self.running = True
self.session = ClientSession(timeout=ClientTimeout(total=30))
retry_count = 0
max_retries = 10
while self.running:
try:
await self.connect_websocket()
retry_count = 0
except websockets.exceptions.ConnectionClosed as e:
logger.warning(f"Satori WebSocket 连接关闭: {e}")
retry_count += 1
except Exception as e:
logger.error(f"Satori WebSocket 连接失败: {e}")
retry_count += 1
if not self.running:
break
if retry_count >= max_retries:
logger.error(f"达到最大重试次数 ({max_retries}),停止重试")
break
if not self.auto_reconnect:
break
delay = min(self.reconnect_delay * (2 ** (retry_count - 1)), 60)
await asyncio.sleep(delay)
if self.session:
await self.session.close()
async def connect_websocket(self):
logger.info(f"Satori 适配器正在连接到 WebSocket: {self.endpoint}")
logger.info(f"Satori 适配器 HTTP API 地址: {self.api_base_url}")
if not self.endpoint.startswith(("ws://", "wss://")):
logger.error(f"无效的WebSocket URL: {self.endpoint}")
raise ValueError(f"WebSocket URL必须以ws://或wss://开头: {self.endpoint}")
try:
websocket = await connect(self.endpoint, additional_headers={})
self.ws = websocket
await asyncio.sleep(0.1)
await self.send_identify()
self.heartbeat_task = asyncio.create_task(self.heartbeat_loop())
async for message in websocket:
try:
await self.handle_message(message) # type: ignore
except Exception as e:
logger.error(f"Satori 处理消息异常: {e}")
except websockets.exceptions.ConnectionClosed as e:
logger.warning(f"Satori WebSocket 连接关闭: {e}")
raise
except Exception as e:
logger.error(f"Satori WebSocket 连接异常: {e}")
raise
finally:
if self.heartbeat_task:
self.heartbeat_task.cancel()
try:
await self.heartbeat_task
except asyncio.CancelledError:
pass
if self.ws:
try:
await self.ws.close()
except Exception as e:
logger.error(f"Satori WebSocket 关闭异常: {e}")
async def send_identify(self):
if not self.ws:
raise Exception("WebSocket连接未建立")
if self._is_websocket_closed(self.ws):
raise Exception("WebSocket连接已关闭")
identify_payload = {
"op": 3, # IDENTIFY
"body": {
"token": str(self.token) if self.token else "", # 字符串
},
}
# 只有在有序列号时才添加sn字段
if self.sequence > 0:
identify_payload["body"]["sn"] = self.sequence
try:
message_str = json.dumps(identify_payload, ensure_ascii=False)
await self.ws.send(message_str)
except websockets.exceptions.ConnectionClosed as e:
logger.error(f"发送 IDENTIFY 信令时连接关闭: {e}")
raise
except Exception as e:
logger.error(f"发送 IDENTIFY 信令失败: {e}")
raise
async def heartbeat_loop(self):
try:
while self.running and self.ws:
await asyncio.sleep(self.heartbeat_interval)
if self.ws and not self._is_websocket_closed(self.ws):
try:
ping_payload = {
"op": 1, # PING
"body": {},
}
await self.ws.send(json.dumps(ping_payload, ensure_ascii=False))
except websockets.exceptions.ConnectionClosed as e:
logger.error(f"Satori WebSocket 连接关闭: {e}")
break
except Exception as e:
logger.error(f"Satori WebSocket 发送心跳失败: {e}")
break
else:
break
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"心跳任务异常: {e}")
async def handle_message(self, message: str):
try:
data = json.loads(message)
op = data.get("op")
body = data.get("body", {})
if op == 4: # READY
self.logins = body.get("logins", [])
self.ready_received = True
# 输出连接成功的bot信息
if self.logins:
for i, login in enumerate(self.logins):
platform = login.get("platform", "")
user = login.get("user", {})
user_id = user.get("id", "")
user_name = user.get("name", "")
logger.info(
f"Satori 连接成功 - Bot {i + 1}: platform={platform}, user_id={user_id}, user_name={user_name}"
)
if "sn" in body:
self.sequence = body["sn"]
elif op == 2: # PONG
pass
elif op == 0: # EVENT
await self.handle_event(body)
if "sn" in body:
self.sequence = body["sn"]
elif op == 5: # META
if "sn" in body:
self.sequence = body["sn"]
except json.JSONDecodeError as e:
logger.error(f"解析 WebSocket 消息失败: {e}, 消息内容: {message}")
except Exception as e:
logger.error(f"处理 WebSocket 消息异常: {e}")
async def handle_event(self, event_data: dict):
try:
event_type = event_data.get("type")
sn = event_data.get("sn")
if sn:
self.sequence = sn
if event_type == "message-created":
message = event_data.get("message", {})
user = event_data.get("user", {})
channel = event_data.get("channel", {})
guild = event_data.get("guild")
login = event_data.get("login", {})
timestamp = event_data.get("timestamp")
if user.get("id") == login.get("user", {}).get("id"):
return
abm = await self.convert_satori_message(
message, user, channel, guild, login, timestamp
)
if abm:
await self.handle_msg(abm)
except Exception as e:
logger.error(f"处理事件失败: {e}")
async def convert_satori_message(
self,
message: dict,
user: dict,
channel: dict,
guild: Optional[dict],
login: dict,
timestamp: Optional[int] = None,
) -> Optional[AstrBotMessage]:
try:
abm = AstrBotMessage()
abm.message_id = message.get("id", "")
abm.raw_message = {
"message": message,
"user": user,
"channel": channel,
"guild": guild,
"login": login,
}
if guild and guild.get("id"):
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = guild.get("id", "")
abm.session_id = channel.get("id", "")
else:
abm.type = MessageType.FRIEND_MESSAGE
abm.session_id = channel.get("id", "")
abm.sender = MessageMember(
user_id=user.get("id", ""),
nickname=user.get("nick", user.get("name", "")),
)
abm.self_id = login.get("user", {}).get("id", "")
content = message.get("content", "")
abm.message = await self.parse_satori_elements(content)
# parse message_str
abm.message_str = ""
for comp in abm.message:
if isinstance(comp, Plain):
abm.message_str += comp.text
# 优先使用Satori事件中的时间戳
if timestamp is not None:
abm.timestamp = timestamp
else:
abm.timestamp = int(time.time())
return abm
except Exception as e:
logger.error(f"转换 Satori 消息失败: {e}")
return None
async def parse_satori_elements(self, content: str) -> list:
"""解析 Satori 消息元素"""
elements = []
if not content:
return elements
try:
wrapped_content = f"<root>{content}</root>"
root = ET.fromstring(wrapped_content)
await self._parse_xml_node(root, elements)
except ET.ParseError as e:
raise ValueError(f"解析 Satori 元素时发生解析错误: {e}")
except Exception as e:
raise e
# 如果没有解析到任何元素,将整个内容当作纯文本
if not elements and content.strip():
elements.append(Plain(text=content))
return elements
async def _parse_xml_node(self, node: ET.Element, elements: list) -> None:
"""递归解析 XML 节点"""
if node.text and node.text.strip():
elements.append(Plain(text=node.text))
for child in node:
tag_name = child.tag.lower()
attrs = child.attrib
if tag_name == "at":
user_id = attrs.get("id") or attrs.get("name", "")
elements.append(At(qq=user_id, name=user_id))
elif tag_name in ("img", "image"):
src = attrs.get("src", "")
if not src:
continue
if src.startswith("data:image/"):
src = src.split(",")[1]
elements.append(Image.fromBase64(src))
elif src.startswith("http"):
elements.append(Image.fromURL(src))
else:
logger.error(f"未知的图片 src 格式: {str(src)[:16]}")
elif tag_name == "file":
src = attrs.get("src", "")
name = attrs.get("name", "文件")
if src:
elements.append(File(file=src, name=name))
elif tag_name in ("audio", "record"):
src = attrs.get("src", "")
if not src:
continue
if src.startswith("data:audio/"):
src = src.split(",")[1]
elements.append(Record.fromBase64(src))
elif src.startswith("http"):
elements.append(Record.fromURL(src))
else:
logger.error(f"未知的音频 src 格式: {str(src)[:16]}")
else:
# 未知标签,递归处理其内容
if child.text and child.text.strip():
elements.append(Plain(text=child.text))
await self._parse_xml_node(child, elements)
# 处理标签后的文本
if child.tail and child.tail.strip():
elements.append(Plain(text=child.tail))
async def handle_msg(self, message: AstrBotMessage):
from .satori_event import SatoriPlatformEvent
message_event = SatoriPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
adapter=self,
)
self.commit_event(message_event)
async def send_http_request(
self,
method: str,
path: str,
data: dict | None = None,
platform: str | None = None,
user_id: str | None = None,
) -> dict:
if not self.session:
raise Exception("HTTP session 未初始化")
headers = {
"Content-Type": "application/json",
}
if self.token:
headers["Authorization"] = f"Bearer {self.token}"
if platform and user_id:
headers["satori-platform"] = platform
headers["satori-user-id"] = user_id
elif self.logins:
current_login = self.logins[0]
headers["satori-platform"] = current_login.get("platform", "")
user = current_login.get("user", {})
headers["satori-user-id"] = user.get("id", "") if user else ""
if not path.startswith("/"):
path = "/" + path
# 使用新的API地址配置
url = f"{self.api_base_url.rstrip('/')}{path}"
try:
async with self.session.request(
method, url, json=data, headers=headers
) as response:
if response.status == 200:
result = await response.json()
return result
else:
return {}
except Exception as e:
logger.error(f"Satori HTTP 请求异常: {e}")
return {}
async def terminate(self):
self.running = False
if self.heartbeat_task:
self.heartbeat_task.cancel()
if self.ws:
try:
await self.ws.close()
except Exception as e:
logger.error(f"Satori WebSocket 关闭异常: {e}")
if self.session:
await self.session.close()

View File

@@ -0,0 +1,221 @@
from typing import TYPE_CHECKING
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, At, File, Record
if TYPE_CHECKING:
from .satori_adapter import SatoriPlatformAdapter
class SatoriPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
adapter: "SatoriPlatformAdapter",
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.adapter = adapter
self.platform = None
self.user_id = None
if (
hasattr(message_obj, "raw_message")
and message_obj.raw_message
and isinstance(message_obj.raw_message, dict)
):
login = message_obj.raw_message.get("login", {})
self.platform = login.get("platform")
user = login.get("user", {})
self.user_id = user.get("id") if user else None
@classmethod
async def send_with_adapter(
cls, adapter: "SatoriPlatformAdapter", message: MessageChain, session_id: str
):
try:
content_parts = []
for component in message.chain:
if isinstance(component, Plain):
text = (
component.text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
content_parts.append(text)
elif isinstance(component, At):
if component.qq:
content_parts.append(f'<at id="{component.qq}"/>')
elif component.name:
content_parts.append(f'<at name="{component.name}"/>')
elif isinstance(component, Image):
try:
image_base64 = await component.convert_to_base64()
if image_base64:
content_parts.append(
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
)
except Exception as e:
logger.error(f"图片转换为base64失败: {e}")
elif isinstance(component, File):
content_parts.append(
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
)
elif isinstance(component, Record):
try:
record_base64 = await component.convert_to_base64()
if record_base64:
content_parts.append(
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
)
except Exception as e:
logger.error(f"语音转换为base64失败: {e}")
content = "".join(content_parts)
channel_id = session_id
data = {"channel_id": channel_id, "content": content}
platform = None
user_id = None
if hasattr(adapter, "logins") and adapter.logins:
current_login = adapter.logins[0]
platform = current_login.get("platform", "")
user = current_login.get("user", {})
user_id = user.get("id", "") if user else ""
result = await adapter.send_http_request(
"POST", "/message.create", data, platform, user_id
)
if result:
return result
else:
return None
except Exception as e:
logger.error(f"Satori 消息发送异常: {e}")
return None
async def send(self, message: MessageChain):
platform = getattr(self, "platform", None)
user_id = getattr(self, "user_id", None)
if not platform or not user_id:
if hasattr(self.adapter, "logins") and self.adapter.logins:
current_login = self.adapter.logins[0]
platform = current_login.get("platform", "")
user = current_login.get("user", {})
user_id = user.get("id", "") if user else ""
try:
content_parts = []
for component in message.chain:
if isinstance(component, Plain):
text = (
component.text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
content_parts.append(text)
elif isinstance(component, At):
if component.qq:
content_parts.append(f'<at id="{component.qq}"/>')
elif component.name:
content_parts.append(f'<at name="{component.name}"/>')
elif isinstance(component, Image):
try:
image_base64 = await component.convert_to_base64()
if image_base64:
content_parts.append(
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
)
except Exception as e:
logger.error(f"图片转换为base64失败: {e}")
elif isinstance(component, File):
content_parts.append(
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
)
elif isinstance(component, Record):
try:
record_base64 = await component.convert_to_base64()
if record_base64:
content_parts.append(
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
)
except Exception as e:
logger.error(f"语音转换为base64失败: {e}")
content = "".join(content_parts)
channel_id = self.session_id
data = {"channel_id": channel_id, "content": content}
result = await self.adapter.send_http_request(
"POST", "/message.create", data, platform, user_id
)
if not result:
logger.error("Satori 消息发送失败")
except Exception as e:
logger.error(f"Satori 消息发送异常: {e}")
await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False):
try:
content_parts = []
async for chain in generator:
if isinstance(chain, MessageChain):
if chain.type == "break":
if content_parts:
content = "".join(content_parts)
temp_chain = MessageChain([Plain(text=content)])
await self.send(temp_chain)
content_parts = []
continue
for component in chain.chain:
if isinstance(component, Plain):
content_parts.append(component.text)
elif isinstance(component, Image):
if content_parts:
content = "".join(content_parts)
temp_chain = MessageChain([Plain(text=content)])
await self.send(temp_chain)
content_parts = []
try:
image_base64 = await component.convert_to_base64()
if image_base64:
img_chain = MessageChain(
[
Plain(
text=f'<img src="data:image/jpeg;base64,{image_base64}"/>'
)
]
)
await self.send(img_chain)
except Exception as e:
logger.error(f"图片转换为base64失败: {e}")
else:
content_parts.append(str(component))
if content_parts:
content = "".join(content_parts)
temp_chain = MessageChain([Plain(text=content)])
await self.send(temp_chain)
except Exception as e:
logger.error(f"Satori 流式消息发送异常: {e}")
return await super().send_streaming(generator, use_fallback)

View File

@@ -183,7 +183,6 @@ class TelegramPlatformAdapter(Platform):
return None
if not re.match(r"^[a-z0-9_]+$", cmd_name) or len(cmd_name) > 32:
logger.debug(f"跳过无法注册的命令: {cmd_name}")
return None
# Build description.

View File

@@ -2,7 +2,7 @@ import time
import asyncio
import uuid
import os
from typing import Awaitable, Any
from typing import Awaitable, Any, Callable
from astrbot.core.platform import (
Platform,
AstrBotMessage,
@@ -13,7 +13,7 @@ from astrbot.core.platform import (
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.message.components import Plain, Image, Record # noqa: F403
from astrbot import logger
from astrbot.core import web_chat_queue
from .webchat_queue_mgr import webchat_queue_mgr, WebChatQueueMgr
from .webchat_event import WebChatMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
@@ -21,14 +21,46 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class QueueListener:
def __init__(self, queue: asyncio.Queue, callback: callable) -> None:
self.queue = queue
def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None:
self.webchat_queue_mgr = webchat_queue_mgr
self.callback = callback
self.running_tasks = set()
async def listen_to_queue(self, conversation_id: str):
"""Listen to a specific conversation queue"""
queue = self.webchat_queue_mgr.get_or_create_queue(conversation_id)
while True:
try:
data = await queue.get()
await self.callback(data)
except Exception as e:
logger.error(
f"Error processing message from conversation {conversation_id}: {e}"
)
break
async def run(self):
"""Monitor for new conversation queues and start listeners"""
monitored_conversations = set()
while True:
data = await self.queue.get()
await self.callback(data)
# Check for new conversations
current_conversations = set(self.webchat_queue_mgr.queues.keys())
new_conversations = current_conversations - monitored_conversations
# Start listeners for new conversations
for conversation_id in new_conversations:
task = asyncio.create_task(self.listen_to_queue(conversation_id))
self.running_tasks.add(task)
task.add_done_callback(self.running_tasks.discard)
monitored_conversations.add(conversation_id)
logger.debug(f"Started listener for conversation: {conversation_id}")
# Clean up monitored conversations that no longer exist
removed_conversations = monitored_conversations - current_conversations
monitored_conversations -= removed_conversations
await asyncio.sleep(1) # Check for new conversations every second
@register_platform_adapter("webchat", "webchat")
@@ -45,7 +77,7 @@ class WebChatAdapter(Platform):
os.makedirs(self.imgs_dir, exist_ok=True)
self.metadata = PlatformMetadata(
name="webchat", description="webchat", id=self.config.get("id")
name="webchat", description="webchat", id="webchat"
)
async def send_by_session(
@@ -105,7 +137,7 @@ class WebChatAdapter(Platform):
abm = await self.convert_message(data)
await self.handle_msg(abm)
bot = QueueListener(web_chat_queue, callback)
bot = QueueListener(webchat_queue_mgr, callback)
return bot.run()
def meta(self) -> PlatformMetadata:
@@ -119,6 +151,10 @@ class WebChatAdapter(Platform):
session_id=message.session_id,
)
_, _, payload = message.raw_message # type: ignore
message_event.set_extra("selected_provider", payload.get("selected_provider"))
message_event.set_extra("selected_model", payload.get("selected_model"))
self.commit_event(message_event)
async def terminate(self):

View File

@@ -5,8 +5,8 @@ from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image, Record
from astrbot.core.utils.io import download_image_by_url
from astrbot.core import web_chat_back_queue
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .webchat_queue_mgr import webchat_queue_mgr
imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
@@ -18,13 +18,18 @@ class WebChatMessageEvent(AstrMessageEvent):
@staticmethod
async def _send(message: MessageChain, session_id: str, streaming: bool = False):
cid = session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
if not message:
await web_chat_back_queue.put(
{"type": "end", "data": "", "streaming": False}
{
"type": "end",
"data": "",
"streaming": False,
} # end means this request is finished
)
return ""
cid = session_id.split("!")[-1]
data = ""
for comp in message.chain:
if isinstance(comp, Plain):
@@ -98,27 +103,21 @@ class WebChatMessageEvent(AstrMessageEvent):
async def send(self, message: MessageChain):
await WebChatMessageEvent._send(message, session_id=self.session_id)
await web_chat_back_queue.put(
{
"type": "end",
"data": "",
"streaming": False,
"cid": self.session_id.split("!")[-1],
}
)
await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False):
final_data = ""
cid = self.session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
async for chain in generator:
if chain.type == "break" and final_data:
# 分割符
await web_chat_back_queue.put(
{
"type": "end",
"type": "break", # break means a segment end
"data": final_data,
"streaming": True,
"cid": self.session_id.split("!")[-1],
"cid": cid,
}
)
final_data = ""
@@ -129,10 +128,10 @@ class WebChatMessageEvent(AstrMessageEvent):
await web_chat_back_queue.put(
{
"type": "end",
"type": "complete", # complete means we return the final result
"data": final_data,
"streaming": True,
"cid": self.session_id.split("!")[-1],
"cid": cid,
}
)
await super().send_streaming(generator, use_fallback)

View File

@@ -0,0 +1,33 @@
import asyncio
class WebChatQueueMgr:
def __init__(self) -> None:
self.queues = {}
"""Conversation ID to asyncio.Queue mapping"""
self.back_queues = {}
"""Conversation ID to asyncio.Queue mapping for responses"""
def get_or_create_queue(self, conversation_id: str) -> asyncio.Queue:
"""Get or create a queue for the given conversation ID"""
if conversation_id not in self.queues:
self.queues[conversation_id] = asyncio.Queue()
return self.queues[conversation_id]
def get_or_create_back_queue(self, conversation_id: str) -> asyncio.Queue:
"""Get or create a back queue for the given conversation ID"""
if conversation_id not in self.back_queues:
self.back_queues[conversation_id] = asyncio.Queue()
return self.back_queues[conversation_id]
def remove_queues(self, conversation_id: str):
"""Remove queues for the given conversation ID"""
if conversation_id in self.queues:
del self.queues[conversation_id]
if conversation_id in self.back_queues:
del self.back_queues[conversation_id]
def has_queue(self, conversation_id: str) -> bool:
"""Check if a queue exists for the given conversation ID"""
return conversation_id in self.queues
webchat_queue_mgr = WebChatQueueMgr()

View File

@@ -210,6 +210,16 @@ class WeChatPadProAdapter(Platform):
logger.error(traceback.format_exc())
return False
def _extract_auth_key(self, data):
"""Helper method to extract auth_key from response data."""
if isinstance(data, dict):
auth_keys = data.get("authKeys") # 新接口
if isinstance(auth_keys, list) and auth_keys:
return auth_keys[0]
elif isinstance(data, list) and data: # 旧接口
return data[0]
return None
async def generate_auth_key(self):
"""
生成授权码。
@@ -218,28 +228,26 @@ class WeChatPadProAdapter(Platform):
params = {"key": self.admin_key}
payload = {"Count": 1, "Days": 365} # 生成一个有效期365天的授权码
self.auth_key = None # Reset auth_key before generating a new one
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status != 200:
logger.error(f"生成授权码失败: {response.status}, {await response.text()}")
return
response_data = await response.json()
# 修正成功判断条件和授权码提取路径
if response.status == 200 and response_data.get("Code") == 200:
# 授权码在 Data 字段的列表中
if (
response_data.get("Data")
and isinstance(response_data["Data"], list)
and len(response_data["Data"]) > 0
):
self.auth_key = response_data["Data"][0]
logger.info(f"成功获取授权码 {self.auth_key[:8]}...")
if response_data.get("Code") == 200:
if data := response_data.get("Data"):
self.auth_key = self._extract_auth_key(data)
if self.auth_key:
logger.info("成功获取授权码")
else:
logger.error(
f"生成授权码成功但未找到授权码: {response_data}"
)
logger.error(f"生成授权码成功但未找到授权码: {response_data}")
else:
logger.error(
f"生成授权码失败: {response.status}, {response_data}"
)
logger.error(f"生成授权码失败: {response_data}")
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
except Exception as e:

View File

@@ -0,0 +1,47 @@
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import PlatformMessageHistory
class PlatformMessageHistoryManager:
def __init__(self, db_helper: BaseDatabase):
self.db = db_helper
async def insert(
self,
platform_id: str,
user_id: str,
content: list[dict], # TODO: parse from message chain
sender_id: str = None,
sender_name: str = None,
):
"""Insert a new platform message history record."""
await self.db.insert_platform_message_history(
platform_id=platform_id,
user_id=user_id,
content=content,
sender_id=sender_id,
sender_name=sender_name,
)
async def get(
self,
platform_id: str,
user_id: str,
page: int = 1,
page_size: int = 200,
) -> list[PlatformMessageHistory]:
"""Get platform message history for a specific user."""
history = await self.db.get_platform_message_history(
platform_id=platform_id,
user_id=user_id,
page=page,
page_size=page_size,
)
history.reverse()
return history
async def delete(self, platform_id: str, user_id: str, offset_sec: int = 86400):
"""Delete platform message history records older than the specified offset."""
await self.db.delete_platform_message_offset(
platform_id=platform_id, user_id=user_id, offset_sec=offset_sec
)

View File

@@ -4,9 +4,11 @@ import json
from astrbot.core.utils.io import download_image_by_url
from astrbot import logger
from dataclasses import dataclass, field
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
from typing import List, Dict, Type, Any
from astrbot.core.agent.tool import ToolSet
from openai.types.chat.chat_completion import ChatCompletion
from google.genai.types import GenerateContentResponse
from anthropic.types import Message
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
@@ -20,6 +22,7 @@ class ProviderType(enum.Enum):
SPEECH_TO_TEXT = "speech_to_text"
TEXT_TO_SPEECH = "text_to_speech"
EMBEDDING = "embedding"
RERANK = "rerank"
@dataclass
@@ -29,11 +32,11 @@ class ProviderMetaData:
desc: str = ""
"""提供商适配器描述."""
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: Type = None
cls_type: Type | None = None
default_config_tmpl: dict = None
default_config_tmpl: dict | None = None
"""平台的默认配置模板"""
provider_display_name: str = None
provider_display_name: str | None = None
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
@@ -57,7 +60,7 @@ class ToolCallMessageSegment:
class AssistantMessageSegment:
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
content: str = None
content: str | None = None
tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list)
role: str = "assistant"
@@ -97,7 +100,7 @@ class ProviderRequest:
"""会话 ID"""
image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表"""
func_tool: FuncCall | None = None
func_tool: ToolSet | None = None
"""可用的函数工具"""
contexts: list[dict] = field(default_factory=list)
"""上下文。格式与 openai 的上下文格式一致:
@@ -110,6 +113,9 @@ class ProviderRequest:
tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
model: str | None = None
"""模型名称,为 None 时使用提供商的默认模型"""
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
@@ -201,17 +207,17 @@ class ProviderRequest:
class LLMResponse:
role: str
"""角色, assistant, tool, err"""
result_chain: MessageChain = None
result_chain: MessageChain | None = None
"""返回的消息链"""
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
tools_call_args: List[Dict[str, Any]] = field(default_factory=list)
"""工具调用参数"""
tools_call_name: List[str] = field(default_factory=list)
"""工具调用名称"""
tools_call_ids: List[str] = field(default_factory=list)
"""工具调用 ID"""
raw_completion: ChatCompletion = None
_new_record: Dict[str, any] = None
raw_completion: ChatCompletion | GenerateContentResponse | Message | None = None
_new_record: Dict[str, Any] | None = None
_completion_text: str = ""
@@ -222,12 +228,12 @@ class LLMResponse:
self,
role: str,
completion_text: str = "",
result_chain: MessageChain = None,
tools_call_args: List[Dict[str, any]] = None,
tools_call_name: List[str] = None,
tools_call_ids: List[str] = None,
raw_completion: ChatCompletion = None,
_new_record: Dict[str, any] = None,
result_chain: MessageChain | None = None,
tools_call_args: List[Dict[str, Any]] | None = None,
tools_call_name: List[str] | None = None,
tools_call_ids: List[str] | None = None,
raw_completion: ChatCompletion | None = None,
_new_record: Dict[str, Any] | None = None,
is_chunk: bool = False,
):
"""初始化 LLMResponse
@@ -290,3 +296,10 @@ class LLMResponse:
}
)
return ret
@dataclass
class RerankResult:
index: int
"""在候选列表中的索引位置"""
relevance_score: float
"""相关性分数"""

View File

@@ -1,32 +1,17 @@
from __future__ import annotations
import json
import textwrap
import os
import asyncio
import logging
from datetime import timedelta
import aiohttp
from typing import Dict, List, Awaitable, Literal, Any
from dataclasses import dataclass
from typing import Optional
from contextlib import AsyncExitStack
from typing import Dict, List, Awaitable
from astrbot import logger
from astrbot.core.utils.log_pipe import LogPipe
from astrbot.core import sp
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.agent.mcp_client import MCPClient
from astrbot.core.agent.tool import ToolSet, FunctionTool
try:
import mcp
from mcp.client.sse import sse_client
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
try:
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
)
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
@@ -39,175 +24,109 @@ SUPPORTED_TYPES = [
] # json schema 支持的数据类型
@dataclass
class FuncTool:
"""
用于描述一个函数调用工具。
"""
name: str
parameters: Dict
description: str
handler: Awaitable = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
"""
active: bool = True
"""是否激活"""
origin: Literal["local", "mcp"] = "local"
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
# MCP 相关字段
mcp_server_name: str = None
"""MCP 服务名称,当 origin 为 mcp 时有效"""
mcp_client: MCPClient = None
"""MCP 客户端,当 origin 为 mcp 时有效"""
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
async def execute(self, **args) -> Any:
"""执行函数调用"""
if self.origin == "local":
if not self.handler:
raise Exception(f"Local function {self.name} has no handler")
return await self.handler(**args)
elif self.origin == "mcp":
if not self.mcp_client or not self.mcp_client.session:
raise Exception(f"MCP client for {self.name} is not available")
# 使用name属性而不是额外的mcp_tool_name
if ":" in self.name:
# 如果名字是格式为 mcp:server:tool_name提取实际的工具名
actual_tool_name = self.name.split(":")[-1]
return await self.mcp_client.session.call_tool(actual_tool_name, args)
else:
return await self.mcp_client.session.call_tool(self.name, args)
else:
raise Exception(f"Unknown function origin: {self.origin}")
# alias
FuncTool = FunctionTool
class MCPClient:
def __init__(self):
# Initialize session and client objects
self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack()
def _prepare_config(config: dict) -> dict:
"""准备配置,处理嵌套格式"""
if "mcpServers" in config and config["mcpServers"]:
first_key = next(iter(config["mcpServers"]))
config = config["mcpServers"][first_key]
config.pop("active", None)
return config
self.name = None
self.active: bool = True
self.tools: List[mcp.Tool] = []
self.server_errlogs: List[str] = []
async def connect_to_server(self, mcp_server_config: dict, name: str):
"""连接到 MCP 服务器
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"""快速测试 MCP 服务器可达性"""
import aiohttp
如果 `url` 参数存在:
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
cfg = _prepare_config(config.copy())
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
cfg = mcp_server_config.copy()
if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
key_0 = list(cfg["mcpServers"].keys())[0]
cfg = cfg["mcpServers"][key_0]
cfg.pop("active", None) # Remove active flag from config
url = cfg["url"]
headers = cfg.get("headers", {})
timeout = cfg.get("timeout", 10)
if "url" in cfg:
is_sse = True
try:
async with aiohttp.ClientSession() as session:
if cfg.get("transport") == "streamable_http":
is_sse = False
if is_sse:
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=cfg.get("timeout", 5),
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
)
streams = await self._streams_context.__aenter__()
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*streams)
)
test_payload = {
"jsonrpc": "2.0",
"method": "initialize",
"id": 0,
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.2.3"},
},
}
async with session.post(
url,
headers={
**headers,
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
},
json=test_payload,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
else:
timeout = timedelta(seconds=cfg.get("timeout", 30))
sse_read_timeout = timedelta(
seconds=cfg.get("sse_read_timeout", 60 * 5)
)
self._streams_context = streamablehttp_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
read_s, write_s, _ = await self._streams_context.__aenter__()
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(read_stream=read_s, write_stream=write_s)
)
return False, f"HTTP {response.status}: {response.reason}"
else:
server_params = mcp.StdioServerParameters(
**cfg,
)
async with session.get(
url,
headers={
**headers,
"Accept": "application/json, text/event-stream",
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
else:
return False, f"HTTP {response.status}: {response.reason}"
def callback(msg: str):
# 处理 MCP 服务的错误日志
self.server_errlogs.append(msg)
stdio_transport = await self.exit_stack.enter_async_context(
mcp.stdio_client(
server_params,
errlog=LogPipe(
level=logging.ERROR,
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
),
),
)
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport)
)
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
response = await self.session.list_tools()
logger.debug(f"MCP server {self.name} list tools response: {response}")
self.tools = response.tools
return response
async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()
except asyncio.TimeoutError:
return False, f"连接超时: {timeout}"
except Exception as e:
return False, f"{e!s}"
class FuncCall:
class FunctionToolManager:
def __init__(self) -> None:
self.func_list: List[FuncTool] = []
"""内部加载的 func tools"""
self.mcp_client_dict: Dict[str, MCPClient] = {}
"""MCP 服务列表"""
self.mcp_service_queue = asyncio.Queue()
"""用于外部控制 MCP 服务的启停"""
self.mcp_client_event: Dict[str, asyncio.Event] = {}
def empty(self) -> bool:
return len(self.func_list) == 0
def spec_to_func(
self,
name: str,
func_args: list,
desc: str,
handler: Awaitable,
) -> FuncTool:
params = {
"type": "object", # hard-coded here
"properties": {},
}
for param in func_args:
params["properties"][param["name"]] = {
"type": param["type"],
"description": param["description"],
}
return FuncTool(
name=name,
parameters=params,
description=desc,
handler=handler,
)
def add_func(
self,
name: str,
@@ -225,22 +144,14 @@ class FuncCall:
# check if the tool has been added before
self.remove_func(name)
params = {
"type": "object", # hard-coded here
"properties": {},
}
for param in func_args:
params["properties"][param["name"]] = {
"type": param["type"],
"description": param["description"],
}
_func = FuncTool(
self.func_list.append(
self.spec_to_func(
name=name,
parameters=params,
description=desc,
func_args=func_args,
desc=desc,
handler=handler,
)
self.func_list.append(_func)
)
logger.info(f"添加函数调用工具: {name}")
def remove_func(self, name: str) -> None:
@@ -252,13 +163,17 @@ class FuncCall:
self.func_list.pop(i)
break
def get_func(self, name) -> FuncTool:
def get_func(self, name) -> FuncTool | None:
for f in self.func_list:
if f.name == name:
return f
return None
async def _init_mcp_clients(self) -> None:
def get_full_tool_set(self) -> ToolSet:
"""获取完整工具集"""
tool_set = ToolSet(self.func_list.copy())
return tool_set
async def init_mcp_clients(self) -> None:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
```
{
@@ -300,72 +215,32 @@ class FuncCall:
)
self.mcp_client_event[name] = event
async def mcp_service_selector(self):
"""为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
{"type": "init"} 初始化所有MCP客户端
{"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
{"type": "terminate"} 终止所有MCP客户端
{"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
"""
while True:
data = await self.mcp_service_queue.get()
if data["type"] == "init":
if "name" in data:
event = asyncio.Event()
asyncio.create_task(
self._init_mcp_client_task_wrapper(
data["name"], data["cfg"], event
)
)
self.mcp_client_event[data["name"]] = event
else:
await self._init_mcp_clients()
elif data["type"] == "terminate":
if "name" in data:
# await self._terminate_mcp_client(data["name"])
if data["name"] in self.mcp_client_event:
self.mcp_client_event[data["name"]].set()
self.mcp_client_event.pop(data["name"], None)
self.func_list = [
f
for f in self.func_list
if not (
f.origin == "mcp" and f.mcp_server_name == data["name"]
)
]
else:
for name in self.mcp_client_dict.keys():
# await self._terminate_mcp_client(name)
# self.mcp_client_event[name].set()
if name in self.mcp_client_event:
self.mcp_client_event[name].set()
self.mcp_client_event.pop(name, None)
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
async def _init_mcp_client_task_wrapper(
self, name: str, cfg: dict, event: asyncio.Event
self,
name: str,
cfg: dict,
event: asyncio.Event,
ready_future: asyncio.Future = None,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:
await self._init_mcp_client(name, cfg)
tools = await self.mcp_client_dict[name].list_tools_and_save()
if ready_future and not ready_future.done():
# tell the caller we are ready
ready_future.set_result(tools)
await event.wait()
logger.info(f"收到 MCP 客户端 {name} 终止信号")
await self._terminate_mcp_client(name)
except Exception as e:
import traceback
traceback.print_exc()
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
if ready_future and not ready_future.done():
ready_future.set_exception(e)
finally:
# 无论如何都能清理
await self._terminate_mcp_client(name)
async def _init_mcp_client(self, name: str, config: dict) -> None:
"""初始化单个MCP客户端"""
try:
# 先清理之前的客户端,如果存在
if name in self.mcp_client_dict:
await self._terminate_mcp_client(name)
@@ -375,6 +250,7 @@ class FuncCall:
self.mcp_client_dict[name] = mcp_client
await mcp_client.connect_to_server(config, name)
tools_res = await mcp_client.list_tools_and_save()
logger.debug(f"MCP server {name} list tools response: {tools_res}")
tool_names = [tool.name for tool in tools_res.tools]
# 移除该MCP服务之前的工具如有
@@ -397,16 +273,6 @@ class FuncCall:
self.func_list.append(func_tool)
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
return
except Exception as e:
import traceback
logger.error(traceback.format_exc())
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
# 发生错误时确保客户端被清理
if name in self.mcp_client_dict:
await self._terminate_mcp_client(name)
return
async def _terminate_mcp_client(self, name: str) -> None:
"""关闭并清理MCP客户端"""
@@ -414,9 +280,9 @@ class FuncCall:
try:
# 关闭MCP连接
await self.mcp_client_dict[name].cleanup()
del self.mcp_client_dict[name]
self.mcp_client_dict.pop(name)
except Exception as e:
logger.info(f"清空 MCP 客户端资源 {name}: {e}")
logger.error(f"清空 MCP 客户端资源 {name}: {e}")
# 移除关联的FuncTool
self.func_list = [
f
@@ -425,204 +291,273 @@ class FuncCall:
]
logger.info(f"已关闭 MCP 服务 {name}")
@staticmethod
async def test_mcp_server_connection(config: dict) -> list[str]:
if "url" in config:
success, error_msg = await _quick_test_mcp_connection(config)
if not success:
raise Exception(error_msg)
mcp_client = MCPClient()
try:
logger.debug(f"testing MCP server connection with config: {config}")
await mcp_client.connect_to_server(config, "test")
tools_res = await mcp_client.list_tools_and_save()
tool_names = [tool.name for tool in tools_res.tools]
finally:
logger.debug("Cleaning up MCP client after testing connection.")
await mcp_client.cleanup()
return tool_names
async def enable_mcp_server(
self,
name: str,
config: dict,
event: asyncio.Event | None = None,
ready_future: asyncio.Future | None = None,
timeout: int = 30,
) -> None:
"""Enable_mcp_server a new MCP server to the manager and initialize it.
Args:
name (str): The name of the MCP server.
config (dict): Configuration for the MCP server.
event (asyncio.Event): Event to signal when the MCP client is ready.
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
timeout (int): Timeout for the initialization.
Raises:
TimeoutError: If the initialization does not complete within the specified timeout.
Exception: If there is an error during initialization.
"""
if not event:
event = asyncio.Event()
if not ready_future:
ready_future = asyncio.Future()
if name in self.mcp_client_dict:
return
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, config, event, ready_future)
)
try:
await asyncio.wait_for(ready_future, timeout=timeout)
finally:
self.mcp_client_event[name] = event
if ready_future.done() and ready_future.exception():
exc = ready_future.exception()
if exc is not None:
raise exc
async def disable_mcp_server(
self, name: str | None = None, timeout: float = 10
) -> None:
"""Disable an MCP server by its name.
Args:
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
timeout (int): Timeout.
"""
if name:
if name not in self.mcp_client_event:
return
client = self.mcp_client_dict.get(name)
self.mcp_client_event[name].set()
if not client:
return
client_running_event = client.running_event
try:
await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
finally:
self.mcp_client_event.pop(name, None)
self.func_list = [
f
for f in self.func_list
if f.origin != "mcp" or f.mcp_server_name != name
]
else:
running_events = [
client.running_event.wait() for client in self.mcp_client_dict.values()
]
for key, event in self.mcp_client_event.items():
event.set()
# waiting for all clients to finish
try:
await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
finally:
self.mcp_client_event.clear()
self.mcp_client_dict.clear()
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
"""
获得 OpenAI API 风格的**已经激活**的工具描述
"""
_l = []
# 处理所有工具包括本地和MCP工具
for f in self.func_list:
if not f.active:
continue
func_ = {
"type": "function",
"function": {
"name": f.name,
# "parameters": f.parameters,
"description": f.description,
},
}
func_["function"]["parameters"] = f.parameters
if not f.parameters.get("properties") and omit_empty_parameter_field:
# 如果 properties 为空,并且 omit_empty_parameter_field 为 True则删除 parameters 字段
del func_["function"]["parameters"]
_l.append(func_)
return _l
tools = [f for f in self.func_list if f.active]
toolset = ToolSet(tools)
return toolset.openai_schema(
omit_empty_parameter_field=omit_empty_parameter_field
)
def get_func_desc_anthropic_style(self) -> list:
"""
获得 Anthropic API 风格的**已经激活**的工具描述
"""
tools = []
for f in self.func_list:
if not f.active:
continue
# Convert internal format to Anthropic style
tool = {
"name": f.name,
"description": f.description,
"input_schema": {
"type": "object",
"properties": f.parameters.get("properties", {}),
# Keep the required field from the original parameters if it exists
"required": f.parameters.get("required", []),
},
}
tools.append(tool)
return tools
tools = [f for f in self.func_list if f.active]
toolset = ToolSet(tools)
return toolset.anthropic_schema()
def get_func_desc_google_genai_style(self) -> dict:
"""
获得 Google GenAI API 风格的**已经激活**的工具描述
"""
tools = [f for f in self.func_list if f.active]
toolset = ToolSet(tools)
return toolset.google_schema()
# Gemini API 支持的数据类型和格式
supported_types = {
"string",
"number",
"integer",
"boolean",
"array",
"object",
"null",
}
supported_formats = {
"string": {"enum", "date-time"},
"integer": {"int32", "int64"},
"number": {"float", "double"},
}
def deactivate_llm_tool(self, name: str) -> bool:
"""停用一个已经注册的函数调用工具。
def convert_schema(schema: dict) -> dict:
"""转换 schema 为 Gemini API 格式"""
Returns:
如果没找到,会返回 False"""
func_tool = self.get_func(name)
if func_tool is not None:
func_tool.active = False
# 如果 schema 包含 anyOf则只返回 anyOf 字段
if "anyOf" in schema:
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
result = {}
if "type" in schema and schema["type"] in supported_types:
result["type"] = schema["type"]
if "format" in schema and schema["format"] in supported_formats.get(
result["type"], set()
):
result["format"] = schema["format"]
else:
# 暂时指定默认为null
result["type"] = "null"
support_fields = {
"title",
"description",
"enum",
"minimum",
"maximum",
"maxItems",
"minItems",
"nullable",
"required",
}
result.update({k: schema[k] for k in support_fields if k in schema})
if "properties" in schema:
properties = {}
for key, value in schema["properties"].items():
prop_value = convert_schema(value)
if "default" in prop_value:
del prop_value["default"]
properties[key] = prop_value
if properties: # 只在有非空属性时添加
result["properties"] = properties
if "items" in schema:
result["items"] = convert_schema(schema["items"])
return result
tools = [
{
"name": f.name,
"description": f.description,
**({"parameters": convert_schema(f.parameters)}),
}
for f in self.func_list
if f.active
]
declarations = {}
if tools:
declarations["function_declarations"] = tools
return declarations
async def func_call(self, question: str, session_id: str, provider) -> tuple:
_l = []
for f in self.func_list:
if not f.active:
continue
_l.append(
{
"name": f.name,
"parameters": f.parameters,
"description": f.description,
}
inactivated_llm_tools: list = sp.get(
"inactivated_llm_tools", [], scope="global", scope_id="global"
)
if name not in inactivated_llm_tools:
inactivated_llm_tools.append(name)
sp.put(
"inactivated_llm_tools",
inactivated_llm_tools,
scope="global",
scope_id="global",
)
func_definition = json.dumps(_l, ensure_ascii=False)
prompt = textwrap.dedent(f"""
ROLE:
你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。
return True
return False
TOOLS:
可用的函数列表:
# 因为不想解决循环引用,所以这里直接传入 star_map 先了...
def activate_llm_tool(self, name: str, star_map: dict) -> bool:
func_tool = self.get_func(name)
if func_tool is not None:
if func_tool.handler_module_path in star_map:
if not star_map[func_tool.handler_module_path].activated:
raise ValueError(
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。"
)
{func_definition}
func_tool.active = True
LIMIT:
1. 你返回的内容应当能够被 Python 的 json 模块解析的 Json 格式字符串。
2. 你的 Json 返回的格式如下:`[{{"name": "<func_name>", "args": <arg_dict>}}, ...]`。参数根据上面提供的函数列表中的参数来填写。
3. 允许必要时返回多个函数调用,但需保证这些函数调用的顺序正确。
4. 如果用户的提问中不需要用到给定的函数,请直接返回 `{{"res": False}}`。
inactivated_llm_tools: list = sp.get(
"inactivated_llm_tools", [], scope="global", scope_id="global"
)
if name in inactivated_llm_tools:
inactivated_llm_tools.remove(name)
sp.put(
"inactivated_llm_tools",
inactivated_llm_tools,
scope="global",
scope_id="global",
)
EXAMPLE:
1. `用户提问`:请问一下天气怎么样? `函数调用`[{{"name": "get_weather", "args": {{"city": "北京"}}}}]
return True
return False
用户的提问是:{question}
""")
@property
def mcp_config_path(self):
data_dir = get_astrbot_data_path()
return os.path.join(data_dir, "mcp_server.json")
def load_mcp_config(self):
if not os.path.exists(self.mcp_config_path):
# 配置文件不存在,创建默认配置
os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True)
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
return DEFAULT_MCP_CONFIG
_c = 0
while _c < 3:
try:
res = await provider.text_chat(prompt, session_id)
if res.find("```") != -1:
res = res[res.find("```json") + 7 : res.rfind("```")]
res = json.loads(res)
break
with open(self.mcp_config_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
_c += 1
if _c == 3:
raise e
if "The message you submitted was too long" in str(e):
raise e
logger.error(f"加载 MCP 配置失败: {e}")
return DEFAULT_MCP_CONFIG
if "res" in res and not res["res"]:
return "", False
def save_mcp_config(self, config: dict):
try:
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=4)
return True
except Exception as e:
logger.error(f"保存 MCP 配置失败: {e}")
return False
tool_call_result = []
for tool in res:
# 说明有函数调用
func_name = tool["name"]
args = tool["args"]
# 调用函数
func_tool = self.get_func(func_name)
if not func_tool:
raise Exception(f"Request function {func_name} not found.")
async def sync_modelscope_mcp_servers(self, access_token: str) -> None:
"""从 ModelScope 平台同步 MCP 服务器配置"""
base_url = "https://www.modelscope.cn/openapi/v1"
url = f"{base_url}/mcp/servers/operational"
headers = {
"Authorization": f"Bearer {access_token.strip()}",
"Content-Type": "application/json",
}
ret = await func_tool.execute(**args)
if ret:
tool_call_result.append(str(ret))
return tool_call_result, True
try:
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
if response.status == 200:
data = await response.json()
mcp_server_list = data.get("data", {}).get(
"mcp_server_list", []
)
local_mcp_config = self.load_mcp_config()
synced_count = 0
for server in mcp_server_list:
server_name = server["name"]
operational_urls = server.get("operational_urls", [])
if not operational_urls:
continue
url_info = operational_urls[0]
server_url = url_info.get("url")
if not server_url:
continue
# 添加到配置中(同名会覆盖)
local_mcp_config["mcpServers"][server_name] = {
"url": server_url,
"transport": "sse",
"active": True,
"provider": "modelscope",
}
synced_count += 1
if synced_count > 0:
self.save_mcp_config(local_mcp_config)
tasks = []
for server in mcp_server_list:
name = server["name"]
tasks.append(
self.enable_mcp_server(
name=name,
config=local_mcp_config["mcpServers"][name],
)
)
await asyncio.gather(*tasks)
logger.info(
f"从 ModelScope 同步了 {synced_count} 个 MCP 服务器"
)
else:
logger.warning("没有找到可用的 ModelScope MCP 服务器")
else:
raise Exception(
f"ModelScope API 请求失败: HTTP {response.status}"
)
except aiohttp.ClientError as e:
raise Exception(f"网络连接错误: {str(e)}")
except Exception as e:
raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {str(e)}")
def __str__(self):
return str(self.func_list)
@@ -630,7 +565,6 @@ class FuncCall:
def __repr__(self):
return str(self.func_list)
async def terminate(self):
for name in self.mcp_client_dict.keys():
await self._terminate_mcp_client(name)
logger.debug(f"清理 MCP 客户端 {name} 资源")
# alias
FuncCall = FunctionToolManager

View File

@@ -3,87 +3,32 @@ import traceback
from typing import List
from astrbot.core import logger, sp
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
from .entities import ProviderType
from .provider import Personality, Provider, STTProvider, TTSProvider
from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
from .register import llm_tools, provider_cls_map
from ..persona_mgr import PersonaManager
class ProviderManager:
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
def __init__(
self,
acm: AstrBotConfigManager,
db_helper: BaseDatabase,
persona_mgr: PersonaManager,
):
self.persona_mgr = persona_mgr
self.acm = acm
config = acm.confs["default"]
self.providers_config: List = config["provider"]
self.provider_settings: dict = config["provider_settings"]
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
self.persona_configs: list = config.get("persona", [])
self.astrbot_config = config
# 人格情景管理
# 目前没有拆成独立的模块
self.default_persona_name = self.provider_settings.get(
"default_personality", "default"
)
self.personas: List[Personality] = []
self.selected_default_persona = None
for persona in self.persona_configs:
begin_dialogs = persona.get("begin_dialogs", [])
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
bd_processed = []
mid_processed = ""
if begin_dialogs:
if len(begin_dialogs) % 2 != 0:
logger.error(
f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。"
)
begin_dialogs = []
user_turn = True
for dialog in begin_dialogs:
bd_processed.append({
"role": "user" if user_turn else "assistant",
"content": dialog,
"_no_save": None, # 不持久化到 db
})
user_turn = not user_turn
if mood_imitation_dialogs:
if len(mood_imitation_dialogs) % 2 != 0:
logger.error(
f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。"
)
mood_imitation_dialogs = []
user_turn = True
for dialog in mood_imitation_dialogs:
role = "A" if user_turn else "B"
mid_processed += f"{role}: {dialog}\n"
if not user_turn:
mid_processed += "\n"
user_turn = not user_turn
try:
persona = Personality(
**persona,
_begin_dialogs_processed=bd_processed,
_mood_imitation_dialogs_processed=mid_processed,
)
if persona["name"] == self.default_persona_name:
self.selected_default_persona = persona
self.personas.append(persona)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
if not self.selected_default_persona and len(self.personas) > 0:
# 默认选择第一个
self.selected_default_persona = self.personas[0]
if not self.selected_default_persona:
self.selected_default_persona = Personality(
prompt="You are a helpful and friendly assistant.",
name="default",
_begin_dialogs_processed=[],
_mood_imitation_dialogs_processed="",
)
self.personas.append(self.selected_default_persona)
# 人格相关属性v4.0.0 版本后被废弃,推荐使用 PersonaManager
self.default_persona_name = persona_mgr.default_persona
self.provider_insts: List[Provider] = []
"""加载的 Provider 的实例"""
@@ -91,53 +36,118 @@ class ProviderManager:
"""加载的 Speech To Text Provider 的实例"""
self.tts_provider_insts: List[TTSProvider] = []
"""加载的 Text To Speech Provider 的实例"""
self.embedding_provider_insts: List[Provider] = []
self.embedding_provider_insts: List[EmbeddingProvider] = []
"""加载的 Embedding Provider 的实例"""
self.inst_map = {}
self.inst_map: dict[str, Provider] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None
"""默认的 Provider 实例"""
self.curr_stt_provider_inst: STTProvider = None
"""默认的 Speech To Text Provider 实例"""
self.curr_tts_provider_inst: TTSProvider = None
"""默认的 Text To Speech Provider 实例"""
self.curr_provider_inst: Provider | None = None
"""默认的 Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
self.curr_stt_provider_inst: STTProvider | None = None
"""默认的 Speech To Text Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
self.curr_tts_provider_inst: TTSProvider | None = None
"""默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
self.db_helper = db_helper
# kdb(experimental)
self.curr_kdb_name = ""
kdb_cfg = config.get("knowledge_db", {})
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
@property
def persona_configs(self) -> list:
"""动态获取最新的 persona 配置"""
return self.persona_mgr.persona_v3_config
@property
def personas(self) -> list:
"""动态获取最新的 personas 列表"""
return self.persona_mgr.personas_v3
@property
def selected_default_persona(self):
"""动态获取最新的默认选中 persona。已弃用请使用 context.persona_mgr.get_default_persona_v3()"""
return self.persona_mgr.selected_default_persona_v3
async def set_provider(
self, provider_id: str, provider_type: ProviderType, umo: str = None
self, provider_id: str, provider_type: ProviderType, umo: str | None = None
):
"""设置提供商。
Args:
provider_id (str): 提供商 ID。
provider_type (ProviderType): 提供商类型。
umo (str, optional): 用户会话 ID用于提供商会话隔离。当用户启用了提供商会话隔离时此参数才生效。
umo (str, optional): 用户会话 ID用于提供商会话隔离。
Version 4.0.0: 这个版本下已经默认隔离提供商
"""
if provider_id not in self.inst_map:
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
if umo and self.provider_settings["separate_provider"]:
perf = sp.get("session_provider_perf", {})
session_perf = perf.get(umo, {})
session_perf[provider_type.value] = provider_id
perf[umo] = session_perf
sp.put("session_provider_perf", perf)
if umo:
await sp.session_put(
umo,
f"provider_perf_{provider_type.value}",
provider_id,
)
return
# 不启用提供商会话隔离模式的情况
self.curr_provider_inst = self.inst_map[provider_id]
if provider_type == ProviderType.TEXT_TO_SPEECH:
sp.put("curr_provider_tts", provider_id)
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
elif provider_type == ProviderType.SPEECH_TO_TEXT:
sp.put("curr_provider_stt", provider_id)
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
elif provider_type == ProviderType.CHAT_COMPLETION:
sp.put("curr_provider", provider_id)
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
"""根据提供商 ID 获取提供商实例"""
return self.inst_map.get(provider_id)
def get_using_provider(self, provider_type: ProviderType, umo=None):
"""获取正在使用的提供商实例。
Args:
provider_type (ProviderType): 提供商类型。
umo (str, optional): 用户会话 ID用于提供商会话隔离。
Returns:
Provider: 正在使用的提供商实例。
"""
provider = None
if umo:
provider_id = sp.get(
f"provider_perf_{provider_type.value}",
None,
scope="umo",
scope_id=umo,
)
if provider_id:
provider = self.inst_map.get(provider_id)
if not provider:
# default setting
config = self.acm.get_conf(umo)
if provider_type == ProviderType.CHAT_COMPLETION:
provider_id = config["provider_settings"].get("default_provider_id")
provider = self.inst_map.get(provider_id)
if not provider:
provider = self.provider_insts[0] if self.provider_insts else None
elif provider_type == ProviderType.SPEECH_TO_TEXT:
provider_id = config["provider_stt_settings"].get("provider_id")
if not provider_id:
return None
provider = self.inst_map.get(provider_id)
if not provider:
provider = (
self.stt_provider_insts[0] if self.stt_provider_insts else None
)
elif provider_type == ProviderType.TEXT_TO_SPEECH:
provider_id = config["provider_tts_settings"].get("provider_id")
if not provider_id:
return None
provider = self.inst_map.get(provider_id)
if not provider:
provider = (
self.tts_provider_insts[0] if self.tts_provider_insts else None
)
else:
raise ValueError(f"Unknown provider type: {provider_type}")
return provider
async def initialize(self):
# 逐个初始化提供商
@@ -145,29 +155,38 @@ class ProviderManager:
await self.load_provider(provider_config)
# 设置默认提供商
self.curr_provider_inst = self.inst_map.get(
self.provider_settings.get("default_provider_id")
selected_provider_id = sp.get(
"curr_provider",
self.provider_settings.get("default_provider_id"),
scope="global",
scope_id="global",
)
selected_stt_provider_id = sp.get(
"curr_provider_stt",
self.provider_stt_settings.get("provider_id"),
scope="global",
scope_id="global",
)
selected_tts_provider_id = sp.get(
"curr_provider_tts",
self.provider_tts_settings.get("provider_id"),
scope="global",
scope_id="global",
)
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
if not self.curr_provider_inst and self.provider_insts:
self.curr_provider_inst = self.provider_insts[0]
self.curr_stt_provider_inst = self.inst_map.get(
self.provider_stt_settings.get("provider_id")
)
self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
if not self.curr_stt_provider_inst and self.stt_provider_insts:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
self.curr_tts_provider_inst = self.inst_map.get(
self.provider_tts_settings.get("provider_id")
)
self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
# 初始化 MCP Client 连接
asyncio.create_task(
self.llm_tools.mcp_service_selector(), name="mcp-service-handler"
)
self.llm_tools.mcp_service_queue.put_nowait({"type": "init"})
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
async def load_provider(self, provider_config: dict):
if not provider_config["enable"]:
@@ -260,6 +279,10 @@ class ProviderManager:
from .sources.gemini_embedding_source import (
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
)
case "vllm_rerank":
from .sources.vllm_rerank_source import (
VLLMRerankProvider as VLLMRerankProvider,
)
except (ImportError, ModuleNotFoundError) as e:
logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
@@ -343,7 +366,7 @@ class ProviderManager:
if not self.curr_provider_inst:
self.curr_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
elif provider_metadata.provider_type in [ProviderType.EMBEDDING, ProviderType.RERANK]:
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
@@ -417,7 +440,7 @@ class ProviderManager:
self.curr_tts_provider_inst = None
if getattr(self.inst_map[provider_id], "terminate", None):
await self.inst_map[provider_id].terminate()
await self.inst_map[provider_id].terminate() # type: ignore
logger.info(
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
@@ -427,6 +450,8 @@ class ProviderManager:
async def terminate(self):
for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"):
await provider_inst.terminate()
# 清理 MCP Client 连接
await self.llm_tools.mcp_service_queue.put({"type": "terminate"})
await provider_inst.terminate() # type: ignore
try:
await self.llm_tools.disable_mcp_server()
except Exception:
logger.error("Error while disabling MCP servers", exc_info=True)

View File

@@ -1,27 +1,24 @@
import abc
from typing import List
from typing import TypedDict, AsyncGenerator
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from typing import AsyncGenerator
from astrbot.core.agent.tool import ToolSet
from astrbot.core.provider.entities import (
LLMResponse,
ToolCallsResult,
ProviderType,
RerankResult,
)
from astrbot.core.provider.register import provider_cls_map
from astrbot.core.db.po import Personality
from dataclasses import dataclass
class Personality(TypedDict):
prompt: str = ""
name: str = ""
begin_dialogs: List[str] = []
mood_imitation_dialogs: List[str] = []
# cache
_begin_dialogs_processed: List[dict] = []
_mood_imitation_dialogs_processed: str = ""
@dataclass
class ProviderMeta:
id: str
model: str
type: str
provider_type: ProviderType
class AbstractProvider(abc.ABC):
@@ -40,10 +37,14 @@ class AbstractProvider(abc.ABC):
def meta(self) -> ProviderMeta:
"""获取 Provider 的元数据"""
provider_type_name = self.provider_config["type"]
meta_data = provider_cls_map.get(provider_type_name)
provider_type = meta_data.provider_type if meta_data else None
return ProviderMeta(
id=self.provider_config["id"],
model=self.get_model(),
type=self.provider_config["type"],
type=provider_type_name,
provider_type=provider_type,
)
@@ -84,10 +85,11 @@ class Provider(AbstractProvider):
prompt: str,
session_id: str = None,
image_urls: list[str] = None,
func_tool: FuncCall = None,
func_tool: ToolSet = None,
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
model: str | None = None,
**kwargs,
) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -112,10 +114,11 @@ class Provider(AbstractProvider):
prompt: str,
session_id: str = None,
image_urls: list[str] = None,
func_tool: FuncCall = None,
func_tool: ToolSet = None,
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
model: str | None = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
@@ -198,3 +201,17 @@ class EmbeddingProvider(AbstractProvider):
def get_dim(self) -> int:
"""获取向量的维度"""
...
class RerankProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
@abc.abstractmethod
async def rerank(
self, query: str, documents: list[str], top_n: int | None = None
) -> list[RerankResult]:
"""获取查询和文档的重排序分数"""
...

View File

@@ -235,6 +235,7 @@ class ProviderAnthropic(Provider):
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -259,7 +260,7 @@ class ProviderAnthropic(Provider):
system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model()
model_config["model"] = model or self.get_model()
payloads = {"messages": new_messages, **model_config}
@@ -285,6 +286,7 @@ class ProviderAnthropic(Provider):
contexts=...,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
):
if contexts is None:
@@ -300,12 +302,16 @@ class ProviderAnthropic(Provider):
# tool calls result
if tool_calls_result:
if not isinstance(tool_calls_result, list):
context_query.extend(tool_calls_result.to_openai_messages())
else:
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model()
model_config["model"] = model or self.get_model()
payloads = {"messages": new_messages, **model_config}

View File

@@ -19,6 +19,7 @@ from ..register import register_provider_adapter
TEMP_DIR = Path("data/temp/azure_tts")
TEMP_DIR.mkdir(parents=True, exist_ok=True)
class OTTSProvider:
def __init__(self, config: Dict):
self.skey = config["OTTS_SKEY"]
@@ -70,12 +71,12 @@ class OTTSProvider:
"style": voice_params["style"],
"role": voice_params["role"],
"rate": voice_params["rate"],
"volume": voice_params["volume"]
"volume": voice_params["volume"],
},
headers={
"User-Agent": f"AstrBot/{VERSION}",
"UAK": "AstrBot/AzureTTS"
}
"UAK": "AstrBot/AzureTTS",
},
)
response.raise_for_status()
file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -88,14 +89,19 @@ class OTTSProvider:
raise RuntimeError(f"OTTS请求失败: {str(e)}") from e
await asyncio.sleep(0.5 * (attempt + 1))
class AzureNativeProvider(TTSProvider):
def __init__(self, provider_config: dict, provider_settings: dict):
super().__init__(provider_config, provider_settings)
self.subscription_key = provider_config.get("azure_tts_subscription_key", "").strip()
self.subscription_key = provider_config.get(
"azure_tts_subscription_key", ""
).strip()
if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
raise ValueError("无效的Azure订阅密钥")
self.region = provider_config.get("azure_tts_region", "eastus").strip()
self.endpoint = f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
self.endpoint = (
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
)
self.client = None
self.token = None
self.token_expire = 0
@@ -104,15 +110,17 @@ class AzureNativeProvider(TTSProvider):
"style": provider_config.get("azure_tts_style", "cheerful"),
"role": provider_config.get("azure_tts_role", "Boy"),
"rate": provider_config.get("azure_tts_rate", "1"),
"volume": provider_config.get("azure_tts_volume", "100")
"volume": provider_config.get("azure_tts_volume", "100"),
}
async def __aenter__(self):
self.client = AsyncClient(headers={
self.client = AsyncClient(
headers={
"User-Agent": f"AstrBot/{VERSION}",
"Content-Type": "application/ssml+xml",
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm"
})
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm",
}
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
@@ -120,10 +128,11 @@ class AzureNativeProvider(TTSProvider):
await self.client.aclose()
async def _refresh_token(self):
token_url = f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
token_url = (
f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
)
response = await self.client.post(
token_url,
headers={"Ocp-Apim-Subscription-Key": self.subscription_key}
token_url, headers={"Ocp-Apim-Subscription-Key": self.subscription_key}
)
response.raise_for_status()
self.token = response.text
@@ -150,8 +159,8 @@ class AzureNativeProvider(TTSProvider):
content=ssml,
headers={
"Authorization": f"Bearer {self.token}",
"User-Agent": f"AstrBot/{VERSION}"
}
"User-Agent": f"AstrBot/{VERSION}",
},
)
response.raise_for_status()
file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -160,6 +169,7 @@ class AzureNativeProvider(TTSProvider):
f.write(chunk)
return str(file_path.resolve())
@register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH)
class AzureTTSProvider(TTSProvider):
def __init__(self, provider_config: dict, provider_settings: dict):
@@ -183,7 +193,7 @@ class AzureTTSProvider(TTSProvider):
error_msg = (
f"JSON解析失败请检查格式错误位置{e.lineno}{e.colno}\n"
f"错误详情: {e.msg}\n"
f"错误上下文: {json_str[max(0, e.pos-30):e.pos+30]}"
f"错误上下文: {json_str[max(0, e.pos - 30) : e.pos + 30]}"
)
raise ValueError(error_msg) from e
except KeyError as e:
@@ -202,8 +212,8 @@ class AzureTTSProvider(TTSProvider):
"style": self.provider_config.get("azure_tts_style"),
"role": self.provider_config.get("azure_tts_role"),
"rate": self.provider_config.get("azure_tts_rate"),
"volume": self.provider_config.get("azure_tts_volume")
}
"volume": self.provider_config.get("azure_tts_volume"),
},
)
else:
async with self.provider as provider:

View File

@@ -67,6 +67,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -74,8 +75,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
session_var = await sp.session_get(session_id, "session_variables", default={})
payload_vars.update(session_var)
if (
@@ -163,6 +163,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
contexts=...,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")

View File

@@ -18,7 +18,7 @@ class ProviderDify(Provider):
self,
provider_config,
provider_settings,
default_persona = None,
default_persona=None,
) -> None:
super().__init__(
provider_config,
@@ -60,12 +60,14 @@ class ProviderDify(Provider):
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
if image_urls is None:
image_urls = []
result = ""
session_id = session_id or kwargs.get("user") # 1734
session_id = session_id or kwargs.get("user") or "unknown" # 1734
conversation_id = self.conversation_ids.get(session_id, "")
files_payload = []
@@ -95,9 +97,9 @@ class ProviderDify(Provider):
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
session_var = await sp.session_get(session_id, "session_variables", default={})
payload_vars.update(session_var)
payload_vars["system_prompt"] = system_prompt
try:
match self.api_type:
@@ -197,6 +199,7 @@ class ProviderDify(Provider):
contexts=...,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")

View File

@@ -1,5 +1,6 @@
import os
import uuid
import re
import ormsgpack
from pydantic import BaseModel, conint
from httpx import AsyncClient
@@ -24,8 +25,8 @@ class ServeTTSRequest(BaseModel):
# 参考音频
references: list[ServeReferenceAudio] = []
# 参考模型 ID
# 例如 https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
# 其中reference_id为 7f92f8afb8ec43bf81429cc1c9199cb1
# 例如 https://fish.audio/m/626bb6d3f3364c9cbc3aa6a67300a664/
# 其中reference_id为 626bb6d3f3364c9cbc3aa6a67300a664
reference_id: str | None = None
# 对中英文文本进行标准化,这可以提高数字的稳定性
normalize: bool = True
@@ -44,6 +45,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key: str = provider_config.get("api_key", "")
self.reference_id: str = provider_config.get("fishaudio-tts-reference-id", "")
self.character: str = provider_config.get("fishaudio-tts-character", "可莉")
self.api_base: str = provider_config.get(
"api_base", "https://api.fish-audio.cn/v1"
@@ -81,11 +83,43 @@ class ProviderFishAudioTTSAPI(TTSProvider):
return item["_id"]
return None
def _validate_reference_id(self, reference_id: str) -> bool:
"""
验证reference_id格式是否有效
Args:
reference_id: 参考模型ID
Returns:
bool: ID是否有效
"""
if not reference_id or not reference_id.strip():
return False
# FishAudio的reference_id通常是32位十六进制字符串
# 例如: 626bb6d3f3364c9cbc3aa6a67300a664
pattern = r'^[a-fA-F0-9]{32}$'
return bool(re.match(pattern, reference_id.strip()))
async def _generate_request(self, text: str) -> dict:
# 向前兼容逻辑优先使用reference_id如果没有则使用角色名称查询
if self.reference_id and self.reference_id.strip():
# 验证reference_id格式
if not self._validate_reference_id(self.reference_id):
raise ValueError(
f"无效的FishAudio参考模型ID: '{self.reference_id}'. "
f"请确保ID是32位十六进制字符串例如: 626bb6d3f3364c9cbc3aa6a67300a664"
f"您可以从 https://fish.audio/zh-CN/discovery 获取有效的模型ID。"
)
reference_id = self.reference_id.strip()
else:
# 回退到原来的角色名称查询逻辑
reference_id = await self._get_reference_id_by_character(self.character)
return ServeTTSRequest(
text=text,
format="wav",
reference_id=await self._get_reference_id_by_character(self.character),
reference_id=reference_id,
)
async def get_audio(self, text: str) -> str:

View File

@@ -14,8 +14,8 @@ import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url
from ..register import register_provider_adapter
@@ -61,7 +61,7 @@ class ProviderGoogleGenAI(Provider):
default_persona,
)
self.api_keys: list = provider_config.get("key", [])
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
self.timeout: int = int(provider_config.get("timeout", 180))
self.api_base: Optional[str] = provider_config.get("api_base", None)
@@ -96,6 +96,9 @@ class ProviderGoogleGenAI(Provider):
async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool:
"""处理API错误返回是否需要重试"""
if e.message is None:
e.message = ""
if e.code == 429 or "API key not valid" in e.message:
keys.remove(self.chosen_api_key)
if len(keys) > 0:
@@ -119,7 +122,7 @@ class ProviderGoogleGenAI(Provider):
async def _prepare_query_config(
self,
payloads: dict,
tools: Optional[FuncCall] = None,
tools: Optional[ToolSet] = None,
system_instruction: Optional[str] = None,
modalities: Optional[list[str]] = None,
temperature: float = 0.7,
@@ -259,10 +262,12 @@ class ProviderGoogleGenAI(Provider):
contents.append(content_cls(parts=part))
gemini_contents: list[types.Content] = []
native_tool_enabled = any([
native_tool_enabled = any(
[
self.provider_config.get("gm_native_coderunner", False),
self.provider_config.get("gm_native_search", False),
])
]
)
for message in payloads["messages"]:
role, content = message["role"], message.get("content")
@@ -319,11 +324,15 @@ class ProviderGoogleGenAI(Provider):
@staticmethod
def _process_content_parts(
result: types.GenerateContentResponse, llm_response: LLMResponse
candidate: types.Candidate, llm_response: LLMResponse
) -> MessageChain:
"""处理内容部分并构建消息链"""
finish_reason = result.candidates[0].finish_reason
result_parts: Optional[types.Part] = result.candidates[0].content.parts
if not candidate.content:
logger.warning(f"收到的 candidate.content 为空: {candidate}")
raise Exception("API 返回的 candidate.content 为空。")
finish_reason = candidate.finish_reason
result_parts: list[types.Part] | None = candidate.content.parts
if finish_reason == types.FinishReason.SAFETY:
raise Exception("模型生成内容未通过 Gemini 平台的安全检查")
@@ -341,22 +350,28 @@ class ProviderGoogleGenAI(Provider):
raise Exception("模型生成内容违反 Gemini 平台政策")
if not result_parts:
logger.debug(result.candidates)
raise Exception("API 返回的内容为空。")
logger.warning(f"收到的 candidate.content.parts 为空: {candidate}")
raise Exception("API 返回的 candidate.content.parts 为空。")
chain = []
part: types.Part
# 暂时这样Fallback
if all(
part.inline_data and part.inline_data.mime_type.startswith("image/")
part.inline_data
and part.inline_data.mime_type
and part.inline_data.mime_type.startswith("image/")
for part in result_parts
):
chain.append(Comp.Plain("这是图片"))
for part in result_parts:
if part.text:
chain.append(Comp.Plain(part.text))
elif part.function_call:
elif (
part.function_call
and part.function_call.name is not None
and part.function_call.args is not None
):
llm_response.role = "tool"
llm_response.tools_call_name.append(part.function_call.name)
llm_response.tools_call_args.append(part.function_call.args)
@@ -364,11 +379,16 @@ class ProviderGoogleGenAI(Provider):
llm_response.tools_call_ids.append(
part.function_call.id or part.function_call.name
)
elif part.inline_data and part.inline_data.mime_type.startswith("image/"):
elif (
part.inline_data
and part.inline_data.mime_type
and part.inline_data.mime_type.startswith("image/")
and part.inline_data.data
):
chain.append(Comp.Image.fromBytes(part.inline_data.data))
return MessageChain(chain=chain)
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
"""非流式请求 Gemini API"""
system_instruction = next(
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
@@ -394,6 +414,10 @@ class ProviderGoogleGenAI(Provider):
config=config,
)
if not result.candidates:
logger.error(f"请求失败, 返回的 candidates 为空: {result}")
raise Exception("请求失败, 返回的 candidates 为空。")
if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
if temperature > 2:
raise Exception("温度参数已超过最大值2仍然发生recitation")
@@ -406,6 +430,8 @@ class ProviderGoogleGenAI(Provider):
break
except APIError as e:
if e.message is None:
e.message = ""
if "Developer instruction is not enabled" in e.message:
logger.warning(
f"{self.get_model()} 不支持 system prompt已自动去除(影响人格设置)"
@@ -429,11 +455,14 @@ class ProviderGoogleGenAI(Provider):
continue
llm_response = LLMResponse("assistant")
llm_response.result_chain = self._process_content_parts(result, llm_response)
llm_response.raw_completion = result
llm_response.result_chain = self._process_content_parts(
result.candidates[0], llm_response
)
return llm_response
async def _query_stream(
self, payloads: dict, tools: FuncCall
self, payloads: dict, tools: ToolSet | None
) -> AsyncGenerator[LLMResponse, None]:
"""流式请求 Gemini API"""
system_instruction = next(
@@ -456,6 +485,8 @@ class ProviderGoogleGenAI(Provider):
)
break
except APIError as e:
if e.message is None:
e.message = ""
if "Developer instruction is not enabled" in e.message:
logger.warning(
f"{self.get_model()} 不支持 system prompt已自动去除(影响人格设置)"
@@ -468,34 +499,61 @@ class ProviderGoogleGenAI(Provider):
raise
continue
# Accumulate the complete response text for the final response
accumulated_text = ""
final_response = None
async for chunk in result:
llm_response = LLMResponse("assistant", is_chunk=True)
if not chunk.candidates:
logger.warning(f"收到的 chunk 中 candidates 为空: {chunk}")
continue
if not chunk.candidates[0].content:
logger.warning(f"收到的 chunk 中 content 为空: {chunk}")
continue
if chunk.candidates[0].content.parts and any(
part.function_call for part in chunk.candidates[0].content.parts
):
llm_response = LLMResponse("assistant", is_chunk=False)
llm_response.raw_completion = chunk
llm_response.result_chain = self._process_content_parts(
chunk, llm_response
chunk.candidates[0], llm_response
)
yield llm_response
break
return
if chunk.text:
accumulated_text += chunk.text
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
yield llm_response
if chunk.candidates[0].finish_reason:
llm_response = LLMResponse("assistant", is_chunk=False)
if not chunk.candidates[0].content.parts:
llm_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
else:
llm_response.result_chain = self._process_content_parts(
chunk, llm_response
# Process the final chunk for potential tool calls or other content
if chunk.candidates[0].content.parts:
final_response = LLMResponse("assistant", is_chunk=False)
final_response.raw_completion = chunk
final_response.result_chain = self._process_content_parts(
chunk.candidates[0], final_response
)
yield llm_response
break
# Yield final complete response with accumulated text
if not final_response:
final_response = LLMResponse("assistant", is_chunk=False)
# Set the complete accumulated text in the final response
if accumulated_text:
final_response.result_chain = MessageChain(
chain=[Comp.Plain(accumulated_text)]
)
elif not final_response.result_chain:
# If no text was accumulated and no final response was set, provide empty space
final_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
yield final_response
async def text_chat(
self,
prompt: str,
@@ -505,6 +563,7 @@ class ProviderGoogleGenAI(Provider):
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -527,7 +586,7 @@ class ProviderGoogleGenAI(Provider):
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model()
model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config}
@@ -542,15 +601,18 @@ class ProviderGoogleGenAI(Provider):
continue
break
raise Exception("请求失败。")
async def text_chat_stream(
self,
prompt: str,
session_id: str = None,
image_urls: list[str] = None,
func_tool: FuncCall = None,
contexts: str = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult = None,
prompt,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
if contexts is None:
@@ -566,10 +628,14 @@ class ProviderGoogleGenAI(Provider):
# tool calls result
if tool_calls_result:
if not isinstance(tool_calls_result, list):
context_query.extend(tool_calls_result.to_openai_messages())
else:
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model()
model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config}
@@ -592,7 +658,9 @@ class ProviderGoogleGenAI(Provider):
return [
m.name.replace("models/", "")
for m in models
if "generateContent" in m.supported_actions
if m.supported_actions
and "generateContent" in m.supported_actions
and m.name
]
except APIError as e:
raise Exception(f"获取模型列表失败: {e.message}")
@@ -607,7 +675,7 @@ class ProviderGoogleGenAI(Provider):
self.chosen_api_key = key
self._init_client()
async def assemble_context(self, text: str, image_urls: list[str] = None):
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
"""
组装上下文。
"""
@@ -628,10 +696,12 @@ class ProviderGoogleGenAI(Provider):
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append({
user_content["content"].append(
{
"type": "image_url",
"image_url": {"url": image_data},
})
}
)
return user_content
else:
return {"role": "user", "content": text}

View File

@@ -22,7 +22,7 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
timeout=int(provider_config.get("timeout", 20)),
)
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
self.dimension = provider_config.get("embedding_dimensions", 1536)
self.dimension = provider_config.get("embedding_dimensions", 1024)
async def get_embedding(self, text: str) -> list[float]:
"""

View File

@@ -30,7 +30,7 @@ class ProviderOpenAIOfficial(Provider):
self,
provider_config,
provider_settings,
default_persona = None,
default_persona=None,
) -> None:
super().__init__(
provider_config,
@@ -99,6 +99,15 @@ class ProviderOpenAIOfficial(Provider):
for key in to_del:
del payloads[key]
# 读取并合并 custom_extra_body 配置
custom_extra_body = self.provider_config.get("custom_extra_body", {})
if isinstance(custom_extra_body, dict):
extra_body.update(custom_extra_body)
# 针对 deepseek 模型的特殊处理deepseek-reasoner调用必须移除 tools ,否则将被切换至 deepseek-chat
if model == "deepseek-reasoner" and "tools" in payloads:
del payloads["tools"]
completion = await self.client.chat.completions.create(
**payloads, stream=False, extra_body=extra_body
)
@@ -129,6 +138,12 @@ class ProviderOpenAIOfficial(Provider):
# 不在默认参数中的参数放在 extra_body 中
extra_body = {}
# 读取并合并 custom_extra_body 配置
custom_extra_body = self.provider_config.get("custom_extra_body", {})
if isinstance(custom_extra_body, dict):
extra_body.update(custom_extra_body)
to_del = []
for key in payloads.keys():
if key not in self.default_params:
@@ -176,7 +191,7 @@ class ProviderOpenAIOfficial(Provider):
raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0]
if choice.message.content:
if choice.message.content is not None:
# text completion
completion_text = str(choice.message.content).strip()
llm_response.result_chain = MessageChain().message(completion_text)
@@ -187,6 +202,9 @@ class ProviderOpenAIOfficial(Provider):
func_name_ls = []
tool_call_ids = []
for tool_call in choice.message.tool_calls:
if isinstance(tool_call, str):
# workaround for #1359
tool_call = json.loads(tool_call)
for tool in tools.func_list:
if tool.name == tool_call.function.name:
# workaround for #1454
@@ -207,7 +225,7 @@ class ProviderOpenAIOfficial(Provider):
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。"
)
if not llm_response.completion_text and not llm_response.tools_call_args:
if llm_response.completion_text is None and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}")
raise Exception(f"API 返回的 completion 无法解析:{completion}")
@@ -222,6 +240,7 @@ class ProviderOpenAIOfficial(Provider):
contexts: list | None = None,
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
**kwargs,
) -> tuple:
"""准备聊天所需的有效载荷和上下文"""
@@ -245,7 +264,7 @@ class ProviderOpenAIOfficial(Provider):
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model()
model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config}
@@ -346,6 +365,7 @@ class ProviderOpenAIOfficial(Provider):
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
payloads, context_query = await self._prepare_chat_payload(
@@ -354,6 +374,7 @@ class ProviderOpenAIOfficial(Provider):
contexts,
system_prompt,
tool_calls_result,
model=model,
**kwargs,
)
@@ -413,6 +434,7 @@ class ProviderOpenAIOfficial(Provider):
contexts=[],
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话,与服务商交互并逐步返回结果"""
@@ -422,6 +444,7 @@ class ProviderOpenAIOfficial(Provider):
contexts,
system_prompt,
tool_calls_result,
model=model,
**kwargs,
)
@@ -477,13 +500,8 @@ class ProviderOpenAIOfficial(Provider):
"""
new_contexts = []
flag = False
for context in contexts:
if flag:
flag = False # 删除 image 后下一条LLM 响应)也要删除
continue
if isinstance(context["content"], list):
flag = True
if "content" in context and isinstance(context["content"], list):
# continue
new_content = []
for item in context["content"]:
@@ -526,7 +544,10 @@ class ProviderOpenAIOfficial(Provider):
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
{"type": "image_url", "image_url": {"url": image_data}}
{
"type": "image_url",
"image_url": {"url": image_data},
}
)
return user_content
else:

View File

@@ -0,0 +1,65 @@
import aiohttp
from astrbot import logger
from ..provider import RerankProvider
from ..register import register_provider_adapter
from ..entities import ProviderType, RerankResult
@register_provider_adapter(
"vllm_rerank",
"VLLM Rerank 适配器",
provider_type=ProviderType.RERANK,
)
class VLLMRerankProvider(RerankProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
self.provider_settings = provider_settings
self.auth_key = provider_config.get("rerank_api_key", "")
self.base_url = provider_config.get("rerank_api_base", "http://127.0.0.1:8000")
self.base_url = self.base_url.rstrip("/")
self.timeout = provider_config.get("timeout", 20)
self.model = provider_config.get("rerank_model", "BAAI/bge-reranker-base")
h = {}
if self.auth_key:
h["Authorization"] = f"Bearer {self.auth_key}"
self.client = aiohttp.ClientSession(
headers=h,
timeout=aiohttp.ClientTimeout(total=self.timeout),
)
async def rerank(
self, query: str, documents: list[str], top_n: int | None = None
) -> list[RerankResult]:
payload = {
"query": query,
"documents": documents,
"model": self.model,
}
if top_n is not None:
payload["top_n"] = top_n
async with self.client.post(
f"{self.base_url}/v1/rerank", json=payload
) as response:
response_data = await response.json()
results = response_data.get("results", [])
if not results:
logger.warning(
f"Rerank API 返回了空的列表数据。原始响应: {response_data}"
)
return [
RerankResult(
index=result["index"],
relevance_score=result["relevance_score"],
)
for result in results
]
async def terminate(self) -> None:
"""关闭客户端会话"""
if self.client:
await self.client.close()
self.client = None

View File

@@ -5,12 +5,12 @@ import os
import traceback
import asyncio
import aiohttp
import requests
from ..provider import TTSProvider
from ..entities import ProviderType
from ..register import register_provider_adapter
from astrbot import logger
@register_provider_adapter(
"volcengine_tts", "火山引擎 TTS", provider_type=ProviderType.TEXT_TO_SPEECH
)
@@ -22,7 +22,9 @@ class ProviderVolcengineTTS(TTSProvider):
self.cluster = provider_config.get("volcengine_cluster", "")
self.voice_type = provider_config.get("volcengine_voice_type", "")
self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0)
self.api_base = provider_config.get("api_base", f"https://openspeech.bytedance.com/api/v1/tts")
self.api_base = provider_config.get(
"api_base", "https://openspeech.bytedance.com/api/v1/tts"
)
self.timeout = provider_config.get("timeout", 20)
def _build_request_payload(self, text: str) -> dict:
@@ -30,11 +32,9 @@ class ProviderVolcengineTTS(TTSProvider):
"app": {
"appid": self.appid,
"token": self.api_key,
"cluster": self.cluster
},
"user": {
"uid": str(uuid.uuid4())
"cluster": self.cluster,
},
"user": {"uid": str(uuid.uuid4())},
"audio": {
"voice_type": self.voice_type,
"encoding": "mp3",
@@ -48,15 +48,15 @@ class ProviderVolcengineTTS(TTSProvider):
"text_type": "plain",
"operation": "query",
"with_frontend": 1,
"frontend_type": "unitTson"
}
"frontend_type": "unitTson",
},
}
async def get_audio(self, text: str) -> str:
"""异步方法获取语音文件路径"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer; {self.api_key}"
"Authorization": f"Bearer; {self.api_key}",
}
payload = self._build_request_payload(text)
@@ -71,7 +71,7 @@ class ProviderVolcengineTTS(TTSProvider):
self.api_base,
data=json.dumps(payload),
headers=headers,
timeout=self.timeout
timeout=self.timeout,
) as response:
logger.debug(f"响应状态码: {response.status}")
@@ -90,8 +90,7 @@ class ProviderVolcengineTTS(TTSProvider):
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None,
lambda: open(file_path, "wb").write(audio_data)
None, lambda: open(file_path, "wb").write(audio_data)
)
return file_path
@@ -99,7 +98,9 @@ class ProviderVolcengineTTS(TTSProvider):
error_msg = resp_data.get("message", "未知错误")
raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
else:
raise Exception(f"火山引擎 TTS API 请求失败: {response.status}, {response_text}")
raise Exception(
f"火山引擎 TTS API 请求失败: {response.status}, {response_text}"
)
except Exception as e:
error_details = traceback.format_exc()

View File

@@ -28,6 +28,7 @@ class ProviderZhipu(ProviderOpenAIOfficial):
func_tool: FuncCall = None,
contexts=None,
system_prompt=None,
model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -38,7 +39,7 @@ class ProviderZhipu(ProviderOpenAIOfficial):
context_query = [*contexts, new_record]
model_cfgs: dict = self.provider_config.get("model_config", {})
model = self.get_model()
model = model or self.get_model()
# glm-4v-flash 只支持一张图片
if model.lower() == "glm-4v-flash" and image_urls and len(context_query) > 1:
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")

View File

@@ -1,4 +1,4 @@
from .star import StarMetadata
from .star import StarMetadata, star_map, star_registry
from .star_manager import PluginManager
from .context import Context
from astrbot.core.provider import Provider
@@ -10,25 +10,50 @@ from astrbot.core.star.star_tools import StarTools
class Star(CommandParserMixin):
"""所有插件Star的父类所有插件都应该继承于这个类"""
def __init__(self, context: Context):
def __init__(self, context: Context, config: dict | None = None):
StarTools.initialize(context)
self.context = context
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not star_map.get(cls.__module__):
metadata = StarMetadata(
star_cls_type=cls,
module_path=cls.__module__,
)
star_map[cls.__module__] = metadata
star_registry.append(metadata)
else:
star_map[cls.__module__].star_cls_type = cls
star_map[cls.__module__].module_path = cls.__module__
async def text_to_image(self, text: str, return_url=True) -> str:
"""将文本转换为图片"""
return await html_renderer.render_t2i(text, return_url=return_url)
return await html_renderer.render_t2i(
text,
return_url=return_url,
template_name=self.context._config.get("t2i_active_template"),
)
async def html_render(
self, tmpl: str, data: dict, return_url=True, options: dict = None
self, tmpl: str, data: dict, return_url=True, options: dict | None = None
) -> str:
"""渲染 HTML"""
return await html_renderer.render_custom_template(
tmpl, data, return_url=return_url, options=options
)
async def initialize(self):
"""当插件被激活时会调用这个方法"""
pass
async def terminate(self):
"""当插件被禁用、重载插件时会调用这个方法"""
pass
def __del__(self):
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
pass
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]

Some files were not shown because too many files have changed in this diff Show More