Compare commits

...

300 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
c6b6eef8c4 Complete Docker compatibility fix with enhanced documentation
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-16 16:03:54 +00:00
copilot-swe-agent[bot]
50cf263076 Implement CLI Docker compatibility fix and login-info command
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-16 16:01:36 +00:00
copilot-swe-agent[bot]
2554548088 Initial commit: fix formatting and explore codebase
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-16 15:53:37 +00:00
copilot-swe-agent[bot]
aa4a2d10e2 Initial plan 2025-08-16 15:48:10 +00:00
Soulter
02a9769b35 fix: 补充工具调用轮数上限配置 2025-08-14 23:51:22 +08:00
Soulter
7640f11bfc docs: update readme 2025-08-14 17:28:51 +08:00
Soulter
9fa44dbcfa docs: update readme 2025-08-14 14:16:22 +08:00
Copilot
2cae941bae Fix incomplete Gemini streaming responses in chat history (#2429)
* Initial plan

* Fix incomplete Gemini streaming responses in chat history

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Raven95676 <Raven95676@gmail.com>
2025-08-14 11:56:50 +08:00
Copilot
bc0784f41d fix: enable_thinking parameter for qwen3 models in non-streaming calls (#2424)
* Initial plan

* Fix ModelScope enable_thinking parameter for non-streaming calls

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* Tighten enable_thinking condition to only Qwen/Qwen3 models

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* qwen3 model handle

* Update astrbot/core/provider/sources/openai_source.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-08-14 11:18:29 +08:00
Copilot
c57d75e01a feat: add comprehensive GitHub Copilot instructions for AstrBot development (#2426)
* Initial plan

* Initial progress - completed repository exploration and dependency installation

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* Complete copilot-instructions.md with comprehensive development guide

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* Update copilot-instructions.md

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-13 23:31:28 +08:00
RC-CHN
73edeae013 perf: 优化hint渲染方式,为部分类型供应商添加默认的温度选项 (#2321)
* feat:为webchat页面添加一个手动上传文件按钮(目前只处理图片)

* fix:上传后清空value,允许触发change事件以多次上传同一张图片

* perf:webchat页面消息发送后清空图片预览缩略图,维持与文本信息行为一致

* perf:将文件输入的值重置为空字符串以提升浏览器兼容性

* feat:webchat文件上传按钮支持多选文件上传

* fix:释放blob URL以防止内存泄漏

* perf:并行化sendMessage中的图片获取逻辑

* perf:优化hint渲染方式,为部分类型供应商添加默认的温度选项
2025-08-12 21:53:06 +08:00
MUKAPP
7d46314dc8 fix: 修复注册文件时由于 file:/// 前缀,导致文件被误判为不存在的问题 (#2325)
fixes #2222
2025-08-12 21:47:31 +08:00
你们的饺子
d5a53a89eb fix: 修复插件的 terminate 无法被正常调用的问题 (#2352) 2025-08-12 21:41:19 +08:00
dependabot[bot]
a85bc510dd chore(deps): bump actions/checkout in the github-actions group (#2400)
Bumps the github-actions group with 1 update: [actions/checkout](https://github.com/actions/checkout).


Updates `actions/checkout` from 4 to 5
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v5)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-08-12 15:15:28 +08:00
Soulter
2beea7d218 📦 release: v3.5.24 2025-08-07 20:36:59 +08:00
Soulter
a93cd3dd5f feat: compshare provider 2025-08-07 20:25:45 +08:00
Soulter
db4d02c2e2 docs: add 1panel deployment method 2025-08-02 19:01:49 +08:00
Soulter
fd7811402b fix: 添加对 metadata 中 description 字段的支持,确保元数据完整性
fixes: #2245
2025-08-02 16:01:10 +08:00
你们的饺子
eb0325e627 fix: 修复了 OpenAI 类型的 LLM 空内容响应导致的无法解析 completion 的错误。 (#2279) 2025-08-02 15:46:11 +08:00
IGCrystal
8b4b04ec09 fix(i18n): add missing noTemplates key (#2292) 2025-08-02 14:16:59 +08:00
Larch-C
9f32c9280f chore: update and rename PLUGIN_PUBLISH.md to PLUGIN_PUBLISH.yml (#2289) 2025-08-02 14:16:19 +08:00
yrk111222
4fcd09cfa8 feat: add ModelScope API support (#2230)
* add ModelScope API support

* update
2025-08-02 14:14:08 +08:00
Misaka Mikoto
7a8d65d37d feat: add plugins local cache and remote file MD5 validation (#2211)
* 修改openai的嵌入模型默认维度为1024

* 为插件市场添加本地缓存
- 优先使用api获取,获取失败时则使用本地缓存
- 每次获取后会更新本地缓存
- 如果获取结果为空,判定为获取失败,使用本地缓存
- 前端页面添加刷新按钮,用于手动刷新本地缓存

* feat: 增强插件市场缓存机制,支持MD5校验以确保数据有效性

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-08-02 14:03:53 +08:00
Raven95676
23129a9ba2 Merge branch 'releases/3.5.23' 2025-07-26 16:49:38 +08:00
Raven95676
7f791e730b fix: changelogs 2025-07-26 16:49:05 +08:00
Raven95676
f7e296b349 Merge branch 'releases/3.5.23' 2025-07-26 16:34:30 +08:00
Raven95676
712d4acaaa release: v3.5.23 2025-07-26 16:32:06 +08:00
Raven95676
74a5c01f21 refactor: remove code and documentation references related to gewechat 2025-07-26 14:19:17 +08:00
Raven95676
3ba8724d77 Merge branch 'master' into dev 2025-07-26 14:02:05 +08:00
鸦羽
6313a7d8a9 Merge pull request #2221 from Raven95676/fix/axios-dependency
fix: update axios version range for vulnerability fix
2025-07-24 18:34:14 +08:00
Raven95676
432a3f520c fix: update axios version range for vulnerability fix 2025-07-24 18:28:02 +08:00
Soulter
191b3e42d4 feat: implement log history retrieval and improve log streaming handling (#2190) 2025-07-23 23:36:08 +08:00
Misaka Mikoto
a27f05fcb4 chore: 修改 OpenAI 嵌入模型提供商默认向量维度为1024 (#2209) 2025-07-23 23:35:04 +08:00
Soulter
2f33e0b873 chore: remove adapters of wechat personal account 2025-07-23 10:51:42 +08:00
Soulter
f0359467f1 chore: remove adapters of wechat personal account 2025-07-23 10:50:43 +08:00
Soulter
d1db8cf2c8 chore: remove adapters of wechat personal account 2025-07-23 10:48:58 +08:00
Soulter
b1985ed2ce Merge branch 'dev' 2025-07-23 00:38:08 +08:00
Gao Jinzhe
140ddc70e6 feat: 使用会话锁保证分段回复时的消息发送顺序 (#2130)
* 优化分段消息发送逻辑,为分段消息添加消息队列

* 删除了不必要的代码

* style: code quality

* 将消息队列机制重构为会话锁机制

* perf: narrow the lock scope

* refactor: replace get_lock with async context manager for session locks

* refactor: optimize session lock management with defaultdict

---------

Co-authored-by: Soulter <905617992@qq.com>
Co-authored-by: Raven95676 <Raven95676@gmail.com>
2025-07-23 00:37:29 +08:00
Soulter
d7fd616470 style: code quality 2025-07-21 17:04:29 +08:00
Soulter
3ccbef141e perf: extension ui 2025-07-21 15:16:49 +08:00
Soulter
e92fbb0443 feat: add ProxySelector component for GitHub proxy configuration and connection testing (#2185) 2025-07-21 15:05:49 +08:00
Soulter
bd270aed68 fix: handle event construction errors in message reply processing 2025-07-20 22:52:14 +08:00
Soulter
28d7864393 perf: tool use page UI (#2182)
* perf: tool use UI

* fix: update background color of item cards in ToolUsePage
2025-07-20 20:24:03 +08:00
RC-CHN
b5d8173ee3 feat: add a file uplod button in WebChat page (#2136)
* feat:为webchat页面添加一个手动上传文件按钮(目前只处理图片)

* fix:上传后清空value,允许触发change事件以多次上传同一张图片

* perf:webchat页面消息发送后清空图片预览缩略图,维持与文本信息行为一致

* perf:将文件输入的值重置为空字符串以提升浏览器兼容性

* feat:webchat文件上传按钮支持多选文件上传

* fix:释放blob URL以防止内存泄漏

* perf:并行化sendMessage中的图片获取逻辑
2025-07-20 16:02:28 +08:00
Soulter
17d62a9af7 refactor: mcp server reload mechanism (#2161)
* refactor: mcp server reload mechanism

* fix: wait for client events

* fix: all other mcp servers are terminated when disable selected server

* fix: resolve type hinting issues in MCPClient and FuncCall methods

* perf: optimize mcp server loaders

* perf: improve MCP client connection testing

* perf: improve error message

* perf: clean code

* perf: increase default timeout for MCP connection and reset dialog message on close

---------

Co-authored-by: Raven95676 <Raven95676@gmail.com>
2025-07-20 15:53:13 +08:00
Soulter
d89fb863ed fix: improve logging and error message details in LLMRequestSubStage 2025-07-18 16:13:27 +08:00
Soulter
a21ad77820 Merge pull request #2146 from Raven95676/fix/mcp
fix: 修复MCP导致的持续占用100% CPU
2025-07-18 13:04:50 +08:00
Raven95676
f86c8e8cab perf: ensure MCP client termination in cleanup process 2025-07-17 23:17:23 +08:00
Raven95676
cb12cbdd3d fix: managing MCP connections with AsyncExitStack 2025-07-16 23:44:51 +08:00
Soulter
6661fa996c fix: audio block does not display 2025-07-14 22:20:03 +08:00
Soulter
c19bca798b fix: xfyun model tool use error workaround
fixes: #1359
2025-07-14 22:07:33 +08:00
Soulter
8f98b411db Merge pull request #2129 from AstrBotDevs/perf-refine-webui-chatpage
Improve: WebUI ChatPage markdown code block background
2025-07-14 21:49:18 +08:00
Soulter
a8aa03847e feat: enhance theme customization with new background properties and markdown styling 2025-07-14 21:47:25 +08:00
Soulter
1bfd747cc6 perf: add system_prompt to payload_vars in dify text_chat method 2025-07-14 11:00:13 +08:00
Soulter
ae06d945a7 Merge pull request #2054 from RC-CHN/master
Feature: Add provider_type field for ProviderMetadata and improve provider availabiliby test
2025-07-13 17:38:22 +08:00
Soulter
9f41d5f34d Merge remote-tracking branch 'origin/master' into RC-CHN/master 2025-07-13 17:35:53 +08:00
Soulter
ef61c52908 fix: remove non-existent Response field 2025-07-13 17:33:13 +08:00
Soulter
d8842ef274 perf: code quality 2025-07-13 17:27:40 +08:00
Soulter
c88fdaf353 Merge pull request #1949 from advent259141/Astrbot_session_manage
[Feature] 支持在 WebUI 上管理会话
2025-07-13 17:23:52 +08:00
Soulter
af295da871 chore: remove /mcp command 2025-07-13 17:11:02 +08:00
Soulter
083235a2fe feat: enhance session management page with tooltips and layout adjustments 2025-07-13 17:06:15 +08:00
Soulter
2a3a5f7eb2 perf: refine session management page ui 2025-07-13 16:57:36 +08:00
Soulter
77c48f280f fix: session management paginator error 2025-07-13 16:36:25 +08:00
Soulter
0ee1eb2f9f chore: remove useless file 2025-07-13 16:30:00 +08:00
Soulter
c2b20365bb Merge pull request #2097 from SheffeyG/fix-status-checking
Fix: add status checking for embedding model providers
2025-07-13 16:19:37 +08:00
Soulter
cfdc7e4452 fix: add debug logging for provider request handling in LLMRequestSubStage
fixes: #2104
2025-07-13 16:12:48 +08:00
Soulter
2363f61aa9 chore: remove 'obvious_hint' fields from configuration metadata and remove some deprecated config 2025-07-13 16:03:47 +08:00
Soulter
557ac6f9fa Merge pull request #2112 from AstrBotDevs/perf-provider-logo
Improve: WebUI provider logo display
2025-07-13 15:34:19 +08:00
Soulter
a49b871cf9 fix: update Azure provider icon URL in getProviderIcon method 2025-07-13 15:33:47 +08:00
Soulter
a0d6b3efba perf: improve provider logo display in webui 2025-07-13 15:27:53 +08:00
Gao Jinzhe
6cabf07bc0 Merge branch 'AstrBotDevs:master' into Astrbot_session_manage 2025-07-13 00:23:29 +08:00
advent259141
a15444ee8c 移除了mcp会话级的启停,增加了批量设置的选项,对相关问题进行了修复 2025-07-13 00:15:21 +08:00
Soulter
ceb5f5669e fix: update active reply bot prefix in logging for clarity 2025-07-13 00:12:31 +08:00
Gao Jinzhe
25b75e05e4 Merge branch 'AstrBotDevs:master' into Astrbot_session_manage 2025-07-12 22:25:20 +08:00
sheffey
4d214bb5c1 check general numbers type instead 2025-07-11 18:36:47 +08:00
sheffey
7cbaed8c6c fix: add status checking for embedding model providers 2025-07-11 18:36:40 +08:00
Soulter
2915fdf665 release: v3.5.22 2025-07-11 12:29:26 +08:00
Soulter
a66c385b08 fix: deadlock when docker is not available 2025-07-11 12:27:49 +08:00
Raven95676
4dace7c5d8 chore: format code 2025-07-11 11:23:53 +08:00
Soulter
8ebf087dbf chore: optimize codes 2025-07-10 23:28:00 +08:00
Soulter
2fa8bda5bb chore: ruff lint 2025-07-10 23:23:29 +08:00
Soulter
a5ae833945 📦 release: v3.5.21 2025-07-10 17:46:36 +08:00
Soulter
d21d42b312 chore: update icon URL for 302.AI to use color version 2025-07-10 17:44:11 +08:00
Soulter
78575f0f0a fix: failed to delete conversation in webchat
fixes: #2071
2025-07-10 17:04:34 +08:00
Soulter
8ccd292d16 Merge pull request #2082 from AstrBotDevs/fix-webchat-segment-reply
fix: 修复 WebChat 下可能消息错位的问题
2025-07-10 17:00:14 +08:00
Soulter
2534f59398 chore: remove debug print statement from chat route 2025-07-10 16:59:58 +08:00
Soulter
5c60dbe2b1 fix: 修复 WebChat 下可能消息错位的问题 2025-07-10 16:52:16 +08:00
Soulter
c99ecde15f Merge pull request #2078 from AstrBotDevs/fix-webchat-image-cannot-render
Fix: webchat cannot render image and audio image normally
2025-07-10 11:57:50 +08:00
Soulter
219f3403d9 fix: webchat cannot render image and audio image normally 2025-07-10 11:51:47 +08:00
Soulter
00f417bad6 Merge pull request #2073 from Raven95676/fix/register_star
fix: 提升兼容性,并尽可能避免数据竞争
2025-07-10 11:03:57 +08:00
Soulter
81649f053b perf: improve log 2025-07-10 10:58:56 +08:00
Raven95676
e5bde50f2d fix: 提升兼容性,并尽可能避免数据竞争 2025-07-09 22:39:30 +08:00
Raven95676
0321e00b0d perf: 移除nh3 2025-07-09 20:32:14 +08:00
Soulter
09528e3292 docs: add model providers 2025-07-09 14:18:59 +08:00
Soulter
e7412a9cbf docs: add model providers 2025-07-09 14:17:39 +08:00
Soulter
01efe5f869 📦 release: v3.5.20 2025-07-09 13:35:44 +08:00
Soulter
28a178a55c Merge pull request #2067 from AstrBotDevs/refactor-aiocqhttp-send-message
Fix: active message cannot handle forward type message properly in aiocqhttp adapter
2025-07-09 13:23:08 +08:00
Soulter
88f130014c perf: streamline message dispatching logic in AiocqhttpMessageEvent 2025-07-09 12:10:18 +08:00
Soulter
af258c590c Merge pull request #2068 from AstrBotDevs/fix-tool-call-result-wrongly-sent
Fix: 修复工具调用被错误地发出到了消息平台上
2025-07-09 12:02:07 +08:00
Soulter
b0eb5733be Merge pull request #2065 from AstrBotDevs/fix-plugin-metadata-load
Improve: add fallback for missing 'desc' in plugin metadata
2025-07-09 12:01:06 +08:00
Soulter
fe35bfba37 Merge pull request #2064 from uersula/fix-image-removal-flag-logic
Fix: 移除 _remove_image_from_context中的flag逻辑
2025-07-09 12:00:30 +08:00
advent259141
7cfbc4ab8f 增加了针对整个会话启停的开关 2025-07-09 11:58:52 +08:00
Soulter
7a9d4f0abd fix: 修复工具调用被错误地发出到了消息平台上
fixes: #2060
2025-07-09 11:43:25 +08:00
Soulter
6f6a5b565c fix: active message cannot handle forward type message properly in aiocqhttp adapter 2025-07-09 11:19:32 +08:00
Soulter
e57deb873c perf: add fallback for missing 'desc' in plugin metadata and improve error logging 2025-07-09 10:47:03 +08:00
Gao Jinzhe
0f692b1608 Merge branch 'master' into Astrbot_session_manage 2025-07-09 10:13:51 +08:00
uersula
8c03e79f99 Fix: Remove buggy flag logic in _remove_image_from_context 2025-07-08 23:01:11 +08:00
Soulter
71290f0929 Merge pull request #2061 from AstrBotDevs/feat-handle-image-in-quote-message
Feature: 支持对引用消息中的图片内容进行理解
2025-07-08 22:11:17 +08:00
Soulter
22364ef7de feat: 支持对引用消息中的图片内容进行理解
fixes: #2056
2025-07-08 22:08:40 +08:00
Ruochen
2cc1eb1abc feat:实现了speech_to_text类型的供应商可用性检查 2025-07-08 21:55:31 +08:00
RC-CHN
90dbcbb4e2 Merge branch 'AstrBotDevs:master' into master 2025-07-08 21:28:50 +08:00
Ruochen
66503d58be feat:实现了text_to_speech类型的供应商可用性测试 2025-07-08 17:52:22 +08:00
Ruochen
8e10f0ce2b feat:实现了embedding类型的供应商可用性检查 2025-07-08 16:51:57 +08:00
Soulter
f51f510f2e perf: enhance date handle in reminder
fixes: #1901
2025-07-08 16:33:46 +08:00
Ruochen
c44f085b47 fix:对非文本生成类供应商暂时跳过测试 2025-07-08 16:32:39 +08:00
RC-CHN
a35f36eeaf Merge branch 'AstrBotDevs:master' into master 2025-07-08 15:34:19 +08:00
Ruochen
14564c392a feat:meta方法增加provider_type字段 2025-07-08 15:33:02 +08:00
Soulter
76e05ea749 Merge pull request #2022 from AstrBotDevs/deprecate/register_star-decorator
[Deprecation] 弃用register_star装饰器
2025-07-08 11:57:28 +08:00
Soulter
ab599dceed Merge branch 'master' into deprecate/register_star-decorator 2025-07-08 11:52:33 +08:00
Soulter
4c37604445 perf: only output deprecation warning once for @register_star decorator 2025-07-08 11:50:55 +08:00
Soulter
bb74018d19 Merge pull request #1998 from diudiu62/feat-wechatpadpro-adapter
增加监听wechatpadpro消息平台的事件
2025-07-08 11:40:13 +08:00
Soulter
575289e5bc feat: complete platform adapter types and update mapping 2025-07-08 11:39:42 +08:00
Soulter
e89da2a7b4 Merge pull request #2035 from cclauss/patch-1
pytest recommendation: `pip install --editable .`
2025-07-08 11:35:34 +08:00
Soulter
bd34959f68 📦 release: v3.5.19 2025-07-08 01:34:08 +08:00
Soulter
622dcf8fd5 fix: 通过指令选择提供商重启后失效 2025-07-08 01:24:19 +08:00
Soulter
9e315739b7 Merge pull request #2051 from AstrBotDevs/perf-ui
Improve: 改善 WebUI 效果
2025-07-08 00:35:52 +08:00
Soulter
7b01adc5df perf: better webui 2025-07-08 00:33:22 +08:00
Soulter
432fc47443 feat: add 302.ai llm provider 2025-07-07 23:01:28 +08:00
Soulter
d8fba44c5e Merge pull request #2049 from uersula/fix/keyerror-in-recovery-handler
Fix: 防止错误恢复机制_remove_image_from_context发生KeyError
2025-07-07 22:13:43 +08:00
Soulter
e29d3d8c01 Merge pull request #2043 from Zhenyi-Wang/master
fix(wechatpadpro): 修复授权码提取逻辑以兼容新旧接口格式
2025-07-07 22:10:20 +08:00
uersula
e678413214 Fix: Prevent KeyError in _remove_image_from_context 2025-07-07 02:30:50 +08:00
Soulter
eaa9d9d087 Merge pull request #2027 from IGCrystal/Branch-2
🐞 fix(WebUI): 解决XSS注入的问题
2025-07-06 18:13:40 +08:00
Soulter
9e3cc076b7 🐞 fix(ReadmeDialog): add variant attribute to close button for consistency 2025-07-06 18:13:00 +08:00
IGCrystal
3bb01fa52c feat(ChatPage): 添加图像预览 2025-07-06 18:08:17 +08:00
IGCrystal
008e49d144 🎈 perf: 优化音频附件的显示 2025-07-06 18:08:17 +08:00
IGCrystal
4e275384b0 🐞 fix(VerticalHeader): 允许HTML渲染 2025-07-06 18:08:17 +08:00
IGCrystal
63ec99f67a 🐞 fix: 添加不存在的翻译键 2025-07-06 18:08:17 +08:00
IGCrystal
14a8bb57df 🐞 fix(WebUI): 解决XSS注入的问题 2025-07-06 18:08:17 +08:00
Soulter
7512bfc710 fix: update user message bubble styling for improved appearance 2025-07-06 18:06:28 +08:00
Soulter
3c3b6dadc3 Merge pull request #2037 from AstrBotDevs/fix/tool_call_result
fix: direct send tool_call_result
2025-07-06 18:05:59 +08:00
Soulter
cd722a0e39 fix: handle direct tool call results 2025-07-06 18:04:46 +08:00
Soulter
a1b5d0a100 Merge remote-tracking branch 'origin/master' into fix/tool_call_result 2025-07-06 17:47:09 +08:00
Raven95676
69d3ae709c fix: direct send tool_call_result 2025-07-06 17:45:07 +08:00
Soulter
67ef993d61 fix: webchat message bubble style 2025-07-06 17:21:57 +08:00
Soulter
20f49890ad fix: provider selection for updating webchat title 2025-07-06 17:18:37 +08:00
Zhenyi Wang
3e4917f0a1 refactor: 重构 wechatpadpro 授权码生成并增强安全性
- 将 generate_auth_key 方法中的授权码提取逻辑重构为新的辅助方法 _extract_auth_key ,以提高代码的可读性和可测试性。
- 在访问 data.get('authKeys') 之前添加 isinstance(data, dict) 检查,以防止潜在的 AttributeError 。
- 移除了 auth_key 的明文日志记录,以避免敏感信息泄露。
- 在生成新密钥之前,将 self.auth_key 初始化为 None ,以避免在失败时保留旧值。
2025-07-06 16:34:55 +08:00
Soulter
99ee75aec6 Merge pull request #2029 from jiongjiongJOJO/master
fix: 增加演示模式下校验插件开启/关闭/安装指令
2025-07-06 16:24:02 +08:00
Zhenyi Wang
1674653a42 fix(wechatpadpro): 修复授权码提取逻辑以兼容新旧接口格式
新接口返回多了一层authKeys字段,同时兼容二者
2025-07-06 16:18:31 +08:00
Christian Clauss
d2f7e55bf5 Run the tests on pull requests 2025-07-05 13:57:58 +02:00
Christian Clauss
9f31df7f3a pytest recommendation: pip install --editable .
https://docs.pytest.org/en/stable/how-to/existingtestsuite.html

This makes setting `PYTHONPATH` unnecessary and will pull requirements from `pyproject.toml` instead of `requirements.txt`, so it is similar to end-user installations.

`makedir -p data/plugins` will do both `mkdir data` and `mkdir data/plugins`.

The `$CI` environment variable might be better to use than `$TESTING` because it is preset to `true` in GitHub Actions.
* https://docs.github.com/en/actions/reference/variables-reference#default-environment-variables
* https://docs.pytest.org/en/stable/explanation/ci.html
2025-07-05 13:52:28 +02:00
Soulter
b8c1b53d67 Merge pull request #2034 from AstrBotDevs/dependabot/github_actions/github-actions-50e66c4123
chore(deps): bump the github-actions group with 4 updates
2025-07-05 19:24:16 +08:00
dependabot[bot]
2495837791 chore(deps): bump the github-actions group with 4 updates
Bumps the github-actions group with 4 updates: [actions/checkout](https://github.com/actions/checkout), [actions/setup-python](https://github.com/actions/setup-python), [codecov/codecov-action](https://github.com/codecov/codecov-action) and [actions/stale](https://github.com/actions/stale).


Updates `actions/checkout` from 3 to 4
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v3...v4)

Updates `actions/setup-python` from 4 to 5
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](https://github.com/actions/setup-python/compare/v4...v5)

Updates `codecov/codecov-action` from 4 to 5
- [Release notes](https://github.com/codecov/codecov-action/releases)
- [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md)
- [Commits](https://github.com/codecov/codecov-action/compare/v4...v5)

Updates `actions/stale` from 5 to 9
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v5...v9)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/setup-python
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: codecov/codecov-action
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/stale
  dependency-version: '9'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-07-05 11:20:25 +00:00
Soulter
b6562e3c47 Merge pull request #2005 from cclauss/patch-1
Keep GitHub Actions up to date with GitHub's Dependabot
2025-07-05 19:19:18 +08:00
Soulter
c57da046ee Merge pull request #2013 from AstrBotDevs/feat/danger-plugin
[Copilot] feat: 添加风险插件安装确认对话框以及风险插件标签特殊处理
2025-07-05 19:18:14 +08:00
JOJO
ff63134c14 fix: 增加演示模式下校验插件开启/关闭/安装指令 2025-07-05 12:43:19 +08:00
鸦羽
3f5210c587 chore: update plugin publish template 2025-07-04 22:28:00 +08:00
IGCrystal
3df5e7b9b9 🐞 fix: 添加tags.danger的翻译键 2025-07-04 17:28:39 +08:00
Soulter
225db66738 fix: refine streaming logic in chat response handling 2025-07-04 16:59:49 +08:00
Soulter
383ebb8f57 feat: add copy functionality for bot messages with success feedback 2025-07-04 16:27:52 +08:00
Raven95676
e1bed60f1f fix: adjust timing of adding to star_registry 2025-07-04 16:13:10 +08:00
Raven95676
edbb856023 refactor: deprecate register_star decorator 2025-07-04 15:54:23 +08:00
Raven95676
98d3ab646f chore: convert some methods to static 2025-07-04 15:07:14 +08:00
Soulter
81be556f1b Merge pull request #2018 from AstrBotDevs/fix-extension-btn-z-index
Fix: adjust z-index for the add button on ExtensionPage
2025-07-04 11:41:10 +08:00
Soulter
f45a085469 fix: adjust z-index for the add button on ExtensionPage
fixes: #1985
2025-07-04 11:40:14 +08:00
Raven95676
210cc58cc3 fix: 更新风险插件警告对话框内容和按钮文本,修正样式 By @Soulter
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-07-04 11:23:19 +08:00
Soulter
1063b11ef6 fix: check provider availability errors on dify 2025-07-04 10:19:58 +08:00
Raven95676
a4e999c47f feat: 添加风险插件安装确认对话框以及风险插件标签特殊处理 2025-07-03 22:16:00 +08:00
Soulter
543e01c301 perf: webui 删除对话使用 conversation_mgr,以保持状态同步 2025-07-03 15:44:45 +08:00
Soulter
14e0aa3ec5 perf: history 和 persona 指令当对话不存在的时候自动创建
fixes: #1997
2025-07-03 15:40:00 +08:00
Christian Clauss
1a8a171f8b Keep GitHub Actions up to date with GitHub's Dependabot
* [Keeping your software supply chain secure with Dependabot](https://docs.github.com/en/code-security/dependabot)
* [Keeping your actions up to date with Dependabot](https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot)
* [Configuration options for the `dependabot.yml` file - package-ecosystem](https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem)
2025-07-03 08:46:42 +02:00
Soulter
f1954f9a43 Merge pull request #1984 from RC-CHN/master
refactor:将前端测试供应商部分修改为独立并发异步获取各个文本供应商的状态
2025-07-03 10:55:51 +08:00
Soulter
441b148501 Merge pull request #1991 from AstrBotDevs/perf/webchat-title
perf: 优化WebChat对话标题生成
2025-07-03 10:53:35 +08:00
Soulter
bd0f30b81c Merge pull request #2003 from AstrBotDevs/feat-webchat-select-provider
Feature: WebChat 增加可选择提供商和模型的功能
2025-07-03 10:52:42 +08:00
Soulter
ad14e9bf40 chore: remove unnecessary logging of payloads in chat completion 2025-07-03 10:50:03 +08:00
Soulter
6f71301aaf fix: log error when selected provider is not found 2025-07-03 10:49:12 +08:00
Soulter
5f0d601baa feat: add support for selecting provider and models in webchat 2025-07-03 10:42:20 +08:00
Soulter
f234a5bcc2 fix: enhance event hook handling to return status and prevent propagation 2025-07-03 00:23:56 +08:00
chenpeng
ab677ea100 修正pilk依赖提示文案
增加监听wechatpadpro消息平台的事件
2025-07-02 17:30:37 +08:00
Soulter
f3ad53e949 feat: add supports for selecting provider and models in webchat 2025-07-02 17:12:30 +08:00
Soulter
d324cfa84d Merge pull request #1987 from AstrBotDevs/refactor-webchat-streaming
Refactor: 重构 WebChat 的 SSE 监听逻辑
2025-07-02 17:11:12 +08:00
Soulter
dd4319d72a Merge pull request #1990 from AstrBotDevs/fix-stream-multi-tool-use-err
fix: Multi-turn tools use error when using streaming output
2025-07-02 15:44:29 +08:00
Raven95676
1f2de3d3d8 perf: 优化WebChat对话标题生成 2025-07-02 10:43:54 +08:00
Raven95676
72702beb0b chore: clean code 2025-07-02 10:29:10 +08:00
Soulter
adb0cbc5dd fix: handle tool_calls_result as list or single object in context query in streaming mode 2025-07-02 10:16:44 +08:00
Soulter
6a503b82c3 refactor: web chat queue management and streamline chat route handling 2025-07-01 22:34:17 +08:00
advent259141
28a87351f1 新增对会话重命名的功能 2025-07-01 21:41:19 +08:00
Soulter
bcc97378b0 feat: implement code copy functionality and enhance code highlighting in ChatPage 2025-07-01 21:15:01 +08:00
Soulter
eb8a138713 feat: enhance conversation actions with delete functionality and improved styling 2025-07-01 21:00:43 +08:00
advent259141
dcd7dcbbdf 解决了conflict 2025-07-01 17:24:56 +08:00
Gao Jinzhe
1538759ba7 Merge branch 'master' into Astrbot_session_manage 2025-07-01 17:19:30 +08:00
Soulter
30e8ea7fd8 chore: add deploy badge 2025-07-01 16:59:58 +08:00
Ruochen
879b7b582c perf:提取重复的错误处理逻辑,优化循环调用 2025-07-01 16:02:56 +08:00
Ruochen
8ba4236402 refactor:将前端测试供应商部分修改为独立并发异步获取各个文本供应商的状态 2025-07-01 15:41:30 +08:00
鸦羽
5eef8fa9b9 Merge pull request #1981 from AstrBotDevs/feat/r1_filter-integration
feat: 集成r1_filter至框架
2025-07-01 13:56:01 +08:00
Raven95676
d03d035437 perf: 合并嵌套的if条件 2025-07-01 13:53:22 +08:00
Raven95676
68e8e1f70b feat: 集成r1_filter至框架 2025-07-01 12:40:52 +08:00
Soulter
7acb45b157 Update README.md 2025-07-01 11:35:14 +08:00
Soulter
c36142deaf perf: chatpage UI 2025-06-30 15:20:46 +08:00
Soulter
5fd6e316fa Merge pull request #1966 from railgun19457/master
修改了一对大括号
2025-06-30 13:33:10 +08:00
railgun19457
39a9d7765a 修改了一对大括号 2025-06-30 00:21:28 +08:00
Soulter
7cfcba29a6 feat: add loading state for dashboard update process 2025-06-29 21:55:13 +08:00
Soulter
9bf8aadca9 📦 release: v3.5.18 2025-06-29 21:52:45 +08:00
Soulter
714d4af63d Merge pull request #1963 from AstrBotDevs/refactor-llm-request
Refactor: 将 LLM Request 部分抽象为 AgentRunner 并优化多轮工具调用
2025-06-29 21:38:43 +08:00
Soulter
8203fdb4f0 fix: webchat show tool call 2025-06-29 21:35:39 +08:00
Soulter
5e1e2d1a4f perf: 优化 ChatPage UI 2025-06-29 21:19:52 +08:00
Soulter
2f941de65b feat: 支持展示工具使用过程 2025-06-29 21:19:40 +08:00
Raven95676
777c503002 perf: change logging level to debug for agent state transitions and LLM responses 2025-06-29 17:32:53 +08:00
Raven95676
e9b23f68fd perf: add AgentState Enum for improved state management 2025-06-29 17:19:53 +08:00
Soulter
efa45e6203 fix: validate and repair message contexts in LLMRequestSubStage 2025-06-29 16:36:08 +08:00
Raven95676
638f55f83c Merge branch 'refactor-llm-request' of https://github.com/AstrBotDevs/AstrBot into refactor-llm-request 2025-06-29 16:13:18 +08:00
Raven95676
8b2fc29d5b chore: remove accidentally committed file 2025-06-29 16:13:15 +08:00
Soulter
b516fb0550 chore: remove dump_plugins.py 2025-06-29 16:12:40 +08:00
Raven95676
efef34c01e style: format code 2025-06-29 16:06:44 +08:00
Soulter
5f1dfa7599 fix: handle LLM response and execute event hook in ToolLoopAgent 2025-06-29 15:58:22 +08:00
Soulter
8e9c7544cf fix: update type check for async generator in PipelineContext 2025-06-29 15:54:32 +08:00
Soulter
4e3d5641c8 chore: code quality 2025-06-29 15:51:56 +08:00
Soulter
20b760529e fix: anthropic api error when using tools 2025-06-29 15:33:08 +08:00
Soulter
a55a07c5ff remove: useless provider init params 2025-06-29 14:43:36 +08:00
Soulter
94ee8ea297 feat: 支持多轮次工具调用并且存储到数据库
移除了 llm tuner 适配器
2025-06-29 14:27:00 +08:00
advent259141
ec5d71d0e1 修复了一下重复的代码问题,删除了不必要的会话级别 LLM 启停状态检查。 2025-06-29 10:02:04 +08:00
advent259141
d121d08d05 大致凭借自己理解修复了一下整个检查流程,防止钩子出现问题 2025-06-29 09:57:31 +08:00
Gao Jinzhe
be08f4a558 Merge branch 'AstrBotDevs:master' into Astrbot_session_manage 2025-06-29 09:11:25 +08:00
Soulter
010f082fbb Merge pull request #1914 from HakimYu/master
fix(AiocqhttpAdapter): 修复at_info.get("nick", "")的错误
2025-06-28 21:52:01 +08:00
Soulter
073cdf6d51 perf: also consider nick 2025-06-28 21:51:10 +08:00
Soulter
4df8606ab6 style: code quality 2025-06-28 20:08:57 +08:00
Soulter
71442d26ec chore: 移除不必要的 MCP 会话控制 2025-06-28 19:58:36 +08:00
advent259141
4f5528869c Merge branch 'Astrbot_session_manage' of https://github.com/advent259141/AstrBot into Astrbot_session_manage 2025-06-28 17:00:00 +08:00
advent259141
f16feff17b 根据会话mcp开关情况选择性传入 func_tool
修改import的位置
deleted:    astrbot/core/star/session_tts_manager.py
复原被覆盖的修改
2025-06-28 16:59:00 +08:00
Soulter
71b233fe5f Merge pull request #1942 from QiChenSn/fix-CommandFilter-ParseForBool
fix:修复commandfilter对布尔类型的解析
2025-06-28 15:10:29 +08:00
Soulter
770dec9ed6 fix: handle boolean parameter parsing correctly in CommandFilter 2025-06-28 15:08:19 +08:00
Soulter
2ca95a988e fix: lint warnings 2025-06-28 15:05:57 +08:00
Gao Jinzhe
d8aae538cd Merge branch 'AstrBotDevs:master' into Astrbot_session_manage 2025-06-28 14:55:38 +08:00
Soulter
cf1e7ee08a Merge pull request #1947 from RC-CHN/master
允许为html_render方法传递参数
2025-06-28 14:52:09 +08:00
Soulter
d14513ddfd fix: lint warnings 2025-06-28 14:51:35 +08:00
Soulter
9a9017bc6c perf: use union oper for merging dict
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-06-28 14:46:29 +08:00
Soulter
3c9b654713 Merge pull request #1923 from Magstic/patch-1
Fix: 仪表盘的『插件配置』中不显示 JSON 编辑窗
2025-06-28 14:45:14 +08:00
Magstic
80d2ad40bc fix: 仪表盘的『插件配置』中不显示 JSON 编辑窗
该提交与 #1919 关联。

精准定位错误 @Pine-Ln,Fix from Gemini 2.5 Pro.

这个问题是由两个错误叠加造成的:

1. **组件崩溃**:`AstrBotConfig.vue` 混用了 Vue 3 的 `<script setup>` 和旧式 `<script>` 写法,导致作用域冲突,模板无法访问国际化函数 `t`,引发 `ReferenceError: t is not defined`。

2. **设置项不显示**:原代码根据用户已保存的设置数据来渲染字段,导致新增的设置项(如 `editor_mode`)因为用户配置中没有初始值而不显示。

1. **统一 API 写法**:将整个组件重构为纯 `<script setup>` 写法,解决作用域冲突。

2. **修正渲染逻辑**:将 `v-for` 循环改为遍历设置蓝图 (metadata) 而不是用户数据,确保所有定义的设置项都能显示。
2025-06-28 14:42:06 +08:00
advent259141
31670e75e5 Merge branch 'Astrbot_session_manage' of https://github.com/advent259141/AstrBot into Astrbot_session_manage 2025-06-27 18:47:25 +08:00
advent259141
ed6011a2be modified: dashboard/src/i18n/loader.ts
modified:   dashboard/src/i18n/locales/en-US/core/navigation.json
增加会话管理英文页面
	modified:   dashboard/src/i18n/locales/zh-CN/core/navigation.json
增加会话管理中文页面
	modified:   dashboard/src/i18n/translations.ts
	modified:   dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts
	modified:   dashboard/src/views/SessionManagementPage.vue
增加会话管理国际化适配
2025-06-27 18:46:02 +08:00
Gao Jinzhe
cdded38ade Merge branch 'AstrBotDevs:master' into Astrbot_session_manage 2025-06-27 17:10:08 +08:00
advent259141
f536f24833 astrbot/core/pipeline/process_stage/method/llm_request.py
astrbot/core/pipeline/result_decorate/stage.py
   astrbot/core/star/session_llm_manager.py
   astrbot/core/star/session_tts_manager.py
   astrbot/dashboard/routes/session_management.py
   astrbot/dashboard/server.py
   dashboard/src/views/SessionManagementPage.vue
   packages/astrbot/main.py
2025-06-27 17:08:05 +08:00
Ruochen
f5bff00b1f Merge branch 'master' of https://github.com/RC-CHN/AstrBot 2025-06-27 17:03:58 +08:00
Ruochen
27c9717445 feat:允许html_render方法传入配置参数 2025-06-27 17:03:26 +08:00
Soulter
863a1ba8ef Merge pull request #1922 from SXP-Simon/master
[feat] (discord_platform_adapter) 增加了对机器人 Role Mention 方法的响应,并且修复了控制面板上 Discord 平台无法优雅重载的 Bug
2025-06-27 14:59:37 +08:00
Soulter
cb04dd2b83 chore: remove unnecessary codes 2025-06-27 14:59:08 +08:00
Soulter
8c7cf51958 chore: code format 2025-06-27 14:46:23 +08:00
Soulter
244fb1fed6 chore: remove useless logger 2025-06-27 14:38:31 +08:00
Soulter
25f7a68a13 Merge pull request #1709 from shuiping233/fix-qq-offical-session-bug
fix: qq_official适配器使用SessionController(会话控制)功能时机器人回复消息无法发送到聊天平台
2025-06-27 14:35:54 +08:00
Soulter
62d8cf79ef fix: remove deprecated pre_send and post_send calls for specific platforms 2025-06-27 14:31:35 +08:00
Gao Jinzhe
646b18d910 Merge branch 'AstrBotDevs:master' into master 2025-06-27 12:26:15 +08:00
QiChenSn
2f81b2e381 fix:修复commandfilter对布尔类型的解析 2025-06-27 02:32:10 +08:00
Soulter
1f5a7e7885 Merge pull request #1940 from AstrBotDevs/fix-tg-active-reply
fix: cannot make active reply in telegram
2025-06-27 00:05:10 +08:00
Soulter
80fca470f2 fix: cannot make active reply in telegram
Co-authored-by: youtiaoguagua <cloudcranesss@210625568+cloudcranesss@users.noreply.github.com>
2025-06-27 00:04:25 +08:00
Soulter
6e9d9ac856 Merge pull request #1907 from IGCrystal/Branch-2
🐞 fix(WebUI): 修复安装插件按钮不可见
2025-06-26 23:28:37 +08:00
Soulter
8d6fada1eb feat(ExtensionPage): show confirm dialog when click install plugin button 2025-06-26 23:25:59 +08:00
Soulter
3e715399a1 fix: 环境变量代理被忽略 (#1895) 2025-06-26 08:52:33 +08:00
Soulter
81cc8831f9 docs: update plugin issue template
docs: issue template

docs: update issue template

docs: update plugin issue template

fix: issue plugin template

docs: update plugin issue template
2025-06-26 08:28:28 +08:00
Soulter
f7370044a7 Merge pull request #1903 from IGCrystal/branch-1
 feat: 对PlatformPage使用翻译键
2025-06-25 22:49:03 +08:00
Soulter
51b015a629 Merge pull request #1830 from zhx8702/feat-wechat-tts-mp3towav
feat: wechatpadpro 触发tts时 添加对mp3格式音频支持
2025-06-25 22:46:10 +08:00
Soulter
392af7a553 fix: add pydub to requirements 2025-06-25 22:31:44 +08:00
鸦羽
d2dd07bad7 Merge pull request #1920 from AstrBotDevs/feat/gemini-tts
feat: 增加Gemini TTS API实现
2025-06-25 14:05:04 +08:00
回归天空
cebcd6925a [fix] (discord_platform_adapter) 解决了 “Discord 平台无法优雅重载” 的 bug
#### 问题现象(AI总结)

- 在通过 Web 面板或配置变更热重载 Discord 平台时,适配器的 terminate() 方法会被调用,但经常出现“卡死”或长时间无响应,导致 Discord 平台无法优雅重载。

- 日志显示停留在“正在清理已注册的斜杠指令...”等步骤,甚至出现超时或异常。

#### 2. 原因分析

- 适配器的 terminate() 方法中,涉及多个异步操作(如取消 polling 任务、清理斜杠指令、关闭客户端)。

- 某些 await 操作(如 await self.client.sync_commands() 或 await self.client.close())在网络异常、事件循环被取消等情况下,可能会阻塞或抛出 CancelledError,导致整个重载流程卡住。

- 之前的实现没有对这些 await 操作加超时保护,也没有分步日志,难以定位具体卡点。

#### 3. 修复措施

- 分步日志:在 terminate() 的每个关键步骤前后都加了详细日志,便于定位卡点。

- 超时保护:对所有关键 await 操作(如 polling 任务取消、指令清理、客户端关闭)都加了 asyncio.wait_for(..., timeout=10),防止无限阻塞。

- 健壮性提升:先 cancel polling 任务,再清理指令,最后关闭客户端。每一步都捕获异常并输出日志,保证即使某一步失败也能继续后续清理。

- 避免重复终止:移除了 run() 方法中的 finally: await self.terminate(),只允许外部统一调度,防止重复调用导致资源冲突或日志重复。

#### 4. 修复效果

- 现在 Discord 平台适配器在热重载或终止时,能优雅地依次完成所有清理步骤,不会因某一步阻塞导致整个流程卡死。
2025-06-25 11:46:49 +08:00
回归天空
e7b4357fc7 [feat] (discord_platform_adapter) 增加了对机器人 Role Mention 方法的响应 2025-06-25 11:41:55 +08:00
Raven95676
dc279dde4a fix: 简化get_audio方法中的提示文本生成逻辑,清除冗余判断逻辑 2025-06-25 10:55:51 +08:00
Raven95676
c0810a674f feat: 增加Gemini TTS API实现 2025-06-25 10:50:04 +08:00
HakimYu
0760cabbbe feat(AiocqhttpAdapter): 修复reply类型的 Event.from_payload报错 2025-06-24 17:20:30 +08:00
HakimYu
3b149c520b fix(AiocqhttpAdapter): 修复at_info.get("nick", "")的错误,并在message_str中针对At类型添加QQ号 2025-06-24 16:30:23 +08:00
Soulter
3d19fc89ff docs: 10k star banner 2025-06-24 02:07:23 +08:00
Soulter
cd1b1919f4 docs: 10k star banner 2025-06-24 01:51:46 +08:00
IGCrystal
0ed646eb27 🐞 fix(WebUI): 修复安装插件按钮不可见 2025-06-23 19:41:56 +08:00
邹永赫
c0c5859c99 Merge pull request #1905 from zouyonghe/master
使用定义的Plain类型代替原始基础类型str,保持代码统一性
2025-06-23 18:52:56 +09:00
邹永赫
a47121b849 使用定义的Plain类型代替原始基础类型str,保持代码统一性 2025-06-23 18:49:47 +09:00
邹永赫
d9dd20e89a Merge pull request #1904 from zouyonghe/master
修复代码重构造成的无法向前兼容在node中发送简单文本信息的问题
2025-06-23 18:20:52 +09:00
邹永赫
ed4609ebe5 修复代码重构造成的无法向前兼容在node中发送简单文本信息的问题 2025-06-23 18:17:37 +09:00
Gao Jinzhe
e24225c828 Merge branch 'master' into master 2025-06-23 15:21:08 +08:00
IGCrystal
01ef86d658 feat: 对PlatformPage使用翻译键 2025-06-23 14:44:06 +08:00
Soulter
cd4802da04 Merge pull request #1902 from railgun19457/master
修复plugin_enable配置无法保存的问题
2025-06-23 13:30:31 +08:00
Misaka Mikoto
2aca65780f Merge branch 'AstrBotDevs:master' into master 2025-06-23 13:29:31 +08:00
Soulter
2c435f7387 Merge pull request #1899 from IGCrystal/branch-1
🐞 fix: 显示运行时长国际化
2025-06-23 13:21:59 +08:00
Soulter
cc1afd1a9c Merge pull request #1900 from AstrBotDevs/fix-hc-jwt
Fix: JWT secret issue
2025-06-23 13:16:08 +08:00
railgun19457
6f098cdba6 修复plugin_enable配置无法保存的问题 2025-06-23 13:06:46 +08:00
Soulter
d03e9fb90a fix: jwt secret 2025-06-23 12:36:11 +08:00
IGCrystal
9f2966abe9 Merge branch 'branch-1' of https://github.com/IGCrystal/AstrBot into branch-1 2025-06-23 12:09:10 +08:00
IGCrystal
4e28ea1883 🐞 fix: 显示运行时长国际化 2025-06-23 12:08:27 +08:00
Soulter
289214e85c Merge pull request #1898 from IGCrystal/branch-1
🐞 fix(WebUI): 修复platform的logo路径问题
2025-06-23 11:59:58 +08:00
IGCrystal
a20d98bf93 🐞 fix(WebUI): 修复platform的logo路径问题 2025-06-23 11:57:20 +08:00
Gao Jinzhe
50a296de20 Merge branch 'AstrBotDevs:master' into master 2025-06-20 14:39:57 +08:00
Gao Jinzhe
c79e38e044 Merge branch 'AstrBotDevs:master' into master 2025-06-17 20:29:32 +08:00
zhx
ccb95f803c feat: wechatpadpro 发送tts时 添加对mp3格式音频支持 2025-06-16 10:05:21 +08:00
Gao Jinzhe
dae745d925 Update server.py 2025-06-16 10:03:18 +08:00
Gao Jinzhe
791db65526 resolve conflict with master branch 2025-06-16 09:50:35 +08:00
Gao Jinzhe
02e2e617f5 Merge branch 'AstrBotDevs:master' into master 2025-06-15 22:04:06 +08:00
advent259141
bfc8024119 modified: astrbot/core/pipeline/process_stage/method/llm_request.py
new file:   astrbot/core/star/session_llm_manager.py
	modified:   astrbot/dashboard/routes/session_management.py
	modified:   dashboard/src/views/SessionManagementPage.vue

 增加了精确到会话的LLM启停管理以及插件启停管理
2025-06-14 03:42:21 +08:00
Gao Jinzhe
f26cf6ed6f Merge branch 'AstrBotDevs:master' into master 2025-06-14 03:03:41 +08:00
advent259141
f2be55bd8e Merge branch 'master' of https://github.com/advent259141/AstrBot 2025-06-13 06:20:49 +08:00
advent259141
d241dd17ca Merge branch 'master' of https://github.com/advent259141/AstrBot 2025-06-13 06:20:09 +08:00
advent259141
cecafdfe6c Merge branch 'master' of https://github.com/advent259141/AstrBot 2025-06-13 03:54:35 +08:00
Soulter
6fecfd1a0e Merge pull request #1800 from AstrBotDevs/feat-weixinkefu-record
feat: 微信客服支持语音的收发
2025-06-13 03:52:15 +08:00
shuiping233
1ce95c473d fix : 在stage.py中专门对qq_official的会话控制器消息进行处理 2025-06-08 10:20:09 +08:00
shuiping233
eb365e398d fix: qq_official适配器使用SessionController(会话控制)功能时机器人回复消息无法发送到聊天平台 2025-06-08 10:20:09 +08:00
176 changed files with 10250 additions and 6309 deletions

View File

@@ -1,40 +1,56 @@
name: '🥳 发布插件' name: 🥳 发布插件
title: "[Plugin] 插件名"
description: 提交插件到插件市场 description: 提交插件到插件市场
labels: [ "plugin-publish" ] title: "[Plugin] 插件名"
labels: ["plugin-publish"]
assignees: []
body: body:
- type: markdown - type: markdown
attributes: attributes:
value: | value: |
欢迎发布插件到插件市场!请确保您的插件经过**完整的**测试。 欢迎发布插件到插件市场!
- type: textarea
attributes:
label: 插件仓库
description: 插件的 GitHub 仓库链接
placeholder: >
如 https://github.com/Soulter/astrbot-github-cards
- type: textarea
attributes:
label: 描述
value: |
插件名:
插件作者:
插件简介:
支持的消息平台:(必填,如 QQ、微信、飞书)
标签:(可选)
社交链接:(可选, 将会在插件市场作者名称上作为可点击的链接)
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。如果您不知道支持哪些消息平台,请填写测试过的消息平台。
- type: checkboxes
attributes:
label: Code of Conduct
options:
- label: >
我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
required: true
- type: markdown - type: markdown
attributes: attributes:
value: "❤️" value: |
## 插件基本信息
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
不熟悉 JSON ?现在可以从 [这里](https://plugins.astrbot.app/#/submit) 获取你的 JSON 啦!获取到了记得复制粘贴过来哦!
- type: textarea
id: plugin-info
attributes:
label: 插件信息
description: 请在下方代码块中填写您的插件信息确保反引号包裹了JSON
value: |
```json
{
"name": "插件名",
"desc": "插件介绍",
"author": "作者名",
"repo": "插件仓库链接",
"tags": [],
"social_link": ""
}
```
validations:
required: true
- type: markdown
attributes:
value: |
## 检查
- type: checkboxes
id: checks
attributes:
label: 插件检查清单
description: 请确认以下所有项目
options:
- label: 我的插件经过完整的测试
required: true
- label: 我的插件不包含恶意代码
required: true
- label: 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
required: true

63
.github/copilot-instructions.md vendored Normal file
View File

@@ -0,0 +1,63 @@
# AstrBot Development Instructions
AstrBot is a multi-platform LLM chatbot and development framework written in Python with a Vue.js dashboard. It supports multiple messaging platforms (QQ, Telegram, Discord, etc.) and various LLM providers (OpenAI, Anthropic, Google Gemini, etc.).
Always reference these instructions first and fallback to search or bash commands only when you encounter unexpected information that does not match the info here.
## Working Effectively
### Bootstrap and Install Dependencies
- **Python 3.10+ required** - Check `.python-version` file
- Install UV package manager: `pip install uv`
- Install project dependencies: `uv sync` -- takes 6-7 minutes. NEVER CANCEL. Set timeout to 10+ minutes.
- Create required directories: `mkdir -p data/plugins data/config data/temp`
### Running the Application
- Run main application: `uv run main.py` -- starts in ~3 seconds
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
### Dashboard Build (Vue.js/Node.js)
- **Prerequisites**: Node.js 20+ and npm 10+ required
- Navigate to dashboard: `cd dashboard`
- Install dashboard dependencies: `npm install` -- takes 2-3 minutes. NEVER CANCEL. Set timeout to 5+ minutes.
- Build dashboard: `npm run build` -- takes 25-30 seconds. NEVER CANCEL.
- Dashboard creates optimized production build in `dashboard/dist/`
### Testing
- Do not generate test files for now.
### Code Quality and Linting
- Install ruff linter: `uv add --dev ruff`
- Check code style: `uv run ruff check .` -- takes <1 second
- Check formatting: `uv run ruff format --check .` -- takes <1 second
- Fix formatting: `uv run ruff format .`
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
### Plugin Development
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
- Plugin system supports function tools and message handlers
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
### Common Issues and Workarounds
- **Dashboard download fails**: Known issue with "division by zero" error - application still works
- **Import errors in tests**: Ensure `uv run` is used to run tests in proper environment
=- **Build timeouts**: Always set appropriate timeouts (10+ minutes for uv sync, 5+ minutes for npm install)
## CI/CD Integration
- GitHub Actions workflows in `.github/workflows/`
- Docker builds supported via `Dockerfile`
- Pre-commit hooks enforce ruff formatting and linting
## Docker Support
- Primary deployment method: `docker run soulter/astrbot:latest`
- Compose file available: `compose.yml`
- Exposes ports: 6185 (WebUI), 6195 (WeChat), 6199 (QQ), etc.
- Volume mount required: `./data:/AstrBot/data`
## Multi-language Support
- Documentation in Chinese (README.md), English (README_en.md), Japanese (README_ja.md)
- UI supports internationalization
- Default language is Chinese
Remember: This is a production chatbot framework with real users. Always test thoroughly and ensure changes don't break existing functionality.

13
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,13 @@
# Keep GitHub Actions up to date with GitHub's Dependabot...
# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
version: 2
updates:
- package-ecosystem: github-actions
directory: /
groups:
github-actions:
patterns:
- "*" # Group all Actions updates into a single larger pull request
schedule:
interval: weekly

View File

@@ -13,7 +13,7 @@ jobs:
contents: write contents: write
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v5
- name: Dashboard Build - name: Dashboard Build
run: | run: |
@@ -70,10 +70,10 @@ jobs:
needs: build-and-publish-to-github-release needs: build-and-publish-to-github-release
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v5
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: '3.10' python-version: '3.10'

View File

@@ -56,7 +56,7 @@ jobs:
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v5
# Initializes the CodeQL tools for scanning. # Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL - name: Initialize CodeQL

View File

@@ -8,6 +8,7 @@ on:
- 'README.md' - 'README.md'
- 'changelogs/**' - 'changelogs/**'
- 'dashboard/**' - 'dashboard/**'
pull_request:
workflow_dispatch: workflow_dispatch:
jobs: jobs:
@@ -16,30 +17,29 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v5
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v5
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install -r requirements.txt pip install pytest pytest-asyncio pytest-cov
pip install pytest pytest-cov pytest-asyncio pip install --editable .
- name: Run tests - name: Run tests
run: | run: |
mkdir data mkdir -p data/plugins
mkdir data/plugins mkdir -p data/config
mkdir data/config mkdir -p data/temp
mkdir data/temp
export TESTING=true export TESTING=true
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }} export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG pytest --cov=. -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov - name: Upload results to Codecov
uses: codecov/codecov-action@v4 uses: codecov/codecov-action@v5
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v5
- name: npm install, build - name: npm install, build
run: | run: |

View File

@@ -12,7 +12,7 @@ jobs:
steps: steps:
- name: Pull The Codes - name: Pull The Codes
uses: actions/checkout@v3 uses: actions/checkout@v5
with: with:
fetch-depth: 0 # Must be 0 so we can fetch tags fetch-depth: 0 # Must be 0 so we can fetch tags

View File

@@ -18,7 +18,7 @@ jobs:
pull-requests: write pull-requests: write
steps: steps:
- uses: actions/stale@v5 - uses: actions/stale@v9
with: with:
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'Stale issue message' stale-issue-message: 'Stale issue message'

View File

@@ -16,7 +16,7 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a> <a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a> <a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e) [![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B4%BB%E8%B7%83%E9%87%8F&cacheSeconds=3600&style=for-the-badge&color=3b618e) ![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7日消息量&cacheSeconds=3600&style=for-the-badge&color=3b618e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600) ![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> <a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
@@ -27,57 +27,50 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型LLM接入功能的聊天机器人及开发框架。 AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型LLM接入功能的聊天机器人及开发框架。
<!-- [![codecov](https://img.shields.io/codecov/c/github/soulter/astrbot?style=for-the-badge)](https://codecov.io/gh/Soulter/AstrBot)
-->
> [!WARNING]
>
> 请务必修改默认密码以及保证 AstrBot 版本 >= 3.5.13。
## ✨ 近期更新
<details><summary>1. AstrBot 现已自带知识库能力</summary>
📚 详见[文档](https://astrbot.app/use/knowledge-base.html)
![image](https://github.com/user-attachments/assets/28b639b0-bb5c-4958-8e94-92ae8cfd1ab4)
</details>
2. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
## ✨ 主要功能 ## ✨ 主要功能
> [!NOTE] 1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本! 2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
3. **Agent**。完善适配的 Agentic 能力。支持多轮工具调用、内置沙盒代码执行器、网页搜索等功能。
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力支持图片理解、语音转文字Whisper 4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富
2. **多消息平台接入**。支持接入 QQOneBot、QQ 官方机器人平台、QQ 频道、微信、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK、VoceChat。支持速率限制、白名单、关键词过滤、百度内容审核 5. **WebUI**。可视化配置和管理机器人,功能齐全
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat可在面板上与大模型对话。
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
> [!TIP]
> WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> 用户名: `astrbot`, 密码: `astrbot`。
## ✨ 使用方式 ## ✨ 使用方式
#### Docker 部署 #### Docker 部署
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。 请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
#### 宝塔面板部署
AstrBot 与宝塔面板合作,已上架至宝塔面板。
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
#### 1Panel 部署
AstrBot 已由 1Panel 官方上架至 1Panel 面板。
请参阅官方文档 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html) 。
#### 在 雨云 上部署
AstrBot 已由雨云官方上架至云应用平台,可一键部署。
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
#### 在 Replit 上部署
社区贡献的部署方式。
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### Windows 一键安装器部署 #### Windows 一键安装器部署
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。 请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
#### 宝塔面板部署
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
#### CasaOS 部署 #### CasaOS 部署
社区贡献的部署方式。 社区贡献的部署方式。
@@ -101,27 +94,14 @@ git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
uv run main.py uv run main.py
``` ```
或者,直接通过 uvx 安装 AstrBot
```bash
mkdir astrbot && cd astrbot
uvx astrbot init
# uvx astrbot run
```
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。 或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
#### Replit 部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
## ⚡ 消息平台支持情况 ## ⚡ 消息平台支持情况
| 平台 | 支持性 | | 平台 | 支持性 |
| -------- | ------- | | -------- | ------- |
| QQ(官方机器人接口) | ✔ | | QQ(官方机器人接口) | ✔ |
| QQ(OneBot) | ✔ | | QQ(OneBot) | ✔ |
| 微信个人号 | ✔ |
| Telegram | ✔ | | Telegram | ✔ |
| 企业微信 | ✔ | | 企业微信 | ✔ |
| 微信客服 | ✔ | | 微信客服 | ✔ |
@@ -140,7 +120,7 @@ uvx astrbot init
| 名称 | 支持性 | 类型 | 备注 | | 名称 | 支持性 | 类型 | 备注 |
| -------- | ------- | ------- | ------- | | -------- | ------- | ------- | ------- |
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、xAI 等兼容 OpenAI API 的服务 | | OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Gemini、Kimi、xAI 等兼容 OpenAI API 的服务 |
| Claude API | ✔ | 文本生成 | | | Claude API | ✔ | 文本生成 | |
| Google Gemini API | ✔ | 文本生成 | | | Google Gemini API | ✔ | 文本生成 | |
| Dify | ✔ | LLMOps | | | Dify | ✔ | LLMOps | |
@@ -148,6 +128,8 @@ uvx astrbot init
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 | | Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 | | LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 | | LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
| 硅基流动 | ✔ | 模型 API 服务平台 | | | 硅基流动 | ✔ | 模型 API 服务平台 | |
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | | | PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
| OneAPI | ✔ | LLM 分发系统 | | | OneAPI | ✔ | LLM 分发系统 | |
@@ -223,7 +205,7 @@ _✨ WebUI ✨_
此外,本项目的诞生离不开以下开源项目: 此外,本项目的诞生离不开以下开源项目:
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy) - [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
## ⭐ Star History ## ⭐ Star History
@@ -237,11 +219,8 @@ _✨ WebUI ✨_
</div> </div>
## Disclaimer ![10k-star-banner-credit-by-kevin](https://github.com/user-attachments/assets/c97fc5fb-20b9-4bc8-9998-c20b930ab097)
1. The project is protected under the `AGPL-v3` opensource license.
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
3. Please ensure compliance with local laws and regulations when using this project.
_私は、高性能ですから!_ _私は、高性能ですから!_

View File

@@ -27,7 +27,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
## ✨ 主な機能 ## ✨ 主な機能
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換Whisperをサポートします。 1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換Whisperをサポートします。
2. **複数のメッセージプラットフォームの接続**。QQOneBot、QQ チャンネル、WeChatGewechatFeishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。 2. **複数のメッセージプラットフォームの接続**。QQOneBot、QQ チャンネル、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。 3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。 4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。 5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
@@ -152,8 +152,7 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
## 免責事項 ## 免責事項
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。 1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
2. WeChat個人アカウントのデプロイメントには [Gewechat](https://github.com/Devo919/Gewechat) サービスを利用しています。AstrBot は Gewechat との接続を保証するだけであり、アカウントのリスク管理に関しては、このプロジェクトの著者は一切の責任を負いません 2. このプロジェクトを使用する際は、現地の法律および規制を遵守してください
3. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
<!-- ## ✨ ATRI [ベータテスト] <!-- ## ✨ ATRI [ベータテスト]
@@ -165,6 +164,4 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
4. TTS 4. TTS
--> -->
_私は、高性能ですから!_ _私は、高性能ですから!_

0
astrbot.lock Normal file
View File

View File

@@ -1 +1 @@
__version__ = "3.5.8" __version__ = "3.5.23"

View File

@@ -139,6 +139,14 @@ def conf():
- dashboard.password: Dashboard 密码 - dashboard.password: Dashboard 密码
- callback_api_base: 回调接口基址 - callback_api_base: 回调接口基址
可用子命令:
- set: 设置配置项值
- get: 获取配置项值
- login-info: 显示 Web 管理面板登录信息
""" """
pass pass
@@ -204,3 +212,44 @@ def get_config(key: str = None):
click.echo(f" {key}: {value}") click.echo(f" {key}: {value}")
except (KeyError, TypeError): except (KeyError, TypeError):
pass pass
@conf.command(name="login-info")
def get_login_info():
"""显示 Web 管理面板的登录信息
在 Docker 环境中使用示例:
docker exec -e ASTRBOT_ROOT=/AstrBot astrbot-container astrbot conf login-info
"""
config = _load_config()
try:
username = _get_nested_item(config, "dashboard.username")
# 注意我们不显示实际的MD5哈希密码而是提示用户如何重置
click.echo("🔐 Web 管理面板登录信息:")
click.echo(f" 用户名: {username}")
click.echo(" 密码: [已加密存储]")
click.echo()
click.echo("💡 如需重置密码,请使用以下命令:")
click.echo(" astrbot conf set dashboard.password <新密码>")
click.echo()
click.echo("🌐 访问地址:")
# 尝试获取端口信息
try:
port = _get_nested_item(config, "dashboard.port")
click.echo(f" http://localhost:{port}")
click.echo(f" http://your-server-ip:{port}")
except (KeyError, TypeError):
click.echo(" http://localhost:6185 (默认端口)")
click.echo(" http://your-server-ip:6185 (默认端口)")
click.echo()
click.echo("📋 Docker 环境使用说明:")
click.echo(" 如果在 Docker 中运行,请使用以下命令格式:")
click.echo(" docker exec -e ASTRBOT_ROOT=/AstrBot <容器名> astrbot conf login-info")
except KeyError:
click.echo("❌ 无法找到登录配置,请先运行 'astrbot init' 初始化")
except Exception as e:
raise click.UsageError(f"获取登录信息失败: {str(e)}")

View File

@@ -16,7 +16,13 @@ def check_astrbot_root(path: str | Path) -> bool:
def get_astrbot_root() -> Path: def get_astrbot_root() -> Path:
"""获取Astrbot根目录路径""" """获取Astrbot根目录路径"""
return Path.cwd() import os
# 使用与core应用相同的路径解析逻辑优先使用ASTRBOT_ROOT环境变量
if path := os.environ.get("ASTRBOT_ROOT"):
return Path(path)
else:
return Path.cwd()
async def check_dashboard(astrbot_root: Path) -> None: async def check_dashboard(astrbot_root: Path) -> None:

View File

@@ -117,19 +117,24 @@ def build_plug_list(plugins_dir: Path) -> list:
# 从 metadata.yaml 加载元数据 # 从 metadata.yaml 加载元数据
metadata = load_yaml_metadata(plugin_dir) metadata = load_yaml_metadata(plugin_dir)
if "desc" not in metadata and "description" in metadata:
metadata["desc"] = metadata["description"]
# 如果成功加载元数据,添加到结果列表 # 如果成功加载元数据,添加到结果列表
if metadata and all( if metadata and all(
k in metadata for k in ["name", "desc", "version", "author", "repo"] k in metadata for k in ["name", "desc", "version", "author", "repo"]
): ):
result.append({ result.append(
"name": str(metadata.get("name", "")), {
"desc": str(metadata.get("desc", "")), "name": str(metadata.get("name", "")),
"version": str(metadata.get("version", "")), "desc": str(metadata.get("desc", "")),
"author": str(metadata.get("author", "")), "version": str(metadata.get("version", "")),
"repo": str(metadata.get("repo", "")), "author": str(metadata.get("author", "")),
"status": PluginStatus.INSTALLED, "repo": str(metadata.get("repo", "")),
"local_path": str(plugin_dir), "status": PluginStatus.INSTALLED,
}) "local_path": str(plugin_dir),
}
)
# 获取在线插件列表 # 获取在线插件列表
online_plugins = [] online_plugins = []
@@ -139,15 +144,17 @@ def build_plug_list(plugins_dir: Path) -> list:
resp.raise_for_status() resp.raise_for_status()
data = resp.json() data = resp.json()
for plugin_id, plugin_info in data.items(): for plugin_id, plugin_info in data.items():
online_plugins.append({ online_plugins.append(
"name": str(plugin_id), {
"desc": str(plugin_info.get("desc", "")), "name": str(plugin_id),
"version": str(plugin_info.get("version", "")), "desc": str(plugin_info.get("desc", "")),
"author": str(plugin_info.get("author", "")), "version": str(plugin_info.get("version", "")),
"repo": str(plugin_info.get("repo", "")), "author": str(plugin_info.get("author", "")),
"status": PluginStatus.NOT_INSTALLED, "repo": str(plugin_info.get("repo", "")),
"local_path": None, "status": PluginStatus.NOT_INSTALLED,
}) "local_path": None,
}
)
except Exception as e: except Exception as e:
click.echo(f"获取在线插件列表失败: {e}", err=True) click.echo(f"获取在线插件列表失败: {e}", err=True)

View File

@@ -13,7 +13,6 @@ from .utils.astrbot_path import get_astrbot_data_path
# 初始化数据存储文件夹 # 初始化数据存储文件夹
os.makedirs(get_astrbot_data_path(), exist_ok=True) os.makedirs(get_astrbot_data_path(), exist_ok=True)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
DEMO_MODE = os.getenv("DEMO_MODE", False) DEMO_MODE = os.getenv("DEMO_MODE", False)
astrbot_config = AstrBotConfig() astrbot_config = AstrBotConfig()
@@ -29,6 +28,3 @@ pip_installer = PipInstaller(
astrbot_config.get("pip_install_arg", ""), astrbot_config.get("pip_install_arg", ""),
astrbot_config.get("pypi_index_url", None), astrbot_config.get("pypi_index_url", None),
) )
web_chat_queue = asyncio.Queue(maxsize=32)
web_chat_back_queue = asyncio.Queue(maxsize=32)

View File

@@ -3,15 +3,17 @@
""" """
import os import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "3.5.17" VERSION = "3.5.24"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db") DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db")
# 默认配置 # 默认配置
DEFAULT_CONFIG = { DEFAULT_CONFIG = {
"config_version": 2, "config_version": 2,
"platform_settings": { "platform_settings": {
"plugin_enable": {},
"unique_session": False, "unique_session": False,
"rate_limit": { "rate_limit": {
"time": 60, "time": 60,
@@ -52,6 +54,7 @@ DEFAULT_CONFIG = {
"wake_prefix": "", "wake_prefix": "",
"web_search": False, "web_search": False,
"web_search_link": False, "web_search_link": False,
"display_reasoning_text": False,
"identifier": False, "identifier": False,
"datetime_system_prompt": True, "datetime_system_prompt": True,
"default_personality": "default", "default_personality": "default",
@@ -59,8 +62,10 @@ DEFAULT_CONFIG = {
"max_context_length": -1, "max_context_length": -1,
"dequeue_context_length": 1, "dequeue_context_length": 1,
"streaming_response": False, "streaming_response": False,
"show_tool_use_status": False,
"streaming_segmented": False, "streaming_segmented": False,
"separate_provider": False, "separate_provider": True,
"max_agent_step": 30,
}, },
"provider_stt_settings": { "provider_stt_settings": {
"enable": False, "enable": False,
@@ -102,6 +107,7 @@ DEFAULT_CONFIG = {
"enable": True, "enable": True,
"username": "astrbot", "username": "astrbot",
"password": "77b90590a8945a7d36c963981a307dc9", "password": "77b90590a8945a7d36c963981a307dc9",
"jwt_secret": "",
"host": "0.0.0.0", "host": "0.0.0.0",
"port": 6185, "port": 6185,
}, },
@@ -152,15 +158,6 @@ CONFIG_METADATA_2 = {
"ws_reverse_port": 6199, "ws_reverse_port": 6199,
"ws_reverse_token": "", "ws_reverse_token": "",
}, },
"微信个人号(Gewechat)": {
"id": "gwchat",
"type": "gewechat",
"enable": False,
"base_url": "http://localhost:2531",
"nickname": "soulter",
"host": "这里填写你的局域网IP或者公网服务器IP",
"port": 11451,
},
"微信个人号(WeChatPadPro)": { "微信个人号(WeChatPadPro)": {
"id": "wechatpadpro", "id": "wechatpadpro",
"type": "wechatpadpro", "type": "wechatpadpro",
@@ -313,8 +310,7 @@ CONFIG_METADATA_2 = {
"id": { "id": {
"description": "机器人名称", "description": "机器人名称",
"type": "string", "type": "string",
"obvious_hint": True, "hint": "机器人名称",
"hint": "机器人名称(ID)不能和其它的平台适配器重复。",
}, },
"type": { "type": {
"description": "适配器类型", "description": "适配器类型",
@@ -365,17 +361,16 @@ CONFIG_METADATA_2 = {
"description": "飞书机器人的名字", "description": "飞书机器人的名字",
"type": "string", "type": "string",
"hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。", "hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
"obvious_hint": True,
}, },
"discord_token":{ "discord_token": {
"description": "Discord Bot Token", "description": "Discord Bot Token",
"type": "string", "type": "string",
"hint": "在此处填入你的Discord Bot Token" "hint": "在此处填入你的Discord Bot Token",
}, },
"discord_proxy":{ "discord_proxy": {
"description": "Discord 代理地址", "description": "Discord 代理地址",
"type": "string", "type": "string",
"hint": "可选的代理地址http://ip:port" "hint": "可选的代理地址http://ip:port",
}, },
"discord_command_register": { "discord_command_register": {
"description": "是否自动将插件指令注册为 Discord 斜杠指令", "description": "是否自动将插件指令注册为 Discord 斜杠指令",
@@ -386,10 +381,6 @@ CONFIG_METADATA_2 = {
"type": "string", "type": "string",
"hint": "可选的 Discord 活动名称。留空则不设置活动。", "hint": "可选的 Discord 活动名称。留空则不设置活动。",
}, },
"discord_guild_id_for_debug": {
"description": "【开发用】指定一个服务器(Guild)ID。在此服务器注册的指令会立刻生效便于调试。留空则注册为全局指令。",
"type": "string",
},
}, },
}, },
"platform_settings": { "platform_settings": {
@@ -442,7 +433,7 @@ CONFIG_METADATA_2 = {
"ignore_bot_self_message": { "ignore_bot_self_message": {
"description": "是否忽略机器人自身的消息", "description": "是否忽略机器人自身的消息",
"type": "bool", "type": "bool",
"hint": "某些平台如 gewechat 会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人", "hint": "某些平台会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
}, },
"ignore_at_all": { "ignore_at_all": {
"description": "是否忽略 @ 全体成员", "description": "是否忽略 @ 全体成员",
@@ -485,13 +476,11 @@ CONFIG_METADATA_2 = {
"regex": { "regex": {
"description": "正则表达式", "description": "正则表达式",
"type": "string", "type": "string",
"obvious_hint": True,
"hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)", "hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)",
}, },
"content_cleanup_rule": { "content_cleanup_rule": {
"description": "过滤分段后的内容", "description": "过滤分段后的内容",
"type": "string", "type": "string",
"obvious_hint": True,
"hint": "移除分段后的内容中的指定的内容。支持正则表达式。如填写 `[。?!]` 将移除所有的句号、问号、感叹号。re.sub(r'<regex>', '', text)", "hint": "移除分段后的内容中的指定的内容。支持正则表达式。如填写 `[。?!]` 将移除所有的句号、问号、感叹号。re.sub(r'<regex>', '', text)",
}, },
}, },
@@ -514,7 +503,6 @@ CONFIG_METADATA_2 = {
"description": "ID 白名单", "description": "ID 白名单",
"type": "list", "type": "list",
"items": {"type": "string"}, "items": {"type": "string"},
"obvious_hint": True,
"hint": "只处理填写的 ID 发来的消息事件,为空时不启用。可使用 /sid 指令获取在平台上的会话 ID(类似 abc:GroupMessage:123)。管理员可使用 /wl 添加白名单", "hint": "只处理填写的 ID 发来的消息事件,为空时不启用。可使用 /sid 指令获取在平台上的会话 ID(类似 abc:GroupMessage:123)。管理员可使用 /wl 添加白名单",
}, },
"id_whitelist_log": { "id_whitelist_log": {
@@ -544,7 +532,6 @@ CONFIG_METADATA_2 = {
"description": "路径映射", "description": "路径映射",
"type": "list", "type": "list",
"items": {"type": "string"}, "items": {"type": "string"},
"obvious_hint": True,
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。", "hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
}, },
}, },
@@ -604,18 +591,19 @@ CONFIG_METADATA_2 = {
"config_template": { "config_template": {
"OpenAI": { "OpenAI": {
"id": "openai", "id": "openai",
"provider": "openai",
"type": "openai_chat_completion", "type": "openai_chat_completion",
"provider_type": "chat_completion", "provider_type": "chat_completion",
"enable": True, "enable": True,
"key": [], "key": [],
"api_base": "https://api.openai.com/v1", "api_base": "https://api.openai.com/v1",
"timeout": 120, "timeout": 120,
"model_config": { "model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"model": "gpt-4o-mini", "hint": "也兼容所有与OpenAI API兼容的服务。",
},
}, },
"Azure OpenAI": { "Azure OpenAI": {
"id": "azure", "id": "azure",
"provider": "azure",
"type": "openai_chat_completion", "type": "openai_chat_completion",
"provider_type": "chat_completion", "provider_type": "chat_completion",
"enable": True, "enable": True,
@@ -623,24 +611,23 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"api_base": "", "api_base": "",
"timeout": 120, "timeout": 120,
"model_config": { "model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"model": "gpt-4o-mini",
},
}, },
"xAI": { "xAI": {
"id": "xai", "id": "xai",
"provider": "xai",
"type": "openai_chat_completion", "type": "openai_chat_completion",
"provider_type": "chat_completion", "provider_type": "chat_completion",
"enable": True, "enable": True,
"key": [], "key": [],
"api_base": "https://api.x.ai/v1", "api_base": "https://api.x.ai/v1",
"timeout": 120, "timeout": 120,
"model_config": { "model_config": {"model": "grok-2-latest", "temperature": 0.4},
"model": "grok-2-latest",
},
}, },
"Anthropic": { "Anthropic": {
"hint": "注意Claude系列模型的温度调节范围为0到1.0,超出可能导致报错",
"id": "claude", "id": "claude",
"provider": "anthropic",
"type": "anthropic_chat_completion", "type": "anthropic_chat_completion",
"provider_type": "chat_completion", "provider_type": "chat_completion",
"enable": True, "enable": True,
@@ -650,21 +637,23 @@ CONFIG_METADATA_2 = {
"model_config": { "model_config": {
"model": "claude-3-5-sonnet-latest", "model": "claude-3-5-sonnet-latest",
"max_tokens": 4096, "max_tokens": 4096,
"temperature": 0.2,
}, },
}, },
"Ollama": { "Ollama": {
"hint": "启用前请确保已正确安装并运行 Ollama 服务端Ollama默认不带鉴权无需修改key",
"id": "ollama_default", "id": "ollama_default",
"provider": "ollama",
"type": "openai_chat_completion", "type": "openai_chat_completion",
"provider_type": "chat_completion", "provider_type": "chat_completion",
"enable": True, "enable": True,
"key": ["ollama"], # ollama 的 key 默认是 ollama "key": ["ollama"], # ollama 的 key 默认是 ollama
"api_base": "http://localhost:11434/v1", "api_base": "http://localhost:11434/v1",
"model_config": { "model_config": {"model": "llama3.1-8b", "temperature": 0.4},
"model": "llama3.1-8b",
},
}, },
"LM Studio": { "LM Studio": {
"id": "lm_studio", "id": "lm_studio",
"provider": "lm_studio",
"type": "openai_chat_completion", "type": "openai_chat_completion",
"provider_type": "chat_completion", "provider_type": "chat_completion",
"enable": True, "enable": True,
@@ -676,6 +665,7 @@ CONFIG_METADATA_2 = {
}, },
"Gemini(OpenAI兼容)": { "Gemini(OpenAI兼容)": {
"id": "gemini_default", "id": "gemini_default",
"provider": "google",
"type": "openai_chat_completion", "type": "openai_chat_completion",
"provider_type": "chat_completion", "provider_type": "chat_completion",
"enable": True, "enable": True,
@@ -684,10 +674,12 @@ CONFIG_METADATA_2 = {
"timeout": 120, "timeout": 120,
"model_config": { "model_config": {
"model": "gemini-1.5-flash", "model": "gemini-1.5-flash",
"temperature": 0.4,
}, },
}, },
"Gemini": { "Gemini": {
"id": "gemini_default", "id": "gemini_default",
"provider": "google",
"type": "googlegenai_chat_completion", "type": "googlegenai_chat_completion",
"provider_type": "chat_completion", "provider_type": "chat_completion",
"enable": True, "enable": True,
@@ -696,6 +688,7 @@ CONFIG_METADATA_2 = {
"timeout": 120, "timeout": 120,
"model_config": { "model_config": {
"model": "gemini-2.0-flash-exp", "model": "gemini-2.0-flash-exp",
"temperature": 0.4,
}, },
"gm_resp_image_modal": False, "gm_resp_image_modal": False,
"gm_native_search": False, "gm_native_search": False,
@@ -713,18 +706,81 @@ CONFIG_METADATA_2 = {
}, },
"DeepSeek": { "DeepSeek": {
"id": "deepseek_default", "id": "deepseek_default",
"provider": "deepseek",
"type": "openai_chat_completion", "type": "openai_chat_completion",
"provider_type": "chat_completion", "provider_type": "chat_completion",
"enable": True, "enable": True,
"key": [], "key": [],
"api_base": "https://api.deepseek.com/v1", "api_base": "https://api.deepseek.com/v1",
"timeout": 120, "timeout": 120,
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
},
"302.AI": {
"id": "302ai",
"provider": "302ai",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.302.ai/v1",
"timeout": 120,
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
},
"硅基流动": {
"id": "siliconflow",
"provider": "siliconflow",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://api.siliconflow.cn/v1",
"model_config": { "model_config": {
"model": "deepseek-chat", "model": "deepseek-ai/DeepSeek-V3",
"temperature": 0.4,
}, },
}, },
"PPIO派欧云": {
"id": "ppio",
"provider": "ppio",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.ppinfra.com/v3/openai",
"timeout": 120,
"model_config": {
"model": "deepseek/deepseek-r1",
"temperature": 0.4,
},
},
"优云智算": {
"id": "compshare",
"provider": "compshare",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.modelverse.cn/v1",
"timeout": 120,
"model_config": {
"model": "moonshotai/Kimi-K2-Instruct",
},
},
"Kimi": {
"id": "moonshot",
"provider": "moonshot",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://api.moonshot.cn/v1",
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
},
"智谱 AI": { "智谱 AI": {
"id": "zhipu_default", "id": "zhipu_default",
"provider": "zhipu",
"type": "zhipu_chat_completion", "type": "zhipu_chat_completion",
"provider_type": "chat_completion", "provider_type": "chat_completion",
"enable": True, "enable": True,
@@ -735,55 +791,9 @@ CONFIG_METADATA_2 = {
"model": "glm-4-flash", "model": "glm-4-flash",
}, },
}, },
"硅基流动": {
"id": "siliconflow",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://api.siliconflow.cn/v1",
"model_config": {
"model": "deepseek-ai/DeepSeek-V3",
},
},
"Kimi": {
"id": "moonshot",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://api.moonshot.cn/v1",
"model_config": {
"model": "moonshot-v1-8k",
},
},
"PPIO派欧云": {
"id": "ppio",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.ppinfra.com/v3/openai",
"timeout": 120,
"model_config": {
"model": "deepseek/deepseek-r1",
},
},
"LLMTuner": {
"id": "llmtuner_default",
"type": "llm_tuner",
"provider_type": "chat_completion",
"enable": True,
"base_model_path": "",
"adapter_model_path": "",
"llmtuner_template": "",
"finetuning_type": "lora",
"quantization_bit": 4,
},
"Dify": { "Dify": {
"id": "dify_app_default", "id": "dify_app_default",
"provider": "dify",
"type": "dify", "type": "dify",
"provider_type": "chat_completion", "provider_type": "chat_completion",
"enable": True, "enable": True,
@@ -794,9 +804,11 @@ CONFIG_METADATA_2 = {
"dify_query_input_key": "astrbot_text_query", "dify_query_input_key": "astrbot_text_query",
"variables": {}, "variables": {},
"timeout": 60, "timeout": 60,
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
}, },
"阿里云百炼应用": { "阿里云百炼应用": {
"id": "dashscope", "id": "dashscope",
"provider": "dashscope",
"type": "dashscope", "type": "dashscope",
"provider_type": "chat_completion", "provider_type": "chat_completion",
"enable": True, "enable": True,
@@ -811,8 +823,20 @@ CONFIG_METADATA_2 = {
"variables": {}, "variables": {},
"timeout": 60, "timeout": 60,
}, },
"ModelScope": {
"id": "modelscope",
"provider": "modelscope",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://api-inference.modelscope.cn/v1",
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
},
"FastGPT": { "FastGPT": {
"id": "fastgpt", "id": "fastgpt",
"provider": "fastgpt",
"type": "openai_chat_completion", "type": "openai_chat_completion",
"provider_type": "chat_completion", "provider_type": "chat_completion",
"enable": True, "enable": True,
@@ -822,6 +846,7 @@ CONFIG_METADATA_2 = {
}, },
"Whisper(API)": { "Whisper(API)": {
"id": "whisper", "id": "whisper",
"provider": "openai",
"type": "openai_whisper_api", "type": "openai_whisper_api",
"provider_type": "speech_to_text", "provider_type": "speech_to_text",
"enable": False, "enable": False,
@@ -830,16 +855,18 @@ CONFIG_METADATA_2 = {
"model": "whisper-1", "model": "whisper-1",
}, },
"Whisper(本地加载)": { "Whisper(本地加载)": {
"whisper_hint": "(不用修改我)", "hint": "启用前请 pip 安装 openai-whisper 库N卡用户大约下载 2GB主要是 torch 和 cudaCPU 用户大约下载 1 GB并且安装 ffmpeg。否则将无法正常转文字。",
"provider": "openai",
"type": "openai_whisper_selfhost", "type": "openai_whisper_selfhost",
"provider_type": "speech_to_text", "provider_type": "speech_to_text",
"enable": False, "enable": False,
"id": "whisper", "id": "whisper_selfhost",
"model": "tiny", "model": "tiny",
}, },
"SenseVoice(本地加载)": { "SenseVoice(本地加载)": {
"sensevoice_hint": "(不用修改我)", "hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库默认使用CPU大约下载 1 GB并且安装 ffmpeg。否则将无法正常转文字。",
"type": "sensevoice_stt_selfhost", "type": "sensevoice_stt_selfhost",
"provider": "sensevoice",
"provider_type": "speech_to_text", "provider_type": "speech_to_text",
"enable": False, "enable": False,
"id": "sensevoice", "id": "sensevoice",
@@ -849,6 +876,7 @@ CONFIG_METADATA_2 = {
"OpenAI TTS(API)": { "OpenAI TTS(API)": {
"id": "openai_tts", "id": "openai_tts",
"type": "openai_tts_api", "type": "openai_tts_api",
"provider": "openai",
"provider_type": "text_to_speech", "provider_type": "text_to_speech",
"enable": False, "enable": False,
"api_key": "", "api_key": "",
@@ -858,8 +886,9 @@ CONFIG_METADATA_2 = {
"timeout": "20", "timeout": "20",
}, },
"Edge TTS": { "Edge TTS": {
"edgetts_hint": "提示:使用这个服务前需要安装有 ffmpeg并且可以直接在终端调用 ffmpeg 指令。", "hint": "提示:使用这个服务前需要安装有 ffmpeg并且可以直接在终端调用 ffmpeg 指令。",
"id": "edge_tts", "id": "edge_tts",
"provider": "microsoft",
"type": "edge_tts", "type": "edge_tts",
"provider_type": "text_to_speech", "provider_type": "text_to_speech",
"enable": False, "enable": False,
@@ -869,6 +898,7 @@ CONFIG_METADATA_2 = {
"GSV TTS(本地加载)": { "GSV TTS(本地加载)": {
"id": "gsv_tts", "id": "gsv_tts",
"enable": False, "enable": False,
"provider": "gpt_sovits",
"type": "gsv_tts_selfhost", "type": "gsv_tts_selfhost",
"provider_type": "text_to_speech", "provider_type": "text_to_speech",
"api_base": "http://127.0.0.1:9880", "api_base": "http://127.0.0.1:9880",
@@ -900,6 +930,7 @@ CONFIG_METADATA_2 = {
"GSVI TTS(API)": { "GSVI TTS(API)": {
"id": "gsvi_tts", "id": "gsvi_tts",
"type": "gsvi_tts_api", "type": "gsvi_tts_api",
"provider": "gpt_sovits_inference",
"provider_type": "text_to_speech", "provider_type": "text_to_speech",
"api_base": "http://127.0.0.1:5000", "api_base": "http://127.0.0.1:5000",
"character": "", "character": "",
@@ -909,6 +940,7 @@ CONFIG_METADATA_2 = {
}, },
"FishAudio TTS(API)": { "FishAudio TTS(API)": {
"id": "fishaudio_tts", "id": "fishaudio_tts",
"provider": "fishaudio",
"type": "fishaudio_tts_api", "type": "fishaudio_tts_api",
"provider_type": "text_to_speech", "provider_type": "text_to_speech",
"enable": False, "enable": False,
@@ -919,6 +951,7 @@ CONFIG_METADATA_2 = {
}, },
"阿里云百炼 TTS(API)": { "阿里云百炼 TTS(API)": {
"id": "dashscope_tts", "id": "dashscope_tts",
"provider": "dashscope",
"type": "dashscope_tts", "type": "dashscope_tts",
"provider_type": "text_to_speech", "provider_type": "text_to_speech",
"enable": False, "enable": False,
@@ -930,6 +963,7 @@ CONFIG_METADATA_2 = {
"Azure TTS": { "Azure TTS": {
"id": "azure_tts", "id": "azure_tts",
"type": "azure_tts", "type": "azure_tts",
"provider": "azure",
"provider_type": "text_to_speech", "provider_type": "text_to_speech",
"enable": True, "enable": True,
"azure_tts_voice": "zh-CN-YunxiaNeural", "azure_tts_voice": "zh-CN-YunxiaNeural",
@@ -943,6 +977,7 @@ CONFIG_METADATA_2 = {
"MiniMax TTS(API)": { "MiniMax TTS(API)": {
"id": "minimax_tts", "id": "minimax_tts",
"type": "minimax_tts_api", "type": "minimax_tts_api",
"provider": "minimax",
"provider_type": "text_to_speech", "provider_type": "text_to_speech",
"enable": False, "enable": False,
"api_key": "", "api_key": "",
@@ -964,6 +999,7 @@ CONFIG_METADATA_2 = {
"火山引擎_TTS(API)": { "火山引擎_TTS(API)": {
"id": "volcengine_tts", "id": "volcengine_tts",
"type": "volcengine_tts", "type": "volcengine_tts",
"provider": "volcengine",
"provider_type": "text_to_speech", "provider_type": "text_to_speech",
"enable": False, "enable": False,
"api_key": "", "api_key": "",
@@ -974,20 +1010,35 @@ CONFIG_METADATA_2 = {
"api_base": "https://openspeech.bytedance.com/api/v1/tts", "api_base": "https://openspeech.bytedance.com/api/v1/tts",
"timeout": 20, "timeout": 20,
}, },
"Gemini TTS": {
"id": "gemini_tts",
"type": "gemini_tts",
"provider": "google",
"provider_type": "text_to_speech",
"enable": False,
"gemini_tts_api_key": "",
"gemini_tts_api_base": "",
"gemini_tts_timeout": 20,
"gemini_tts_model": "gemini-2.5-flash-preview-tts",
"gemini_tts_prefix": "",
"gemini_tts_voice_name": "Leda",
},
"OpenAI Embedding": { "OpenAI Embedding": {
"id": "openai_embedding", "id": "openai_embedding",
"type": "openai_embedding", "type": "openai_embedding",
"provider": "openai",
"provider_type": "embedding", "provider_type": "embedding",
"enable": True, "enable": True,
"embedding_api_key": "", "embedding_api_key": "",
"embedding_api_base": "", "embedding_api_base": "",
"embedding_model": "", "embedding_model": "",
"embedding_dimensions": 1536, "embedding_dimensions": 1024,
"timeout": 20, "timeout": 20,
}, },
"Gemini Embedding": { "Gemini Embedding": {
"id": "gemini_embedding", "id": "gemini_embedding",
"type": "gemini_embedding", "type": "gemini_embedding",
"provider": "google",
"provider_type": "embedding", "provider_type": "embedding",
"enable": True, "enable": True,
"embedding_api_key": "", "embedding_api_key": "",
@@ -998,17 +1049,19 @@ CONFIG_METADATA_2 = {
}, },
}, },
"items": { "items": {
"provider": {
"type": "string",
"invisible": True,
},
"gpt_weights_path": { "gpt_weights_path": {
"description": "GPT模型文件路径", "description": "GPT模型文件路径",
"type": "string", "type": "string",
"hint": "即“.ckpt”后缀的文件请使用绝对路径路径两端不要带双引号不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)", "hint": "即“.ckpt”后缀的文件请使用绝对路径路径两端不要带双引号不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
"obvious_hint": True,
}, },
"sovits_weights_path": { "sovits_weights_path": {
"description": "SoVITS模型文件路径", "description": "SoVITS模型文件路径",
"type": "string", "type": "string",
"hint": "即“.pth”后缀的文件请使用绝对路径路径两端不要带双引号不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)", "hint": "即“.pth”后缀的文件请使用绝对路径路径两端不要带双引号不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
"obvious_hint": True,
}, },
"gsv_default_parms": { "gsv_default_parms": {
"description": "GPT_SoVITS默认参数", "description": "GPT_SoVITS默认参数",
@@ -1019,13 +1072,11 @@ CONFIG_METADATA_2 = {
"description": "参考音频文件路径", "description": "参考音频文件路径",
"type": "string", "type": "string",
"hint": "必填!请使用绝对路径!路径两端不要带双引号!", "hint": "必填!请使用绝对路径!路径两端不要带双引号!",
"obvious_hint": True,
}, },
"gsv_prompt_text": { "gsv_prompt_text": {
"description": "参考音频文本", "description": "参考音频文本",
"type": "string", "type": "string",
"hint": "必填!请填写参考音频讲述的文本", "hint": "必填!请填写参考音频讲述的文本",
"obvious_hint": True,
}, },
"gsv_prompt_lang": { "gsv_prompt_lang": {
"description": "参考音频文本语言", "description": "参考音频文本语言",
@@ -1252,19 +1303,16 @@ CONFIG_METADATA_2 = {
"description": "启用原生搜索功能", "description": "启用原生搜索功能",
"type": "bool", "type": "bool",
"hint": "启用后所有函数工具将全部失效,免费次数限制请查阅官方文档", "hint": "启用后所有函数工具将全部失效,免费次数限制请查阅官方文档",
"obvious_hint": True,
}, },
"gm_native_coderunner": { "gm_native_coderunner": {
"description": "启用原生代码执行器", "description": "启用原生代码执行器",
"type": "bool", "type": "bool",
"hint": "启用后所有函数工具将全部失效", "hint": "启用后所有函数工具将全部失效",
"obvious_hint": True,
}, },
"gm_url_context": { "gm_url_context": {
"description": "启用URL上下文功能", "description": "启用URL上下文功能",
"type": "bool", "type": "bool",
"hint": "启用后所有函数工具将全部失效", "hint": "启用后所有函数工具将全部失效",
"obvious_hint": True,
}, },
"gm_safety_settings": { "gm_safety_settings": {
"description": "安全过滤器", "description": "安全过滤器",
@@ -1448,7 +1496,6 @@ CONFIG_METADATA_2 = {
"description": "部署SenseVoice", "description": "部署SenseVoice",
"type": "string", "type": "string",
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库默认使用CPU大约下载 1 GB并且安装 ffmpeg。否则将无法正常转文字。", "hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库默认使用CPU大约下载 1 GB并且安装 ffmpeg。否则将无法正常转文字。",
"obvious_hint": True,
}, },
"is_emotion": { "is_emotion": {
"description": "情绪识别", "description": "情绪识别",
@@ -1463,18 +1510,10 @@ CONFIG_METADATA_2 = {
"variables": { "variables": {
"description": "工作流固定输入变量", "description": "工作流固定输入变量",
"type": "object", "type": "object",
"obvious_hint": True,
"items": {}, "items": {},
"hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。", "hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
"invisible": True, "invisible": True,
}, },
# "fastgpt_app_type": {
# "description": "应用类型",
# "type": "string",
# "hint": "FastGPT 应用的应用类型。",
# "options": ["agent", "workflow", "plugin"],
# "obvious_hint": True,
# },
"dashscope_app_type": { "dashscope_app_type": {
"description": "应用类型", "description": "应用类型",
"type": "string", "type": "string",
@@ -1485,7 +1524,6 @@ CONFIG_METADATA_2 = {
"dialog-workflow", "dialog-workflow",
"task-workflow", "task-workflow",
], ],
"obvious_hint": True,
}, },
"timeout": { "timeout": {
"description": "超时时间", "description": "超时时间",
@@ -1495,26 +1533,22 @@ CONFIG_METADATA_2 = {
"openai-tts-voice": { "openai-tts-voice": {
"description": "voice", "description": "voice",
"type": "string", "type": "string",
"obvious_hint": True,
"hint": "OpenAI TTS 的声音。OpenAI 默认支持:'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'", "hint": "OpenAI TTS 的声音。OpenAI 默认支持:'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'",
}, },
"fishaudio-tts-character": { "fishaudio-tts-character": {
"description": "character", "description": "character",
"type": "string", "type": "string",
"obvious_hint": True,
"hint": "fishaudio TTS 的角色。默认为可莉。更多角色请访问https://fish.audio/zh-CN/discovery", "hint": "fishaudio TTS 的角色。默认为可莉。更多角色请访问https://fish.audio/zh-CN/discovery",
}, },
"whisper_hint": { "whisper_hint": {
"description": "本地部署 Whisper 模型须知", "description": "本地部署 Whisper 模型须知",
"type": "string", "type": "string",
"hint": "启用前请 pip 安装 openai-whisper 库N卡用户大约下载 2GB主要是 torch 和 cudaCPU 用户大约下载 1 GB并且安装 ffmpeg。否则将无法正常转文字。", "hint": "启用前请 pip 安装 openai-whisper 库N卡用户大约下载 2GB主要是 torch 和 cudaCPU 用户大约下载 1 GB并且安装 ffmpeg。否则将无法正常转文字。",
"obvious_hint": True,
}, },
"id": { "id": {
"description": "ID", "description": "ID",
"type": "string", "type": "string",
"obvious_hint": True, "hint": "模型提供商名字。",
"hint": "ID 不能和其它的服务提供商重复,否则将发生严重冲突。",
}, },
"type": { "type": {
"description": "模型提供商种类", "description": "模型提供商种类",
@@ -1529,53 +1563,27 @@ CONFIG_METADATA_2 = {
"enable": { "enable": {
"description": "启用", "description": "启用",
"type": "bool", "type": "bool",
"hint": "是否启用该模型。未启用的模型将不会被使用", "hint": "是否启用。",
}, },
"key": { "key": {
"description": "API Key", "description": "API Key",
"type": "list", "type": "list",
"items": {"type": "string"}, "items": {"type": "string"},
"hint": "API Key 列表。填写好后输入回车即可添加 API Key。支持多个 API Key。", "hint": "提供商 API Key。",
}, },
"api_base": { "api_base": {
"description": "API Base URL", "description": "API Base URL",
"type": "string", "type": "string",
"hint": "API Base URL 请在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1", "hint": "API Base URL 请在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
"obvious_hint": True,
},
"base_model_path": {
"description": "基座模型路径",
"type": "string",
"hint": "基座模型路径。",
},
"adapter_model_path": {
"description": "Adapter 模型路径",
"type": "string",
"hint": "Adapter 模型路径。如 Lora",
},
"llmtuner_template": {
"description": "template",
"type": "string",
"hint": "基座模型的类型。如 llama3, qwen, 请参考 LlamaFactory 文档。",
},
"finetuning_type": {
"description": "微调类型",
"type": "string",
"hint": "微调类型。如 `lora`",
},
"quantization_bit": {
"description": "量化位数",
"type": "int",
"hint": "量化位数。如 4",
}, },
"model_config": { "model_config": {
"description": "文本生成模型", "description": "模型配置",
"type": "object", "type": "object",
"items": { "items": {
"model": { "model": {
"description": "模型名称", "description": "模型名称",
"type": "string", "type": "string",
"hint": "大语言模型名称,一般是小写的英文。如 gpt-4o-mini, deepseek-chat", "hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
}, },
"max_tokens": { "max_tokens": {
"description": "模型最大输出长度tokens", "description": "模型最大输出长度tokens",
@@ -1622,7 +1630,6 @@ CONFIG_METADATA_2 = {
"description": "启用大语言模型聊天", "description": "启用大语言模型聊天",
"type": "bool", "type": "bool",
"hint": "如需切换大语言模型提供商,请使用 /provider 命令。", "hint": "如需切换大语言模型提供商,请使用 /provider 命令。",
"obvious_hint": True,
}, },
"separate_provider": { "separate_provider": {
"description": "提供商会话隔离", "description": "提供商会话隔离",
@@ -1642,25 +1649,26 @@ CONFIG_METADATA_2 = {
"web_search": { "web_search": {
"description": "启用网页搜索", "description": "启用网页搜索",
"type": "bool", "type": "bool",
"obvious_hint": True,
"hint": "能访问 Google 时效果最佳(国内需要在 `其他配置` 开启 HTTP 代理)。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。", "hint": "能访问 Google 时效果最佳(国内需要在 `其他配置` 开启 HTTP 代理)。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
}, },
"web_search_link": { "web_search_link": {
"description": "网页搜索引用链接", "description": "网页搜索引用链接",
"type": "bool", "type": "bool",
"obvious_hint": True,
"hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。", "hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。",
}, },
"display_reasoning_text": {
"description": "显示思考内容",
"type": "bool",
"hint": "开启后,将在回复中显示模型的思考过程。",
},
"identifier": { "identifier": {
"description": "启动识别群员", "description": "启动识别群员",
"type": "bool", "type": "bool",
"obvious_hint": True,
"hint": "在 Prompt 前加上群成员的名字以让模型更好地了解群聊状态。启用将略微增加 token 开销。", "hint": "在 Prompt 前加上群成员的名字以让模型更好地了解群聊状态。启用将略微增加 token 开销。",
}, },
"datetime_system_prompt": { "datetime_system_prompt": {
"description": "启用日期时间系统提示", "description": "启用日期时间系统提示",
"type": "bool", "type": "bool",
"obvious_hint": True,
"hint": "启用后,会在系统提示词中加上当前机器的日期时间。", "hint": "启用后,会在系统提示词中加上当前机器的日期时间。",
}, },
"default_personality": { "default_personality": {
@@ -1688,10 +1696,19 @@ CONFIG_METADATA_2 = {
"type": "bool", "type": "bool",
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台", "hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
}, },
"show_tool_use_status": {
"description": "函数调用状态输出",
"type": "bool",
"hint": "在触发函数调用时输出其函数名和内容。",
},
"streaming_segmented": { "streaming_segmented": {
"description": "不支持流式回复的平台分段输出", "description": "不支持流式回复的平台分段输出",
"type": "bool", "type": "bool",
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 和 gewechat 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项", "hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
},
"max_agent_step": {
"description": "工具调用轮数上限",
"type": "int",
}, },
}, },
}, },
@@ -1712,7 +1729,6 @@ CONFIG_METADATA_2 = {
"description": "人格名称", "description": "人格名称",
"type": "string", "type": "string",
"hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。", "hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。",
"obvious_hint": True,
}, },
"prompt": { "prompt": {
"description": "设定(系统提示词)", "description": "设定(系统提示词)",
@@ -1724,14 +1740,12 @@ CONFIG_METADATA_2 = {
"type": "list", "type": "list",
"items": {"type": "string"}, "items": {"type": "string"},
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话", "hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
"obvious_hint": True,
}, },
"mood_imitation_dialogs": { "mood_imitation_dialogs": {
"description": "对话风格模仿", "description": "对话风格模仿",
"type": "list", "type": "list",
"items": {"type": "string"}, "items": {"type": "string"},
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话", "hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
"obvious_hint": True,
}, },
}, },
}, },
@@ -1743,7 +1757,6 @@ CONFIG_METADATA_2 = {
"description": "启用语音转文本(STT)", "description": "启用语音转文本(STT)",
"type": "bool", "type": "bool",
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。", "hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。",
"obvious_hint": True,
}, },
"provider_id": { "provider_id": {
"description": "提供商 ID", "description": "提供商 ID",
@@ -1760,7 +1773,6 @@ CONFIG_METADATA_2 = {
"description": "启用文本转语音(TTS)", "description": "启用文本转语音(TTS)",
"type": "bool", "type": "bool",
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。", "hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。",
"obvious_hint": True,
}, },
"provider_id": { "provider_id": {
"description": "提供商 ID", "description": "提供商 ID",
@@ -1771,7 +1783,6 @@ CONFIG_METADATA_2 = {
"description": "启用语音和文字双输出", "description": "启用语音和文字双输出",
"type": "bool", "type": "bool",
"hint": "启用后Bot 将同时输出语音和文字消息。", "hint": "启用后Bot 将同时输出语音和文字消息。",
"obvious_hint": True,
}, },
"use_file_service": { "use_file_service": {
"description": "使用文件服务提供 TTS 语音文件", "description": "使用文件服务提供 TTS 语音文件",
@@ -1787,25 +1798,21 @@ CONFIG_METADATA_2 = {
"group_icl_enable": { "group_icl_enable": {
"description": "群聊内记录各群员对话", "description": "群聊内记录各群员对话",
"type": "bool", "type": "bool",
"obvious_hint": True,
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。", "hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
}, },
"group_message_max_cnt": { "group_message_max_cnt": {
"description": "群聊消息最大数量", "description": "群聊消息最大数量",
"type": "int", "type": "int",
"obvious_hint": True,
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。", "hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
}, },
"image_caption": { "image_caption": {
"description": "群聊图像转述(需模型支持)", "description": "群聊图像转述(需模型支持)",
"type": "bool", "type": "bool",
"obvious_hint": True,
"hint": "用模型将群聊中的图片消息转述为文字,推荐 gpt-4o-mini 模型。和机器人的唤醒聊天中的图片消息仍然会直接作为上下文输入。", "hint": "用模型将群聊中的图片消息转述为文字,推荐 gpt-4o-mini 模型。和机器人的唤醒聊天中的图片消息仍然会直接作为上下文输入。",
}, },
"image_caption_provider_id": { "image_caption_provider_id": {
"description": "图像转述提供商 ID", "description": "图像转述提供商 ID",
"type": "string", "type": "string",
"obvious_hint": True,
"hint": "可选。图像转述提供商 ID。如为空将选择聊天使用的提供商。", "hint": "可选。图像转述提供商 ID。如为空将选择聊天使用的提供商。",
}, },
"image_caption_prompt": { "image_caption_prompt": {
@@ -1819,14 +1826,12 @@ CONFIG_METADATA_2 = {
"enable": { "enable": {
"description": "启用主动回复", "description": "启用主动回复",
"type": "bool", "type": "bool",
"obvious_hint": True,
"hint": "启用后会根据触发概率主动回复群聊内的对话。QQ官方API(qq_official)不可用", "hint": "启用后会根据触发概率主动回复群聊内的对话。QQ官方API(qq_official)不可用",
}, },
"whitelist": { "whitelist": {
"description": "主动回复白名单", "description": "主动回复白名单",
"type": "list", "type": "list",
"items": {"type": "string"}, "items": {"type": "string"},
"obvious_hint": True,
"hint": "启用后,只有在白名单内的群聊会被主动回复。为空时不启用白名单过滤。需要通过 /sid 获取 SID 添加到这里。", "hint": "启用后,只有在白名单内的群聊会被主动回复。为空时不启用白名单过滤。需要通过 /sid 获取 SID 添加到这里。",
}, },
"method": { "method": {
@@ -1838,13 +1843,11 @@ CONFIG_METADATA_2 = {
"possibility_reply": { "possibility_reply": {
"description": "回复概率", "description": "回复概率",
"type": "float", "type": "float",
"obvious_hint": True,
"hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。", "hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。",
}, },
"prompt": { "prompt": {
"description": "提示词", "description": "提示词",
"type": "string", "type": "string",
"obvious_hint": True,
"hint": "提示词。当提示词为空时,如果触发回复,则向 LLM 请求的是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。", "hint": "提示词。当提示词为空时,如果触发回复,则向 LLM 请求的是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
}, },
}, },
@@ -1860,7 +1863,6 @@ CONFIG_METADATA_2 = {
"description": "机器人唤醒前缀", "description": "机器人唤醒前缀",
"type": "list", "type": "list",
"items": {"type": "string"}, "items": {"type": "string"},
"obvious_hint": True,
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。更改此配置将影响整个 Bot 的功能唤醒,包括所有指令。如果您不保留 `/`则内置指令help等将需要通过您的唤醒前缀来触发。", "hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。更改此配置将影响整个 Bot 的功能唤醒,包括所有指令。如果您不保留 `/`则内置指令help等将需要通过您的唤醒前缀来触发。",
}, },
"t2i": { "t2i": {
@@ -1887,13 +1889,11 @@ CONFIG_METADATA_2 = {
"timezone": { "timezone": {
"description": "时区", "description": "时区",
"type": "string", "type": "string",
"obvious_hint": True,
"hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab", "hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab",
}, },
"callback_api_base": { "callback_api_base": {
"description": "对外可达的回调接口地址", "description": "对外可达的回调接口地址",
"type": "string", "type": "string",
"obvious_hint": True,
"hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址host因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185https://example.com 等。", "hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址host因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185https://example.com 等。",
}, },
"log_level": { "log_level": {
@@ -1941,90 +1941,3 @@ DEFAULT_VALUE_MAP = {
"list": [], "list": [],
"object": {}, "object": {},
} }
# "project_atri": {
# "description": "Project ATRI 配置",
# "type": "object",
# "items": {
# "enable": {"description": "启用", "type": "bool"},
# "long_term_memory": {
# "description": "长期记忆",
# "type": "object",
# "items": {
# "enable": {"description": "启用", "type": "bool"},
# "summary_threshold_cnt": {
# "description": "摘要阈值",
# "type": "int",
# "hint": "当一个会话的对话记录数量超过该阈值时,会自动进行摘要。",
# },
# "embedding_provider_id": {
# "description": "Embedding provider ID",
# "type": "string",
# "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置",
# "obvious_hint": True,
# },
# "summarize_provider_id": {
# "description": "Summary provider ID",
# "type": "string",
# "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary请确保所填的 provider id 在 `配置页` 中存在。",
# "obvious_hint": True,
# },
# },
# },
# "active_message": {
# "description": "主动消息",
# "type": "object",
# "items": {
# "enable": {"description": "启用", "type": "bool"},
# },
# },
# "vision": {
# "description": "视觉理解",
# "type": "object",
# "items": {
# "enable": {"description": "启用", "type": "bool"},
# "provider_id_or_ofa_model_path": {
# "description": "提供商 ID 或 OFA 模型路径",
# "type": "string",
# "hint": "将会使用指定的 provider 来进行视觉处理,请确保所填的 provider id 在 `配置页` 中存在。",
# },
# },
# },
# "split_response": {
# "description": "是否分割回复",
# "type": "bool",
# "hint": "启用后,将会根据句子分割回复以更像人类回复。每次回复之间具有随机的时间间隔。默认启用。",
# },
# "persona": {
# "description": "人格",
# "type": "string",
# "hint": "默认人格。当启动 ATRI 之后,在 Provider 处设置的人格将会失效。",
# "obvious_hint": True,
# },
# "chat_provider_id": {
# "description": "Chat provider ID",
# "type": "string",
# "hint": "将会使用指定的 provider 来进行文本聊天,请确保所填的 provider id 在 `配置页` 中存在。",
# "obvious_hint": True,
# },
# "chat_base_model_path": {
# "description": "用于聊天的基座模型路径",
# "type": "string",
# "hint": "用于聊天的基座模型路径。当填写此项和 Lora 路径后,将会忽略上面设置的 Chat provider ID。",
# "obvious_hint": True,
# },
# "chat_adapter_model_path": {
# "description": "用于聊天的 Lora 模型路径",
# "type": "string",
# "hint": "Lora 模型路径。",
# "obvious_hint": True,
# },
# "quantization_bit": {
# "description": "量化位数",
# "type": "int",
# "hint": "模型量化位数。如果你不知道这是什么,请不要修改。默认为 4。",
# "obvious_hint": True,
# },
# },
# },

View File

@@ -88,7 +88,10 @@ class ConversationManager:
return self.session_conversations.get(unified_msg_origin, None) return self.session_conversations.get(unified_msg_origin, None)
async def get_conversation( async def get_conversation(
self, unified_msg_origin: str, conversation_id: str self,
unified_msg_origin: str,
conversation_id: str,
create_if_not_exists: bool = False,
) -> Conversation: ) -> Conversation:
"""获取会话的对话 """获取会话的对话
@@ -98,6 +101,13 @@ class ConversationManager:
Returns: Returns:
conversation (Conversation): 对话对象 conversation (Conversation): 对话对象
""" """
conv = self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
if not conv and create_if_not_exists:
# 如果对话不存在且需要创建,则新建一个对话
conversation_id = await self.new_conversation(unified_msg_origin)
return self.db.get_conversation_by_user_id(
unified_msg_origin, conversation_id
)
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id) return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]: async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:

View File

@@ -46,9 +46,12 @@ class AstrBotCoreLifecycle:
self.astrbot_config = astrbot_config # 初始化配置 self.astrbot_config = astrbot_config # 初始化配置
self.db = db # 初始化数据库 self.db = db # 初始化数据库
# 根据环境变量设置代理 # 设置代理
os.environ["https_proxy"] = self.astrbot_config["http_proxy"] if self.astrbot_config.get("http_proxy", ""):
os.environ["http_proxy"] = self.astrbot_config["http_proxy"] os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
if proxy := os.environ.get("https_proxy"):
logger.debug(f"Using proxy: {proxy}")
os.environ["no_proxy"] = "localhost" os.environ["no_proxy"] = "localhost"
async def initialize(self): async def initialize(self):

View File

@@ -2,6 +2,8 @@ import asyncio
import os import os
import uuid import uuid
import time import time
from urllib.parse import urlparse, unquote
import platform
class FileTokenService: class FileTokenService:
@@ -15,7 +17,9 @@ class FileTokenService:
async def _cleanup_expired_tokens(self): async def _cleanup_expired_tokens(self):
"""清理过期的令牌""" """清理过期的令牌"""
now = time.time() now = time.time()
expired_tokens = [token for token, (_, expire) in self.staged_files.items() if expire < now] expired_tokens = [
token for token, (_, expire) in self.staged_files.items() if expire < now
]
for token in expired_tokens: for token in expired_tokens:
self.staged_files.pop(token, None) self.staged_files.pop(token, None)
@@ -32,15 +36,35 @@ class FileTokenService:
Raises: Raises:
FileNotFoundError: 当路径不存在时抛出 FileNotFoundError: 当路径不存在时抛出
""" """
# 处理 file:///
try:
parsed_uri = urlparse(file_path)
if parsed_uri.scheme == "file":
local_path = unquote(parsed_uri.path)
if platform.system() == "Windows" and local_path.startswith("/"):
local_path = local_path[1:]
else:
# 如果没有 file:/// 前缀,则认为是普通路径
local_path = file_path
except Exception:
# 解析失败时,按原路径处理
local_path = file_path
async with self.lock: async with self.lock:
await self._cleanup_expired_tokens() await self._cleanup_expired_tokens()
if not os.path.exists(file_path): if not os.path.exists(local_path):
raise FileNotFoundError(f"文件不存在: {file_path}") raise FileNotFoundError(
f"文件不存在: {local_path} (原始输入: {file_path})"
)
file_token = str(uuid.uuid4()) file_token = str(uuid.uuid4())
expire_time = time.time() + (timeout if timeout is not None else self.default_timeout) expire_time = time.time() + (
self.staged_files[file_token] = (file_path, expire_time) timeout if timeout is not None else self.default_timeout
)
# 存储转换后的真实路径
self.staged_files[file_token] = (local_path, expire_time)
return file_token return file_token
async def handle_file(self, file_token: str) -> str: async def handle_file(self, file_token: str) -> str:

View File

@@ -96,8 +96,6 @@ class LogBroker:
Queue: 订阅者的队列, 可用于接收日志消息 Queue: 订阅者的队列, 可用于接收日志消息
""" """
q = Queue(maxsize=CACHED_SIZE + 10) q = Queue(maxsize=CACHED_SIZE + 10)
for log in self.log_cache:
q.put_nowait(log)
self.subscribers.append(q) self.subscribers.append(q)
return q return q

View File

@@ -125,6 +125,9 @@ class Plain(BaseMessageComponent):
def toDict(self): def toDict(self):
return {"type": "text", "data": {"text": self.text.strip()}} return {"type": "text", "data": {"text": self.text.strip()}}
async def to_dict(self):
return {"type": "text", "data": {"text": self.text}}
class Face(BaseMessageComponent): class Face(BaseMessageComponent):
type: ComponentType = "Face" type: ComponentType = "Face"
@@ -610,6 +613,10 @@ class Node(BaseMessageComponent):
"data": {"file": f"base64://{bs64}"}, "data": {"file": f"base64://{bs64}"},
} }
) )
elif isinstance(comp, Plain):
# For Plain segments, we need to handle the plain differently
d = await comp.to_dict()
data_content.append(d)
elif isinstance(comp, File): elif isinstance(comp, File):
# For File segments, we need to handle the file differently # For File segments, we need to handle the file differently
d = await comp.to_dict() d = await comp.to_dict()

View File

@@ -24,6 +24,8 @@ class MessageChain:
chain: List[BaseMessageComponent] = field(default_factory=list) chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: Optional[bool] = None # None 为跟随用户设置 use_t2i_: Optional[bool] = None # None 为跟随用户设置
type: Optional[str] = None
"""消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。"""
def message(self, message: str): def message(self, message: str):
"""添加一条文本消息到消息链 `chain` 中。 """添加一条文本消息到消息链 `chain` 中。
@@ -98,6 +100,15 @@ class MessageChain:
self.chain.append(Image.fromFileSystem(path)) self.chain.append(Image.fromFileSystem(path))
return self return self
def base64_image(self, base64_str: str):
"""添加一条图片消息base64 编码字符串)到消息链 `chain` 中。
Example:
CommandResult().base64_image("iVBORw0KGgoAAAANSUhEUgAAAAUA...")
"""
self.chain.append(Image.fromBase64(base64_str))
return self
def use_t2i(self, use_t2i: bool): def use_t2i(self, use_t2i: bool):
"""设置是否使用文本转图片服务。 """设置是否使用文本转图片服务。
@@ -157,7 +168,7 @@ class ResultContentType(enum.Enum):
"""普通的消息结果""" """普通的消息结果"""
STREAMING_RESULT = enum.auto() STREAMING_RESULT = enum.auto()
"""调用 LLM 产生的流式结果""" """调用 LLM 产生的流式结果"""
STREAMING_FINISH= enum.auto() STREAMING_FINISH = enum.auto()
"""流式输出完成""" """流式输出完成"""

View File

@@ -1,22 +1,24 @@
from astrbot.core.message.message_event_result import ( from astrbot.core.message.message_event_result import (
MessageEventResult,
EventResultType, EventResultType,
MessageEventResult,
) )
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .rate_limit_check.stage import RateLimitStage
from .content_safety_check.stage import ContentSafetyCheckStage from .content_safety_check.stage import ContentSafetyCheckStage
from .platform_compatibility.stage import PlatformCompatibilityStage from .platform_compatibility.stage import PlatformCompatibilityStage
from .preprocess_stage.stage import PreProcessStage from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage from .respond.stage import RespondStage
from .result_decorate.stage import ResultDecorateStage
from .session_status_check.stage import SessionStatusCheckStage
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
# 管道阶段顺序 # 管道阶段顺序
STAGES_ORDER = [ STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒 "WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单 "WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制 "RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全 "ContentSafetyCheckStage", # 检查内容安全
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性 "PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
@@ -29,6 +31,7 @@ STAGES_ORDER = [
__all__ = [ __all__ = [
"WakingCheckStage", "WakingCheckStage",
"WhitelistCheckStage", "WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage", "RateLimitStage",
"ContentSafetyCheckStage", "ContentSafetyCheckStage",
"PlatformCompatibilityStage", "PlatformCompatibilityStage",

View File

@@ -1,6 +1,14 @@
import inspect
import traceback
import typing as T
from dataclasses import dataclass from dataclasses import dataclass
from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star import PluginManager from astrbot.core.star import PluginManager
from astrbot.api import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
@dataclass @dataclass
@@ -9,3 +17,97 @@ class PipelineContext:
astrbot_config: AstrBotConfig # AstrBot 配置对象 astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: PluginManager # 插件管理器对象 plugin_manager: PluginManager # 插件管理器对象
async def call_event_hook(
self,
event: AstrMessageEvent,
hook_type: EventType,
*args,
) -> bool:
"""调用事件钩子函数
Returns:
bool: 如果事件被终止,返回 True
"""
platform_id = event.get_platform_id()
handlers = star_handlers_registry.get_handlers_by_event_type(
hook_type, platform_id=platform_id
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, *args)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return event.is_stopped()
async def call_handler(
self,
event: AstrMessageEvent,
handler: T.Awaitable,
*args,
**kwargs,
) -> T.AsyncGenerator[None, None]:
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
2. 协程: 执行一次并处理返回值
Args:
ctx (PipelineContext): 消息管道上下文对象
event (AstrMessageEvent): 事件对象
handler (Awaitable): 事件处理函数
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError as _:
# 向下兼容
trace_ = traceback.format_exc()
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret

View File

@@ -0,0 +1,58 @@
import abc
import typing as T
from dataclasses import dataclass
from astrbot.core.provider.entities import LLMResponse
from ....message.message_event_result import MessageChain
from enum import Enum, auto
class AgentState(Enum):
"""Agent 状态枚举"""
IDLE = auto() # 初始状态
RUNNING = auto() # 运行中
DONE = auto() # 完成
ERROR = auto() # 错误状态
class AgentResponseData(T.TypedDict):
chain: MessageChain
@dataclass
class AgentResponse:
type: str
data: AgentResponseData
class BaseAgentRunner:
@abc.abstractmethod
async def reset(self) -> None:
"""
Reset the agent to its initial state.
This method should be called before starting a new run.
"""
...
@abc.abstractmethod
async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
"""
Process a single step of the agent.
"""
...
@abc.abstractmethod
def done(self) -> bool:
"""
Check if the agent has completed its task.
Returns True if the agent is done, False otherwise.
"""
...
@abc.abstractmethod
def get_final_llm_resp(self) -> LLMResponse | None:
"""
Get the final observation from the agent.
This method should be called after the agent is done.
"""
...

View File

@@ -0,0 +1,306 @@
import sys
import traceback
import typing as T
from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState
from ...context import PipelineContext
from astrbot.core.provider.provider import Provider
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import (
MessageChain,
)
from astrbot.core.provider.entities import (
ProviderRequest,
LLMResponse,
ToolCallMessageSegment,
AssistantMessageSegment,
ToolCallsResult,
)
from mcp.types import (
TextContent,
ImageContent,
EmbeddedResource,
TextResourceContents,
BlobResourceContents,
)
from astrbot.core.star.star_handler import EventType
from astrbot import logger
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# TODO:
# 1. 处理平台不兼容的处理器
class ToolLoopAgent(BaseAgentRunner):
def __init__(
self, provider: Provider, event: AstrMessageEvent, pipeline_ctx: PipelineContext
) -> None:
self.provider = provider
self.req = None
self.event = event
self.pipeline_ctx = pipeline_ctx
self._state = AgentState.IDLE
self.final_llm_resp = None
self.streaming = False
@override
async def reset(self, req: ProviderRequest, streaming: bool) -> None:
self.req = req
self.streaming = streaming
self.final_llm_resp = None
self._state = AgentState.IDLE
def _transition_state(self, new_state: AgentState) -> None:
"""转换 Agent 状态"""
if self._state != new_state:
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
self._state = new_state
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
if self.streaming:
stream = self.provider.text_chat_stream(**self.req.__dict__)
async for resp in stream: # type: ignore
yield resp
else:
yield await self.provider.text_chat(**self.req.__dict__)
@override
async def step(self):
"""
Process a single step of the agent.
This method should return the result of the step.
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
# 开始处理,转换到运行状态
self._transition_state(AgentState.RUNNING)
llm_resp_result = None
async for llm_response in self._iter_llm_responses():
assert isinstance(llm_response, LLMResponse)
if llm_response.is_chunk:
if llm_response.result_chain:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(chain=llm_response.result_chain),
)
else:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(llm_response.completion_text)
),
)
continue
llm_resp_result = llm_response
break # got final response
if not llm_resp_result:
return
# 处理 LLM 响应
llm_resp = llm_resp_result
if llm_resp.role == "err":
# 如果 LLM 响应错误,转换到错误状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.ERROR)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(
f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
)
),
)
if not llm_resp.tools_call_name:
# 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE)
# 执行事件钩子
if await self.pipeline_ctx.call_event_hook(
self.event, EventType.OnLLMResponseEvent, llm_resp
):
return
# 返回 LLM 结果
if llm_resp.result_chain:
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=llm_resp.result_chain),
)
elif llm_resp.completion_text:
yield AgentResponse(
type="llm_result",
data=AgentResponseData(
chain=MessageChain().message(llm_resp.completion_text)
),
)
# 如果有工具调用,还需处理工具调用
if llm_resp.tools_call_name:
tool_call_result_blocks = []
for tool_call_name in llm_resp.tools_call_name:
yield AgentResponse(
type="tool_call",
data=AgentResponseData(
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}")
),
)
async for result in self._handle_function_tools(self.req, llm_resp):
if isinstance(result, list):
tool_call_result_blocks = result
elif isinstance(result, MessageChain):
yield AgentResponse(
type="tool_call_result",
data=AgentResponseData(chain=result),
)
# 将结果添加到上下文中
tool_calls_result = ToolCallsResult(
tool_calls_info=AssistantMessageSegment(
role="assistant",
tool_calls=llm_resp.to_openai_tool_calls(),
content=llm_resp.completion_text,
),
tool_calls_result=tool_call_result_blocks,
)
self.req.append_tool_calls_result(tool_calls_result)
async def _handle_function_tools(
self,
req: ProviderRequest,
llm_response: LLMResponse,
) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]:
"""处理函数工具调用。"""
tool_call_result_blocks: list[ToolCallMessageSegment] = []
logger.info(f"Agent 使用工具: {llm_response.tools_call_name}")
# 执行函数调用
for func_tool_name, func_tool_args, func_tool_id in zip(
llm_response.tools_call_name,
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
try:
if not req.func_tool:
return
func_tool = req.func_tool.get_func(func_tool_name)
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
)
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
res = await client.session.call_tool(func_tool.name, func_tool_args)
if not res:
continue
if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
)
)
yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resource.text,
)
)
yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回的数据类型不受支持",
)
)
yield MessageChain().message("返回的数据类型不受支持。")
else:
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
# 尝试调用工具函数
wrapper = self.pipeline_ctx.call_handler(
self.event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None:
# Tool 返回结果
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resp,
)
)
yield MessageChain().message(resp)
else:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
self._transition_state(AgentState.DONE)
if res := self.event.get_result():
if res.chain:
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
self.event.clear_result()
except Exception as e:
logger.warning(traceback.format_exc())
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: {str(e)}",
)
)
# 处理函数调用响应
if tool_call_result_blocks:
yield tool_call_result_blocks
def done(self) -> bool:
"""检查 Agent 是否已完成工作"""
return self._state in (AgentState.DONE, AgentState.ERROR)
def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp

View File

@@ -2,57 +2,47 @@
本地 Agent 模式的 LLM 调用 Stage 本地 Agent 模式的 LLM 调用 Stage
""" """
import traceback
import asyncio import asyncio
import copy
import json import json
from typing import Union, AsyncGenerator import traceback
from ...context import PipelineContext from typing import AsyncGenerator, Union
from ..stage import Stage from astrbot.core import logger
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import ( from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult, MessageEventResult,
ResultContentType, ResultContentType,
MessageChain,
) )
from astrbot.core.message.components import Image from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger from astrbot.core.provider import Provider
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entities import ( from astrbot.core.provider.entities import (
ProviderRequest,
LLMResponse, LLMResponse,
ToolCallMessageSegment, ProviderRequest,
AssistantMessageSegment,
ToolCallsResult,
) )
from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star import star_map from astrbot.core.star.star_handler import EventType
from mcp.types import ( from astrbot.core.utils.metrics import Metric
TextContent, from ...context import PipelineContext
ImageContent, from ..agent_runner.tool_loop_agent import ToolLoopAgent
EmbeddedResource, from ..stage import Stage
TextResourceContents,
BlobResourceContents,
)
from astrbot.core import web_chat_back_queue
class LLMRequestSubStage(Stage): class LLMRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None: async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx self.ctx = ctx
self.bot_wake_prefixs = ctx.astrbot_config["wake_prefix"] # list conf = ctx.astrbot_config
self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][ settings = conf["provider_settings"]
"wake_prefix" self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
] # str self.provider_wake_prefix: str = settings["wake_prefix"] # str
self.max_context_length = ctx.astrbot_config["provider_settings"][ self.max_context_length = settings["max_context_length"] # int
"max_context_length" self.dequeue_context_length: int = min(
] # int max(1, settings["dequeue_context_length"]),
self.dequeue_context_length = min(
max(1, ctx.astrbot_config["provider_settings"]["dequeue_context_length"]),
self.max_context_length - 1, self.max_context_length - 1,
) # int )
self.streaming_response = ctx.astrbot_config["provider_settings"][ self.streaming_response: bool = settings["streaming_response"]
"streaming_response" self.max_step: int = settings.get("max_agent_step", 10)
] # bool self.show_tool_use: bool = settings.get("show_tool_use_status", True)
for bwp in self.bot_wake_prefixs: for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp): if self.provider_wake_prefix.startswith(bwp):
@@ -63,16 +53,33 @@ class LLMRequestSubStage(Stage):
self.conv_manager = ctx.plugin_manager.context.conversation_manager self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
"""选择使用的 LLM 提供商"""
sel_provider = event.get_extra("selected_provider")
_ctx = self.ctx.plugin_manager.context
if sel_provider and isinstance(sel_provider, str):
provider = _ctx.get_provider_by_id(sel_provider)
if not provider:
logger.error(f"未找到指定的提供商: {sel_provider}")
return provider
return _ctx.get_using_provider(umo=event.unified_msg_origin)
async def process( async def process(
self, event: AstrMessageEvent, _nested: bool = False self, event: AstrMessageEvent, _nested: bool = False
) -> Union[None, AsyncGenerator[None, None]]: ) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]: if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。") logger.debug("未启用 LLM 能力,跳过处理。")
return return
umo = event.unified_msg_origin
provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo) # 检查会话级别的LLM启停状态
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM跳过处理。")
return
provider = self._select_provider(event)
if provider is None: if provider is None:
return return
@@ -83,13 +90,12 @@ class LLMRequestSubStage(Stage):
) )
if req.conversation: if req.conversation:
all_contexts = json.loads(req.conversation.history) req.contexts = json.loads(req.conversation.history)
req.contexts = self._process_tool_message_pairs(
all_contexts, remove_tags=True
)
else: else:
req = ProviderRequest(prompt="", image_urls=[]) req = ProviderRequest(prompt="", image_urls=[])
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if self.provider_wake_prefix: if self.provider_wake_prefix:
if not event.message_str.startswith(self.provider_wake_prefix): if not event.message_str.startswith(self.provider_wake_prefix):
return return
@@ -127,26 +133,8 @@ class LLMRequestSubStage(Stage):
return return
# 执行请求 LLM 前事件钩子。 # 执行请求 LLM 前事件钩子。
# 装饰 system_prompt 等功能 if await self.ctx.call_event_hook(event, EventType.OnLLMRequestEvent, req):
# 获取当前平台ID return
platform_id = event.get_platform_id()
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMRequestEvent, platform_id=platform_id
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, req)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
if isinstance(req.contexts, str): if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts) req.contexts = json.loads(req.contexts)
@@ -176,77 +164,77 @@ class LLMRequestSubStage(Stage):
if not req.session_id: if not req.session_id:
req.session_id = event.unified_msg_origin req.session_id = event.unified_msg_origin
async def requesting(req: ProviderRequest): # fix messages
try: req.contexts = self.fix_messages(req.contexts)
need_loop = True
while need_loop:
need_loop = False
logger.debug(f"提供商请求 Payload: {req}")
final_llm_response = None # Call Agent
tool_loop_agent = ToolLoopAgent(
if self.streaming_response: provider=provider,
stream = provider.text_chat_stream(**req.__dict__) event=event,
async for llm_response in stream: pipeline_ctx=self.ctx,
if llm_response.is_chunk: )
if llm_response.result_chain: logger.debug(
yield llm_response.result_chain # MessageChain f"handle provider[id: {provider.provider_config['id']}] request: {req}"
else: )
yield MessageChain().message( await tool_loop_agent.reset(req=req, streaming=self.streaming_response)
llm_response.completion_text
)
else:
final_llm_response = llm_response
else:
final_llm_response = await provider.text_chat(
**req.__dict__
) # 请求 LLM
if not final_llm_response:
raise Exception("LLM response is None.")
# 执行 LLM 响应后的事件钩子。
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMResponseEvent
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, final_llm_response)
except BaseException:
logger.error(traceback.format_exc())
async def requesting():
step_idx = 0
while step_idx < self.max_step:
step_idx += 1
try:
async for resp in tool_loop_agent.step():
if event.is_stopped(): if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
resp.data["chain"].type = "tool_call_result"
await event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if self.streaming_response:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if (
self.show_tool_use
or event.get_platform_name() == "webchat"
):
resp.data["chain"].type = "tool_call"
await event.send(resp.data["chain"])
continue
if self.streaming_response: if not self.streaming_response:
# 流式输出的处理 content_typ = (
async for result in self._handle_llm_stream_response( ResultContentType.LLM_RESULT
event, req, final_llm_response if resp.type == "llm_result"
): else ResultContentType.GENERAL_RESULT
if isinstance(result, ProviderRequest): )
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM event.set_result(
req = result MessageEventResult(
need_loop = True chain=resp.data["chain"].chain,
else: result_content_type=content_typ,
yield )
else: )
# 非流式输出的处理 yield
async for result in self._handle_llm_response( event.clear_result()
event, req, final_llm_response else:
): if resp.type == "streaming_delta":
if isinstance(result, ProviderRequest): yield resp.data["chain"] # MessageChain
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM if tool_loop_agent.done():
req = result break
need_loop = True
else:
yield
except Exception as e:
logger.error(traceback.format_exc())
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
)
)
return
asyncio.create_task( asyncio.create_task(
Metric.upload( Metric.upload(
llm_tick=1, llm_tick=1,
@@ -255,45 +243,41 @@ class LLMRequestSubStage(Stage):
) )
) )
# 保存到历史记录 if self.streaming_response:
await self._save_to_history(event, req, final_llm_response) # 流式响应
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
)
)
if not self.streaming_response:
event.set_extra("tool_call_result", None)
async for _ in requesting(req):
yield
else:
event.set_result( event.set_result(
MessageEventResult() MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT) .set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(requesting(req)) .set_async_stream(requesting())
) )
# 这里使用yield来暂停当前阶段等待流式输出完成后继续处理
yield yield
if tool_loop_agent.done():
if event.get_extra("tool_call_result"): if final_llm_resp := tool_loop_agent.get_final_llm_resp():
event.set_result(event.get_extra("tool_call_result")) if final_llm_resp.completion_text:
event.set_extra("tool_call_result", None) chain = (
MessageChain().message(final_llm_resp.completion_text).chain
)
else:
chain = final_llm_resp.result_chain.chain
event.set_result(
MessageEventResult(
chain=chain,
result_content_type=ResultContentType.STREAMING_FINISH,
)
)
else:
async for _ in requesting():
yield yield
# 暂时直接发出去 # 异步处理 WebChat 特殊情况
if img_b64 := event.get_extra("tool_call_img_respond"):
await event.send(MessageChain(chain=[Image.fromBase64(img_b64)]))
event.set_extra("tool_call_img_respond", None)
if event.get_platform_name() == "webchat": if event.get_platform_name() == "webchat":
# 异步处理 WebChat 特殊情况 asyncio.create_task(self._handle_webchat(event, req, provider))
asyncio.create_task(self._handle_webchat(event, req))
async def _handle_webchat(self, event: AstrMessageEvent, req: ProviderRequest): await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp())
async def _handle_webchat(
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title""" """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
conversation = await self.conv_manager.get_conversation( conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, req.conversation.cid event.unified_msg_origin, req.conversation.cid
@@ -303,21 +287,16 @@ class LLMRequestSubStage(Stage):
latest_pair = messages[-2:] latest_pair = messages[-2:]
if not latest_pair: if not latest_pair:
return return
provider = self.ctx.plugin_manager.context.get_using_provider()
cleaned_text = "User: " + latest_pair[0].get("content", "").strip() cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
# if len(latest_pair) > 1:
# cleaned_text += (
# "\nAssistant: " + latest_pair[1].get("content", "").strip()
# )
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}") logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
llm_resp = await provider.text_chat( llm_resp = await prov.text_chat(
system_prompt="You are expert in summarizing user's query.", system_prompt="You are expert in summarizing user's query.",
prompt=( prompt=(
f"Please summarize the following query of user:\n" f"Please summarize the following query of user:\n"
f"{cleaned_text}\n" f"{cleaned_text}\n"
"Only output the summary within 10 words, DO NOT INCLUDE any other text." "Only output the summary within 10 words, DO NOT INCLUDE any other text."
"You must use the same language as the user." "You must use the same language as the user."
"If you think the dialog is too short to summarize, only output a special mark: `None`" "If you think the dialog is too short to summarize, only output a special mark: `<None>`"
), ),
) )
if llm_resp and llm_resp.completion_text: if llm_resp and llm_resp.completion_text:
@@ -325,7 +304,7 @@ class LLMRequestSubStage(Stage):
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}" f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}"
) )
title = llm_resp.completion_text.strip() title = llm_resp.completion_text.strip()
if not title or "None" == title: if not title or "<None>" in title:
return return
await self.conv_manager.update_conversation_title( await self.conv_manager.update_conversation_title(
event.unified_msg_origin, title=title event.unified_msg_origin, title=title
@@ -341,330 +320,50 @@ class LLMRequestSubStage(Stage):
cid=cid, cid=cid,
title=title, title=title,
) )
web_chat_back_queue.put_nowait(
{
"type": "update_title",
"cid": cid,
"data": title,
}
)
async def _handle_llm_response(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse,
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
"""处理非流式 LLM 响应。
Returns:
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest表示需要再次调用 LLM
Yields:
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
"""
if llm_response.role == "assistant":
# text completion
if llm_response.result_chain:
event.set_result(
MessageEventResult(
chain=llm_response.result_chain.chain
).set_result_content_type(ResultContentType.LLM_RESULT)
)
else:
event.set_result(
MessageEventResult()
.message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT)
)
elif llm_response.role == "err":
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
)
)
elif llm_response.role == "tool":
# 处理函数工具调用
async for result in self._handle_function_tools(event, req, llm_response):
yield result
async def _handle_llm_stream_response(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse,
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
"""处理流式 LLM 响应。
专门用于处理流式输出完成后的响应,与非流式响应处理分离。
Returns:
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest表示需要再次调用 LLM
Yields:
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
"""
if llm_response.role == "assistant":
# text completion
if llm_response.result_chain:
event.set_result(
MessageEventResult(
chain=llm_response.result_chain.chain
).set_result_content_type(ResultContentType.STREAMING_FINISH)
)
else:
event.set_result(
MessageEventResult()
.message(llm_response.completion_text)
.set_result_content_type(ResultContentType.STREAMING_FINISH)
)
elif llm_response.role == "err":
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
)
)
elif llm_response.role == "tool":
# 处理函数工具调用
async for result in self._handle_function_tools(event, req, llm_response):
yield result
async def _handle_function_tools(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse,
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
"""处理函数工具调用。
Returns:
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest表示需要再次调用 LLM
"""
# function calling
tool_call_result: list[ToolCallMessageSegment] = []
logger.info(
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
)
for func_tool_name, func_tool_args, func_tool_id in zip(
llm_response.tools_call_name,
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
try:
func_tool = req.func_tool.get_func(func_tool_name)
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
)
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
res = await client.session.call_tool(func_tool.name, func_tool_args)
if res:
# TODO 仅对ImageContent | EmbeddedResource进行了简单的Fallback
if isinstance(res.content[0], TextContent):
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
)
)
elif isinstance(res.content[0], ImageContent):
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
event.set_extra(
"tool_call_img_respond",
res.content[0].data,
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resource.text,
)
)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
event.set_extra(
"tool_call_img_respond",
res.content[0].data,
)
else:
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回的数据类型不受支持",
)
)
else:
# 获取处理器,过滤掉平台不兼容的处理器
platform_id = event.get_platform_id()
star_md = star_map.get(func_tool.handler_module_path)
if (
star_md
and platform_id in star_md.supported_platforms
and not star_md.supported_platforms[platform_id]
):
logger.debug(
f"处理器 {func_tool_name}({star_md.name}) 在当前平台不兼容或者被禁用,跳过执行"
)
# 直接跳过不添加任何消息到tool_call_result
continue
logger.info(
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
)
# 尝试调用工具函数
wrapper = self._call_handler(
self.ctx, event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None: # 有 return 返回
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resp,
)
)
else:
res = event.get_result()
if res and res.chain:
event.set_extra("tool_call_result", res)
yield # 有生成器返回
event.clear_result() # 清除上一个 handler 的结果
except BaseException as e:
logger.warning(traceback.format_exc())
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: {str(e)}",
)
)
if tool_call_result:
# 函数调用结果
req.func_tool = None # 暂时不支持递归工具调用
assistant_msg_seg = AssistantMessageSegment(
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
)
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
req.tool_calls_result = ToolCallsResult(
tool_calls_info=assistant_msg_seg,
tool_calls_result=tool_call_result,
)
yield req # 再次执行 LLM 请求
else:
if llm_response.completion_text:
event.set_result(
MessageEventResult().message(llm_response.completion_text)
)
async def _save_to_history( async def _save_to_history(
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse | None,
): ):
if not req or not req.conversation or not llm_response: if (
not req
or not req.conversation
or not llm_response
or llm_response.role != "assistant"
):
return return
if llm_response.role == "assistant": # 历史上下文
# 文本回复 messages = copy.deepcopy(req.contexts)
contexts = req.contexts.copy() # 这一轮对话请求的用户输入
contexts.append(await req.assemble_context()) messages.append(await req.assemble_context())
# 这一轮对话的 LLM 响应
if req.tool_calls_result:
if not isinstance(req.tool_calls_result, list):
messages.extend(req.tool_calls_result.to_openai_messages())
elif isinstance(req.tool_calls_result, list):
for tcr in req.tool_calls_result:
messages.extend(tcr.to_openai_messages())
messages.append({"role": "assistant", "content": llm_response.completion_text})
messages = list(filter(lambda item: "_no_save" not in item, messages))
await self.conv_manager.update_conversation(
event.unified_msg_origin, req.conversation.cid, history=messages
)
# 记录并标记函数调用结果 def fix_messages(self, messages: list[dict]) -> list[dict]:
if req.tool_calls_result: """验证并且修复上下文"""
tool_calls_messages = req.tool_calls_result.to_openai_messages() fixed_messages = []
for message in messages:
# 添加标记 if message.get("role") == "tool":
for message in tool_calls_messages: # tool block 前面必须要有 user 和 assistant block
message["_tool_call_history"] = True if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
processed_tool_messages = self._process_tool_message_pairs( # 我们直接将之前的上下文都清空
tool_calls_messages, remove_tags=False fixed_messages = []
) else:
fixed_messages.append(message)
contexts.extend(processed_tool_messages)
contexts.append(
{"role": "assistant", "content": llm_response.completion_text}
)
contexts_to_save = list(
filter(lambda item: "_no_save" not in item, contexts)
)
await self.conv_manager.update_conversation(
event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
)
def _process_tool_message_pairs(self, messages, remove_tags=True):
"""处理工具调用消息确保assistant和tool消息成对出现
Args:
messages (list): 消息列表
remove_tags (bool): 是否移除_tool_call_history标记
Returns:
list: 处理后的消息列表保证了assistant和对应tool消息的成对出现
"""
result = []
i = 0
while i < len(messages):
current_msg = messages[i]
# 普通消息直接添加
if "_tool_call_history" not in current_msg:
result.append(current_msg.copy() if remove_tags else current_msg)
i += 1
continue
# 工具调用消息成对处理
if current_msg.get("role") == "assistant" and "tool_calls" in current_msg:
assistant_msg = current_msg.copy()
if remove_tags and "_tool_call_history" in assistant_msg:
del assistant_msg["_tool_call_history"]
related_tools = []
j = i + 1
while (
j < len(messages)
and messages[j].get("role") == "tool"
and "_tool_call_history" in messages[j]
):
tool_msg = messages[j].copy()
if remove_tags:
del tool_msg["_tool_call_history"]
related_tools.append(tool_msg)
j += 1
# 成对的时候添加到结果
if related_tools:
result.append(assistant_msg)
result.extend(related_tools)
i = j # 跳过已处理
else: else:
# 单独的tool消息 fixed_messages.append(message)
i += 1 return fixed_messages
return result

View File

@@ -50,7 +50,7 @@ class StarRequestSubStage(Stage):
logger.debug( logger.debug(
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}" f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
) )
wrapper = self._call_handler(self.ctx, event, handler.handler, **params) wrapper = self.ctx.call_handler(event, handler.handler, **params)
async for ret in wrapper: async for ret in wrapper:
yield ret yield ret
event.clear_result() # 清除上一个 handler 的结果 event.clear_result() # 清除上一个 handler 的结果

View File

@@ -13,6 +13,7 @@ from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map from astrbot.core.star.star import star_map
from astrbot.core.utils.path_util import path_Mapping from astrbot.core.utils.path_util import path_Mapping
from astrbot.core.utils.session_lock import session_lock_manager
@register_stage @register_stage
@@ -128,9 +129,7 @@ class RespondStage(Stage):
"streaming_segmented", False "streaming_segmented", False
) )
logger.info(f"应用流式输出({event.get_platform_name()})") logger.info(f"应用流式输出({event.get_platform_name()})")
await event._pre_send()
await event.send_streaming(result.async_stream, use_fallback) await event.send_streaming(result.async_stream, use_fallback)
await event._post_send()
return return
elif len(result.chain) > 0: elif len(result.chain) > 0:
# 检查路径映射 # 检查路径映射
@@ -141,8 +140,6 @@ class RespondStage(Stage):
component.file = path_Mapping(mappings, component.file) component.file = path_Mapping(mappings, component.file)
event.get_result().chain[idx] = component event.get_result().chain[idx] = component
await event._pre_send()
# 检查消息链是否为空 # 检查消息链是否为空
try: try:
if await self._is_empty_message_chain(result.chain): if await self._is_empty_message_chain(result.chain):
@@ -158,9 +155,14 @@ class RespondStage(Stage):
c for c in result.chain if not isinstance(c, Comp.Record) c for c in result.chain if not isinstance(c, Comp.Record)
] ]
if self.enable_seg and ( if (
(self.only_llm_result and result.is_llm_result()) self.enable_seg
or not self.only_llm_result and (
(self.only_llm_result and result.is_llm_result())
or not self.only_llm_result
)
and event.get_platform_name()
not in ["qq_official", "weixin_official_account", "dingtalk"]
): ):
decorated_comps = [] decorated_comps = []
if self.reply_with_mention: if self.reply_with_mention:
@@ -176,25 +178,26 @@ class RespondStage(Stage):
result.chain.remove(comp) result.chain.remove(comp)
break break
for rcomp in record_comps: # leverage lock to guarentee the order of message sending among different events
i = await self._calc_comp_interval(rcomp) async with session_lock_manager.acquire_lock(event.unified_msg_origin):
await asyncio.sleep(i) for rcomp in record_comps:
try: i = await self._calc_comp_interval(rcomp)
await event.send(MessageChain([rcomp])) await asyncio.sleep(i)
except Exception as e: try:
logger.error(f"发送消息失败: {e} chain: {result.chain}") await event.send(MessageChain([rcomp]))
break except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
# 分段回复 break
for comp in non_record_comps: # 分段回复
i = await self._calc_comp_interval(comp) for comp in non_record_comps:
await asyncio.sleep(i) i = await self._calc_comp_interval(comp)
try: await asyncio.sleep(i)
await event.send(MessageChain([*decorated_comps, comp])) try:
decorated_comps = [] # 清空已发送的装饰组件 await event.send(MessageChain([*decorated_comps, comp]))
except Exception as e: decorated_comps = [] # 清空已发送的装饰组件
logger.error(f"发送消息失败: {e} chain: {result.chain}") except Exception as e:
break logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
else: else:
for rcomp in record_comps: for rcomp in record_comps:
try: try:
@@ -208,7 +211,6 @@ class RespondStage(Stage):
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
logger.error(f"发送消息失败: {e} chain: {result.chain}") logger.error(f"发送消息失败: {e} chain: {result.chain}")
await event._post_send()
logger.info( logger.info(
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}" f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
) )

View File

@@ -3,11 +3,12 @@ import time
import traceback import traceback
from typing import AsyncGenerator, Union from typing import AsyncGenerator, Union
from astrbot.core import html_renderer, logger, file_token_service from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType from astrbot.core.platform.message_type import MessageType
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star import star_map from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry from astrbot.core.star.star_handler import EventType, star_handlers_registry
@@ -141,7 +142,11 @@ class ResultDecorateStage(Stage):
break break
# 分段回复 # 分段回复
if self.enable_segmented_reply: if self.enable_segmented_reply and event.get_platform_name() not in [
"qq_official",
"weixin_official_account",
"dingtalk",
]:
if ( if (
self.only_llm_result and result.is_llm_result() self.only_llm_result and result.is_llm_result()
) or not self.only_llm_result: ) or not self.only_llm_result:
@@ -172,10 +177,12 @@ class ResultDecorateStage(Stage):
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider( tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
event.unified_msg_origin event.unified_msg_origin
) )
if ( if (
self.ctx.astrbot_config["provider_tts_settings"]["enable"] self.ctx.astrbot_config["provider_tts_settings"]["enable"]
and result.is_llm_result() and result.is_llm_result()
and tts_provider and tts_provider
and SessionServiceManager.should_process_tts_request(event)
): ):
new_chain = [] new_chain = []
for comp in result.chain: for comp in result.chain:

View File

@@ -73,7 +73,7 @@ class PipelineScheduler:
await self._process_stages(event) await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理 # 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if not event._has_send_oper and event.get_platform_name() == "webchat": if event.get_platform_name() == "webchat":
await event.send(None) await event.send(None)
logger.debug("pipeline 执行完毕。") logger.debug("pipeline 执行完毕。")

View File

@@ -0,0 +1,22 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core import logger
@register_stage
class SessionStatusCheckStage(Stage):
"""检查会话是否整体启用"""
async def initialize(self, ctx: PipelineContext) -> None:
pass
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
# 检查会话是否整体启用
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
event.stop_event()

View File

@@ -1,12 +1,8 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import inspect from typing import List, AsyncGenerator, Union
import traceback
from astrbot.api import logger
from typing import List, AsyncGenerator, Union, Awaitable
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext from .context import PipelineContext
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类 registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
@@ -41,70 +37,3 @@ class Stage(abc.ABC):
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段) Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
""" """
raise NotImplementedError raise NotImplementedError
async def _call_handler(
self,
ctx: PipelineContext,
event: AstrMessageEvent,
handler: Awaitable,
*args,
**kwargs,
) -> AsyncGenerator[None, None]:
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
1. 异步生成器: 实现洋葱模型每次yield都会将控制权交回上层
2. 协程: 执行一次并处理返回值
Args:
ctx (PipelineContext): 消息管道上下文对象
event (AstrMessageEvent): 待处理的事件对象
handler (Awaitable): 事件处理函数
*args: 传递给handler的位置参数
**kwargs: 传递给handler的关键字参数
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器(async def)
trace_ = None
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError as _:
# 向下兼容
trace_ = traceback.format_exc()
# 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
if isinstance(ready_to_call, AsyncGenerator):
# 如果是一个异步生成器, 进入洋葱模型
_has_yielded = False # 是否返回过值
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield # 传递控制权给上一层的process函数
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret # 传递控制权给上一层的process函数
if not _has_yielded:
# 如果这个异步生成器没有执行到yield分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield # 传递控制权给上一层的process函数
else:
yield ret # 传递控制权给上一层的process函数

View File

@@ -1,13 +1,16 @@
from ..stage import Stage, register_stage from typing import AsyncGenerator, Union
from ..context import PipelineContext
from astrbot import logger from astrbot import logger
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.message.components import At, AtAll, Reply from astrbot.core.message.components import At, AtAll, Reply
from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
from astrbot.core.star.star import star_map from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.filter.permission import PermissionTypeFilter from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
from ..context import PipelineContext
from ..stage import Stage, register_stage
@register_stage @register_stage
@@ -135,7 +138,6 @@ class WakingCheckStage(Stage):
f"插件 {star_map[handler.handler_module_path].name}: {e}" f"插件 {star_map[handler.handler_module_path].name}: {e}"
) )
) )
await event._post_send()
event.stop_event() event.stop_event()
passed = False passed = False
break break
@@ -150,7 +152,6 @@ class WakingCheckStage(Stage):
f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。" f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。"
) )
) )
await event._post_send()
logger.info( logger.info(
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。" f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
) )
@@ -166,7 +167,12 @@ class WakingCheckStage(Stage):
"parsed_params" "parsed_params"
) )
event.clear_extra() event._extras.pop("parsed_params", None)
# 根据会话配置过滤插件处理器
activated_handlers = SessionPluginManager.filter_handlers_by_session(
event, activated_handlers
)
event.set_extra("activated_handlers", activated_handlers) event.set_extra("activated_handlers", activated_handlers)
event.set_extra("handlers_parsed_params", handlers_parsed_params) event.set_extra("handlers_parsed_params", handlers_parsed_params)

View File

@@ -227,7 +227,7 @@ class AstrMessageEvent(abc.ABC):
): ):
"""发送流式消息到消息平台,使用异步生成器。 """发送流式消息到消息平台,使用异步生成器。
目前仅支持: telegramqq official 私聊。 目前仅支持: telegramqq official 私聊。
Fallback仅支持 aiocqhttp, gewechat Fallback仅支持 aiocqhttp。
""" """
asyncio.create_task( asyncio.create_task(
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name) Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
@@ -235,10 +235,10 @@ class AstrMessageEvent(abc.ABC):
self._has_send_oper = True self._has_send_oper = True
async def _pre_send(self): async def _pre_send(self):
"""调度器会在执行 send() 前调用该方法""" """调度器会在执行 send() 前调用该方法 deprecated in v3.5.18"""
async def _post_send(self): async def _post_send(self):
"""调度器会在执行 send() 后调用该方法""" """调度器会在执行 send() 后调用该方法 deprecated in v3.5.18"""
def set_result(self, result: Union[MessageEventResult, str]): def set_result(self, result: Union[MessageEventResult, str]):
"""设置消息事件的结果。 """设置消息事件的结果。
@@ -419,7 +419,6 @@ class AstrMessageEvent(abc.ABC):
适配情况: 适配情况:
- gewechat
- aiocqhttp(OneBotv11) - aiocqhttp(OneBotv11)
""" """
... ...

View File

@@ -58,10 +58,6 @@ class PlatformManager:
from .sources.qqofficial_webhook.qo_webhook_adapter import ( from .sources.qqofficial_webhook.qo_webhook_adapter import (
QQOfficialWebhookPlatformAdapter, # noqa: F401 QQOfficialWebhookPlatformAdapter, # noqa: F401
) )
case "gewechat":
from .sources.gewechat.gewechat_platform_adapter import (
GewechatPlatformAdapter, # noqa: F401
)
case "wechatpadpro": case "wechatpadpro":
from .sources.wechatpadpro.wechatpadpro_adapter import ( from .sources.wechatpadpro.wechatpadpro_adapter import (
WeChatPadProAdapter, # noqa: F401 WeChatPadProAdapter, # noqa: F401

View File

@@ -1,7 +1,7 @@
import asyncio import asyncio
import re import re
from typing import AsyncGenerator, Dict, List from typing import AsyncGenerator, Dict, List
from aiocqhttp import CQHttp from aiocqhttp import CQHttp, Event
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import ( from astrbot.api.message_components import (
Image, Image,
@@ -58,50 +58,85 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
ret.append(d) ret.append(d)
return ret return ret
async def send(self, message: MessageChain): @classmethod
async def _dispatch_send(
cls,
bot: CQHttp,
event: Event | None,
is_group: bool,
session_id: str,
messages: list[dict],
):
if event:
await bot.send(event=event, message=messages)
elif is_group:
await bot.send_group_msg(group_id=session_id, message=messages)
else:
await bot.send_private_msg(user_id=session_id, message=messages)
@classmethod
async def send_message(
cls,
bot: CQHttp,
message_chain: MessageChain,
event: Event | None = None,
is_group: bool = False,
session_id: str = None,
):
"""发送消息"""
# 转发消息、文件消息不能和普通消息混在一起发送 # 转发消息、文件消息不能和普通消息混在一起发送
send_one_by_one = any( send_one_by_one = any(
isinstance(seg, (Node, Nodes, File)) for seg in message.chain isinstance(seg, (Node, Nodes, File)) for seg in message_chain.chain
) )
if send_one_by_one: if not send_one_by_one:
for seg in message.chain: ret = await cls._parse_onebot_json(message_chain)
if isinstance(seg, (Node, Nodes)):
# 合并转发消息
if isinstance(seg, Node):
nodes = Nodes([seg])
seg = nodes
payload = await seg.to_dict()
if self.get_group_id():
payload["group_id"] = self.get_group_id()
await self.bot.call_action("send_group_forward_msg", **payload)
else:
payload["user_id"] = self.get_sender_id()
await self.bot.call_action(
"send_private_forward_msg", **payload
)
elif isinstance(seg, File):
d = await AiocqhttpMessageEvent._from_segment_to_dict(seg)
await self.bot.send(
self.message_obj.raw_message,
[d],
)
else:
await self.bot.send(
self.message_obj.raw_message,
await AiocqhttpMessageEvent._parse_onebot_json(
MessageChain([seg])
),
)
await asyncio.sleep(0.5)
else:
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
if not ret: if not ret:
return return
await self.bot.send(self.message_obj.raw_message, ret) await cls._dispatch_send(bot, event, is_group, session_id, ret)
return
for seg in message_chain.chain:
if isinstance(seg, (Node, Nodes)):
# 合并转发消息
if isinstance(seg, Node):
nodes = Nodes([seg])
seg = nodes
payload = await seg.to_dict()
if is_group:
payload["group_id"] = session_id
await bot.call_action("send_group_forward_msg", **payload)
else:
payload["user_id"] = session_id
await bot.call_action("send_private_forward_msg", **payload)
elif isinstance(seg, File):
d = await cls._from_segment_to_dict(seg)
await cls._dispatch_send(bot, event, is_group, session_id, [d])
else:
messages = await cls._parse_onebot_json(MessageChain([seg]))
if not messages:
continue
await cls._dispatch_send(bot, event, is_group, session_id, messages)
await asyncio.sleep(0.5)
async def send(self, message: MessageChain):
"""发送消息"""
event = self.message_obj.raw_message
assert isinstance(event, Event), "Event must be an instance of aiocqhttp.Event"
is_group = False
if self.get_group_id():
is_group = True
session_id = self.get_group_id()
else:
session_id = self.get_sender_id()
await self.send_message(
bot=self.bot,
message_chain=message,
event=event,
is_group=is_group,
session_id=session_id,
)
await super().send(message) await super().send(message)
async def send_streaming( async def send_streaming(

View File

@@ -83,19 +83,18 @@ class AiocqhttpAdapter(Platform):
async def send_by_session( async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain self, session: MessageSesion, message_chain: MessageChain
): ):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain) is_group = session.message_type == MessageType.GROUP_MESSAGE
match session.message_type.value: if is_group:
case MessageType.GROUP_MESSAGE.value: session_id = session.session_id.split("_")[-1]
if "_" in session.session_id: else:
# 独立会话 session_id = session.session_id
_, group_id = session.session_id.split("_") await AiocqhttpMessageEvent.send_message(
await self.bot.send_group_msg(group_id=group_id, message=ret) bot=self.bot,
else: message_chain=message_chain,
await self.bot.send_group_msg( event=None, # 这里不需要 event因为是通过 session 发送的
group_id=session.session_id, message=ret is_group=is_group,
) session_id=session_id,
case MessageType.FRIEND_MESSAGE.value: )
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
await super().send_by_session(session, message_chain) await super().send_by_session(session, message_chain)
async def convert_message(self, event: Event) -> AstrBotMessage: async def convert_message(self, event: Event) -> AstrBotMessage:
@@ -168,9 +167,7 @@ class AiocqhttpAdapter(Platform):
if "sub_type" in event: if "sub_type" in event:
if event["sub_type"] == "poke" and "target_id" in event: if event["sub_type"] == "poke" and "target_id" in event:
abm.message.append( abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) # noqa: F405
Poke(qq=str(event["target_id"]), type="poke")
) # noqa: F405
return abm return abm
@@ -273,8 +270,16 @@ class AiocqhttpAdapter(Platform):
action="get_msg", action="get_msg",
message_id=int(m["data"]["id"]), message_id=int(m["data"]["id"]),
) )
# 添加必要的 post_type 字段,防止 Event.from_payload 报错
reply_event_data["post_type"] = "message"
new_event = Event.from_payload(reply_event_data)
if not new_event:
logger.error(
f"无法从回复消息数据构造 Event 对象: {reply_event_data}"
)
continue
abm_reply = await self._convert_handle_message_event( abm_reply = await self._convert_handle_message_event(
Event.from_payload(reply_event_data), get_reply=False new_event, get_reply=False
) )
reply_seg = Reply( reply_seg = Reply(
@@ -307,7 +312,9 @@ class AiocqhttpAdapter(Platform):
user_id=int(m["data"]["qq"]), user_id=int(m["data"]["qq"]),
) )
if at_info: if at_info:
nickname = at_info.get("nick", "") nickname = at_info.get("nick", "") or at_info.get(
"nickname", ""
)
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"} is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
abm.message.append( abm.message.append(
@@ -322,7 +329,7 @@ class AiocqhttpAdapter(Platform):
first_at_self_processed = True first_at_self_processed = True
else: else:
# 非第一个@机器人或@其他用户添加到message_str # 非第一个@机器人或@其他用户添加到message_str
message_str += f" @{nickname} " message_str += f" @{nickname}({m['data']['qq']}) "
else: else:
abm.message.append(At(qq=str(m["data"]["qq"]), name="")) abm.message.append(At(qq=str(m["data"]["qq"]), name=""))
except ActionFailed as e: except ActionFailed as e:

View File

@@ -57,6 +57,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
logger.error(f"钉钉图片处理失败: {e}") logger.error(f"钉钉图片处理失败: {e}")
logger.warning(f"跳过图片发送: {image_path}") logger.warning(f"跳过图片发送: {image_path}")
continue continue
async def send(self, message: MessageChain): async def send(self, message: MessageChain):
await self.send_with_client(self.client, message) await self.send_with_client(self.client, message)
await super().send(message) await super().send(message)

View File

@@ -41,7 +41,8 @@ class DiscordBotClient(discord.Bot):
await self.on_ready_once_callback() await self.on_ready_once_callback()
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True) f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True
)
def _create_message_data(self, message: discord.Message) -> dict: def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典""" """从 discord.Message 创建数据字典"""
@@ -90,7 +91,6 @@ class DiscordBotClient(discord.Bot):
message_data = self._create_message_data(message) message_data = self._create_message_data(message)
await self.on_message_received(message_data) await self.on_message_received(message_data)
def _extract_interaction_content(self, interaction: discord.Interaction) -> str: def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
"""从交互中提取内容""" """从交互中提取内容"""
interaction_type = interaction.type interaction_type = interaction.type

View File

@@ -79,9 +79,12 @@ class DiscordButton(BaseMessageComponent):
self.url = url self.url = url
self.disabled = disabled self.disabled = disabled
class DiscordReference(BaseMessageComponent): class DiscordReference(BaseMessageComponent):
"""Discord引用组件""" """Discord引用组件"""
type: str = "discord_reference" type: str = "discord_reference"
def __init__(self, message_id: str, channel_id: str): def __init__(self, message_id: str, channel_id: str):
self.message_id = message_id self.message_id = message_id
self.channel_id = channel_id self.channel_id = channel_id
@@ -98,7 +101,6 @@ class DiscordView(BaseMessageComponent):
self.components = components or [] self.components = components or []
self.timeout = timeout self.timeout = timeout
def to_discord_view(self) -> discord.ui.View: def to_discord_view(self) -> discord.ui.View:
"""转换为Discord View对象""" """转换为Discord View对象"""
view = discord.ui.View(timeout=self.timeout) view = discord.ui.View(timeout=self.timeout)

View File

@@ -46,6 +46,8 @@ class DiscordPlatformAdapter(Platform):
self.enable_command_register = self.config.get("discord_command_register", True) self.enable_command_register = self.config.get("discord_command_register", True)
self.guild_id = self.config.get("discord_guild_id_for_debug", None) self.guild_id = self.config.get("discord_guild_id_for_debug", None)
self.activity_name = self.config.get("discord_activity_name", None) self.activity_name = self.config.get("discord_activity_name", None)
self.shutdown_event = asyncio.Event()
self._polling_task = None
@override @override
async def send_by_session( async def send_by_session(
@@ -137,7 +139,8 @@ class DiscordPlatformAdapter(Platform):
self.client.on_ready_once_callback = callback self.client.on_ready_once_callback = callback
try: try:
await self.client.start_polling() self._polling_task = asyncio.create_task(self.client.start_polling())
await self.shutdown_event.wait()
except discord.errors.LoginFailure: except discord.errors.LoginFailure:
logger.error("[Discord] 登录失败。请检查你的 Bot Token 是否正确。") logger.error("[Discord] 登录失败。请检查你的 Bot Token 是否正确。")
except discord.errors.ConnectionClosed: except discord.errors.ConnectionClosed:
@@ -162,42 +165,47 @@ class DiscordPlatformAdapter(Platform):
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage: def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
"""将普通消息转换为 AstrBotMessage""" """将普通消息转换为 AstrBotMessage"""
message: discord.Message = data["message"] message: discord.Message = data["message"]
is_mentioned = data.get("is_mentioned", False)
content = message.content content = message.content
# 如果机器人被@,移除@部分 # 如果机器人被@,移除@部分
if ( # 剥离 User Mention (<@id>, <@!id>)
is_mentioned if self.client and self.client.user:
and self.client
and self.client.user
and self.client.user in message.mentions
):
# 构建机器人的@字符串,格式为 <@USER_ID> 或 <@!USER_ID>
mention_str = f"<@{self.client.user.id}>" mention_str = f"<@{self.client.user.id}>"
mention_str_nickname = ( mention_str_nickname = f"<@!{self.client.user.id}>"
f"<@!{self.client.user.id}>" # 有些客户端会使用带!的格式
)
if content.startswith(mention_str): if content.startswith(mention_str):
content = content[len(mention_str) :].lstrip() content = content[len(mention_str) :].lstrip()
elif content.startswith(mention_str_nickname): elif content.startswith(mention_str_nickname):
content = content[len(mention_str_nickname) :].lstrip() content = content[len(mention_str_nickname) :].lstrip()
abm = AstrBotMessage() # 剥离 Role Mentionbot 拥有的任一角色被提及,<@&role_id>
if (
hasattr(message, "role_mentions")
and hasattr(message, "guild")
and message.guild
):
bot_member = (
message.guild.get_member(self.client.user.id)
if self.client and self.client.user
else None
)
if bot_member and hasattr(bot_member, "roles"):
for role in bot_member.roles:
role_mention_str = f"<@&{role.id}>"
if content.startswith(role_mention_str):
content = content[len(role_mention_str) :].lstrip()
break # 只剥离第一个匹配的角色 mention
abm = AstrBotMessage()
abm.type = self._get_message_type(message.channel) abm.type = self._get_message_type(message.channel)
abm.group_id = self._get_channel_id(message.channel) abm.group_id = self._get_channel_id(message.channel)
abm.message_str = content abm.message_str = content
abm.sender = MessageMember( abm.sender = MessageMember(
user_id=str(message.author.id), nickname=message.author.display_name user_id=str(message.author.id), nickname=message.author.display_name
) )
message_chain = [] message_chain = []
if abm.message_str: if abm.message_str:
message_chain.append(Plain(text=abm.message_str)) message_chain.append(Plain(text=abm.message_str))
if message.attachments: if message.attachments:
for attachment in message.attachments: for attachment in message.attachments:
if attachment.content_type and attachment.content_type.startswith( if attachment.content_type and attachment.content_type.startswith(
@@ -210,7 +218,6 @@ class DiscordPlatformAdapter(Platform):
message_chain.append( message_chain.append(
File(name=attachment.filename, url=attachment.url) File(name=attachment.filename, url=attachment.url)
) )
abm.message = message_chain abm.message = message_chain
abm.raw_message = message abm.raw_message = message
abm.self_id = self.client_self_id abm.self_id = self.client_self_id
@@ -237,13 +244,35 @@ class DiscordPlatformAdapter(Platform):
# 检查是否为斜杠指令 # 检查是否为斜杠指令
is_slash_command = message_event.interaction_followup_webhook is not None is_slash_command = message_event.interaction_followup_webhook is not None
# 检查是否被@ # 检查是否被@User Mention 或 Bot 拥有的 Role Mention
is_mention = ( is_mention = False
# User Mention
if (
self.client self.client
and self.client.user and self.client.user
and hasattr(message.raw_message, "mentions") and hasattr(message.raw_message, "mentions")
and self.client.user in message.raw_message.mentions ):
) if self.client.user in message.raw_message.mentions:
is_mention = True
# Role MentionBot 拥有的角色被提及)
if not is_mention and hasattr(message.raw_message, "role_mentions"):
bot_member = None
if hasattr(message.raw_message, "guild") and message.raw_message.guild:
try:
bot_member = message.raw_message.guild.get_member(
self.client.user.id
)
except Exception:
bot_member = None
if bot_member and hasattr(bot_member, "roles"):
bot_roles = set(bot_member.roles)
mentioned_roles = set(message.raw_message.role_mentions)
if (
bot_roles
and mentioned_roles
and bot_roles.intersection(mentioned_roles)
):
is_mention = True
# 如果是斜杠指令或被@的消息,设置为唤醒状态 # 如果是斜杠指令或被@的消息,设置为唤醒状态
if is_slash_command or is_mention: if is_slash_command or is_mention:
@@ -255,23 +284,37 @@ class DiscordPlatformAdapter(Platform):
@override @override
async def terminate(self): async def terminate(self):
"""终止适配器""" """终止适配器"""
logger.info("[Discord] 正在终止适配器...") logger.info("[Discord] 正在终止适配器... (step 1: cancel polling task)")
self.shutdown_event.set()
# 优先 cancel polling_task
if self._polling_task:
self._polling_task.cancel()
try:
await asyncio.wait_for(self._polling_task, timeout=10)
except asyncio.CancelledError:
logger.info("[Discord] polling_task 已取消。")
except Exception as e:
logger.warning(f"[Discord] polling_task 取消异常: {e}")
logger.info("[Discord] 正在清理已注册的斜杠指令... (step 2)")
# 清理指令 # 清理指令
if self.enable_command_register and self.client: if self.enable_command_register and self.client:
logger.info("[Discord] 正在清理已注册的斜杠指令...")
try: try:
# 传入空的列表来清除所有全局指令 await asyncio.wait_for(
# 如果指定了 guild_id则只清除该服务器的指令 self.client.sync_commands(
await self.client.sync_commands( commands=[],
commands=[], guild_ids=[self.guild_id] if self.guild_id else None guild_ids=[self.guild_id] if self.guild_id else None,
),
timeout=10,
) )
logger.info("[Discord] 指令清理完成。") logger.info("[Discord] 指令清理完成。")
except Exception as e: except Exception as e:
logger.error(f"[Discord] 清理指令时发生错误: {e}", exc_info=True) logger.error(f"[Discord] 清理指令时发生错误: {e}", exc_info=True)
logger.info("[Discord] 正在关闭 Discord 客户端... (step 3)")
if self.client and hasattr(self.client, "close"): if self.client and hasattr(self.client, "close"):
await self.client.close() try:
await asyncio.wait_for(self.client.close(), timeout=10)
except Exception as e:
logger.warning(f"[Discord] 客户端关闭异常: {e}")
logger.info("[Discord] 适配器已终止。") logger.info("[Discord] 适配器已终止。")
def register_handler(self, handler_info): def register_handler(self, handler_info):

View File

@@ -53,7 +53,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
# 解析消息链为 Discord 所需的对象 # 解析消息链为 Discord 所需的对象
try: try:
content, files, view, embeds, reference_message_id = await self._parse_to_discord(message) (
content,
files,
view,
embeds,
reference_message_id,
) = await self._parse_to_discord(message)
except Exception as e: except Exception as e:
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True) logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
return return
@@ -206,8 +212,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
if await asyncio.to_thread(path.exists): if await asyncio.to_thread(path.exists):
file_bytes = await asyncio.to_thread(path.read_bytes) file_bytes = await asyncio.to_thread(path.read_bytes)
files.append( files.append(
discord.File(BytesIO(file_bytes), discord.File(BytesIO(file_bytes), filename=i.name)
filename=i.name)
) )
else: else:
logger.warning( logger.warning(

View File

@@ -1,812 +0,0 @@
import asyncio
import base64
import datetime
import os
import re
import uuid
import threading
import aiohttp
import anyio
import quart
from astrbot.api import logger, sp
from astrbot.api.message_components import Plain, Image, At, Record, Video
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.core.utils.io import download_image_by_url
from .downloader import GeweDownloader
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
try:
from .xml_data_parser import GeweDataParser
except (ImportError, ModuleNotFoundError) as e:
logger.warning(
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
)
class SimpleGewechatClient:
"""针对 Gewechat 的简单实现。
@author: Soulter
@website: https://github.com/Soulter
"""
def __init__(
self,
base_url: str,
nickname: str,
host: str,
port: int,
event_queue: asyncio.Queue,
):
self.base_url = base_url
if self.base_url.endswith("/"):
self.base_url = self.base_url[:-1]
self.download_base_url = self.base_url.split(":")[:-1] # 去掉端口
self.download_base_url = ":".join(self.download_base_url) + ":2532/download/"
self.base_url += "/v2/api"
logger.info(f"Gewechat API: {self.base_url}")
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
if isinstance(port, str):
port = int(port)
self.token = None
self.headers = {}
self.nickname = nickname
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
self.server = quart.Quart(__name__)
self.server.add_url_rule(
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
)
self.server.add_url_rule(
"/astrbot-gewechat/file/<file_token>",
view_func=self._handle_file,
methods=["GET"],
)
self.host = host
self.port = port
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
self.event_queue = event_queue
self.multimedia_downloader = None
self.userrealnames = {}
self.shutdown_event = asyncio.Event()
self.staged_files = {}
"""存储了允许外部访问的文件列表。auth_token: file_path。通过 register_file 方法注册。"""
self.lock = asyncio.Lock()
async def get_token_id(self):
"""获取 Gewechat Token。"""
async with aiohttp.ClientSession() as session:
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
json_blob = await resp.json()
self.token = json_blob["data"]
logger.info(f"获取到 Gewechat Token: {self.token}")
self.headers = {"X-GEWE-TOKEN": self.token}
async def _convert(self, data: dict) -> AstrBotMessage:
if "TypeName" in data:
type_name = data["TypeName"]
elif "type_name" in data:
type_name = data["type_name"]
else:
raise Exception("无法识别的消息类型")
# 以下没有业务处理,只是避免控制台打印太多的日志
if type_name == "ModContacts":
logger.info("gewechat下发ModContacts消息通知。")
return
if type_name == "DelContacts":
logger.info("gewechat下发DelContacts消息通知。")
return
if type_name == "Offline":
logger.critical("收到 gewechat 下线通知。")
return
d = None
if "Data" in data:
d = data["Data"]
elif "data" in data:
d = data["data"]
if not d:
logger.warning(f"消息不含 data 字段: {data}")
return
if "CreateTime" in d:
# 得到系统 UTF+8 的 ts
tz_offset = datetime.timedelta(hours=8)
tz = datetime.timezone(tz_offset)
ts = datetime.datetime.now(tz).timestamp()
create_time = d["CreateTime"]
if create_time < ts - 30:
logger.warning(f"消息时间戳过旧: {create_time},当前时间戳: {ts}")
return
abm = AstrBotMessage()
from_user_name = d["FromUserName"]["string"] # 消息来源
d["to_wxid"] = from_user_name # 用于发信息
abm.message_id = str(d.get("MsgId"))
abm.session_id = from_user_name
abm.self_id = data["Wxid"] # 机器人的 wxid
user_id = "" # 发送人 wxid
content = d["Content"]["string"] # 消息内容
at_me = False
at_wxids = []
if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE
_t = content.split(":\n")
user_id = _t[0]
content = _t[1]
# at
msg_source = d["MsgSource"]
if "\u2005" in content:
# at
# content = content.split('\u2005')[1]
content = re.sub(r"@[^\u2005]*\u2005", "", content)
at_wxids = re.findall(
r"<atuserlist><!\[CDATA\[.*?(?:,|\b)([^,]+?)(?=,|\]\]></atuserlist>)",
msg_source,
)
abm.group_id = from_user_name
if (
f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source
or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source
):
at_me = True
if "在群聊中@了你" in d.get("PushContent", ""):
at_me = True
else:
abm.type = MessageType.FRIEND_MESSAGE
user_id = from_user_name
# 检查消息是否由自己发送,若是则忽略
# 已经有可配置项专门配置是否需要响应自己的消息,因此这里注释掉。
# if user_id == abm.self_id:
# logger.info("忽略自己发送的消息")
# return None
abm.message = []
# 解析用户真实名字
user_real_name = "unknown"
if abm.group_id:
if (
abm.group_id not in self.userrealnames
or user_id not in self.userrealnames[abm.group_id]
):
# 获取群成员列表,并且缓存
if abm.group_id not in self.userrealnames:
self.userrealnames[abm.group_id] = {}
member_list = await self.get_chatroom_member_list(abm.group_id)
logger.debug(f"获取到 {abm.group_id} 的群成员列表。")
if member_list and "memberList" in member_list:
for member in member_list["memberList"]:
self.userrealnames[abm.group_id][member["wxid"]] = member[
"nickName"
]
if user_id in self.userrealnames[abm.group_id]:
user_real_name = self.userrealnames[abm.group_id][user_id]
else:
user_real_name = self.userrealnames[abm.group_id][user_id]
else:
try:
info = (await self.get_user_or_group_info(user_id))["data"][0]
user_real_name = info["nickName"]
except Exception as e:
logger.debug(f"获取用户 {user_id} 昵称失败: {e}")
user_real_name = user_id
if at_me:
abm.message.insert(0, At(qq=abm.self_id, name=self.nickname))
for wxid in at_wxids:
# 群聊里 At 其他人的列表
_username = self.userrealnames.get(abm.group_id, {}).get(wxid, wxid)
abm.message.append(At(qq=wxid, name=_username))
abm.sender = MessageMember(user_id, user_real_name)
abm.raw_message = d
abm.message_str = ""
if user_id == "weixin":
# 忽略微信团队消息
return
# 不同消息类型
match d["MsgType"]:
case 1:
# 文本消息
abm.message.append(Plain(content))
abm.message_str = content
case 3:
# 图片消息
file_url = await self.multimedia_downloader.download_image(
self.appid, content
)
logger.debug(f"下载图片: {file_url}")
file_path = await download_image_by_url(file_url)
abm.message.append(Image(file=file_path, url=file_path))
case 34:
# 语音消息
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(
temp_dir, f"gewe_voice_{abm.message_id}.silk"
)
async with await anyio.open_file(file_path, "wb") as f:
await f.write(voice_data)
abm.message.append(Record(file=file_path, url=file_path))
# 以下已知消息类型,没有业务处理,只是避免控制台打印太多的日志
case 37: # 好友申请
logger.info("消息类型(37):好友申请")
case 42: # 名片
logger.info("消息类型(42):名片")
case 43: # 视频
video = Video(file="", cover=content)
abm.message.append(video)
case 47: # emoji
data_parser = GeweDataParser(content, abm.group_id == "")
emoji = data_parser.parse_emoji()
abm.message.append(emoji)
case 48: # 地理位置
logger.info("消息类型(48):地理位置")
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
data_parser = GeweDataParser(content, abm.group_id == "")
segments = data_parser.parse_mutil_49()
if segments:
abm.message.extend(segments)
for seg in segments:
if isinstance(seg, Plain):
abm.message_str += seg.text
case 51: # 帐号消息同步?
logger.info("消息类型(51):帐号消息同步?")
case 10000: # 被踢出群聊/更换群主/修改群名称
logger.info("消息类型(10000):被踢出群聊/更换群主/修改群名称")
case 10002: # 撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办
logger.info(
"消息类型(10002):撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办"
)
case _:
logger.info(f"未实现的消息类型: {d['MsgType']}")
abm.raw_message = d
logger.debug(f"abm: {abm}")
return abm
async def _callback(self):
data = await quart.request.json
logger.debug(f"收到 gewechat 回调: {data}")
if data.get("testMsg", None):
return quart.jsonify({"r": "AstrBot ACK"})
abm = None
try:
abm = await self._convert(data)
except BaseException as e:
logger.warning(
f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}"
)
if abm:
coro = getattr(self, "on_event_received")
if coro:
await coro(abm)
return quart.jsonify({"r": "AstrBot ACK"})
async def _register_file(self, file_path: str) -> str:
"""向 AstrBot 回调服务器 注册一个允许外部访问的文件。
Args:
file_path (str): 文件路径。
Returns:
str: 返回一个 auth_token文件路径为 file_path。通过 /astrbot-gewechat/file/auth_token 得到文件。
"""
async with self.lock:
if not os.path.exists(file_path):
raise Exception(f"文件不存在: {file_path}")
file_token = str(uuid.uuid4())
self.staged_files[file_token] = file_path
return file_token
async def _handle_file(self, file_token):
async with self.lock:
if file_token not in self.staged_files:
logger.warning(f"请求的文件 {file_token} 不存在。")
return quart.abort(404)
if not os.path.exists(self.staged_files[file_token]):
logger.warning(f"请求的文件 {self.staged_files[file_token]} 不存在。")
return quart.abort(404)
file_path = self.staged_files[file_token]
self.staged_files.pop(file_token, None)
return await quart.send_file(file_path)
async def _set_callback_url(self):
logger.info("设置回调,请等待...")
await asyncio.sleep(3)
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/tools/setCallback",
headers=self.headers,
json={"token": self.token, "callbackUrl": self.callback_url},
) as resp:
json_blob = await resp.json()
logger.info(f"设置回调结果: {json_blob}")
if json_blob["ret"] != 200:
raise Exception(f"设置回调失败: {json_blob}")
logger.info(
f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。如果仍没收到请到管理面板聊天页输入 /gewe_logout 重新登录。"
)
async def start_polling(self):
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
await self.server.run_task(
host="0.0.0.0",
port=self.port,
shutdown_trigger=self.shutdown_trigger,
)
async def shutdown_trigger(self):
await self.shutdown_event.wait()
async def check_online(self, appid: str):
"""检查 APPID 对应的设备是否在线。"""
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/checkOnline",
headers=self.headers,
json={"appId": appid},
) as resp:
json_blob = await resp.json()
return json_blob["data"]
async def logout(self):
"""登出 gewechat。"""
if self.appid:
online = await self.check_online(self.appid)
if online:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/logout",
headers=self.headers,
json={"appId": self.appid},
) as resp:
json_blob = await resp.json()
logger.info(f"登出结果: {json_blob}")
async def login(self):
"""登录 gewechat。一般来说插件用不到这个方法。"""
if self.token is None:
await self.get_token_id()
self.multimedia_downloader = GeweDownloader(
self.base_url, self.download_base_url, self.token
)
if self.appid:
try:
online = await self.check_online(self.appid)
if online:
logger.info(f"APPID: {self.appid} 已在线")
return
except Exception as e:
logger.error(f"检查在线状态失败: {e}")
sp.put(f"gewechat-appid-{self.nickname}", "")
self.appid = None
payload = {"appId": self.appid}
if self.appid:
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/getLoginQrCode",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
if json_blob["ret"] != 200:
error_msg = json_blob.get("data", {}).get("msg", "")
if "设备不存在" in error_msg:
logger.error(
f"检测到无效的appid: {self.appid},将清除并重新登录。"
)
sp.put(f"gewechat-appid-{self.nickname}", "")
self.appid = None
return await self.login()
else:
raise Exception(f"获取二维码失败: {json_blob}")
qr_data = json_blob["data"]["qrData"]
qr_uuid = json_blob["data"]["uuid"]
appid = json_blob["data"]["appId"]
logger.info(f"APPID: {appid}")
logger.warning(
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
)
except Exception as e:
raise e
# 执行登录
retry_cnt = 64
payload.update({"uuid": qr_uuid, "appId": appid})
while retry_cnt > 0:
retry_cnt -= 1
# 需要验证码
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
code_file_path = os.path.join(temp_dir, "gewe_code")
if os.path.exists(code_file_path):
with open(code_file_path, "r") as f:
code = f.read().strip()
if not code:
logger.warning(
"未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
)
await asyncio.sleep(5)
continue
payload["captchCode"] = code
logger.info(f"使用验证码: {code}")
try:
os.remove(code_file_path)
except Exception:
logger.warning(f"删除验证码文件 {code_file_path} 失败。")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/checkLogin",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.info(f"检查登录状态: {json_blob}")
ret = json_blob["ret"]
msg = ""
if json_blob["data"] and "msg" in json_blob["data"]:
msg = json_blob["data"]["msg"]
if ret == 500 and "安全验证码" in msg:
logger.warning(
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
)
else:
if "status" in json_blob["data"]:
status = json_blob["data"]["status"]
nickname = json_blob["data"].get("nickName", "")
if status == 1:
logger.info(f"等待确认...{nickname}")
elif status == 2:
logger.info(f"绿泡泡平台登录成功: {nickname}")
break
elif status == 0:
logger.info("等待扫码...")
else:
logger.warning(f"未知状态: {status}")
await asyncio.sleep(5)
if appid:
sp.put(f"gewechat-appid-{self.nickname}", appid)
self.appid = appid
logger.info(f"已保存 APPID: {appid}")
"""API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1
"""
async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict:
"""获取群成员列表。
Args:
chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。
Returns:
dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。
"""
payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomMemberList",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
return json_blob["data"]
async def post_text(self, to_wxid, content: str, ats: str = ""):
"""发送纯文本消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"content": content,
}
if ats:
payload["ats"] = ats
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postText", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送消息结果: {json_blob}")
async def post_image(self, to_wxid, image_url: str):
"""发送图片消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"imgUrl": image_url,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postImage", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送图片结果: {json_blob}")
async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""):
"""发送emoji消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"emojiMd5": emoji_md5,
"emojiSize": emoji_size,
}
# 优先表情包若拿不到表情包的md5就用当作图片发
try:
if emoji_md5 != "" and emoji_size != "":
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postEmoji",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.info(
f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}"
)
else:
await self.post_image(to_wxid, cdnurl)
except Exception as e:
logger.error(e)
async def post_video(
self, to_wxid, video_url: str, thumb_url: str, video_duration: int
):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"videoUrl": video_url,
"thumbUrl": thumb_url,
"videoDuration": video_duration,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postVideo", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送视频结果: {json_blob}")
async def forward_video(self, to_wxid, cnd_xml: str):
"""转发视频
Args:
to_wxid (str): 发送给谁
cnd_xml (str): 视频消息的cdn信息
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"xml": cnd_xml,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/forwardVideo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"转发视频结果: {json_blob}")
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
"""发送语音信息
Args:
voice_url (str): 语音文件的网络链接
voice_duration (int): 语音时长,毫秒
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"voiceUrl": voice_url,
"voiceDuration": voice_duration,
}
logger.debug(f"发送语音: {payload}")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}")
async def post_file(self, to_wxid, file_url: str, file_name: str):
"""发送文件
Args:
to_wxid (string): 微信ID
file_url (str): 文件的网络链接
file_name (str): 文件名
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"fileUrl": file_url,
"fileName": file_name,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postFile", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送文件结果: {json_blob}")
async def add_friend(self, v3: str, v4: str, content: str):
"""申请添加好友"""
payload = {
"appId": self.appid,
"scene": 3,
"content": content,
"v4": v4,
"v3": v3,
"option": 2,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/addContacts",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"申请添加好友结果: {json_blob}")
return json_blob
async def get_group(self, group_id: str):
payload = {
"appId": self.appid,
"chatroomId": group_id,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomInfo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_group_member(self, group_id: str):
payload = {
"appId": self.appid,
"chatroomId": group_id,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomMemberList",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def accept_group_invite(self, url: str):
"""同意进群"""
payload = {"appId": self.appid, "url": url}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/agreeJoinRoom",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def add_group_member_to_friend(
self, group_id: str, to_wxid: str, content: str
):
payload = {
"appId": self.appid,
"chatroomId": group_id,
"content": content,
"memberWxid": to_wxid,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/addGroupMemberAsFriend",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_user_or_group_info(self, *ids):
"""
获取用户或群组信息。
:param ids: 可变数量的 wxid 参数
"""
wxids_str = list(ids)
payload = {
"appId": self.appid,
"wxids": wxids_str, # 使用逗号分隔的字符串
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/getDetailInfo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_contacts_list(self):
"""
获取通讯录列表
见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504
"""
payload = {"appId": self.appid}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/fetchContactsList",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取通讯录列表结果: {json_blob}")
return json_blob

View File

@@ -1,55 +0,0 @@
from astrbot import logger
import aiohttp
import json
class GeweDownloader:
def __init__(self, base_url: str, download_base_url: str, token: str):
self.base_url = base_url
self.download_base_url = download_base_url
self.headers = {"Content-Type": "application/json", "X-GEWE-TOKEN": token}
async def _post_json(self, baseurl: str, route: str, payload: dict):
async with aiohttp.ClientSession() as session:
async with session.post(
f"{baseurl}{route}", headers=self.headers, json=payload
) as resp:
return await resp.read()
async def download_voice(self, appid: str, xml: str, msg_id: str):
payload = {"appId": appid, "xml": xml, "msgId": msg_id}
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
async def download_image(self, appid: str, xml: str) -> str:
"""返回一个可下载的 URL"""
choices = [2, 3] # 2:常规图片 3:缩略图
for choice in choices:
try:
payload = {"appId": appid, "xml": xml, "type": choice}
data = await self._post_json(
self.base_url, "/message/downloadImage", payload
)
json_blob = json.loads(data)
if "fileUrl" in json_blob["data"]:
return self.download_base_url + json_blob["data"]["fileUrl"]
except BaseException as e:
logger.error(f"gewe download image: {e}")
continue
raise Exception("无法下载图片")
async def download_emoji_md5(self, app_id, emoji_md5):
"""下载emoji"""
try:
payload = {"appId": app_id, "emojiMd5": emoji_md5}
# gewe 计划中的接口暂时没有实现。返回代码404
data = await self._post_json(
self.base_url, "/message/downloadEmojiMd5", payload
)
json_blob = json.loads(data)
return json_blob
except BaseException as e:
logger.error(f"gewe download emoji: {e}")

View File

@@ -1,264 +0,0 @@
import asyncio
import re
import wave
import uuid
import traceback
import os
from typing import AsyncGenerator
from astrbot.core.utils.io import download_file
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember
from astrbot.api.message_components import (
Plain,
Image,
Record,
At,
File,
Video,
WechatEmoji as Emoji,
)
from .client import SimpleGewechatClient
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
def get_wav_duration(file_path):
with wave.open(file_path, "rb") as wav_file:
file_size = os.path.getsize(file_path)
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
if n_frames == 2147483647:
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
elif n_frames == 0:
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
else:
duration = n_frames / float(framerate)
return duration
class GewechatPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client: SimpleGewechatClient,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
@staticmethod
async def send_with_client(
message: MessageChain, to_wxid: str, client: SimpleGewechatClient
):
if not to_wxid:
logger.error("无法获取到 to_wxid。")
return
# 检查@
ats = []
ats_names = []
for comp in message.chain:
if isinstance(comp, At):
ats.append(comp.qq)
ats_names.append(comp.name)
has_at = False
for comp in message.chain:
if isinstance(comp, Plain):
text = comp.text
payload = {
"to_wxid": to_wxid,
"content": text,
}
if not has_at and ats:
ats = f"{','.join(ats)}"
ats_names = f"@{' @'.join(ats_names)}"
text = f"{ats_names} {text}"
payload["content"] = text
payload["ats"] = ats
has_at = True
await client.post_text(**payload)
elif isinstance(comp, Image):
img_path = await comp.convert_to_file_path()
# 为了安全,向 AstrBot 回调服务注册可被 gewechat 访问的文件,并获得文件 token
token = await client._register_file(img_path)
img_url = f"{client.file_server_url}/{token}"
logger.debug(f"gewe callback img url: {img_url}")
await client.post_image(to_wxid, img_url)
elif isinstance(comp, Video):
if comp.cover != "":
await client.forward_video(to_wxid, comp.cover)
else:
try:
from pyffmpeg import FFmpeg
except (ImportError, ModuleNotFoundError):
logger.error(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
)
raise ModuleNotFoundError(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
)
video_url = comp.file
# 根据 url 下载视频
if video_url.startswith("http"):
video_filename = f"{uuid.uuid4()}.mp4"
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
video_path = os.path.join(temp_dir, video_filename)
await download_file(video_url, video_path)
else:
video_path = video_url
video_token = await client._register_file(video_path)
video_callback_url = f"{client.file_server_url}/{video_token}"
# 获取视频第一帧
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
thumb_path = os.path.join(
temp_dir, f"gewechat_video_thumb_{uuid.uuid4()}.jpg"
)
video_path = video_path.replace(" ", "\\ ")
try:
ff = FFmpeg()
command = f"-i {video_path} -ss 0 -vframes 1 {thumb_path}"
ff.options(command)
thumb_token = await client._register_file(thumb_path)
thumb_url = f"{client.file_server_url}/{thumb_token}"
except Exception as e:
logger.error(f"获取视频第一帧失败: {e}")
# 获取视频时长
try:
from pyffmpeg import FFprobe
# 创建 FFprobe 实例
ffprobe = FFprobe(video_url)
# 获取时长字符串
duration_str = ffprobe.duration
# 处理时长字符串
video_duration = float(duration_str.replace(":", ""))
except Exception as e:
logger.error(f"获取时长失败: {e}")
video_duration = 10
# 发送视频
await client.post_video(
to_wxid, video_callback_url, thumb_url, video_duration
)
# 删除临时缩略图文件
if os.path.exists(thumb_path):
os.remove(thumb_path)
elif isinstance(comp, Record):
# 默认已经存在 data/temp 中
record_url = comp.file
record_path = await comp.convert_to_file_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
try:
duration = await wav_to_tencent_silk(record_path, silk_path)
except Exception as e:
logger.error(traceback.format_exc())
await client.post_text(to_wxid, f"语音文件转换失败。{str(e)}")
logger.info("Silk 语音文件格式转换至: " + record_path)
if duration == 0:
duration = get_wav_duration(record_path)
token = await client._register_file(silk_path)
record_url = f"{client.file_server_url}/{token}"
logger.debug(f"gewe callback record url: {record_url}")
await client.post_voice(to_wxid, record_url, duration * 1000)
elif isinstance(comp, File):
file_path = comp.file
file_name = comp.name
if file_path.startswith("file:///"):
file_path = file_path[8:]
elif file_path.startswith("http"):
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
temp_file_path = os.path.join(temp_dir, file_name)
await download_file(file_path, temp_file_path)
file_path = temp_file_path
else:
file_path = file_path
token = await client._register_file(file_path)
file_url = f"{client.file_server_url}/{token}"
logger.debug(f"gewe callback file url: {file_url}")
await client.post_file(to_wxid, file_url, file_name)
elif isinstance(comp, Emoji):
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
elif isinstance(comp, At):
pass
else:
logger.debug(f"gewechat 忽略: {comp.type}")
async def send(self, message: MessageChain):
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client)
await super().send(message)
async def get_group(self, group_id=None, **kwargs):
# 确定有效的 group_id
if group_id is None:
group_id = self.get_group_id()
if not group_id:
return None
res = await self.client.get_group(group_id)
data: dict = res["data"]
if not data["chatroomId"]:
return None
members = [
MessageMember(user_id=member["wxid"], nickname=member["nickName"])
for member in data.get("memberList", [])
]
return Group(
group_id=data["chatroomId"],
group_name=data.get("nickName"),
group_avatar=data.get("smallHeadImgUrl"),
group_owner=data.get("chatRoomOwner"),
members=members,
)
async def send_streaming(
self, generator: AsyncGenerator, use_fallback: bool = False
):
if not use_fallback:
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
buffer = ""
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
async for chain in generator:
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
buffer = await self.process_buffer(buffer, pattern)
else:
await self.send(MessageChain(chain=[comp]))
await asyncio.sleep(1.5) # 限速
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator, use_fallback)

View File

@@ -1,103 +0,0 @@
import sys
import asyncio
import os
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from .gewechat_event import GewechatPlatformEvent
from .client import SimpleGewechatClient
from astrbot import logger
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器")
class GewechatPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settingss = platform_settings
self.test_mode = os.environ.get("TEST_MODE", "off") == "on"
self.client = None
self.client = SimpleGewechatClient(
self.config["base_url"],
self.config["nickname"],
self.config["host"],
self.config["port"],
self._event_queue,
)
async def on_event_received(abm: AstrBotMessage):
await self.handle_msg(abm)
self.client.on_event_received = on_event_received
@override
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
session_id = session.session_id
if "#" in session_id:
# unique session
to_wxid = session_id.split("#")[1]
else:
to_wxid = session_id
await GewechatPlatformEvent.send_with_client(
message_chain, to_wxid, self.client
)
await super().send_by_session(session, message_chain)
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="gewechat",
description="基于 gewechat 的 Wechat 适配器",
id=self.config.get("id"),
)
async def terminate(self):
self.client.shutdown_event.set()
try:
await self.client.server.shutdown()
except Exception as _:
pass
logger.info("Gewechat 适配器已被优雅地关闭。")
async def logout(self):
await self.client.logout()
@override
def run(self):
return self._run()
async def _run(self):
await self.client.login()
await self.client.start_polling()
async def handle_msg(self, message: AstrBotMessage):
if message.type == MessageType.GROUP_MESSAGE:
if self.settingss["unique_session"]:
message.session_id = message.sender.user_id + "#" + message.group_id
message_event = GewechatPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client,
)
self.commit_event(message_event)
def get_client(self) -> SimpleGewechatClient:
return self.client

View File

@@ -1,110 +0,0 @@
from defusedxml import ElementTree as eT
from astrbot.api import logger
from astrbot.api.message_components import (
WechatEmoji as Emoji,
Reply,
Plain,
BaseMessageComponent,
)
class GeweDataParser:
def __init__(self, data, is_private_chat):
self.data = data
self.is_private_chat = is_private_chat
def _format_to_xml(self):
return eT.fromstring(self.data)
def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
appmsg_type = self._format_to_xml().find(".//appmsg/type")
if appmsg_type is None:
return
match appmsg_type.text:
case "57":
return self.parse_reply()
def parse_emoji(self) -> Emoji | None:
try:
emoji_element = self._format_to_xml().find(".//emoji")
# 提取 md5 和 len 属性
if emoji_element is not None:
md5_value = emoji_element.get("md5")
emoji_size = emoji_element.get("len")
cdnurl = emoji_element.get("cdnurl")
return Emoji(md5=md5_value, md5_len=emoji_size, cdnurl=cdnurl)
except Exception as e:
logger.error(f"gewechat: parse_emoji failed, {e}")
def parse_reply(self) -> list[Reply, Plain] | None:
"""解析引用消息
Returns:
list[Reply, Plain]: 一个包含两个元素的列表。Reply 消息对象和引用者说的文本内容。微信平台下引用消息时只能发送文本消息。
"""
try:
replied_id = -1
replied_uid = 0
replied_nickname = ""
replied_content = "" # 被引用者说的内容
content = "" # 引用者说的内容
root = self._format_to_xml()
refermsg = root.find(".//refermsg")
if refermsg is not None:
# 被引用的信息
svrid = refermsg.find("svrid")
fromusr = refermsg.find("fromusr")
displayname = refermsg.find("displayname")
refermsg_content = refermsg.find("content")
if svrid is not None:
replied_id = svrid.text
if fromusr is not None:
replied_uid = fromusr.text
if displayname is not None:
replied_nickname = displayname.text
if refermsg_content is not None:
# 处理引用嵌套,包括嵌套公众号消息
if refermsg_content.text.startswith(
"<msg>"
) or refermsg_content.text.startswith("<?xml"):
try:
logger.debug("gewechat: Reference message is nested")
refer_root = eT.fromstring(refermsg_content.text)
img = refer_root.find("img")
if img is not None:
replied_content = "[图片]"
else:
app_msg = refer_root.find("appmsg")
refermsg_content_title = app_msg.find("title")
logger.debug(
f"gewechat: Reference message nesting: {refermsg_content_title.text}"
)
replied_content = refermsg_content_title.text
except Exception as e:
logger.error(f"gewechat: nested failed, {e}")
# 处理异常情况
replied_content = refermsg_content.text
else:
replied_content = refermsg_content.text
# 提取引用者说的内容
title = root.find(".//appmsg/title")
if title is not None:
content = title.text
reply_seg = Reply(
id=replied_id,
chain=[Plain(replied_content)],
sender_id=replied_uid,
sender_nickname=replied_nickname,
message_str=replied_content,
)
plain_seg = Plain(content)
return [reply_seg, plain_seg]
except Exception as e:
logger.error(f"gewechat: parse_reply failed, {e}")

View File

@@ -28,10 +28,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
self.send_buffer = None self.send_buffer = None
async def send(self, message: MessageChain): async def send(self, message: MessageChain):
if not self.send_buffer: self.send_buffer = message
self.send_buffer = message await self._post_send()
else:
self.send_buffer.chain.extend(message.chain)
async def send_streaming(self, generator, use_fallback: bool = False): async def send_streaming(self, generator, use_fallback: bool = False):
"""流式输出仅支持消息列表私聊""" """流式输出仅支持消息列表私聊"""

View File

@@ -308,7 +308,9 @@ class SlackAdapter(Platform):
base64_content = base64.b64encode(content).decode("utf-8") base64_content = base64.b64encode(content).decode("utf-8")
return base64_content return base64_content
else: else:
logger.error(f"Failed to download slack file: {resp.status} {await resp.text()}") logger.error(
f"Failed to download slack file: {resp.status} {await resp.text()}"
)
raise Exception(f"下载文件失败: {resp.status}") raise Exception(f"下载文件失败: {resp.status}")
async def run(self) -> Awaitable[Any]: async def run(self) -> Awaitable[Any]:

View File

@@ -75,7 +75,13 @@ class SlackMessageEvent(AstrMessageEvent):
"text": {"type": "mrkdwn", "text": "文件上传失败"}, "text": {"type": "mrkdwn", "text": "文件上传失败"},
} }
file_url = response["files"][0]["permalink"] file_url = response["files"][0]["permalink"]
return {"type": "section", "text": {"type": "mrkdwn", "text": f"文件: <{file_url}|{segment.name or '文件'}>"}} return {
"type": "section",
"text": {
"type": "mrkdwn",
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
},
}
else: else:
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}} return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}

View File

@@ -40,20 +40,21 @@ class TelegramPlatformEvent(AstrMessageEvent):
super().__init__(message_str, message_obj, platform_meta, session_id) super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client self.client = client
def _split_message(self, text: str) -> list[str]: @classmethod
if len(text) <= self.MAX_MESSAGE_LENGTH: def _split_message(cls, text: str) -> list[str]:
if len(text) <= cls.MAX_MESSAGE_LENGTH:
return [text] return [text]
chunks = [] chunks = []
while text: while text:
if len(text) <= self.MAX_MESSAGE_LENGTH: if len(text) <= cls.MAX_MESSAGE_LENGTH:
chunks.append(text) chunks.append(text)
break break
split_point = self.MAX_MESSAGE_LENGTH split_point = cls.MAX_MESSAGE_LENGTH
segment = text[: self.MAX_MESSAGE_LENGTH] segment = text[: cls.MAX_MESSAGE_LENGTH]
for _, pattern in self.SPLIT_PATTERNS.items(): for _, pattern in cls.SPLIT_PATTERNS.items():
if matches := list(pattern.finditer(segment)): if matches := list(pattern.finditer(segment)):
last_match = matches[-1] last_match = matches[-1]
split_point = last_match.end() split_point = last_match.end()
@@ -64,8 +65,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
return chunks return chunks
@classmethod
async def send_with_client( async def send_with_client(
self, client: ExtBot, message: MessageChain, user_name: str cls, client: ExtBot, message: MessageChain, user_name: str
): ):
image_path = None image_path = None
@@ -97,7 +99,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
if at_user_id and not at_flag: if at_user_id and not at_flag:
i.text = f"@{at_user_id} {i.text}" i.text = f"@{at_user_id} {i.text}"
at_flag = True at_flag = True
chunks = self._split_message(i.text) chunks = cls._split_message(i.text)
for chunk in chunks: for chunk in chunks:
try: try:
md_text = telegramify_markdown.markdownify( md_text = telegramify_markdown.markdownify(
@@ -158,6 +160,12 @@ class TelegramPlatformEvent(AstrMessageEvent):
async for chain in generator: async for chain in generator:
if isinstance(chain, MessageChain): if isinstance(chain, MessageChain):
if chain.type == "break":
# 分割符
message_id = None # 重置消息 ID
delta = "" # 重置 delta
continue
# 处理消息链中的每个组件 # 处理消息链中的每个组件
for i in chain.chain: for i in chain.chain:
if isinstance(i, Plain): if isinstance(i, Plain):

View File

@@ -2,7 +2,7 @@ import time
import asyncio import asyncio
import uuid import uuid
import os import os
from typing import Awaitable, Any from typing import Awaitable, Any, Callable
from astrbot.core.platform import ( from astrbot.core.platform import (
Platform, Platform,
AstrBotMessage, AstrBotMessage,
@@ -13,7 +13,7 @@ from astrbot.core.platform import (
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.message.components import Plain, Image, Record # noqa: F403 from astrbot.core.message.components import Plain, Image, Record # noqa: F403
from astrbot import logger from astrbot import logger
from astrbot.core import web_chat_queue from .webchat_queue_mgr import webchat_queue_mgr, WebChatQueueMgr
from .webchat_event import WebChatMessageEvent from .webchat_event import WebChatMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter from ...register import register_platform_adapter
@@ -21,14 +21,46 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class QueueListener: class QueueListener:
def __init__(self, queue: asyncio.Queue, callback: callable) -> None: def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None:
self.queue = queue self.webchat_queue_mgr = webchat_queue_mgr
self.callback = callback self.callback = callback
self.running_tasks = set()
async def listen_to_queue(self, conversation_id: str):
"""Listen to a specific conversation queue"""
queue = self.webchat_queue_mgr.get_or_create_queue(conversation_id)
while True:
try:
data = await queue.get()
await self.callback(data)
except Exception as e:
logger.error(
f"Error processing message from conversation {conversation_id}: {e}"
)
break
async def run(self): async def run(self):
"""Monitor for new conversation queues and start listeners"""
monitored_conversations = set()
while True: while True:
data = await self.queue.get() # Check for new conversations
await self.callback(data) current_conversations = set(self.webchat_queue_mgr.queues.keys())
new_conversations = current_conversations - monitored_conversations
# Start listeners for new conversations
for conversation_id in new_conversations:
task = asyncio.create_task(self.listen_to_queue(conversation_id))
self.running_tasks.add(task)
task.add_done_callback(self.running_tasks.discard)
monitored_conversations.add(conversation_id)
logger.debug(f"Started listener for conversation: {conversation_id}")
# Clean up monitored conversations that no longer exist
removed_conversations = monitored_conversations - current_conversations
monitored_conversations -= removed_conversations
await asyncio.sleep(1) # Check for new conversations every second
@register_platform_adapter("webchat", "webchat") @register_platform_adapter("webchat", "webchat")
@@ -45,7 +77,7 @@ class WebChatAdapter(Platform):
os.makedirs(self.imgs_dir, exist_ok=True) os.makedirs(self.imgs_dir, exist_ok=True)
self.metadata = PlatformMetadata( self.metadata = PlatformMetadata(
name="webchat", description="webchat", id=self.config.get("id") name="webchat", description="webchat", id=self.config.get("id", "")
) )
async def send_by_session( async def send_by_session(
@@ -105,7 +137,7 @@ class WebChatAdapter(Platform):
abm = await self.convert_message(data) abm = await self.convert_message(data)
await self.handle_msg(abm) await self.handle_msg(abm)
bot = QueueListener(web_chat_queue, callback) bot = QueueListener(webchat_queue_mgr, callback)
return bot.run() return bot.run()
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:
@@ -119,6 +151,10 @@ class WebChatAdapter(Platform):
session_id=message.session_id, session_id=message.session_id,
) )
_, _, payload = message.raw_message # type: ignore
message_event.set_extra("selected_provider", payload.get("selected_provider"))
message_event.set_extra("selected_model", payload.get("selected_model"))
self.commit_event(message_event) self.commit_event(message_event)
async def terminate(self): async def terminate(self):

View File

@@ -5,8 +5,8 @@ from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image, Record from astrbot.api.message_components import Plain, Image, Record
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
from astrbot.core import web_chat_back_queue
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .webchat_queue_mgr import webchat_queue_mgr
imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
@@ -18,13 +18,18 @@ class WebChatMessageEvent(AstrMessageEvent):
@staticmethod @staticmethod
async def _send(message: MessageChain, session_id: str, streaming: bool = False): async def _send(message: MessageChain, session_id: str, streaming: bool = False):
cid = session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
if not message: if not message:
await web_chat_back_queue.put( await web_chat_back_queue.put(
{"type": "end", "data": "", "streaming": False} {
"type": "end",
"data": "",
"streaming": False,
} # end means this request is finished
) )
return "" return ""
cid = session_id.split("!")[-1]
data = "" data = ""
for comp in message.chain: for comp in message.chain:
if isinstance(comp, Plain): if isinstance(comp, Plain):
@@ -35,6 +40,7 @@ class WebChatMessageEvent(AstrMessageEvent):
"cid": cid, "cid": cid,
"data": data, "data": data,
"streaming": streaming, "streaming": streaming,
"chain_type": message.type,
} }
) )
elif isinstance(comp, Image): elif isinstance(comp, Image):
@@ -97,29 +103,35 @@ class WebChatMessageEvent(AstrMessageEvent):
async def send(self, message: MessageChain): async def send(self, message: MessageChain):
await WebChatMessageEvent._send(message, session_id=self.session_id) await WebChatMessageEvent._send(message, session_id=self.session_id)
await web_chat_back_queue.put(
{
"type": "end",
"data": "",
"streaming": False,
"cid": self.session_id.split("!")[-1],
}
)
await super().send(message) await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False): async def send_streaming(self, generator, use_fallback: bool = False):
final_data = "" final_data = ""
cid = self.session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
async for chain in generator: async for chain in generator:
if chain.type == "break" and final_data:
# 分割符
await web_chat_back_queue.put(
{
"type": "break", # break means a segment end
"data": final_data,
"streaming": True,
"cid": cid,
}
)
final_data = ""
continue
final_data += await WebChatMessageEvent._send( final_data += await WebChatMessageEvent._send(
chain, session_id=self.session_id, streaming=True chain, session_id=self.session_id, streaming=True
) )
await web_chat_back_queue.put( await web_chat_back_queue.put(
{ {
"type": "end", "type": "complete", # complete means we return the final result
"data": final_data, "data": final_data,
"streaming": True, "streaming": True,
"cid": self.session_id.split("!")[-1], "cid": cid,
} }
) )
await super().send_streaming(generator, use_fallback) await super().send_streaming(generator, use_fallback)

View File

@@ -0,0 +1,35 @@
import asyncio
class WebChatQueueMgr:
def __init__(self) -> None:
self.queues = {}
"""Conversation ID to asyncio.Queue mapping"""
self.back_queues = {}
"""Conversation ID to asyncio.Queue mapping for responses"""
def get_or_create_queue(self, conversation_id: str) -> asyncio.Queue:
"""Get or create a queue for the given conversation ID"""
if conversation_id not in self.queues:
self.queues[conversation_id] = asyncio.Queue()
return self.queues[conversation_id]
def get_or_create_back_queue(self, conversation_id: str) -> asyncio.Queue:
"""Get or create a back queue for the given conversation ID"""
if conversation_id not in self.back_queues:
self.back_queues[conversation_id] = asyncio.Queue()
return self.back_queues[conversation_id]
def remove_queues(self, conversation_id: str):
"""Remove queues for the given conversation ID"""
if conversation_id in self.queues:
del self.queues[conversation_id]
if conversation_id in self.back_queues:
del self.back_queues[conversation_id]
def has_queue(self, conversation_id: str) -> bool:
"""Check if a queue exists for the given conversation ID"""
return conversation_id in self.queues
webchat_queue_mgr = WebChatQueueMgr()

View File

@@ -210,6 +210,16 @@ class WeChatPadProAdapter(Platform):
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return False return False
def _extract_auth_key(self, data):
"""Helper method to extract auth_key from response data."""
if isinstance(data, dict):
auth_keys = data.get("authKeys") # 新接口
if isinstance(auth_keys, list) and auth_keys:
return auth_keys[0]
elif isinstance(data, list) and data: # 旧接口
return data[0]
return None
async def generate_auth_key(self): async def generate_auth_key(self):
""" """
生成授权码。 生成授权码。
@@ -218,28 +228,30 @@ class WeChatPadProAdapter(Platform):
params = {"key": self.admin_key} params = {"key": self.admin_key}
payload = {"Count": 1, "Days": 365} # 生成一个有效期365天的授权码 payload = {"Count": 1, "Days": 365} # 生成一个有效期365天的授权码
self.auth_key = None # Reset auth_key before generating a new one
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
try: try:
async with session.post(url, params=params, json=payload) as response: async with session.post(url, params=params, json=payload) as response:
if response.status != 200:
logger.error(
f"生成授权码失败: {response.status}, {await response.text()}"
)
return
response_data = await response.json() response_data = await response.json()
# 修正成功判断条件和授权码提取路径 if response_data.get("Code") == 200:
if response.status == 200 and response_data.get("Code") == 200: if data := response_data.get("Data"):
# 授权码在 Data 字段的列表中 self.auth_key = self._extract_auth_key(data)
if (
response_data.get("Data") if self.auth_key:
and isinstance(response_data["Data"], list) logger.info("成功获取授权码")
and len(response_data["Data"]) > 0
):
self.auth_key = response_data["Data"][0]
logger.info(f"成功获取授权码 {self.auth_key[:8]}...")
else: else:
logger.error( logger.error(
f"生成授权码成功但未找到授权码: {response_data}" f"生成授权码成功但未找到授权码: {response_data}"
) )
else: else:
logger.error( logger.error(f"生成授权码失败: {response_data}")
f"生成授权码失败: {response.status}, {response_data}"
)
except aiohttp.ClientConnectorError as e: except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}") logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
except Exception as e: except Exception as e:

View File

@@ -17,7 +17,7 @@ from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType
from astrbot.core.platform.platform_metadata import PlatformMetadata from astrbot.core.platform.platform_metadata import PlatformMetadata
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk_base64 from astrbot.core.utils.tencent_record_helper import audio_to_tencent_silk_base64
if TYPE_CHECKING: if TYPE_CHECKING:
from .wechatpadpro_adapter import WeChatPadProAdapter from .wechatpadpro_adapter import WeChatPadProAdapter
@@ -113,7 +113,7 @@ class WeChatPadProMessageEvent(AstrMessageEvent):
async def _send_voice(self, session: aiohttp.ClientSession, comp: Record): async def _send_voice(self, session: aiohttp.ClientSession, comp: Record):
record_path = await comp.convert_to_file_path() record_path = await comp.convert_to_file_path()
# 默认已经存在 data/temp 中 # 默认已经存在 data/temp 中
b64, duration = await wav_to_tencent_silk_base64(record_path) b64, duration = await audio_to_tencent_silk_base64(record_path)
payload = { payload = {
"ToUserName": self.session_id, "ToUserName": self.session_id,
"VoiceData": b64, "VoiceData": b64,

View File

@@ -48,7 +48,12 @@ class WeChatKF(BaseWeChatAPI):
注意可能会出现返回条数少于limit的情况需结合返回的has_more字段判断是否继续请求。 注意可能会出现返回条数少于limit的情况需结合返回的has_more字段判断是否继续请求。
:return: 接口调用结果 :return: 接口调用结果
""" """
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid} data = {
"token": token,
"cursor": cursor,
"limit": limit,
"open_kfid": open_kfid,
}
return self._post("kf/sync_msg", data=data) return self._post("kf/sync_msg", data=data)
def get_service_state(self, open_kfid, external_userid): def get_service_state(self, open_kfid, external_userid):
@@ -72,7 +77,9 @@ class WeChatKF(BaseWeChatAPI):
} }
return self._post("kf/service_state/get", data=data) return self._post("kf/service_state/get", data=data)
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""): def trans_service_state(
self, open_kfid, external_userid, service_state, servicer_userid=""
):
""" """
变更会话状态 变更会话状态
@@ -180,7 +187,9 @@ class WeChatKF(BaseWeChatAPI):
""" """
return self._get("kf/customer/get_upgrade_service_config") return self._get("kf/customer/get_upgrade_service_config")
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None): def upgrade_service(
self, open_kfid, external_userid, service_type, member=None, groupchat=None
):
""" """
为客户升级为专员或客户群服务 为客户升级为专员或客户群服务
@@ -246,7 +255,9 @@ class WeChatKF(BaseWeChatAPI):
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time} data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
return self._post("kf/get_corp_statistic", data=data) return self._post("kf/get_corp_statistic", data=data)
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None): def get_servicer_statistic(
self, start_time, end_time, open_kfid=None, servicer_userid=None
):
""" """
获取「客户数据统计」接待人员明细数据 获取「客户数据统计」接待人员明细数据

View File

@@ -26,6 +26,7 @@ from optionaldict import optionaldict
from wechatpy.client.api.base import BaseWeChatAPI from wechatpy.client.api.base import BaseWeChatAPI
class WeChatKFMessage(BaseWeChatAPI): class WeChatKFMessage(BaseWeChatAPI):
""" """
发送微信客服消息 发送微信客服消息
@@ -125,35 +126,55 @@ class WeChatKFMessage(BaseWeChatAPI):
msg={"msgtype": "news", "link": {"link": articles_data}}, msg={"msgtype": "news", "link": {"link": articles_data}},
) )
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""): def send_msgmenu(
self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""
):
return self.send( return self.send(
user_id, user_id,
open_kfid, open_kfid,
msgid, msgid,
msg={ msg={
"msgtype": "msgmenu", "msgtype": "msgmenu",
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content}, "msgmenu": {
"head_content": head_content,
"list": menu_list,
"tail_content": tail_content,
},
}, },
) )
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""): def send_location(
self, user_id, open_kfid, name, address, latitude, longitude, msgid=""
):
return self.send( return self.send(
user_id, user_id,
open_kfid, open_kfid,
msgid, msgid,
msg={ msg={
"msgtype": "location", "msgtype": "location",
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude}, "msgmenu": {
"name": name,
"address": address,
"latitude": latitude,
"longitude": longitude,
},
}, },
) )
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""): def send_miniprogram(
self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""
):
return self.send( return self.send(
user_id, user_id,
open_kfid, open_kfid,
msgid, msgid,
msg={ msg={
"msgtype": "miniprogram", "msgtype": "miniprogram",
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath}, "msgmenu": {
"appid": appid,
"title": title,
"thumb_media_id": thumb_media_id,
"pagepath": pagepath,
},
}, },
) )

View File

@@ -160,7 +160,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.wexin_event_workers[msg.id] = future self.wexin_event_workers[msg.id] = future
await self.convert_message(msg, future) await self.convert_message(msg, future)
# I love shield so much! # I love shield so much!
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s result = await asyncio.wait_for(
asyncio.shield(future), 60
) # wait for 60s
logger.debug(f"Got future result: {result}") logger.debug(f"Got future result: {result}")
self.wexin_event_workers.pop(msg.id, None) self.wexin_event_workers.pop(msg.id, None)
return result # xml. see weixin_offacc_event.py return result # xml. see weixin_offacc_event.py

View File

@@ -150,7 +150,6 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
return return
logger.info(f"微信公众平台上传语音返回: {response}") logger.info(f"微信公众平台上传语音返回: {response}")
if active_send_mode: if active_send_mode:
self.client.message.send_voice( self.client.message.send_voice(
message_obj.sender.user_id, message_obj.sender.user_id,

View File

@@ -58,7 +58,7 @@ class AssistantMessageSegment:
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" """OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
content: str = None content: str = None
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list)
role: str = "assistant" role: str = "assistant"
def to_dict(self): def to_dict(self):
@@ -67,7 +67,7 @@ class AssistantMessageSegment:
} }
if self.content: if self.content:
ret["content"] = self.content ret["content"] = self.content
elif self.tool_calls: if self.tool_calls:
ret["tool_calls"] = self.tool_calls ret["tool_calls"] = self.tool_calls
return ret return ret
@@ -95,27 +95,38 @@ class ProviderRequest:
"""提示词""" """提示词"""
session_id: str = "" session_id: str = ""
"""会话 ID""" """会话 ID"""
image_urls: List[str] = None image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表""" """图片 URL 列表"""
func_tool: FuncCall = None func_tool: FuncCall | None = None
"""可用的函数工具""" """可用的函数工具"""
contexts: List = None contexts: list[dict] = field(default_factory=list)
"""上下文。格式与 openai 的上下文格式一致: """上下文。格式与 openai 的上下文格式一致:
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages 参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
""" """
system_prompt: str = "" system_prompt: str = ""
"""系统提示词""" """系统提示词"""
conversation: Conversation = None conversation: Conversation | None = None
tool_calls_result: ToolCallsResult = None tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls""" """附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
model: str | None = None
"""模型名称,为 None 时使用提供商的默认模型"""
def __repr__(self): def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})" return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
def __str__(self): def __str__(self):
return self.__repr__() return self.__repr__()
def append_tool_calls_result(self, tool_calls_result: ToolCallsResult):
"""添加工具调用结果到请求中"""
if not self.tool_calls_result:
self.tool_calls_result = []
if isinstance(self.tool_calls_result, ToolCallsResult):
self.tool_calls_result = [self.tool_calls_result]
self.tool_calls_result.append(tool_calls_result)
def _print_friendly_context(self): def _print_friendly_context(self):
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>""" """打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
if not self.contexts: if not self.contexts:

View File

@@ -39,6 +39,72 @@ SUPPORTED_TYPES = [
] # json schema 支持的数据类型 ] # json schema 支持的数据类型
def _prepare_config(config: dict) -> dict:
"""准备配置,处理嵌套格式"""
if "mcpServers" in config and config["mcpServers"]:
first_key = next(iter(config["mcpServers"]))
config = config["mcpServers"][first_key]
config.pop("active", None)
return config
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"""快速测试 MCP 服务器可达性"""
import aiohttp
cfg = _prepare_config(config.copy())
url = cfg["url"]
headers = cfg.get("headers", {})
timeout = cfg.get("timeout", 10)
try:
async with aiohttp.ClientSession() as session:
if cfg.get("transport") == "streamable_http":
test_payload = {
"jsonrpc": "2.0",
"method": "initialize",
"id": 0,
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.2.3"},
},
}
async with session.post(
url,
headers={
**headers,
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
},
json=test_payload,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
else:
return False, f"HTTP {response.status}: {response.reason}"
else:
async with session.get(
url,
headers={
**headers,
"Accept": "application/json, text/event-stream",
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
else:
return False, f"HTTP {response.status}: {response.reason}"
except asyncio.TimeoutError:
return False, f"连接超时: {timeout}"
except Exception as e:
return False, f"{e!s}"
@dataclass @dataclass
class FuncTool: class FuncTool:
""" """
@@ -80,12 +146,10 @@ class FuncTool:
if not self.mcp_client or not self.mcp_client.session: if not self.mcp_client or not self.mcp_client.session:
raise Exception(f"MCP client for {self.name} is not available") raise Exception(f"MCP client for {self.name} is not available")
# 使用name属性而不是额外的mcp_tool_name # 使用name属性而不是额外的mcp_tool_name
if ":" in self.name: actual_tool_name = (
# 如果名字是格式为 mcp:server:tool_name提取实际的工具名 self.name.split(":")[-1] if ":" in self.name else self.name
actual_tool_name = self.name.split(":")[-1] )
return await self.mcp_client.session.call_tool(actual_tool_name, args) return await self.mcp_client.session.call_tool(actual_tool_name, args)
else:
return await self.mcp_client.session.call_tool(self.name, args)
else: else:
raise Exception(f"Unknown function origin: {self.origin}") raise Exception(f"Unknown function origin: {self.origin}")
@@ -100,6 +164,7 @@ class MCPClient:
self.active: bool = True self.active: bool = True
self.tools: List[mcp.Tool] = [] self.tools: List[mcp.Tool] = []
self.server_errlogs: List[str] = [] self.server_errlogs: List[str] = []
self.running_event = asyncio.Event()
async def connect_to_server(self, mcp_server_config: dict, name: str): async def connect_to_server(self, mcp_server_config: dict, name: str):
"""连接到 MCP 服务器 """连接到 MCP 服务器
@@ -112,17 +177,19 @@ class MCPClient:
Args: Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
""" """
cfg = mcp_server_config.copy() cfg = _prepare_config(mcp_server_config.copy())
if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
key_0 = list(cfg["mcpServers"].keys())[0] def logging_callback(msg: str):
cfg = cfg["mcpServers"][key_0] # 处理 MCP 服务的错误日志
cfg.pop("active", None) # Remove active flag from config print(f"MCP Server {name} Error: {msg}")
self.server_errlogs.append(msg)
if "url" in cfg: if "url" in cfg:
is_sse = True success, error_msg = await _quick_test_mcp_connection(cfg)
if cfg.get("transport") == "streamable_http": if not success:
is_sse = False raise Exception(error_msg)
if is_sse:
if cfg.get("transport") != "streamable_http":
# SSE transport method # SSE transport method
self._streams_context = sse_client( self._streams_context = sse_client(
url=cfg["url"], url=cfg["url"],
@@ -130,11 +197,18 @@ class MCPClient:
timeout=cfg.get("timeout", 5), timeout=cfg.get("timeout", 5),
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5), sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
) )
streams = await self._streams_context.__aenter__() streams = await self.exit_stack.enter_async_context(
self._streams_context
)
# Create a new client session # Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context( self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*streams) mcp.ClientSession(
*streams,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
)
) )
else: else:
timeout = timedelta(seconds=cfg.get("timeout", 30)) timeout = timedelta(seconds=cfg.get("timeout", 30))
@@ -148,11 +222,19 @@ class MCPClient:
sse_read_timeout=sse_read_timeout, sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True), terminate_on_close=cfg.get("terminate_on_close", True),
) )
read_s, write_s, _ = await self._streams_context.__aenter__() read_s, write_s, _ = await self.exit_stack.enter_async_context(
self._streams_context
)
# Create a new client session # Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context( self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(read_stream=read_s, write_stream=write_s) mcp.ClientSession(
read_stream=read_s,
write_stream=write_s,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
)
) )
else: else:
@@ -172,7 +254,7 @@ class MCPClient:
logger=logger, logger=logger,
identifier=f"MCPServer-{name}", identifier=f"MCPServer-{name}",
callback=callback, callback=callback,
), ), # type: ignore
), ),
) )
@@ -180,19 +262,18 @@ class MCPClient:
self.session = await self.exit_stack.enter_async_context( self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport) mcp.ClientSession(*stdio_transport)
) )
await self.session.initialize() await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult: async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools""" """List all tools from the server and save them to self.tools"""
response = await self.session.list_tools() response = await self.session.list_tools()
logger.debug(f"MCP server {self.name} list tools response: {response}")
self.tools = response.tools self.tools = response.tools
return response return response
async def cleanup(self): async def cleanup(self):
"""Clean up resources""" """Clean up resources"""
await self.exit_stack.aclose() await self.exit_stack.aclose()
self.running_event.set() # Set the running event to indicate cleanup is done
class FuncCall: class FuncCall:
@@ -201,8 +282,6 @@ class FuncCall:
"""内部加载的 func tools""" """内部加载的 func tools"""
self.mcp_client_dict: Dict[str, MCPClient] = {} self.mcp_client_dict: Dict[str, MCPClient] = {}
"""MCP 服务列表""" """MCP 服务列表"""
self.mcp_service_queue = asyncio.Queue()
"""用于外部控制 MCP 服务的启停"""
self.mcp_client_event: Dict[str, asyncio.Event] = {} self.mcp_client_event: Dict[str, asyncio.Event] = {}
def empty(self) -> bool: def empty(self) -> bool:
@@ -258,7 +337,7 @@ class FuncCall:
return f return f
return None return None
async def _init_mcp_clients(self) -> None: async def init_mcp_clients(self) -> None:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下: """从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
``` ```
{ {
@@ -300,113 +379,64 @@ class FuncCall:
) )
self.mcp_client_event[name] = event self.mcp_client_event[name] = event
async def mcp_service_selector(self):
"""为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
{"type": "init"} 初始化所有MCP客户端
{"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
{"type": "terminate"} 终止所有MCP客户端
{"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
"""
while True:
data = await self.mcp_service_queue.get()
if data["type"] == "init":
if "name" in data:
event = asyncio.Event()
asyncio.create_task(
self._init_mcp_client_task_wrapper(
data["name"], data["cfg"], event
)
)
self.mcp_client_event[data["name"]] = event
else:
await self._init_mcp_clients()
elif data["type"] == "terminate":
if "name" in data:
# await self._terminate_mcp_client(data["name"])
if data["name"] in self.mcp_client_event:
self.mcp_client_event[data["name"]].set()
self.mcp_client_event.pop(data["name"], None)
self.func_list = [
f
for f in self.func_list
if not (
f.origin == "mcp" and f.mcp_server_name == data["name"]
)
]
else:
for name in self.mcp_client_dict.keys():
# await self._terminate_mcp_client(name)
# self.mcp_client_event[name].set()
if name in self.mcp_client_event:
self.mcp_client_event[name].set()
self.mcp_client_event.pop(name, None)
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
async def _init_mcp_client_task_wrapper( async def _init_mcp_client_task_wrapper(
self, name: str, cfg: dict, event: asyncio.Event self,
name: str,
cfg: dict,
event: asyncio.Event,
ready_future: asyncio.Future = None,
) -> None: ) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常""" """初始化 MCP 客户端的包装函数,用于捕获异常"""
try: try:
await self._init_mcp_client(name, cfg) await self._init_mcp_client(name, cfg)
tools = await self.mcp_client_dict[name].list_tools_and_save()
if ready_future and not ready_future.done():
# tell the caller we are ready
ready_future.set_result(tools)
await event.wait() await event.wait()
logger.info(f"收到 MCP 客户端 {name} 终止信号") logger.info(f"收到 MCP 客户端 {name} 终止信号")
await self._terminate_mcp_client(name)
except Exception as e: except Exception as e:
import traceback logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
if ready_future and not ready_future.done():
traceback.print_exc() ready_future.set_exception(e)
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}") finally:
# 无论如何都能清理
await self._terminate_mcp_client(name)
async def _init_mcp_client(self, name: str, config: dict) -> None: async def _init_mcp_client(self, name: str, config: dict) -> None:
"""初始化单个MCP客户端""" """初始化单个MCP客户端"""
try: # 先清理之前的客户端,如果存在
# 先清理之前的客户端,如果存在 if name in self.mcp_client_dict:
if name in self.mcp_client_dict: await self._terminate_mcp_client(name)
await self._terminate_mcp_client(name)
mcp_client = MCPClient() mcp_client = MCPClient()
mcp_client.name = name mcp_client.name = name
self.mcp_client_dict[name] = mcp_client self.mcp_client_dict[name] = mcp_client
await mcp_client.connect_to_server(config, name) await mcp_client.connect_to_server(config, name)
tools_res = await mcp_client.list_tools_and_save() tools_res = await mcp_client.list_tools_and_save()
tool_names = [tool.name for tool in tools_res.tools] logger.debug(f"MCP server {name} list tools response: {tools_res}")
tool_names = [tool.name for tool in tools_res.tools]
# 移除该MCP服务之前的工具如有 # 移除该MCP服务之前的工具如有
self.func_list = [ self.func_list = [
f f
for f in self.func_list for f in self.func_list
if not (f.origin == "mcp" and f.mcp_server_name == name) if not (f.origin == "mcp" and f.mcp_server_name == name)
] ]
# 将 MCP 工具转换为 FuncTool 并添加到 func_list # 将 MCP 工具转换为 FuncTool 并添加到 func_list
for tool in mcp_client.tools: for tool in mcp_client.tools:
func_tool = FuncTool( func_tool = FuncTool(
name=tool.name, name=tool.name,
parameters=tool.inputSchema, parameters=tool.inputSchema,
description=tool.description, description=tool.description,
origin="mcp", origin="mcp",
mcp_server_name=name, mcp_server_name=name,
mcp_client=mcp_client, mcp_client=mcp_client,
) )
self.func_list.append(func_tool) self.func_list.append(func_tool)
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}") logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
return
except Exception as e:
import traceback
logger.error(traceback.format_exc())
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
# 发生错误时确保客户端被清理
if name in self.mcp_client_dict:
await self._terminate_mcp_client(name)
return
async def _terminate_mcp_client(self, name: str) -> None: async def _terminate_mcp_client(self, name: str) -> None:
"""关闭并清理MCP客户端""" """关闭并清理MCP客户端"""
@@ -414,9 +444,9 @@ class FuncCall:
try: try:
# 关闭MCP连接 # 关闭MCP连接
await self.mcp_client_dict[name].cleanup() await self.mcp_client_dict[name].cleanup()
del self.mcp_client_dict[name] self.mcp_client_dict.pop(name)
except Exception as e: except Exception as e:
logger.info(f"清空 MCP 客户端资源 {name}: {e}") logger.error(f"清空 MCP 客户端资源 {name}: {e}")
# 移除关联的FuncTool # 移除关联的FuncTool
self.func_list = [ self.func_list = [
f f
@@ -425,6 +455,103 @@ class FuncCall:
] ]
logger.info(f"已关闭 MCP 服务 {name}") logger.info(f"已关闭 MCP 服务 {name}")
@staticmethod
async def test_mcp_server_connection(config: dict) -> list[str]:
if "url" in config:
success, error_msg = await _quick_test_mcp_connection(config)
if not success:
raise Exception(error_msg)
mcp_client = MCPClient()
try:
logger.debug(f"testing MCP server connection with config: {config}")
await mcp_client.connect_to_server(config, "test")
tools_res = await mcp_client.list_tools_and_save()
tool_names = [tool.name for tool in tools_res.tools]
finally:
logger.debug("Cleaning up MCP client after testing connection.")
await mcp_client.cleanup()
return tool_names
async def enable_mcp_server(
self,
name: str,
config: dict,
event: asyncio.Event | None = None,
ready_future: asyncio.Future | None = None,
timeout: int = 30,
) -> None:
"""Enable_mcp_server a new MCP server to the manager and initialize it.
Args:
name (str): The name of the MCP server.
config (dict): Configuration for the MCP server.
event (asyncio.Event): Event to signal when the MCP client is ready.
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
timeout (int): Timeout for the initialization.
Raises:
TimeoutError: If the initialization does not complete within the specified timeout.
Exception: If there is an error during initialization.
"""
if not event:
event = asyncio.Event()
if not ready_future:
ready_future = asyncio.Future()
if name in self.mcp_client_dict:
return
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, config, event, ready_future)
)
try:
await asyncio.wait_for(ready_future, timeout=timeout)
finally:
self.mcp_client_event[name] = event
if ready_future.done() and ready_future.exception():
exc = ready_future.exception()
if exc is not None:
raise exc
async def disable_mcp_server(
self, name: str | None = None, timeout: float = 10
) -> None:
"""Disable an MCP server by its name.
Args:
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
timeout (int): Timeout.
"""
if name:
if name not in self.mcp_client_event:
return
client = self.mcp_client_dict.get(name)
self.mcp_client_event[name].set()
if not client:
return
client_running_event = client.running_event
try:
await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
finally:
self.mcp_client_event.pop(name, None)
self.func_list = [
f
for f in self.func_list
if f.origin != "mcp" or f.mcp_server_name != name
]
else:
running_events = [
client.running_event.wait() for client in self.mcp_client_dict.values()
]
for key, event in self.mcp_client_event.items():
event.set()
# waiting for all clients to finish
try:
await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
finally:
self.mcp_client_event.clear()
self.mcp_client_dict.clear()
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list: def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
""" """
获得 OpenAI API 风格的**已经激活**的工具描述 获得 OpenAI API 风格的**已经激活**的工具描述
@@ -629,8 +756,3 @@ class FuncCall:
def __repr__(self): def __repr__(self):
return str(self.func_list) return str(self.func_list)
async def terminate(self):
for name in self.mcp_client_dict.keys():
await self._terminate_mcp_client(name)
logger.debug(f"清理 MCP 客户端 {name} 资源")

View File

@@ -1,12 +1,14 @@
import traceback
import asyncio import asyncio
from astrbot.core.config.astrbot_config import AstrBotConfig import traceback
from .provider import Provider, STTProvider, TTSProvider, Personality
from .entities import ProviderType
from typing import List from typing import List
from astrbot.core.db import BaseDatabase
from .register import provider_cls_map, llm_tools
from astrbot.core import logger, sp from astrbot.core import logger, sp
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.db import BaseDatabase
from .entities import ProviderType
from .provider import Personality, Provider, STTProvider, TTSProvider, EmbeddingProvider
from .register import llm_tools, provider_cls_map
class ProviderManager: class ProviderManager:
@@ -91,17 +93,17 @@ class ProviderManager:
"""加载的 Speech To Text Provider 的实例""" """加载的 Speech To Text Provider 的实例"""
self.tts_provider_insts: List[TTSProvider] = [] self.tts_provider_insts: List[TTSProvider] = []
"""加载的 Text To Speech Provider 的实例""" """加载的 Text To Speech Provider 的实例"""
self.embedding_provider_insts: List[Provider] = [] self.embedding_provider_insts: List[EmbeddingProvider] = []
"""加载的 Embedding Provider 的实例""" """加载的 Embedding Provider 的实例"""
self.inst_map = {} self.inst_map: dict[str, Provider] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例""" """Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None self.curr_provider_inst: Provider | None = None
"""默认的 Provider 实例""" """默认的 Provider 实例"""
self.curr_stt_provider_inst: STTProvider = None self.curr_stt_provider_inst: STTProvider | None = None
"""默认的 Speech To Text Provider 实例""" """默认的 Speech To Text Provider 实例"""
self.curr_tts_provider_inst: TTSProvider = None self.curr_tts_provider_inst: TTSProvider | None = None
"""默认的 Text To Speech Provider 实例""" """默认的 Text To Speech Provider 实例"""
self.db_helper = db_helper self.db_helper = db_helper
@@ -145,29 +147,29 @@ class ProviderManager:
await self.load_provider(provider_config) await self.load_provider(provider_config)
# 设置默认提供商 # 设置默认提供商
self.curr_provider_inst = self.inst_map.get( selected_provider_id = sp.get(
self.provider_settings.get("default_provider_id") "curr_provider", self.provider_settings.get("default_provider_id")
) )
selected_stt_provider_id = sp.get(
"curr_provider_stt", self.provider_stt_settings.get("provider_id")
)
selected_tts_provider_id = sp.get(
"curr_provider_tts", self.provider_tts_settings.get("provider_id")
)
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
if not self.curr_provider_inst and self.provider_insts: if not self.curr_provider_inst and self.provider_insts:
self.curr_provider_inst = self.provider_insts[0] self.curr_provider_inst = self.provider_insts[0]
self.curr_stt_provider_inst = self.inst_map.get( self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
self.provider_stt_settings.get("provider_id")
)
if not self.curr_stt_provider_inst and self.stt_provider_insts: if not self.curr_stt_provider_inst and self.stt_provider_insts:
self.curr_stt_provider_inst = self.stt_provider_insts[0] self.curr_stt_provider_inst = self.stt_provider_insts[0]
self.curr_tts_provider_inst = self.inst_map.get( self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
self.provider_tts_settings.get("provider_id")
)
if not self.curr_tts_provider_inst and self.tts_provider_insts: if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0] self.curr_tts_provider_inst = self.tts_provider_insts[0]
# 初始化 MCP Client 连接 # 初始化 MCP Client 连接
asyncio.create_task( asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
self.llm_tools.mcp_service_selector(), name="mcp-service-handler"
)
self.llm_tools.mcp_service_queue.put_nowait({"type": "init"})
async def load_provider(self, provider_config: dict): async def load_provider(self, provider_config: dict):
if not provider_config["enable"]: if not provider_config["enable"]:
@@ -190,11 +192,6 @@ class ProviderManager:
from .sources.anthropic_source import ( from .sources.anthropic_source import (
ProviderAnthropic as ProviderAnthropic, ProviderAnthropic as ProviderAnthropic,
) )
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import (
LLMTunerModelLoader as LLMTunerModelLoader,
)
case "dify": case "dify":
from .sources.dify_source import ProviderDify as ProviderDify from .sources.dify_source import ProviderDify as ProviderDify
case "dashscope": case "dashscope":
@@ -253,6 +250,10 @@ class ProviderManager:
from .sources.volcengine_tts import ( from .sources.volcengine_tts import (
ProviderVolcengineTTS as ProviderVolcengineTTS, ProviderVolcengineTTS as ProviderVolcengineTTS,
) )
case "gemini_tts":
from .sources.gemini_tts_source import (
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
)
case "openai_embedding": case "openai_embedding":
from .sources.openai_embedding_source import ( from .sources.openai_embedding_source import (
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider, OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
@@ -326,8 +327,6 @@ class ProviderManager:
inst = provider_metadata.cls_type( inst = provider_metadata.cls_type(
provider_config, provider_config,
self.provider_settings, self.provider_settings,
self.db_helper,
self.provider_settings.get("persistant_history", True),
self.selected_default_persona, self.selected_default_persona,
) )
@@ -420,7 +419,7 @@ class ProviderManager:
self.curr_tts_provider_inst = None self.curr_tts_provider_inst = None
if getattr(self.inst_map[provider_id], "terminate", None): if getattr(self.inst_map[provider_id], "terminate", None):
await self.inst_map[provider_id].terminate() await self.inst_map[provider_id].terminate() # type: ignore
logger.info( logger.info(
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})" f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
@@ -430,6 +429,8 @@ class ProviderManager:
async def terminate(self): async def terminate(self):
for provider_inst in self.provider_insts: for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"): if hasattr(provider_inst, "terminate"):
await provider_inst.terminate() await provider_inst.terminate() # type: ignore
# 清理 MCP Client 连接 try:
await self.llm_tools.mcp_service_queue.put({"type": "terminate"}) await self.llm_tools.disable_mcp_server()
except Exception:
logger.error("Error while disabling MCP servers", exc_info=True)

View File

@@ -1,9 +1,9 @@
import abc import abc
from typing import List from typing import List
from astrbot.core.db import BaseDatabase
from typing import TypedDict, AsyncGenerator from typing import TypedDict, AsyncGenerator
from astrbot.core.provider.func_tool_manager import FuncCall from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult from astrbot.core.provider.entities import LLMResponse, ToolCallsResult, ProviderType
from astrbot.core.provider.register import provider_cls_map
from dataclasses import dataclass from dataclasses import dataclass
@@ -23,6 +23,7 @@ class ProviderMeta:
id: str id: str
model: str model: str
type: str type: str
provider_type: ProviderType
class AbstractProvider(abc.ABC): class AbstractProvider(abc.ABC):
@@ -41,10 +42,14 @@ class AbstractProvider(abc.ABC):
def meta(self) -> ProviderMeta: def meta(self) -> ProviderMeta:
"""获取 Provider 的元数据""" """获取 Provider 的元数据"""
provider_type_name = self.provider_config["type"]
meta_data = provider_cls_map.get(provider_type_name)
provider_type = meta_data.provider_type if meta_data else None
return ProviderMeta( return ProviderMeta(
id=self.provider_config["id"], id=self.provider_config["id"],
model=self.get_model(), model=self.get_model(),
type=self.provider_config["type"], type=provider_type_name,
provider_type=provider_type,
) )
@@ -53,15 +58,13 @@ class Provider(AbstractProvider):
self, self,
provider_config: dict, provider_config: dict,
provider_settings: dict, provider_settings: dict,
persistant_history: bool = True, default_persona: Personality | None = None,
db_helper: BaseDatabase = None,
default_persona: Personality = None,
) -> None: ) -> None:
super().__init__(provider_config) super().__init__(provider_config)
self.provider_settings = provider_settings self.provider_settings = provider_settings
self.curr_personality: Personality = default_persona self.curr_personality = default_persona
"""维护了当前的使用的 persona即人格。可能为 None""" """维护了当前的使用的 persona即人格。可能为 None"""
@abc.abstractmethod @abc.abstractmethod
@@ -86,11 +89,12 @@ class Provider(AbstractProvider):
self, self,
prompt: str, prompt: str,
session_id: str = None, session_id: str = None,
image_urls: List[str] = None, image_urls: list[str] = None,
func_tool: FuncCall = None, func_tool: FuncCall = None,
contexts: List = None, contexts: list = None,
system_prompt: str = None, system_prompt: str = None,
tool_calls_result: ToolCallsResult = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
model: str | None = None,
**kwargs, **kwargs,
) -> LLMResponse: ) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。 """获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -114,11 +118,12 @@ class Provider(AbstractProvider):
self, self,
prompt: str, prompt: str,
session_id: str = None, session_id: str = None,
image_urls: List[str] = None, image_urls: list[str] = None,
func_tool: FuncCall = None, func_tool: FuncCall = None,
contexts: List = None, contexts: list = None,
system_prompt: str = None, system_prompt: str = None,
tool_calls_result: ToolCallsResult = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
model: str | None = None,
**kwargs, **kwargs,
) -> AsyncGenerator[LLMResponse, None]: ) -> AsyncGenerator[LLMResponse, None]:
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。

View File

@@ -1,3 +1,6 @@
import json
import anthropic
import base64
from typing import List from typing import List
from mimetypes import guess_type from mimetypes import guess_type
@@ -5,41 +8,33 @@ from anthropic import AsyncAnthropic
from anthropic.types import Message from anthropic.types import Message
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase from astrbot.api.provider import Provider
from astrbot.api.provider import Provider, Personality
from astrbot import logger from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall from astrbot.core.provider.func_tool_manager import FuncCall
from ..register import register_provider_adapter from ..register import register_provider_adapter
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult from typing import AsyncGenerator
from .openai_source import ProviderOpenAIOfficial
@register_provider_adapter( @register_provider_adapter(
"anthropic_chat_completion", "Anthropic Claude API 提供商适配器" "anthropic_chat_completion", "Anthropic Claude API 提供商适配器"
) )
class ProviderAnthropic(ProviderOpenAIOfficial): class ProviderAnthropic(Provider):
def __init__( def __init__(
self, self,
provider_config: dict, provider_config,
provider_settings: dict, provider_settings,
db_helper: BaseDatabase, default_persona=None,
persistant_history=True,
default_persona: Personality = None,
) -> None: ) -> None:
# Skip OpenAI's __init__ and call Provider's __init__ directly super().__init__(
Provider.__init__(
self,
provider_config, provider_config,
provider_settings, provider_settings,
persistant_history,
db_helper,
default_persona, default_persona,
) )
self.chosen_api_key = None self.chosen_api_key: str = ""
self.api_keys: List = provider_config.get("key", []) self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
self.base_url = provider_config.get("api_base", "https://api.anthropic.com") self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
self.timeout = provider_config.get("timeout", 120) self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str): if isinstance(self.timeout, str):
@@ -51,10 +46,63 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
self.set_model(provider_config["model_config"]["model"]) self.set_model(provider_config["model_config"]["model"])
def _prepare_payload(self, messages: list[dict]):
"""准备 Anthropic API 的请求 payload
Args:
messages: OpenAI 格式的消息列表,包含用户输入和系统提示等信息
Returns:
system_prompt: 系统提示内容
new_messages: 处理后的消息列表,去除系统提示
"""
system_prompt = ""
new_messages = []
for message in messages:
if message["role"] == "system":
system_prompt = message["content"]
elif message["role"] == "assistant":
blocks = []
if isinstance(message["content"], str):
blocks.append({"type": "text", "text": message["content"]})
if "tool_calls" in message:
for tool_call in message["tool_calls"]:
blocks.append( # noqa: PERF401
{
"type": "tool_use",
"name": tool_call["function"]["name"],
"input": json.loads(tool_call["function"]["arguments"])
if isinstance(tool_call["function"]["arguments"], str)
else tool_call["function"]["arguments"],
"id": tool_call["id"],
}
)
new_messages.append(
{
"role": "assistant",
"content": blocks,
}
)
elif message["role"] == "tool":
new_messages.append(
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": message["tool_call_id"],
"content": message["content"],
}
],
}
)
else:
new_messages.append(message)
return system_prompt, new_messages
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools: if tools:
tool_list = tools.get_func_desc_anthropic_style() if tool_list := tools.get_func_desc_anthropic_style():
if tool_list:
payloads["tools"] = tool_list payloads["tools"] = tool_list
completion = await self.client.messages.create(**payloads, stream=False) completion = await self.client.messages.create(**payloads, stream=False)
@@ -64,70 +112,158 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
if len(completion.content) == 0: if len(completion.content) == 0:
raise Exception("API 返回的 completion 为空。") raise Exception("API 返回的 completion 为空。")
# TODO: 如果进行函数调用,思维链被截断,用户可能需要思维链的内容
# 选最后一条消息如果要进行函数调用anthropic会先返回文本消息的思维链然后再返回函数调用请求
content = completion.content[-1]
llm_response = LLMResponse("assistant") llm_response = LLMResponse(role="assistant")
if content.type == "text": for content_block in completion.content:
# text completion if content_block.type == "text":
completion_text = str(content.text).strip() completion_text = str(content_block.text).strip()
# llm_response.completion_text = completion_text llm_response.completion_text = completion_text
llm_response.result_chain = MessageChain().message(completion_text)
# Anthropic每次只返回一个函数调用
if completion.stop_reason == "tool_use":
# tools call (function calling)
args_ls = []
func_name_ls = []
tool_use_ids = []
func_name_ls.append(content.name)
args_ls.append(content.input)
tool_use_ids.append(content.id)
llm_response.role = "tool"
llm_response.tools_call_args = args_ls
llm_response.tools_call_name = func_name_ls
llm_response.tools_call_ids = tool_use_ids
if content_block.type == "tool_use":
llm_response.tools_call_args.append(content_block.input)
llm_response.tools_call_name.append(content_block.name)
llm_response.tools_call_ids.append(content_block.id)
# TODO(Soulter): 处理 end_turn 情况
if not llm_response.completion_text and not llm_response.tools_call_args: if not llm_response.completion_text and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}") raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}")
raise Exception(f"API 返回的 completion 无法解析:{completion}")
llm_response.raw_completion = completion
return llm_response return llm_response
async def _query_stream(
self, payloads: dict, tools: FuncCall
) -> AsyncGenerator[LLMResponse, None]:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
payloads["tools"] = tool_list
# 用于累积工具调用信息
tool_use_buffer = {}
# 用于累积最终结果
final_text = ""
final_tool_calls = []
async with self.client.messages.stream(**payloads) as stream:
assert isinstance(stream, anthropic.AsyncMessageStream)
async for event in stream:
if event.type == "content_block_start":
if event.content_block.type == "text":
# 文本块开始
yield LLMResponse(
role="assistant", completion_text="", is_chunk=True
)
elif event.content_block.type == "tool_use":
# 工具使用块开始,初始化缓冲区
tool_use_buffer[event.index] = {
"id": event.content_block.id,
"name": event.content_block.name,
"input": {},
}
elif event.type == "content_block_delta":
if event.delta.type == "text_delta":
# 文本增量
final_text += event.delta.text
yield LLMResponse(
role="assistant",
completion_text=event.delta.text,
is_chunk=True,
)
elif event.delta.type == "input_json_delta":
# 工具调用参数增量
if event.index in tool_use_buffer:
# 累积 JSON 输入
if "input_json" not in tool_use_buffer[event.index]:
tool_use_buffer[event.index]["input_json"] = ""
tool_use_buffer[event.index]["input_json"] += (
event.delta.partial_json
)
elif event.type == "content_block_stop":
# 内容块结束
if event.index in tool_use_buffer:
# 解析完整的工具调用
tool_info = tool_use_buffer[event.index]
try:
if "input_json" in tool_info:
tool_info["input"] = json.loads(tool_info["input_json"])
# 添加到最终结果
final_tool_calls.append(
{
"id": tool_info["id"],
"name": tool_info["name"],
"input": tool_info["input"],
}
)
yield LLMResponse(
role="tool",
completion_text="",
tools_call_args=[tool_info["input"]],
tools_call_name=[tool_info["name"]],
tools_call_ids=[tool_info["id"]],
is_chunk=True,
)
except json.JSONDecodeError:
# JSON 解析失败,跳过这个工具调用
logger.warning(f"工具调用参数 JSON 解析失败: {tool_info}")
# 清理缓冲区
del tool_use_buffer[event.index]
# 返回最终的完整结果
final_response = LLMResponse(
role="assistant", completion_text=final_text, is_chunk=False
)
if final_tool_calls:
final_response.tools_call_args = [
call["input"] for call in final_tool_calls
]
final_response.tools_call_name = [call["name"] for call in final_tool_calls]
final_response.tools_call_ids = [call["id"] for call in final_tool_calls]
yield final_response
async def text_chat( async def text_chat(
self, self,
prompt: str, prompt,
session_id: str = None, session_id=None,
image_urls: List[str] = [], image_urls=None,
func_tool: FuncCall = None, func_tool=None,
contexts=None, contexts=None,
system_prompt=None, system_prompt=None,
tool_calls_result: ToolCallsResult = None, tool_calls_result=None,
model=None,
**kwargs, **kwargs,
) -> LLMResponse: ) -> LLMResponse:
if contexts is None: if contexts is None:
contexts = [] contexts = []
if not prompt:
prompt = "<image>"
new_record = await self.assemble_context(prompt, image_urls) new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record] context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query: for part in context_query:
if "_no_save" in part: if "_no_save" in part:
del part["_no_save"] del part["_no_save"]
# tool calls result
if tool_calls_result: if tool_calls_result:
# 暂时这样写。 if not isinstance(tool_calls_result, list):
prompt += f"Here are the related results via using tools: {str(tool_calls_result.tool_calls_result)}" context_query.extend(tool_calls_result.to_openai_messages())
else:
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {}) model_config = self.provider_config.get("model_config", {})
model_config["model"] = model or self.get_model()
payloads = {"messages": new_messages, **model_config}
payloads = {"messages": context_query, **model_config}
# Anthropic has a different way of handling system prompts # Anthropic has a different way of handling system prompts
if system_prompt: if system_prompt:
payloads["system"] = system_prompt payloads["system"] = system_prompt
@@ -135,32 +271,9 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
llm_response = None llm_response = None
try: try:
llm_response = await self._query(payloads, func_tool) llm_response = await self._query(payloads, func_tool)
except Exception as e: except Exception as e:
if "maximum context length" in str(e): logger.error(f"发生了错误。Provider 配置如下: {model_config}")
retry_cnt = 20 raise e
while retry_cnt > 0:
logger.warning(
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
)
try:
await self.pop_record(context_query)
response = await self.client.messages.create(
messages=context_query, **model_config
)
llm_response = LLMResponse("assistant")
llm_response.result_chain = MessageChain().message(response.content[0].text)
llm_response.raw_completion = response
return llm_response
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
return LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
else:
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
raise e
return llm_response return llm_response
@@ -173,23 +286,41 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
contexts=..., contexts=...,
system_prompt=None, system_prompt=None,
tool_calls_result=None, tool_calls_result=None,
model=None,
**kwargs, **kwargs,
): ):
# raise NotImplementedError("This method is not implemented yet.") if contexts is None:
# 调用 text_chat 模拟流式 contexts = []
llm_response = await self.text_chat( new_record = await self.assemble_context(prompt, image_urls)
prompt=prompt, context_query = [*contexts, new_record]
session_id=session_id, if system_prompt:
image_urls=image_urls, context_query.insert(0, {"role": "system", "content": system_prompt})
func_tool=func_tool,
contexts=contexts, for part in context_query:
system_prompt=system_prompt, if "_no_save" in part:
tool_calls_result=tool_calls_result, del part["_no_save"]
)
llm_response.is_chunk = True # tool calls result
yield llm_response if tool_calls_result:
llm_response.is_chunk = False if not isinstance(tool_calls_result, list):
yield llm_response context_query.extend(tool_calls_result.to_openai_messages())
else:
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {})
model_config["model"] = model or self.get_model()
payloads = {"messages": new_messages, **model_config}
# Anthropic has a different way of handling system prompts
if system_prompt:
payloads["system"] = system_prompt
async for llm_response in self._query_stream(payloads, func_tool):
yield llm_response
async def assemble_context(self, text: str, image_urls: List[str] = None): async def assemble_context(self, text: str, image_urls: List[str] = None):
"""组装上下文,支持文本和图片""" """组装上下文,支持文本和图片"""
@@ -232,3 +363,28 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
) )
return {"role": "user", "content": content} return {"role": "user", "content": content}
async def encode_image_bs64(self, image_url: str) -> str:
"""
将图片转换为 base64
"""
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
return ""
def get_current_key(self) -> str:
return self.chosen_api_key
async def get_models(self) -> List[str]:
models_str = []
models = await self.client.models.list()
models = sorted(models.data, key=lambda x: x.id)
for model in models:
models_str.append(model.id)
return models_str
def set_key(self, key: str):
self.chosen_api_key = key

View File

@@ -19,6 +19,7 @@ from ..register import register_provider_adapter
TEMP_DIR = Path("data/temp/azure_tts") TEMP_DIR = Path("data/temp/azure_tts")
TEMP_DIR.mkdir(parents=True, exist_ok=True) TEMP_DIR.mkdir(parents=True, exist_ok=True)
class OTTSProvider: class OTTSProvider:
def __init__(self, config: Dict): def __init__(self, config: Dict):
self.skey = config["OTTS_SKEY"] self.skey = config["OTTS_SKEY"]
@@ -70,12 +71,12 @@ class OTTSProvider:
"style": voice_params["style"], "style": voice_params["style"],
"role": voice_params["role"], "role": voice_params["role"],
"rate": voice_params["rate"], "rate": voice_params["rate"],
"volume": voice_params["volume"] "volume": voice_params["volume"],
}, },
headers={ headers={
"User-Agent": f"AstrBot/{VERSION}", "User-Agent": f"AstrBot/{VERSION}",
"UAK": "AstrBot/AzureTTS" "UAK": "AstrBot/AzureTTS",
} },
) )
response.raise_for_status() response.raise_for_status()
file_path.parent.mkdir(parents=True, exist_ok=True) file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -88,14 +89,19 @@ class OTTSProvider:
raise RuntimeError(f"OTTS请求失败: {str(e)}") from e raise RuntimeError(f"OTTS请求失败: {str(e)}") from e
await asyncio.sleep(0.5 * (attempt + 1)) await asyncio.sleep(0.5 * (attempt + 1))
class AzureNativeProvider(TTSProvider): class AzureNativeProvider(TTSProvider):
def __init__(self, provider_config: dict, provider_settings: dict): def __init__(self, provider_config: dict, provider_settings: dict):
super().__init__(provider_config, provider_settings) super().__init__(provider_config, provider_settings)
self.subscription_key = provider_config.get("azure_tts_subscription_key", "").strip() self.subscription_key = provider_config.get(
"azure_tts_subscription_key", ""
).strip()
if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key): if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
raise ValueError("无效的Azure订阅密钥") raise ValueError("无效的Azure订阅密钥")
self.region = provider_config.get("azure_tts_region", "eastus").strip() self.region = provider_config.get("azure_tts_region", "eastus").strip()
self.endpoint = f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1" self.endpoint = (
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
)
self.client = None self.client = None
self.token = None self.token = None
self.token_expire = 0 self.token_expire = 0
@@ -104,15 +110,17 @@ class AzureNativeProvider(TTSProvider):
"style": provider_config.get("azure_tts_style", "cheerful"), "style": provider_config.get("azure_tts_style", "cheerful"),
"role": provider_config.get("azure_tts_role", "Boy"), "role": provider_config.get("azure_tts_role", "Boy"),
"rate": provider_config.get("azure_tts_rate", "1"), "rate": provider_config.get("azure_tts_rate", "1"),
"volume": provider_config.get("azure_tts_volume", "100") "volume": provider_config.get("azure_tts_volume", "100"),
} }
async def __aenter__(self): async def __aenter__(self):
self.client = AsyncClient(headers={ self.client = AsyncClient(
"User-Agent": f"AstrBot/{VERSION}", headers={
"Content-Type": "application/ssml+xml", "User-Agent": f"AstrBot/{VERSION}",
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm" "Content-Type": "application/ssml+xml",
}) "X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm",
}
)
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
@@ -120,10 +128,11 @@ class AzureNativeProvider(TTSProvider):
await self.client.aclose() await self.client.aclose()
async def _refresh_token(self): async def _refresh_token(self):
token_url = f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken" token_url = (
f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
)
response = await self.client.post( response = await self.client.post(
token_url, token_url, headers={"Ocp-Apim-Subscription-Key": self.subscription_key}
headers={"Ocp-Apim-Subscription-Key": self.subscription_key}
) )
response.raise_for_status() response.raise_for_status()
self.token = response.text self.token = response.text
@@ -150,8 +159,8 @@ class AzureNativeProvider(TTSProvider):
content=ssml, content=ssml,
headers={ headers={
"Authorization": f"Bearer {self.token}", "Authorization": f"Bearer {self.token}",
"User-Agent": f"AstrBot/{VERSION}" "User-Agent": f"AstrBot/{VERSION}",
} },
) )
response.raise_for_status() response.raise_for_status()
file_path.parent.mkdir(parents=True, exist_ok=True) file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -160,6 +169,7 @@ class AzureNativeProvider(TTSProvider):
f.write(chunk) f.write(chunk)
return str(file_path.resolve()) return str(file_path.resolve())
@register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH) @register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH)
class AzureTTSProvider(TTSProvider): class AzureTTSProvider(TTSProvider):
def __init__(self, provider_config: dict, provider_settings: dict): def __init__(self, provider_config: dict, provider_settings: dict):
@@ -183,7 +193,7 @@ class AzureTTSProvider(TTSProvider):
error_msg = ( error_msg = (
f"JSON解析失败请检查格式错误位置{e.lineno}{e.colno}\n" f"JSON解析失败请检查格式错误位置{e.lineno}{e.colno}\n"
f"错误详情: {e.msg}\n" f"错误详情: {e.msg}\n"
f"错误上下文: {json_str[max(0, e.pos-30):e.pos+30]}" f"错误上下文: {json_str[max(0, e.pos - 30) : e.pos + 30]}"
) )
raise ValueError(error_msg) from e raise ValueError(error_msg) from e
except KeyError as e: except KeyError as e:
@@ -202,8 +212,8 @@ class AzureTTSProvider(TTSProvider):
"style": self.provider_config.get("azure_tts_style"), "style": self.provider_config.get("azure_tts_style"),
"role": self.provider_config.get("azure_tts_role"), "role": self.provider_config.get("azure_tts_role"),
"rate": self.provider_config.get("azure_tts_rate"), "rate": self.provider_config.get("azure_tts_rate"),
"volume": self.provider_config.get("azure_tts_volume") "volume": self.provider_config.get("azure_tts_volume"),
} },
) )
else: else:
async with self.provider as provider: async with self.provider as provider:

View File

@@ -5,7 +5,6 @@ from typing import List
from .. import Provider, Personality from .. import Provider, Personality
from ..entities import LLMResponse from ..entities import LLMResponse
from ..func_tool_manager import FuncCall from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter from ..register import register_provider_adapter
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.message_event_result import MessageChain
from .openai_source import ProviderOpenAIOfficial from .openai_source import ProviderOpenAIOfficial
@@ -19,16 +18,12 @@ class ProviderDashscope(ProviderOpenAIOfficial):
self, self,
provider_config: dict, provider_config: dict,
provider_settings: dict, provider_settings: dict,
db_helper: BaseDatabase, default_persona: Personality | None = None,
persistant_history=False,
default_persona: Personality = None,
) -> None: ) -> None:
Provider.__init__( Provider.__init__(
self, self,
provider_config, provider_config,
provider_settings, provider_settings,
persistant_history,
db_helper,
default_persona, default_persona,
) )
self.api_key = provider_config.get("dashscope_api_key", "") self.api_key = provider_config.get("dashscope_api_key", "")
@@ -72,6 +67,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
func_tool: FuncCall = None, func_tool: FuncCall = None,
contexts: List = None, contexts: List = None,
system_prompt: str = None, system_prompt: str = None,
model=None,
**kwargs, **kwargs,
) -> LLMResponse: ) -> LLMResponse:
if contexts is None: if contexts is None:
@@ -168,6 +164,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
contexts=..., contexts=...,
system_prompt=None, system_prompt=None,
tool_calls_result=None, tool_calls_result=None,
model=None,
**kwargs, **kwargs,
): ):
# raise NotImplementedError("This method is not implemented yet.") # raise NotImplementedError("This method is not implemented yet.")

View File

@@ -1,10 +1,9 @@
import astrbot.core.message.components as Comp import astrbot.core.message.components as Comp
import os import os
from typing import List from typing import List
from .. import Provider, Personality from .. import Provider
from ..entities import LLMResponse from ..entities import LLMResponse
from ..func_tool_manager import FuncCall from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter from ..register import register_provider_adapter
from astrbot.core.utils.dify_api_client import DifyAPIClient from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_image_by_url, download_file from astrbot.core.utils.io import download_image_by_url, download_file
@@ -17,17 +16,13 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class ProviderDify(Provider): class ProviderDify(Provider):
def __init__( def __init__(
self, self,
provider_config: dict, provider_config,
provider_settings: dict, provider_settings,
db_helper: BaseDatabase, default_persona=None,
persistant_history=False,
default_persona: Personality = None,
) -> None: ) -> None:
super().__init__( super().__init__(
provider_config, provider_config,
provider_settings, provider_settings,
persistant_history,
db_helper,
default_persona, default_persona,
) )
self.api_key = provider_config.get("dify_api_key", "") self.api_key = provider_config.get("dify_api_key", "")
@@ -65,12 +60,14 @@ class ProviderDify(Provider):
func_tool: FuncCall = None, func_tool: FuncCall = None,
contexts: List = None, contexts: List = None,
system_prompt: str = None, system_prompt: str = None,
tool_calls_result=None,
model=None,
**kwargs, **kwargs,
) -> LLMResponse: ) -> LLMResponse:
if image_urls is None: if image_urls is None:
image_urls = [] image_urls = []
result = "" result = ""
session_id = session_id or kwargs.get("user") # 1734 session_id = session_id or kwargs.get("user") or "unknown" # 1734
conversation_id = self.conversation_ids.get(session_id, "") conversation_id = self.conversation_ids.get(session_id, "")
files_payload = [] files_payload = []
@@ -103,6 +100,7 @@ class ProviderDify(Provider):
session_vars = sp.get("session_variables", {}) session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {}) session_var = session_vars.get(session_id, {})
payload_vars.update(session_var) payload_vars.update(session_var)
payload_vars["system_prompt"] = system_prompt
try: try:
match self.api_type: match self.api_type:
@@ -202,6 +200,7 @@ class ProviderDify(Provider):
contexts=..., contexts=...,
system_prompt=None, system_prompt=None,
tool_calls_result=None, tool_calls_result=None,
model=None,
**kwargs, **kwargs,
): ):
# raise NotImplementedError("This method is not implemented yet.") # raise NotImplementedError("This method is not implemented yet.")

View File

@@ -12,10 +12,9 @@ from google.genai.errors import APIError
import astrbot.core.message.components as Comp import astrbot.core.message.components as Comp
from astrbot import logger from astrbot import logger
from astrbot.api.provider import Personality, Provider from astrbot.api.provider import Provider
from astrbot.core.db import BaseDatabase
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.func_tool_manager import FuncCall from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
@@ -52,17 +51,13 @@ class ProviderGoogleGenAI(Provider):
def __init__( def __init__(
self, self,
provider_config: dict, provider_config,
provider_settings: dict, provider_settings,
db_helper: BaseDatabase, default_persona=None,
persistant_history=True,
default_persona: Personality = None,
) -> None: ) -> None:
super().__init__( super().__init__(
provider_config, provider_config,
provider_settings, provider_settings,
persistant_history,
db_helper,
default_persona, default_persona,
) )
self.api_keys: list = provider_config.get("key", []) self.api_keys: list = provider_config.get("key", [])
@@ -475,6 +470,10 @@ class ProviderGoogleGenAI(Provider):
raise raise
continue continue
# Accumulate the complete response text for the final response
accumulated_text = ""
final_response = None
async for chunk in result: async for chunk in result:
llm_response = LLMResponse("assistant", is_chunk=True) llm_response = LLMResponse("assistant", is_chunk=True)
@@ -486,32 +485,47 @@ class ProviderGoogleGenAI(Provider):
chunk, llm_response chunk, llm_response
) )
yield llm_response yield llm_response
break return
if chunk.text: if chunk.text:
accumulated_text += chunk.text
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)]) llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
yield llm_response yield llm_response
if chunk.candidates[0].finish_reason: if chunk.candidates[0].finish_reason:
llm_response = LLMResponse("assistant", is_chunk=False) # Process the final chunk for potential tool calls or other content
if not chunk.candidates[0].content.parts: if chunk.candidates[0].content.parts:
llm_response.result_chain = MessageChain(chain=[Comp.Plain(" ")]) final_response = LLMResponse("assistant", is_chunk=False)
else: final_response.result_chain = self._process_content_parts(
llm_response.result_chain = self._process_content_parts( chunk, final_response
chunk, llm_response
) )
yield llm_response
break break
# Yield final complete response with accumulated text
if not final_response:
final_response = LLMResponse("assistant", is_chunk=False)
# Set the complete accumulated text in the final response
if accumulated_text:
final_response.result_chain = MessageChain(
chain=[Comp.Plain(accumulated_text)]
)
elif not final_response.result_chain:
# If no text was accumulated and no final response was set, provide empty space
final_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
yield final_response
async def text_chat( async def text_chat(
self, self,
prompt: str, prompt: str,
session_id: str = None, session_id=None,
image_urls: list[str] = None, image_urls=None,
func_tool: FuncCall = None, func_tool=None,
contexts: list = None, contexts=None,
system_prompt: str = None, system_prompt=None,
tool_calls_result: ToolCallsResult = None, tool_calls_result=None,
model=None,
**kwargs, **kwargs,
) -> LLMResponse: ) -> LLMResponse:
if contexts is None: if contexts is None:
@@ -527,10 +541,14 @@ class ProviderGoogleGenAI(Provider):
# tool calls result # tool calls result
if tool_calls_result: if tool_calls_result:
context_query.extend(tool_calls_result.to_openai_messages()) if not isinstance(tool_calls_result, list):
context_query.extend(tool_calls_result.to_openai_messages())
else:
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {}) model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model() model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config} payloads = {"messages": context_query, **model_config}
@@ -547,13 +565,14 @@ class ProviderGoogleGenAI(Provider):
async def text_chat_stream( async def text_chat_stream(
self, self,
prompt: str, prompt,
session_id: str = None, session_id=None,
image_urls: list[str] = None, image_urls=None,
func_tool: FuncCall = None, func_tool=None,
contexts: str = None, contexts=None,
system_prompt: str = None, system_prompt=None,
tool_calls_result: ToolCallsResult = None, tool_calls_result=None,
model=None,
**kwargs, **kwargs,
) -> AsyncGenerator[LLMResponse, None]: ) -> AsyncGenerator[LLMResponse, None]:
if contexts is None: if contexts is None:
@@ -569,10 +588,14 @@ class ProviderGoogleGenAI(Provider):
# tool calls result # tool calls result
if tool_calls_result: if tool_calls_result:
context_query.extend(tool_calls_result.to_openai_messages()) if not isinstance(tool_calls_result, list):
context_query.extend(tool_calls_result.to_openai_messages())
else:
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {}) model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model() model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config} payloads = {"messages": context_query, **model_config}
@@ -632,7 +655,10 @@ class ProviderGoogleGenAI(Provider):
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue continue
user_content["content"].append( user_content["content"].append(
{"type": "image_url", "image_url": {"url": image_data}} {
"type": "image_url",
"image_url": {"url": image_data},
}
) )
return user_content return user_content
else: else:

View File

@@ -0,0 +1,79 @@
import os
import uuid
import wave
from google import genai
from google.genai import types
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ..entities import ProviderType
from ..provider import TTSProvider
from ..register import register_provider_adapter
@register_provider_adapter(
"gemini_tts", "Gemini TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
)
class ProviderGeminiTTSAPI(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
api_key: str = provider_config.get("gemini_tts_api_key", "")
api_base: str | None = provider_config.get("gemini_tts_api_base")
timeout: int = int(provider_config.get("gemini_tts_timeout", 20))
http_options = types.HttpOptions(timeout=timeout * 1000)
if api_base:
if api_base.endswith("/"):
api_base = api_base[:-1]
http_options.base_url = api_base
self.client = genai.Client(api_key=api_key, http_options=http_options).aio
self.model: str = provider_config.get(
"gemini_tts_model", "gemini-2.5-flash-preview-tts"
)
self.prefix: str | None = provider_config.get(
"gemini_tts_prefix",
)
self.voice_name: str = provider_config.get("gemini_tts_voice_name", "Leda")
async def get_audio(self, text: str) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"gemini_tts_{uuid.uuid4()}.wav")
prompt = f"{self.prefix}: {text}" if self.prefix else text
response = await self.client.models.generate_content(
model=self.model,
contents=prompt,
config=types.GenerateContentConfig(
response_modalities=["AUDIO"],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(
voice_name=self.voice_name,
)
)
),
),
)
# 不想看类型检查报错
if (
not response.candidates
or not response.candidates[0].content
or not response.candidates[0].content.parts
or not response.candidates[0].content.parts[0].inline_data
or not response.candidates[0].content.parts[0].inline_data.data
):
raise Exception("No audio content returned from Gemini TTS API.")
with wave.open(path, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(24000)
wf.writeframes(response.candidates[0].content.parts[0].inline_data.data)
return path

View File

@@ -1,134 +0,0 @@
import os
from llmtuner.chat import ChatModel
from typing import List
from .. import Provider
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
@register_provider_adapter(
"llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型"
)
class LLMTunerModelLoader(Provider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=True,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
persistant_history,
db_helper,
default_persona,
)
if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists(
provider_config["adapter_model_path"]
):
raise FileNotFoundError("模型文件路径不存在。")
self.base_model_path = provider_config["base_model_path"]
self.adapter_model_path = provider_config["adapter_model_path"]
self.model = ChatModel(
{
"model_name_or_path": self.base_model_path,
"adapter_name_or_path": self.adapter_model_path,
"template": provider_config["llmtuner_template"],
"finetuning_type": provider_config["finetuning_type"],
"quantization_bit": provider_config["quantization_bit"],
}
)
self.set_model(
os.path.basename(self.base_model_path)
+ "_"
+ os.path.basename(self.adapter_model_path)
)
async def assemble_context(self, text: str, image_urls: List[str] = None):
"""
组装上下文。
"""
return {"role": "user", "content": text}
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
system_prompt = ""
new_record = {"role": "user", "content": prompt}
query_context = [*contexts, new_record]
# 提取出系统提示
system_idxs = []
for idx, context in enumerate(query_context):
if context["role"] == "system":
system_idxs.append(idx)
if "_no_save" in context:
del context["_no_save"]
for idx in reversed(system_idxs):
system_prompt += " " + query_context.pop(idx)["content"]
conf = {
"messages": query_context,
"system": system_prompt,
}
if func_tool:
tool_list = func_tool.get_func_desc_openai_style()
if tool_list:
conf["tools"] = tool_list
responses = await self.model.achat(**conf)
llm_response = LLMResponse("assistant", responses[-1].response_text)
return llm_response
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=...,
func_tool=None,
contexts=...,
system_prompt=None,
tool_calls_result=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def get_current_key(self):
return "none"
async def set_key(self, key):
pass
async def get_models(self):
return [self.get_model()]

View File

@@ -22,7 +22,7 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
timeout=int(provider_config.get("timeout", 20)), timeout=int(provider_config.get("timeout", 20)),
) )
self.model = provider_config.get("embedding_model", "text-embedding-3-small") self.model = provider_config.get("embedding_model", "text-embedding-3-small")
self.dimension = provider_config.get("embedding_dimensions", 1536) self.dimension = provider_config.get("embedding_dimensions", 1024)
async def get_embedding(self, text: str) -> list[float]: async def get_embedding(self, text: str) -> list[float]:
""" """

View File

@@ -9,14 +9,12 @@ import astrbot.core.message.components as Comp
from openai import AsyncOpenAI, AsyncAzureOpenAI from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion import ChatCompletion
# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai._exceptions import NotFoundError, UnprocessableEntityError from openai._exceptions import NotFoundError, UnprocessableEntityError
from openai.lib.streaming.chat._completions import ChatCompletionStreamState from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.db import BaseDatabase from astrbot.api.provider import Provider
from astrbot.api.provider import Provider, Personality
from astrbot import logger from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List, AsyncGenerator from typing import List, AsyncGenerator
@@ -30,17 +28,13 @@ from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
class ProviderOpenAIOfficial(Provider): class ProviderOpenAIOfficial(Provider):
def __init__( def __init__(
self, self,
provider_config: dict, provider_config,
provider_settings: dict, provider_settings,
db_helper: BaseDatabase, default_persona=None,
persistant_history=True,
default_persona: Personality = None,
) -> None: ) -> None:
super().__init__( super().__init__(
provider_config, provider_config,
provider_settings, provider_settings,
persistant_history,
db_helper,
default_persona, default_persona,
) )
self.chosen_api_key = None self.chosen_api_key = None
@@ -105,6 +99,11 @@ class ProviderOpenAIOfficial(Provider):
for key in to_del: for key in to_del:
del payloads[key] del payloads[key]
# 针对 qwen3 模型的特殊处理:非流式调用必须设置 enable_thinking=false
model = payloads.get("model", "")
if "qwen3" in model.lower():
extra_body["enable_thinking"] = False
completion = await self.client.chat.completions.create( completion = await self.client.chat.completions.create(
**payloads, stream=False, extra_body=extra_body **payloads, stream=False, extra_body=extra_body
) )
@@ -182,7 +181,7 @@ class ProviderOpenAIOfficial(Provider):
raise Exception("API 返回的 completion 为空。") raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0] choice = completion.choices[0]
if choice.message.content: if choice.message.content is not None:
# text completion # text completion
completion_text = str(choice.message.content).strip() completion_text = str(choice.message.content).strip()
llm_response.result_chain = MessageChain().message(completion_text) llm_response.result_chain = MessageChain().message(completion_text)
@@ -193,6 +192,9 @@ class ProviderOpenAIOfficial(Provider):
func_name_ls = [] func_name_ls = []
tool_call_ids = [] tool_call_ids = []
for tool_call in choice.message.tool_calls: for tool_call in choice.message.tool_calls:
if isinstance(tool_call, str):
# workaround for #1359
tool_call = json.loads(tool_call)
for tool in tools.func_list: for tool in tools.func_list:
if tool.name == tool_call.function.name: if tool.name == tool_call.function.name:
# workaround for #1454 # workaround for #1454
@@ -213,7 +215,7 @@ class ProviderOpenAIOfficial(Provider):
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。" "API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。"
) )
if not llm_response.completion_text and not llm_response.tools_call_args: if llm_response.completion_text is None and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}") logger.error(f"API 返回的 completion 无法解析:{completion}")
raise Exception(f"API 返回的 completion 无法解析:{completion}") raise Exception(f"API 返回的 completion 无法解析:{completion}")
@@ -224,12 +226,11 @@ class ProviderOpenAIOfficial(Provider):
async def _prepare_chat_payload( async def _prepare_chat_payload(
self, self,
prompt: str, prompt: str,
session_id: str = None, image_urls: list[str] | None = None,
image_urls: list[str] = None, contexts: list | None = None,
func_tool: FuncCall = None, system_prompt: str | None = None,
contexts: list = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
system_prompt: str = None, model: str | None = None,
tool_calls_result: ToolCallsResult = None,
**kwargs, **kwargs,
) -> tuple: ) -> tuple:
"""准备聊天所需的有效载荷和上下文""" """准备聊天所需的有效载荷和上下文"""
@@ -246,14 +247,18 @@ class ProviderOpenAIOfficial(Provider):
# tool calls result # tool calls result
if tool_calls_result: if tool_calls_result:
context_query.extend(tool_calls_result.to_openai_messages()) if isinstance(tool_calls_result, ToolCallsResult):
context_query.extend(tool_calls_result.to_openai_messages())
else:
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {}) model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model() model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config} payloads = {"messages": context_query, **model_config}
return payloads, context_query, func_tool return payloads, context_query
async def _handle_api_error( async def _handle_api_error(
self, self,
@@ -350,16 +355,16 @@ class ProviderOpenAIOfficial(Provider):
contexts=None, contexts=None,
system_prompt=None, system_prompt=None,
tool_calls_result=None, tool_calls_result=None,
model=None,
**kwargs, **kwargs,
) -> LLMResponse: ) -> LLMResponse:
payloads, context_query, func_tool = await self._prepare_chat_payload( payloads, context_query = await self._prepare_chat_payload(
prompt, prompt,
session_id,
image_urls, image_urls,
func_tool,
contexts, contexts,
system_prompt, system_prompt,
tool_calls_result, tool_calls_result,
model=model,
**kwargs, **kwargs,
) )
@@ -419,17 +424,17 @@ class ProviderOpenAIOfficial(Provider):
contexts=[], contexts=[],
system_prompt=None, system_prompt=None,
tool_calls_result=None, tool_calls_result=None,
model=None,
**kwargs, **kwargs,
) -> AsyncGenerator[LLMResponse, None]: ) -> AsyncGenerator[LLMResponse, None]:
"""流式对话,与服务商交互并逐步返回结果""" """流式对话,与服务商交互并逐步返回结果"""
payloads, context_query, func_tool = await self._prepare_chat_payload( payloads, context_query = await self._prepare_chat_payload(
prompt, prompt,
session_id,
image_urls, image_urls,
func_tool,
contexts, contexts,
system_prompt, system_prompt,
tool_calls_result, tool_calls_result,
model=model,
**kwargs, **kwargs,
) )
@@ -485,13 +490,8 @@ class ProviderOpenAIOfficial(Provider):
""" """
new_contexts = [] new_contexts = []
flag = False
for context in contexts: for context in contexts:
if flag: if "content" in context and isinstance(context["content"], list):
flag = False # 删除 image 后下一条LLM 响应)也要删除
continue
if isinstance(context["content"], list):
flag = True
# continue # continue
new_content = [] new_content = []
for item in context["content"]: for item in context["content"]:
@@ -534,7 +534,10 @@ class ProviderOpenAIOfficial(Provider):
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue continue
user_content["content"].append( user_content["content"].append(
{"type": "image_url", "image_url": {"url": image_data}} {
"type": "image_url",
"image_url": {"url": image_data},
}
) )
return user_content return user_content
else: else:

View File

@@ -5,12 +5,12 @@ import os
import traceback import traceback
import asyncio import asyncio
import aiohttp import aiohttp
import requests
from ..provider import TTSProvider from ..provider import TTSProvider
from ..entities import ProviderType from ..entities import ProviderType
from ..register import register_provider_adapter from ..register import register_provider_adapter
from astrbot import logger from astrbot import logger
@register_provider_adapter( @register_provider_adapter(
"volcengine_tts", "火山引擎 TTS", provider_type=ProviderType.TEXT_TO_SPEECH "volcengine_tts", "火山引擎 TTS", provider_type=ProviderType.TEXT_TO_SPEECH
) )
@@ -22,7 +22,9 @@ class ProviderVolcengineTTS(TTSProvider):
self.cluster = provider_config.get("volcengine_cluster", "") self.cluster = provider_config.get("volcengine_cluster", "")
self.voice_type = provider_config.get("volcengine_voice_type", "") self.voice_type = provider_config.get("volcengine_voice_type", "")
self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0) self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0)
self.api_base = provider_config.get("api_base", f"https://openspeech.bytedance.com/api/v1/tts") self.api_base = provider_config.get(
"api_base", "https://openspeech.bytedance.com/api/v1/tts"
)
self.timeout = provider_config.get("timeout", 20) self.timeout = provider_config.get("timeout", 20)
def _build_request_payload(self, text: str) -> dict: def _build_request_payload(self, text: str) -> dict:
@@ -30,11 +32,9 @@ class ProviderVolcengineTTS(TTSProvider):
"app": { "app": {
"appid": self.appid, "appid": self.appid,
"token": self.api_key, "token": self.api_key,
"cluster": self.cluster "cluster": self.cluster,
},
"user": {
"uid": str(uuid.uuid4())
}, },
"user": {"uid": str(uuid.uuid4())},
"audio": { "audio": {
"voice_type": self.voice_type, "voice_type": self.voice_type,
"encoding": "mp3", "encoding": "mp3",
@@ -48,15 +48,15 @@ class ProviderVolcengineTTS(TTSProvider):
"text_type": "plain", "text_type": "plain",
"operation": "query", "operation": "query",
"with_frontend": 1, "with_frontend": 1,
"frontend_type": "unitTson" "frontend_type": "unitTson",
} },
} }
async def get_audio(self, text: str) -> str: async def get_audio(self, text: str) -> str:
"""异步方法获取语音文件路径""" """异步方法获取语音文件路径"""
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer; {self.api_key}" "Authorization": f"Bearer; {self.api_key}",
} }
payload = self._build_request_payload(text) payload = self._build_request_payload(text)
@@ -71,7 +71,7 @@ class ProviderVolcengineTTS(TTSProvider):
self.api_base, self.api_base,
data=json.dumps(payload), data=json.dumps(payload),
headers=headers, headers=headers,
timeout=self.timeout timeout=self.timeout,
) as response: ) as response:
logger.debug(f"响应状态码: {response.status}") logger.debug(f"响应状态码: {response.status}")
@@ -90,8 +90,7 @@ class ProviderVolcengineTTS(TTSProvider):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
await loop.run_in_executor( await loop.run_in_executor(
None, None, lambda: open(file_path, "wb").write(audio_data)
lambda: open(file_path, "wb").write(audio_data)
) )
return file_path return file_path
@@ -99,7 +98,9 @@ class ProviderVolcengineTTS(TTSProvider):
error_msg = resp_data.get("message", "未知错误") error_msg = resp_data.get("message", "未知错误")
raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}") raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
else: else:
raise Exception(f"火山引擎 TTS API 请求失败: {response.status}, {response_text}") raise Exception(
f"火山引擎 TTS API 请求失败: {response.status}, {response_text}"
)
except Exception as e: except Exception as e:
error_details = traceback.format_exc() error_details = traceback.format_exc()

View File

@@ -1,4 +1,3 @@
from astrbot.core.db import BaseDatabase
from astrbot import logger from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List from typing import List
@@ -13,15 +12,11 @@ class ProviderZhipu(ProviderOpenAIOfficial):
self, self,
provider_config: dict, provider_config: dict,
provider_settings: dict, provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=True,
default_persona=None, default_persona=None,
) -> None: ) -> None:
super().__init__( super().__init__(
provider_config, provider_config,
provider_settings, provider_settings,
db_helper,
persistant_history,
default_persona, default_persona,
) )
@@ -33,6 +28,7 @@ class ProviderZhipu(ProviderOpenAIOfficial):
func_tool: FuncCall = None, func_tool: FuncCall = None,
contexts=None, contexts=None,
system_prompt=None, system_prompt=None,
model=None,
**kwargs, **kwargs,
) -> LLMResponse: ) -> LLMResponse:
if contexts is None: if contexts is None:
@@ -43,7 +39,7 @@ class ProviderZhipu(ProviderOpenAIOfficial):
context_query = [*contexts, new_record] context_query = [*contexts, new_record]
model_cfgs: dict = self.provider_config.get("model_config", {}) model_cfgs: dict = self.provider_config.get("model_config", {})
model = self.get_model() model = model or self.get_model()
# glm-4v-flash 只支持一张图片 # glm-4v-flash 只支持一张图片
if model.lower() == "glm-4v-flash" and image_urls and len(context_query) > 1: if model.lower() == "glm-4v-flash" and image_urls and len(context_query) > 1:
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片") logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")

View File

@@ -1,4 +1,4 @@
from .star import StarMetadata from .star import StarMetadata, star_map, star_registry
from .star_manager import PluginManager from .star_manager import PluginManager
from .context import Context from .context import Context
from astrbot.core.provider import Provider from astrbot.core.provider import Provider
@@ -10,23 +10,48 @@ from astrbot.core.star.star_tools import StarTools
class Star(CommandParserMixin): class Star(CommandParserMixin):
"""所有插件Star的父类所有插件都应该继承于这个类""" """所有插件Star的父类所有插件都应该继承于这个类"""
def __init__(self, context: Context): def __init__(self, context: Context, config: dict | None = None):
StarTools.initialize(context) StarTools.initialize(context)
self.context = context self.context = context
async def text_to_image(self, text: str, return_url=True) -> str: def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not star_map.get(cls.__module__):
metadata = StarMetadata(
star_cls_type=cls,
module_path=cls.__module__,
)
star_map[cls.__module__] = metadata
star_registry.append(metadata)
else:
star_map[cls.__module__].star_cls_type = cls
star_map[cls.__module__].module_path = cls.__module__
@staticmethod
async def text_to_image(text: str, return_url=True) -> str:
"""将文本转换为图片""" """将文本转换为图片"""
return await html_renderer.render_t2i(text, return_url=return_url) return await html_renderer.render_t2i(text, return_url=return_url)
async def html_render(self, tmpl: str, data: dict, return_url=True) -> str: @staticmethod
async def html_render(
tmpl: str, data: dict, return_url=True, options: dict = None
) -> str:
"""渲染 HTML""" """渲染 HTML"""
return await html_renderer.render_custom_template( return await html_renderer.render_custom_template(
tmpl, data, return_url=return_url tmpl, data, return_url=return_url, options=options
) )
async def initialize(self):
"""当插件被激活时会调用这个方法"""
pass
async def terminate(self): async def terminate(self):
"""当插件被禁用、重载插件时会调用这个方法""" """当插件被禁用、重载插件时会调用这个方法"""
pass pass
def __del__(self):
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
pass
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"] __all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]

View File

@@ -2,7 +2,12 @@ from asyncio import Queue
from typing import List, Union from typing import List, Union
from astrbot.core import sp from astrbot.core import sp
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider from astrbot.core.provider.provider import (
Provider,
TTSProvider,
STTProvider,
EmbeddingProvider,
)
from astrbot.core.provider.entities import ProviderType from astrbot.core.provider.entities import ProviderType
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -141,6 +146,10 @@ class Context:
"""获取所有用于 STT 任务的 Provider。""" """获取所有用于 STT 任务的 Provider。"""
return self.provider_manager.stt_provider_insts return self.provider_manager.stt_provider_insts
def get_all_embedding_providers(self) -> List[EmbeddingProvider]:
"""获取所有用于 Embedding 任务的 Provider。"""
return self.provider_manager.embedding_provider_insts
def get_using_provider(self, umo: str = None) -> Provider: def get_using_provider(self, umo: str = None) -> Provider:
""" """
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。

View File

@@ -7,10 +7,13 @@ from astrbot.core.config import AstrBotConfig
from .custom_filter import CustomFilter from .custom_filter import CustomFilter
from ..star_handler import StarHandlerMetadata from ..star_handler import StarHandlerMetadata
class GreedyStr(str): class GreedyStr(str):
"""标记指令完成其他参数接收后的所有剩余文本。""" """标记指令完成其他参数接收后的所有剩余文本。"""
pass pass
# 标准指令受到 wake_prefix 的制约。 # 标准指令受到 wake_prefix 的制约。
class CommandFilter(HandlerFilter): class CommandFilter(HandlerFilter):
"""标准指令过滤器""" """标准指令过滤器"""
@@ -18,8 +21,8 @@ class CommandFilter(HandlerFilter):
def __init__( def __init__(
self, self,
command_name: str, command_name: str,
alias: set = None, alias: set | None = None,
handler_md: StarHandlerMetadata = None, handler_md: StarHandlerMetadata | None = None,
parent_command_names: List[str] = [""], parent_command_names: List[str] = [""],
): ):
self.command_name = command_name self.command_name = command_name
@@ -110,6 +113,17 @@ class CommandFilter(HandlerFilter):
elif isinstance(param_type_or_default_val, str): elif isinstance(param_type_or_default_val, str):
# 如果 param_type_or_default_val 是字符串,直接赋值 # 如果 param_type_or_default_val 是字符串,直接赋值
result[param_name] = params[i] result[param_name] = params[i]
elif isinstance(param_type_or_default_val, bool):
# 处理布尔类型
lower_param = str(params[i]).lower()
if lower_param in ["true", "yes", "1"]:
result[param_name] = True
elif lower_param in ["false", "no", "0"]:
result[param_name] = False
else:
raise ValueError(
f"参数 {param_name} 必须是布尔值true/false, yes/no, 1/0"
)
elif isinstance(param_type_or_default_val, int): elif isinstance(param_type_or_default_val, int):
result[param_name] = int(params[i]) result[param_name] = int(params[i])
elif isinstance(param_type_or_default_val, float): elif isinstance(param_type_or_default_val, float):

View File

@@ -113,8 +113,7 @@ class CommandGroupFilter(HandlerFilter):
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg) + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
) )
raise ValueError( raise ValueError(
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
+ tree
) )
# complete_command_names = [name + " " for name in complete_command_names] # complete_command_names = [name + " " for name in complete_command_names]

View File

@@ -8,22 +8,45 @@ from typing import Union
class PlatformAdapterType(enum.Flag): class PlatformAdapterType(enum.Flag):
AIOCQHTTP = enum.auto() AIOCQHTTP = enum.auto()
QQOFFICIAL = enum.auto() QQOFFICIAL = enum.auto()
VCHAT = enum.auto()
GEWECHAT = enum.auto()
TELEGRAM = enum.auto() TELEGRAM = enum.auto()
WECOM = enum.auto() WECOM = enum.auto()
LARK = enum.auto() LARK = enum.auto()
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT | GEWECHAT | TELEGRAM | WECOM | LARK WECHATPADPRO = enum.auto()
DINGTALK = enum.auto()
DISCORD = enum.auto()
SLACK = enum.auto()
KOOK = enum.auto()
VOCECHAT = enum.auto()
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
ALL = (
AIOCQHTTP
| QQOFFICIAL
| TELEGRAM
| WECOM
| LARK
| WECHATPADPRO
| DINGTALK
| DISCORD
| SLACK
| KOOK
| VOCECHAT
| WEIXIN_OFFICIAL_ACCOUNT
)
ADAPTER_NAME_2_TYPE = { ADAPTER_NAME_2_TYPE = {
"aiocqhttp": PlatformAdapterType.AIOCQHTTP, "aiocqhttp": PlatformAdapterType.AIOCQHTTP,
"qq_official": PlatformAdapterType.QQOFFICIAL, "qq_official": PlatformAdapterType.QQOFFICIAL,
"vchat": PlatformAdapterType.VCHAT,
"gewechat": PlatformAdapterType.GEWECHAT,
"telegram": PlatformAdapterType.TELEGRAM, "telegram": PlatformAdapterType.TELEGRAM,
"wecom": PlatformAdapterType.WECOM, "wecom": PlatformAdapterType.WECOM,
"lark": PlatformAdapterType.LARK, "lark": PlatformAdapterType.LARK,
"dingtalk": PlatformAdapterType.DINGTALK,
"discord": PlatformAdapterType.DISCORD,
"slack": PlatformAdapterType.SLACK,
"kook": PlatformAdapterType.KOOK,
"wechatpadpro": PlatformAdapterType.WECHATPADPRO,
"vocechat": PlatformAdapterType.VOCECHAT,
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
} }

View File

@@ -1,9 +1,17 @@
from ..star import star_registry, StarMetadata, star_map import warnings
from astrbot.core.star import StarMetadata, star_map
_warned_register_star = False
def register_star(name: str, author: str, desc: str, version: str, repo: str = None): def register_star(name: str, author: str, desc: str, version: str, repo: str = None):
"""注册一个插件(Star)。 """注册一个插件(Star)。
[DEPRECATED] 该装饰器已废弃,将在未来版本中移除。
在 v3.5.19 版本之后(不含),您不需要使用该装饰器来装饰插件类,
AstrBot 会自动识别继承自 Star 的类并将其作为插件类加载。
Args: Args:
name: 插件名称。 name: 插件名称。
author: 作者。 author: 作者。
@@ -21,18 +29,32 @@ def register_star(name: str, author: str, desc: str, version: str, repo: str = N
帮助信息会被自动提取。使用 `/plugin <插件名> 可以查看帮助信息。` 帮助信息会被自动提取。使用 `/plugin <插件名> 可以查看帮助信息。`
""" """
def decorator(cls): global _warned_register_star
star_metadata = StarMetadata( if not _warned_register_star:
name=name, _warned_register_star = True
author=author, warnings.warn(
desc=desc, "The 'register_star' decorator is deprecated and will be removed in a future version.",
version=version, DeprecationWarning,
repo=repo, stacklevel=2,
star_cls_type=cls,
module_path=cls.__module__,
) )
star_registry.append(star_metadata)
star_map[cls.__module__] = star_metadata def decorator(cls):
if not star_map.get(cls.__module__):
metadata = StarMetadata(
name=name,
author=author,
desc=desc,
version=version,
repo=repo,
)
star_map[cls.__module__] = metadata
else:
star_map[cls.__module__].name = name
star_map[cls.__module__].author = author
star_map[cls.__module__].desc = desc
star_map[cls.__module__].version = version
star_map[cls.__module__].repo = repo
return cls return cls
return decorator return decorator

View File

@@ -0,0 +1,293 @@
"""
会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态
"""
from typing import Dict
from astrbot.core import logger, sp
from astrbot.core.platform.astr_message_event import AstrMessageEvent
class SessionServiceManager:
"""管理会话级别的服务启停状态包括LLM和TTS"""
# =============================================================================
# LLM 相关方法
# =============================================================================
@staticmethod
def is_llm_enabled_for_session(session_id: str) -> bool:
"""检查LLM是否在指定会话中启用
Args:
session_id: 会话ID (unified_msg_origin)
Returns:
bool: True表示启用False表示禁用
"""
# 获取会话服务配置
session_config = sp.get("session_service_config", {}) or {}
session_services = session_config.get(session_id, {})
# 如果配置了该会话的LLM状态返回该状态
llm_enabled = session_services.get("llm_enabled")
if llm_enabled is not None:
return llm_enabled
# 如果没有配置,默认为启用(兼容性考虑)
return True
@staticmethod
def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
"""设置LLM在指定会话中的启停状态
Args:
session_id: 会话ID (unified_msg_origin)
enabled: True表示启用False表示禁用
"""
# 获取当前配置
session_config = sp.get("session_service_config", {}) or {}
if session_id not in session_config:
session_config[session_id] = {}
# 设置LLM状态
session_config[session_id]["llm_enabled"] = enabled
# 保存配置
sp.put("session_service_config", session_config)
logger.info(
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
)
@staticmethod
def should_process_llm_request(event: AstrMessageEvent) -> bool:
"""检查是否应该处理LLM请求
Args:
event: 消息事件
Returns:
bool: True表示应该处理False表示跳过
"""
session_id = event.unified_msg_origin
return SessionServiceManager.is_llm_enabled_for_session(session_id)
# =============================================================================
# TTS 相关方法
# =============================================================================
@staticmethod
def is_tts_enabled_for_session(session_id: str) -> bool:
"""检查TTS是否在指定会话中启用
Args:
session_id: 会话ID (unified_msg_origin)
Returns:
bool: True表示启用False表示禁用
"""
# 获取会话服务配置
session_config = sp.get("session_service_config", {}) or {}
session_services = session_config.get(session_id, {})
# 如果配置了该会话的TTS状态返回该状态
tts_enabled = session_services.get("tts_enabled")
if tts_enabled is not None:
return tts_enabled
# 如果没有配置,默认为启用(兼容性考虑)
return True
@staticmethod
def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
"""设置TTS在指定会话中的启停状态
Args:
session_id: 会话ID (unified_msg_origin)
enabled: True表示启用False表示禁用
"""
# 获取当前配置
session_config = sp.get("session_service_config", {}) or {}
if session_id not in session_config:
session_config[session_id] = {}
# 设置TTS状态
session_config[session_id]["tts_enabled"] = enabled
# 保存配置
sp.put("session_service_config", session_config)
logger.info(
f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}"
)
@staticmethod
def should_process_tts_request(event: AstrMessageEvent) -> bool:
"""检查是否应该处理TTS请求
Args:
event: 消息事件
Returns:
bool: True表示应该处理False表示跳过
"""
session_id = event.unified_msg_origin
return SessionServiceManager.is_tts_enabled_for_session(session_id)
# =============================================================================
# 会话整体启停相关方法
# =============================================================================
@staticmethod
def is_session_enabled(session_id: str) -> bool:
"""检查会话是否整体启用
Args:
session_id: 会话ID (unified_msg_origin)
Returns:
bool: True表示启用False表示禁用
"""
# 获取会话服务配置
session_config = sp.get("session_service_config", {}) or {}
session_services = session_config.get(session_id, {})
# 如果配置了该会话的整体状态,返回该状态
session_enabled = session_services.get("session_enabled")
if session_enabled is not None:
return session_enabled
# 如果没有配置,默认为启用(兼容性考虑)
return True
@staticmethod
def set_session_status(session_id: str, enabled: bool) -> None:
"""设置会话的整体启停状态
Args:
session_id: 会话ID (unified_msg_origin)
enabled: True表示启用False表示禁用
"""
# 获取当前配置
session_config = sp.get("session_service_config", {}) or {}
if session_id not in session_config:
session_config[session_id] = {}
# 设置会话整体状态
session_config[session_id]["session_enabled"] = enabled
# 保存配置
sp.put("session_service_config", session_config)
logger.info(
f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}"
)
@staticmethod
def should_process_session_request(event: AstrMessageEvent) -> bool:
"""检查是否应该处理会话请求(会话整体启停检查)
Args:
event: 消息事件
Returns:
bool: True表示应该处理False表示跳过
"""
session_id = event.unified_msg_origin
return SessionServiceManager.is_session_enabled(session_id)
# =============================================================================
# 会话命名相关方法
# =============================================================================
@staticmethod
def get_session_custom_name(session_id: str) -> str:
"""获取会话的自定义名称
Args:
session_id: 会话ID (unified_msg_origin)
Returns:
str: 自定义名称如果没有设置则返回None
"""
session_config = sp.get("session_service_config", {}) or {}
session_services = session_config.get(session_id, {})
return session_services.get("custom_name")
@staticmethod
def set_session_custom_name(session_id: str, custom_name: str) -> None:
"""设置会话的自定义名称
Args:
session_id: 会话ID (unified_msg_origin)
custom_name: 自定义名称,可以为空字符串来清除名称
"""
# 获取当前配置
session_config = sp.get("session_service_config", {}) or {}
if session_id not in session_config:
session_config[session_id] = {}
# 设置自定义名称
if custom_name and custom_name.strip():
session_config[session_id]["custom_name"] = custom_name.strip()
else:
# 如果传入空名称,则删除自定义名称
session_config[session_id].pop("custom_name", None)
# 保存配置
sp.put("session_service_config", session_config)
logger.info(
f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}"
)
@staticmethod
def get_session_display_name(session_id: str) -> str:
"""获取会话的显示名称优先显示自定义名称否则显示原始session_id的最后一段
Args:
session_id: 会话ID (unified_msg_origin)
Returns:
str: 显示名称
"""
custom_name = SessionServiceManager.get_session_custom_name(session_id)
if custom_name:
return custom_name
# 如果没有自定义名称返回session_id的最后一段
return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id
# =============================================================================
# 通用配置方法
# =============================================================================
@staticmethod
def get_session_service_config(session_id: str) -> Dict[str, bool]:
"""获取指定会话的服务配置
Args:
session_id: 会话ID (unified_msg_origin)
Returns:
Dict[str, bool]: 包含session_enabled、llm_enabled、tts_enabled的字典
"""
session_config = sp.get("session_service_config", {}) or {}
return session_config.get(
session_id,
{
"session_enabled": True, # 默认启用
"llm_enabled": True, # 默认启用
"tts_enabled": True, # 默认启用
},
)
@staticmethod
def get_all_session_configs() -> Dict[str, Dict[str, bool]]:
"""获取所有会话的服务配置
Returns:
Dict[str, Dict[str, bool]]: 所有会话的服务配置
"""
return sp.get("session_service_config", {}) or {}

View File

@@ -0,0 +1,142 @@
"""
会话插件管理器 - 负责管理每个会话的插件启停状态
"""
from astrbot.core import sp, logger
from typing import Dict, List
from astrbot.core.platform.astr_message_event import AstrMessageEvent
class SessionPluginManager:
"""管理会话级别的插件启停状态"""
@staticmethod
def is_plugin_enabled_for_session(session_id: str, plugin_name: str) -> bool:
"""检查插件是否在指定会话中启用
Args:
session_id: 会话ID (unified_msg_origin)
plugin_name: 插件名称
Returns:
bool: True表示启用False表示禁用
"""
# 获取会话插件配置
session_plugin_config = sp.get("session_plugin_config", {}) or {}
session_config = session_plugin_config.get(session_id, {})
enabled_plugins = session_config.get("enabled_plugins", [])
disabled_plugins = session_config.get("disabled_plugins", [])
# 如果插件在禁用列表中返回False
if plugin_name in disabled_plugins:
return False
# 如果插件在启用列表中返回True
if plugin_name in enabled_plugins:
return True
# 如果都没有配置,默认为启用(兼容性考虑)
return True
@staticmethod
def set_plugin_status_for_session(
session_id: str, plugin_name: str, enabled: bool
) -> None:
"""设置插件在指定会话中的启停状态
Args:
session_id: 会话ID (unified_msg_origin)
plugin_name: 插件名称
enabled: True表示启用False表示禁用
"""
# 获取当前配置
session_plugin_config = sp.get("session_plugin_config", {}) or {}
if session_id not in session_plugin_config:
session_plugin_config[session_id] = {
"enabled_plugins": [],
"disabled_plugins": [],
}
session_config = session_plugin_config[session_id]
enabled_plugins = session_config.get("enabled_plugins", [])
disabled_plugins = session_config.get("disabled_plugins", [])
if enabled:
# 启用插件
if plugin_name in disabled_plugins:
disabled_plugins.remove(plugin_name)
if plugin_name not in enabled_plugins:
enabled_plugins.append(plugin_name)
else:
# 禁用插件
if plugin_name in enabled_plugins:
enabled_plugins.remove(plugin_name)
if plugin_name not in disabled_plugins:
disabled_plugins.append(plugin_name)
# 保存配置
session_config["enabled_plugins"] = enabled_plugins
session_config["disabled_plugins"] = disabled_plugins
session_plugin_config[session_id] = session_config
sp.put("session_plugin_config", session_plugin_config)
logger.info(
f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}"
)
@staticmethod
def get_session_plugin_config(session_id: str) -> Dict[str, List[str]]:
"""获取指定会话的插件配置
Args:
session_id: 会话ID (unified_msg_origin)
Returns:
Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典
"""
session_plugin_config = sp.get("session_plugin_config", {}) or {}
return session_plugin_config.get(
session_id, {"enabled_plugins": [], "disabled_plugins": []}
)
@staticmethod
def filter_handlers_by_session(event: AstrMessageEvent, handlers: List) -> List:
"""根据会话配置过滤处理器列表
Args:
event: 消息事件
handlers: 原始处理器列表
Returns:
List: 过滤后的处理器列表
"""
from astrbot.core.star.star import star_map
session_id = event.unified_msg_origin
filtered_handlers = []
for handler in handlers:
# 获取处理器对应的插件
plugin = star_map.get(handler.handler_module_path)
if not plugin:
# 如果找不到插件元数据,允许执行(可能是系统插件)
filtered_handlers.append(handler)
continue
# 跳过保留插件(系统插件)
if plugin.reserved:
filtered_handlers.append(handler)
continue
# 检查插件是否在当前会话中启用
if SessionPluginManager.is_plugin_enabled_for_session(
session_id, plugin.name
):
filtered_handlers.append(handler)
else:
logger.debug(
f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}"
)
return filtered_handlers

View File

@@ -1,14 +1,18 @@
from __future__ import annotations from __future__ import annotations
from types import ModuleType
from typing import List, Dict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from types import ModuleType
from typing import TYPE_CHECKING
from astrbot.core.config import AstrBotConfig from astrbot.core.config import AstrBotConfig
star_registry: List[StarMetadata] = [] star_registry: list[StarMetadata] = []
star_map: Dict[str, StarMetadata] = {} star_map: dict[str, StarMetadata] = {}
"""key 是模块路径__module__""" """key 是模块路径__module__"""
if TYPE_CHECKING:
from . import Star
@dataclass @dataclass
class StarMetadata: class StarMetadata:
@@ -18,22 +22,27 @@ class StarMetadata:
当 activated 为 False 时star_cls 可能为 None请不要在插件未激活时调用 star_cls 的方法。 当 activated 为 False 时star_cls 可能为 None请不要在插件未激活时调用 star_cls 的方法。
""" """
name: str name: str | None = None
author: str # 插件作者 """插件名"""
desc: str # 插件简介 author: str | None = None
version: str # 插件版本 """插件作者"""
repo: str = None # 插件仓库地址 desc: str | None = None
"""插件简介"""
version: str | None = None
"""插件版本"""
repo: str | None = None
"""插件仓库地址"""
star_cls_type: type = None star_cls_type: type[Star] | None = None
"""插件的类对象的类型""" """插件的类对象的类型"""
module_path: str = None module_path: str | None = None
"""插件的模块路径""" """插件的模块路径"""
star_cls: object = None star_cls: Star | None = None
"""插件的类对象""" """插件的类对象"""
module: ModuleType = None module: ModuleType | None = None
"""插件的模块对象""" """插件的模块对象"""
root_dir_name: str = None root_dir_name: str | None = None
"""插件的目录名称""" """插件的目录名称"""
reserved: bool = False reserved: bool = False
"""是否是 AstrBot 的保留插件""" """是否是 AstrBot 的保留插件"""
@@ -41,17 +50,20 @@ class StarMetadata:
activated: bool = True activated: bool = True
"""是否被激活""" """是否被激活"""
config: AstrBotConfig = None config: AstrBotConfig | None = None
"""插件配置""" """插件配置"""
star_handler_full_names: List[str] = field(default_factory=list) star_handler_full_names: list[str] = field(default_factory=list)
"""注册的 Handler 的全名列表""" """注册的 Handler 的全名列表"""
supported_platforms: Dict[str, bool] = field(default_factory=dict) supported_platforms: dict[str, bool] = field(default_factory=dict)
"""插件支持的平台ID字典key为平台IDvalue为是否支持""" """插件支持的平台ID字典key为平台IDvalue为是否支持"""
def __str__(self) -> str: def __str__(self) -> str:
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})" return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}"
def __repr__(self) -> str:
return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}"
def update_platform_compatibility(self, plugin_enable_config: dict) -> None: def update_platform_compatibility(self, plugin_enable_config: dict) -> None:
"""更新插件支持的平台列表 """更新插件支持的平台列表

View File

@@ -7,6 +7,7 @@ from .star import star_map
T = TypeVar("T", bound="StarHandlerMetadata") T = TypeVar("T", bound="StarHandlerMetadata")
class StarHandlerRegistry(Generic[T]): class StarHandlerRegistry(Generic[T]):
def __init__(self): def __init__(self):
self.star_handlers_map: Dict[str, StarHandlerMetadata] = {} self.star_handlers_map: Dict[str, StarHandlerMetadata] = {}
@@ -49,7 +50,8 @@ class StarHandlerRegistry(Generic[T]):
self, module_name: str self, module_name: str
) -> List[StarHandlerMetadata]: ) -> List[StarHandlerMetadata]:
return [ return [
handler for handler in self._handlers handler
for handler in self._handlers
if handler.handler_module_path == module_name if handler.handler_module_path == module_name
] ]
@@ -67,6 +69,7 @@ class StarHandlerRegistry(Generic[T]):
def __len__(self): def __len__(self):
return len(self._handlers) return len(self._handlers)
star_handlers_registry = StarHandlerRegistry() star_handlers_registry = StarHandlerRegistry()

View File

@@ -11,7 +11,6 @@ import os
import sys import sys
import traceback import traceback
from types import ModuleType from types import ModuleType
from typing import List
import yaml import yaml
@@ -37,12 +36,6 @@ except ImportError:
if os.getenv("ASTRBOT_RELOAD", "0") == "1": if os.getenv("ASTRBOT_RELOAD", "0") == "1":
logger.warning("未安装 watchfiles无法实现插件的热重载。") logger.warning("未安装 watchfiles无法实现插件的热重载。")
try:
import nh3
except ImportError:
logger.warning("未安装 nh3 库,无法清理插件 README.md 中的 HTML 标签。")
nh3 = None
class PluginManager: class PluginManager:
def __init__(self, context: Context, config: AstrBotConfig): def __init__(self, context: Context, config: AstrBotConfig):
@@ -64,6 +57,8 @@ class PluginManager:
"""保留插件的路径。在 packages 目录下""" """保留插件的路径。在 packages 目录下"""
self.conf_schema_fname = "_conf_schema.json" self.conf_schema_fname = "_conf_schema.json"
"""插件配置 Schema 文件名""" """插件配置 Schema 文件名"""
self._pm_lock = asyncio.Lock()
"""StarManager操作互斥锁"""
self.failed_plugin_info = "" self.failed_plugin_info = ""
if os.getenv("ASTRBOT_RELOAD", "0") == "1": if os.getenv("ASTRBOT_RELOAD", "0") == "1":
@@ -119,7 +114,8 @@ class PluginManager:
reloaded_plugins.add(plugin_name) reloaded_plugins.add(plugin_name)
break break
def _get_classes(self, arg: ModuleType): @staticmethod
def _get_classes(arg: ModuleType):
"""获取指定模块(可以理解为一个 python 文件)下所有的类""" """获取指定模块(可以理解为一个 python 文件)下所有的类"""
classes = [] classes = []
clsmembers = inspect.getmembers(arg, inspect.isclass) clsmembers = inspect.getmembers(arg, inspect.isclass)
@@ -129,7 +125,8 @@ class PluginManager:
break break
return classes return classes
def _get_modules(self, path): @staticmethod
def _get_modules(path):
modules = [] modules = []
dirs = os.listdir(path) dirs = os.listdir(path)
@@ -155,7 +152,7 @@ class PluginManager:
) )
return modules return modules
def _get_plugin_modules(self) -> List[dict]: def _get_plugin_modules(self) -> list[dict]:
plugins = [] plugins = []
if os.path.exists(self.plugin_store_path): if os.path.exists(self.plugin_store_path):
plugins.extend(self._get_modules(self.plugin_store_path)) plugins.extend(self._get_modules(self.plugin_store_path))
@@ -166,7 +163,7 @@ class PluginManager:
plugins.extend(_p) plugins.extend(_p)
return plugins return plugins
async def _check_plugin_dept_update(self, target_plugin: str = None): async def _check_plugin_dept_update(self, target_plugin: str | None = None):
"""检查插件的依赖 """检查插件的依赖
如果 target_plugin 为 None则检查所有插件的依赖 如果 target_plugin 为 None则检查所有插件的依赖
""" """
@@ -189,10 +186,11 @@ class PluginManager:
except Exception as e: except Exception as e:
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}") logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
def _load_plugin_metadata(self, plugin_path: str, plugin_obj=None) -> StarMetadata: @staticmethod
"""v3.4.0 以前的方式载入插件元数据 def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | None:
"""先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。
先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。 Notes: 旧版本 AstrBot 插件可能使用的是 info() 函数获取元数据。
""" """
metadata = None metadata = None
@@ -204,11 +202,14 @@ class PluginManager:
os.path.join(plugin_path, "metadata.yaml"), "r", encoding="utf-8" os.path.join(plugin_path, "metadata.yaml"), "r", encoding="utf-8"
) as f: ) as f:
metadata = yaml.safe_load(f) metadata = yaml.safe_load(f)
elif plugin_obj: elif plugin_obj and hasattr(plugin_obj, "info"):
# 使用 info() 函数 # 使用 info() 函数
metadata = plugin_obj.info() metadata = plugin_obj.info()
if isinstance(metadata, dict): if isinstance(metadata, dict):
if "desc" not in metadata and "description" in metadata:
metadata["desc"] = metadata["description"]
if ( if (
"name" not in metadata "name" not in metadata
or "desc" not in metadata or "desc" not in metadata
@@ -228,8 +229,9 @@ class PluginManager:
return metadata return metadata
@staticmethod
def _get_plugin_related_modules( def _get_plugin_related_modules(
self, plugin_root_dir: str, is_reserved: bool = False plugin_root_dir: str, is_reserved: bool = False
) -> list[str]: ) -> list[str]:
"""获取与指定插件相关的所有已加载模块名 """获取与指定插件相关的所有已加载模块名
@@ -251,8 +253,8 @@ class PluginManager:
def _purge_modules( def _purge_modules(
self, self,
module_patterns: list[str] = None, module_patterns: list[str] | None = None,
root_dir_name: str = None, root_dir_name: str | None = None,
is_reserved: bool = False, is_reserved: bool = False,
): ):
"""从 sys.modules 中移除指定的模块 """从 sys.modules 中移除指定的模块
@@ -293,50 +295,51 @@ class PluginManager:
- success (bool): 重载是否成功 - success (bool): 重载是否成功
- error_message (str|None): 错误信息,成功时为 None - error_message (str|None): 错误信息,成功时为 None
""" """
specified_module_path = None async with self._pm_lock:
if specified_plugin_name: specified_module_path = None
for smd in star_registry: if specified_plugin_name:
if smd.name == specified_plugin_name: for smd in star_registry:
specified_module_path = smd.module_path if smd.name == specified_plugin_name:
break specified_module_path = smd.module_path
break
# 终止插件 # 终止插件
if not specified_module_path: if not specified_module_path:
# 重载所有插件 # 重载所有插件
for smd in star_registry: for smd in star_registry:
try: try:
await self._terminate_plugin(smd) await self._terminate_plugin(smd)
except Exception as e: except Exception as e:
logger.warning(traceback.format_exc()) logger.warning(traceback.format_exc())
logger.warning( logger.warning(
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。" f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
) )
if smd.name and smd.module_path:
await self._unbind_plugin(smd.name, smd.module_path)
await self._unbind_plugin(smd.name, smd.module_path) star_handlers_registry.clear()
star_map.clear()
star_registry.clear()
else:
# 只重载指定插件
smd = star_map.get(specified_module_path)
if smd:
try:
await self._terminate_plugin(smd)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
)
if smd.name:
await self._unbind_plugin(smd.name, specified_module_path)
star_handlers_registry.clear() result = await self.load(specified_module_path)
star_map.clear()
star_registry.clear()
else:
# 只重载指定插件
smd = star_map.get(specified_module_path)
if smd:
try:
await self._terminate_plugin(smd)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
)
await self._unbind_plugin(smd.name, specified_module_path) # 更新所有插件的平台兼容性
await self.update_all_platform_compatibility()
result = await self.load(specified_module_path) return result
# 更新所有插件的平台兼容性
await self.update_all_platform_compatibility()
return result
async def update_all_platform_compatibility(self): async def update_all_platform_compatibility(self):
"""更新所有插件的平台兼容性设置""" """更新所有插件的平台兼容性设置"""
@@ -435,7 +438,7 @@ class PluginManager:
) )
if path in star_map: if path in star_map:
# 通过装饰器的方式注册插件 # 通过 __init__subclass__ 注册插件
metadata = star_map[path] metadata = star_map[path]
try: try:
@@ -449,13 +452,15 @@ class PluginManager:
metadata.desc = metadata_yaml.desc metadata.desc = metadata_yaml.desc
metadata.version = metadata_yaml.version metadata.version = metadata_yaml.version
metadata.repo = metadata_yaml.repo metadata.repo = metadata_yaml.repo
except Exception: except Exception as e:
pass logger.warning(
f"插件 {root_dir_name} 元数据载入失败: {str(e)}。使用默认元数据。"
)
logger.info(metadata)
metadata.config = plugin_config metadata.config = plugin_config
if path not in inactivated_plugins: if path not in inactivated_plugins:
# 只有没有禁用插件时才实例化插件类 # 只有没有禁用插件时才实例化插件类
if plugin_config: if plugin_config and metadata.star_cls_type:
# metadata.config = plugin_config
try: try:
metadata.star_cls = metadata.star_cls_type( metadata.star_cls = metadata.star_cls_type(
context=self.context, config=plugin_config context=self.context, config=plugin_config
@@ -464,7 +469,7 @@ class PluginManager:
metadata.star_cls = metadata.star_cls_type( metadata.star_cls = metadata.star_cls_type(
context=self.context context=self.context
) )
else: elif metadata.star_cls_type:
metadata.star_cls = metadata.star_cls_type( metadata.star_cls = metadata.star_cls_type(
context=self.context context=self.context
) )
@@ -481,6 +486,10 @@ class PluginManager:
) )
metadata.update_platform_compatibility(plugin_enable_config) metadata.update_platform_compatibility(plugin_enable_config)
assert metadata.module_path is not None, (
f"插件 {metadata.name} 的模块路径为空。"
)
# 绑定 handler # 绑定 handler
related_handlers = ( related_handlers = (
star_handlers_registry.get_handlers_by_module_name( star_handlers_registry.get_handlers_by_module_name(
@@ -489,7 +498,8 @@ class PluginManager:
) )
for handler in related_handlers: for handler in related_handlers:
handler.handler = functools.partial( handler.handler = functools.partial(
handler.handler, metadata.star_cls handler.handler,
metadata.star_cls, # type: ignore
) )
# 绑定 llm_tool handler # 绑定 llm_tool handler
for func_tool in llm_tools.func_list: for func_tool in llm_tools.func_list:
@@ -499,7 +509,8 @@ class PluginManager:
): ):
func_tool.handler_module_path = metadata.module_path func_tool.handler_module_path = metadata.module_path
func_tool.handler = functools.partial( func_tool.handler = functools.partial(
func_tool.handler, metadata.star_cls func_tool.handler,
metadata.star_cls, # type: ignore
) )
if func_tool.name in inactivated_llm_tools: if func_tool.name in inactivated_llm_tools:
func_tool.active = False func_tool.active = False
@@ -526,13 +537,12 @@ class PluginManager:
obj = getattr(module, classes[0])( obj = getattr(module, classes[0])(
context=self.context context=self.context
) # 实例化插件类 ) # 实例化插件类
else:
logger.info(f"插件 {metadata.name} 已被禁用。")
metadata = None
metadata = self._load_plugin_metadata( metadata = self._load_plugin_metadata(
plugin_path=plugin_dir_path, plugin_obj=obj plugin_path=plugin_dir_path, plugin_obj=obj
) )
if not metadata:
raise Exception(f"无法找到插件 {plugin_dir_path} 的元数据。")
metadata.star_cls = obj metadata.star_cls = obj
metadata.config = plugin_config metadata.config = plugin_config
metadata.module = module metadata.module = module
@@ -547,6 +557,10 @@ class PluginManager:
if metadata.module_path in inactivated_plugins: if metadata.module_path in inactivated_plugins:
metadata.activated = False metadata.activated = False
assert metadata.module_path is not None, (
f"插件 {metadata.name} 的模块路径为空。"
)
full_names = [] full_names = []
for handler in star_handlers_registry.get_handlers_by_module_name( for handler in star_handlers_registry.get_handlers_by_module_name(
metadata.module_path metadata.module_path
@@ -586,7 +600,7 @@ class PluginManager:
metadata.star_handler_full_names = full_names metadata.star_handler_full_names = full_names
# 执行 initialize() 方法 # 执行 initialize() 方法
if hasattr(metadata.star_cls, "initialize"): if hasattr(metadata.star_cls, "initialize") and metadata.star_cls:
await metadata.star_cls.initialize() await metadata.star_cls.initialize()
except BaseException as e: except BaseException as e:
@@ -622,43 +636,45 @@ class PluginManager:
- readme: README.md 文件的内容(如果存在) - readme: README.md 文件的内容(如果存在)
如果找不到插件元数据则返回 None。 如果找不到插件元数据则返回 None。
""" """
plugin_path = await self.updator.install(repo_url, proxy) async with self._pm_lock:
# reload the plugin plugin_path = await self.updator.install(repo_url, proxy)
dir_name = os.path.basename(plugin_path) # reload the plugin
await self.load(specified_dir_name=dir_name) dir_name = os.path.basename(plugin_path)
await self.load(specified_dir_name=dir_name)
# Get the plugin metadata to return repo info # Get the plugin metadata to return repo info
plugin = self.context.get_registered_star(dir_name) plugin = self.context.get_registered_star(dir_name)
if not plugin: if not plugin:
# Try to find by other name if directory name doesn't match plugin name # Try to find by other name if directory name doesn't match plugin name
for star in self.context.get_all_stars(): for star in self.context.get_all_stars():
if star.root_dir_name == dir_name: if star.root_dir_name == dir_name:
plugin = star plugin = star
break break
# Extract README.md content if exists # Extract README.md content if exists
readme_content = None readme_content = None
readme_path = os.path.join(plugin_path, "README.md") readme_path = os.path.join(plugin_path, "README.md")
if not os.path.exists(readme_path): if not os.path.exists(readme_path):
readme_path = os.path.join(plugin_path, "readme.md") readme_path = os.path.join(plugin_path, "readme.md")
if os.path.exists(readme_path) and nh3: if os.path.exists(readme_path):
try: try:
with open(readme_path, "r", encoding="utf-8") as f: with open(readme_path, "r", encoding="utf-8") as f:
readme_content = f.read() readme_content = f.read()
cleaned_content = nh3.clean(readme_content) except Exception as e:
except Exception as e: logger.warning(
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}") f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}"
)
plugin_info = None plugin_info = None
if plugin: if plugin:
plugin_info = { plugin_info = {
"repo": plugin.repo, "repo": plugin.repo,
"readme": cleaned_content, "readme": readme_content,
"name": plugin.name, "name": plugin.name,
} }
return plugin_info return plugin_info
async def uninstall_plugin(self, plugin_name: str): async def uninstall_plugin(self, plugin_name: str):
"""卸载指定的插件。 """卸载指定的插件。
@@ -669,32 +685,33 @@ class PluginManager:
Raises: Raises:
Exception: 当插件不存在、是保留插件时,或删除插件文件夹失败时抛出异常 Exception: 当插件不存在、是保留插件时,或删除插件文件夹失败时抛出异常
""" """
plugin = self.context.get_registered_star(plugin_name) async with self._pm_lock:
if not plugin: plugin = self.context.get_registered_star(plugin_name)
raise Exception("插件不存在。") if not plugin:
if plugin.reserved: raise Exception("插件不存在。")
raise Exception("该插件是 AstrBot 保留插件,无法卸载。") if plugin.reserved:
root_dir_name = plugin.root_dir_name raise Exception("该插件是 AstrBot 保留插件,无法卸载。")
ppath = self.plugin_store_path root_dir_name = plugin.root_dir_name
ppath = self.plugin_store_path
# 终止插件 # 终止插件
try: try:
await self._terminate_plugin(plugin) await self._terminate_plugin(plugin)
except Exception as e: except Exception as e:
logger.warning(traceback.format_exc()) logger.warning(traceback.format_exc())
logger.warning( logger.warning(
f"插件 {plugin_name} 未被正常终止 {str(e)}, 可能会导致资源泄露等问题。" f"插件 {plugin_name} 未被正常终止 {str(e)}, 可能会导致资源泄露等问题。"
) )
# 从 star_registry 和 star_map 中删除 # 从 star_registry 和 star_map 中删除
await self._unbind_plugin(plugin_name, plugin.module_path) await self._unbind_plugin(plugin_name, plugin.module_path)
try: try:
remove_dir(os.path.join(ppath, root_dir_name)) remove_dir(os.path.join(ppath, root_dir_name))
except Exception as e: except Exception as e:
raise Exception( raise Exception(
f"移除插件成功,但是删除插件文件夹失败: {str(e)}。您可以手动删除该文件夹,位于 addons/plugins/ 下。" f"移除插件成功,但是删除插件文件夹失败: {str(e)}。您可以手动删除该文件夹,位于 addons/plugins/ 下。"
) )
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
"""解绑并移除一个插件。 """解绑并移除一个插件。
@@ -725,6 +742,9 @@ class PluginManager:
]: ]:
del star_handlers_registry.star_handlers_map[k] del star_handlers_registry.star_handlers_map[k]
if plugin is None:
return
self._purge_modules( self._purge_modules(
root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved
) )
@@ -747,35 +767,37 @@ class PluginManager:
将插件的 module_path 加入到 data/shared_preferences.json 的 inactivated_plugins 列表中。 将插件的 module_path 加入到 data/shared_preferences.json 的 inactivated_plugins 列表中。
并且同时将插件启用的 llm_tool 禁用。 并且同时将插件启用的 llm_tool 禁用。
""" """
plugin = self.context.get_registered_star(plugin_name) async with self._pm_lock:
if not plugin: plugin = self.context.get_registered_star(plugin_name)
raise Exception("插件不存在。") if not plugin:
raise Exception("插件不存在。")
# 调用插件的终止方法 # 调用插件的终止方法
await self._terminate_plugin(plugin) await self._terminate_plugin(plugin)
# 加入到 shared_preferences 中 # 加入到 shared_preferences 中
inactivated_plugins: list = sp.get("inactivated_plugins", []) inactivated_plugins: list = sp.get("inactivated_plugins", [])
if plugin.module_path not in inactivated_plugins: if plugin.module_path not in inactivated_plugins:
inactivated_plugins.append(plugin.module_path) inactivated_plugins.append(plugin.module_path)
inactivated_llm_tools: list = list( inactivated_llm_tools: list = list(
set(sp.get("inactivated_llm_tools", [])) set(sp.get("inactivated_llm_tools", []))
) # 后向兼容 ) # 后向兼容
# 禁用插件启用的 llm_tool # 禁用插件启用的 llm_tool
for func_tool in llm_tools.func_list: for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path: if func_tool.handler_module_path == plugin.module_path:
func_tool.active = False func_tool.active = False
if func_tool.name not in inactivated_llm_tools: if func_tool.name not in inactivated_llm_tools:
inactivated_llm_tools.append(func_tool.name) inactivated_llm_tools.append(func_tool.name)
sp.put("inactivated_plugins", inactivated_plugins) sp.put("inactivated_plugins", inactivated_plugins)
sp.put("inactivated_llm_tools", inactivated_llm_tools) sp.put("inactivated_llm_tools", inactivated_llm_tools)
plugin.activated = False plugin.activated = False
async def _terminate_plugin(self, star_metadata: StarMetadata): @staticmethod
async def _terminate_plugin(star_metadata: StarMetadata):
"""终止插件,调用插件的 terminate() 和 __del__() 方法""" """终止插件,调用插件的 terminate() 和 __del__() 方法"""
logger.info(f"正在终止插件 {star_metadata.name} ...") logger.info(f"正在终止插件 {star_metadata.name} ...")
@@ -784,11 +806,14 @@ class PluginManager:
logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。") logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。")
return return
if hasattr(star_metadata.star_cls, "__del__"): if star_metadata.star_cls is None:
return
if "__del__" in star_metadata.star_cls_type.__dict__:
asyncio.get_event_loop().run_in_executor( asyncio.get_event_loop().run_in_executor(
None, star_metadata.star_cls.__del__ None, star_metadata.star_cls.__del__
) )
elif hasattr(star_metadata.star_cls, "terminate"): elif "terminate" in star_metadata.star_cls_type.__dict__:
await star_metadata.star_cls.terminate() await star_metadata.star_cls.terminate()
async def turn_on_plugin(self, plugin_name: str): async def turn_on_plugin(self, plugin_name: str):

View File

@@ -182,7 +182,9 @@ class StarTools:
plugin_name = metadata.name plugin_name = metadata.name
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)) data_dir = Path(
os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)
)
try: try:
data_dir.mkdir(parents=True, exist_ok=True) data_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -56,9 +56,7 @@ class AstrBotUpdator(RepoZipUpdator):
try: try:
if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli
if os.name == "nt": if os.name == "nt":
args = [ args = [f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]]
f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]
]
else: else:
args = sys.argv[1:] args = sys.argv[1:]
os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args) os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args)

View File

@@ -30,7 +30,7 @@ def on_error(func, path, exc_info):
raise exc_info[1] raise exc_info[1]
def remove_dir(file_path) -> bool: def remove_dir(file_path: str) -> bool:
if not os.path.exists(file_path): if not os.path.exists(file_path):
return True return True
shutil.rmtree(file_path, onerror=on_error) shutil.rmtree(file_path, onerror=on_error)

View File

@@ -0,0 +1,29 @@
import asyncio
from collections import defaultdict
from contextlib import asynccontextmanager
class SessionLockManager:
def __init__(self):
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._lock_count: dict[str, int] = defaultdict(int)
self._access_lock = asyncio.Lock()
@asynccontextmanager
async def acquire_lock(self, session_id: str):
async with self._access_lock:
lock = self._locks[session_id]
self._lock_count[session_id] += 1
try:
async with lock:
yield
finally:
async with self._access_lock:
self._lock_count[session_id] -= 1
if self._lock_count[session_id] == 0:
self._locks.pop(session_id, None)
self._lock_count.pop(session_id, None)
session_lock_manager = SessionLockManager()

View File

@@ -1,7 +1,10 @@
import json import json
import os import os
from typing import TypeVar
from .astrbot_path import get_astrbot_data_path from .astrbot_path import get_astrbot_data_path
_VT = TypeVar("_VT")
class SharedPreferences: class SharedPreferences:
def __init__(self, path=None): def __init__(self, path=None):
@@ -24,7 +27,7 @@ class SharedPreferences:
json.dump(self._data, f, indent=4, ensure_ascii=False) json.dump(self._data, f, indent=4, ensure_ascii=False)
f.flush() f.flush()
def get(self, key, default=None): def get(self, key, default: _VT = None) -> _VT:
return self._data.get(key, default) return self._data.get(key, default)
def put(self, key, value): def put(self, key, value):

View File

@@ -11,7 +11,7 @@ ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"
class NetworkRenderStrategy(RenderStrategy): class NetworkRenderStrategy(RenderStrategy):
def __init__(self, base_url: str = ASTRBOT_T2I_DEFAULT_ENDPOINT) -> None: def __init__(self, base_url: str | None = None) -> None:
super().__init__() super().__init__()
if not base_url: if not base_url:
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
@@ -34,18 +34,22 @@ class NetworkRenderStrategy(RenderStrategy):
self.BASE_RENDER_URL += "/text2img" self.BASE_RENDER_URL += "/text2img"
async def render_custom_template( async def render_custom_template(
self, tmpl_str: str, tmpl_data: dict, return_url: bool = True self,
tmpl_str: str,
tmpl_data: dict,
return_url: bool = True,
options: dict | None = None,
) -> str: ) -> str:
"""使用自定义文转图模板""" """使用自定义文转图模板"""
default_options = {"full_page": True, "type": "jpeg", "quality": 40}
if options:
default_options |= options
post_data = { post_data = {
"tmpl": tmpl_str, "tmpl": tmpl_str,
"json": return_url, "json": return_url,
"tmpldata": tmpl_data, "tmpldata": tmpl_data,
"options": { "options": default_options,
"full_page": True,
"type": "jpeg",
"quality": 40,
},
} }
if return_url: if return_url:
ssl_context = ssl.create_default_context(cafile=certifi.where()) ssl_context = ssl.create_default_context(cafile=certifi.where())

View File

@@ -6,7 +6,7 @@ logger = LogManager.GetLogger(log_name="astrbot")
class HtmlRenderer: class HtmlRenderer:
def __init__(self, endpoint_url: str = None): def __init__(self, endpoint_url: str | None = None):
self.network_strategy = NetworkRenderStrategy(endpoint_url) self.network_strategy = NetworkRenderStrategy(endpoint_url)
self.local_strategy = LocalRenderStrategy() self.local_strategy = LocalRenderStrategy()
@@ -16,19 +16,24 @@ class HtmlRenderer:
self.network_strategy.set_endpoint(endpoint_url) self.network_strategy.set_endpoint(endpoint_url)
async def render_custom_template( async def render_custom_template(
self, tmpl_str: str, tmpl_data: dict, return_url: bool = False self,
tmpl_str: str,
tmpl_data: dict,
return_url: bool = False,
options: dict | None = None,
): ):
"""使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。 """使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。
@param tmpl_str: HTML Jinja2 模板。 @param tmpl_str: HTML Jinja2 模板。
@param tmpl_data: jinja2 模板数据。 @param tmpl_data: jinja2 模板数据。
@param options: 渲染选项。
@return: 图片 URL 或者文件路径,取决于 return_url 参数。 @return: 图片 URL 或者文件路径,取决于 return_url 参数。
@example: 参见 https://astrbot.app 插件开发部分。 @example: 参见 https://astrbot.app 插件开发部分。
""" """
local = locals() return await self.network_strategy.render_custom_template(
local.pop("self") tmpl_str, tmpl_data, return_url, options
return await self.network_strategy.render_custom_template(**local) )
async def render_t2i( async def render_t2i(
self, text: str, use_network: bool = True, return_url: bool = False self, text: str, use_network: bool = True, return_url: bool = False

View File

@@ -1,9 +1,11 @@
import base64 import base64
import wave import wave
import os import os
import subprocess
from io import BytesIO from io import BytesIO
import asyncio import asyncio
import tempfile import tempfile
from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@@ -57,33 +59,89 @@ async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
return duration return duration
async def wav_to_tencent_silk_base64(wav_path: str) -> str: async def convert_to_pcm_wav(input_path: str, output_path: str) -> str:
""" """
WAV 文件转为 Silk并返回 Base64 字符串 MP3 或其他音频格式转换为 PCM 16bit WAV采样率24000Hz单声道
默认采样率为 24000输出临时文件为 temp/output.silk 若转换失败则抛出异常
"""
try:
from pyffmpeg import FFmpeg
ff = FFmpeg()
ff.convert(input=input_path, output=output_path)
except Exception as e:
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
p = await asyncio.create_subprocess_exec(
"ffmpeg",
"-y",
"-i",
input_path,
"-acodec",
"pcm_s16le",
"-ar",
"24000",
"-ac",
"1",
"-af",
"apad=pad_dur=2",
"-fflags",
"+genpts",
"-hide_banner",
output_path,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
stdout, stderr = await p.communicate()
logger.info(f"[FFmpeg] stdout: {stdout.decode().strip()}")
logger.debug(f"[FFmpeg] stderr: {stderr.decode().strip()}")
logger.info(f"[FFmpeg] return code: {p.returncode}")
if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
return output_path
else:
raise RuntimeError("生成的WAV文件不存在或为空")
async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]:
"""
将 MP3/WAV 文件转为 Tencent Silk 并返回 base64 编码与时长(秒)。
参数: 参数:
- wav_path: 输入 .wav 文件路径(需为 PCM 16bit - audio_path: 输入音频文件路径(.mp3 或 .wav
返回: 返回:
- Base64 编码的 Silk 字符串 - silk_b64: Base64 编码的 Silk 字符串
- duration: 音频时长(秒) - duration: 音频时长(秒)
""" """
try: try:
import pilk import pilk
except ImportError as e: except ImportError as e:
raise Exception("pysilk 模块未安装,请安装 pysilk") from e raise Exception("未安装 pilk: pip install pilk") from e
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True) os.makedirs(temp_dir, exist_ok=True)
with wave.open(wav_path, "rb") as wav: # 是否需要转换为 WAV
rate = wav.getframerate() ext = os.path.splitext(audio_path)[1].lower()
temp_wav = tempfile.NamedTemporaryFile(
suffix=".wav", delete=False, dir=temp_dir
).name
with tempfile.NamedTemporaryFile( if ext != ".wav":
await convert_to_pcm_wav(audio_path, temp_wav)
# 删除原文件
os.remove(audio_path)
wav_path = temp_wav
else:
wav_path = audio_path
with wave.open(wav_path, "rb") as wav_file:
rate = wav_file.getframerate()
silk_path = tempfile.NamedTemporaryFile(
suffix=".silk", delete=False, dir=temp_dir suffix=".silk", delete=False, dir=temp_dir
) as tmp_file: ).name
silk_path = tmp_file.name
try: try:
duration = await asyncio.to_thread( duration = await asyncio.to_thread(
@@ -96,5 +154,7 @@ async def wav_to_tencent_silk_base64(wav_path: str) -> str:
return silk_b64, duration # 已是秒 return silk_b64, duration # 已是秒
finally: finally:
if os.path.exists(wav_path) and wav_path != audio_path:
os.remove(wav_path)
if os.path.exists(silk_path): if os.path.exists(silk_path):
os.remove(silk_path) os.remove(silk_path)

View File

@@ -9,6 +9,7 @@ from .chat import ChatRoute
from .tools import ToolsRoute # 导入新的ToolsRoute from .tools import ToolsRoute # 导入新的ToolsRoute
from .conversation import ConversationRoute from .conversation import ConversationRoute
from .file import FileRoute from .file import FileRoute
from .session_management import SessionManagementRoute
__all__ = [ __all__ = [
@@ -23,4 +24,5 @@ __all__ = [
"ToolsRoute", "ToolsRoute",
"ConversationRoute", "ConversationRoute",
"FileRoute", "FileRoute",
"SessionManagementRoute",
] ]

View File

@@ -3,7 +3,7 @@ import datetime
import asyncio import asyncio
from .route import Route, Response, RouteContext from .route import Route, Response, RouteContext
from quart import request from quart import request
from astrbot.core import WEBUI_SK, DEMO_MODE from astrbot.core import DEMO_MODE
from astrbot import logger from astrbot import logger
@@ -80,5 +80,8 @@ class AuthRoute(Route):
"username": username, "username": username,
"exp": datetime.datetime.utcnow() + datetime.timedelta(days=7), "exp": datetime.datetime.utcnow() + datetime.timedelta(days=7),
} }
token = jwt.encode(payload, WEBUI_SK, algorithm="HS256") jwt_token = self.config["dashboard"].get("jwt_secret", None)
if not jwt_token:
raise ValueError("JWT secret is not set in the cmd_config.")
token = jwt.encode(payload, jwt_token, algorithm="HS256")
return token return token

View File

@@ -2,7 +2,7 @@ import uuid
import json import json
import os import os
from .route import Route, Response, RouteContext from .route import Route, Response, RouteContext
from astrbot.core import web_chat_queue, web_chat_back_queue from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from quart import request, Response as QuartResponse, g, make_response from quart import request, Response as QuartResponse, g, make_response
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
import asyncio import asyncio
@@ -21,7 +21,6 @@ class ChatRoute(Route):
super().__init__(context) super().__init__(context)
self.routes = { self.routes = {
"/chat/send": ("POST", self.chat), "/chat/send": ("POST", self.chat),
"/chat/listen": ("GET", self.listener),
"/chat/new_conversation": ("GET", self.new_conversation), "/chat/new_conversation": ("GET", self.new_conversation),
"/chat/conversations": ("GET", self.get_conversations), "/chat/conversations": ("GET", self.get_conversations),
"/chat/get_conversation": ("GET", self.get_conversation), "/chat/get_conversation": ("GET", self.get_conversation),
@@ -40,9 +39,6 @@ class ChatRoute(Route):
self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"] self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"]
self.curr_user_cid = {}
self.curr_chat_sse = {}
async def status(self): async def status(self):
has_llm_enabled = ( has_llm_enabled = (
self.core_lifecycle.provider_manager.curr_provider_inst is not None self.core_lifecycle.provider_manager.curr_provider_inst is not None
@@ -124,6 +120,8 @@ class ChatRoute(Route):
conversation_id = post_data["conversation_id"] conversation_id = post_data["conversation_id"]
image_url = post_data.get("image_url") image_url = post_data.get("image_url")
audio_url = post_data.get("audio_url") audio_url = post_data.get("audio_url")
selected_provider = post_data.get("selected_provider")
selected_model = post_data.get("selected_model")
if not message and not image_url and not audio_url: if not message and not image_url and not audio_url:
return ( return (
Response() Response()
@@ -133,21 +131,10 @@ class ChatRoute(Route):
if not conversation_id: if not conversation_id:
return Response().error("conversation_id is empty").__dict__ return Response().error("conversation_id is empty").__dict__
self.curr_user_cid[username] = conversation_id # Get conversation-specific queues
back_queue = webchat_queue_mgr.get_or_create_back_queue(conversation_id)
await web_chat_queue.put( # append user message
(
username,
conversation_id,
{
"message": message,
"image_url": image_url, # list
"audio_url": audio_url,
},
)
)
# 持久化
conversation = self.db.get_conversation_by_user_id(username, conversation_id) conversation = self.db.get_conversation_by_user_id(username, conversation_id)
try: try:
history = json.loads(conversation.history) history = json.loads(conversation.history)
@@ -164,30 +151,12 @@ class ChatRoute(Route):
username, conversation_id, history=json.dumps(history) username, conversation_id, history=json.dumps(history)
) )
return Response().ok().__dict__
async def listener(self):
"""一直保持长连接"""
username = g.get("username", "guest")
if username in self.curr_chat_sse:
return Response().error("Already connected").__dict__
self.curr_chat_sse[username] = None
heartbeat = json.dumps({"type": "heartbeat", "data": "ping"})
async def stream(): async def stream():
try: try:
yield f"data: {heartbeat}\n\n" # 心跳包
while True: while True:
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(back_queue.get(), timeout=10)
web_chat_back_queue.get(), timeout=10
) # 设置超时时间为5秒
except asyncio.TimeoutError: except asyncio.TimeoutError:
yield f"data: {heartbeat}\n\n" # 心跳包
continue continue
if not result: if not result:
@@ -197,19 +166,13 @@ class ChatRoute(Route):
type = result.get("type") type = result.get("type")
cid = result.get("cid") cid = result.get("cid")
streaming = result.get("streaming", False) streaming = result.get("streaming", False)
if cid != self.curr_user_cid.get(username):
# 丢弃
continue
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n" yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
if streaming and type != "end": if type == "end":
continue break
elif (streaming and type == "complete") or not streaming:
if type == "update_title": # append bot message
continue
if result_text:
conversation = self.db.get_conversation_by_user_id( conversation = self.db.get_conversation_by_user_id(
username, cid username, cid
) )
@@ -222,11 +185,27 @@ class ChatRoute(Route):
self.db.update_conversation( self.db.update_conversation(
username, cid, history=json.dumps(history) username, cid, history=json.dumps(history)
) )
except BaseException as _: except BaseException as _:
logger.debug(f"用户 {username} 断开聊天长连接。") logger.debug(f"用户 {username} 断开聊天长连接。")
self.curr_chat_sse.pop(username)
return return
# Put message to conversation-specific queue
chat_queue = webchat_queue_mgr.get_or_create_queue(conversation_id)
await chat_queue.put(
(
username,
conversation_id,
{
"message": message,
"image_url": image_url, # list
"audio_url": audio_url,
"selected_provider": selected_provider,
"selected_model": selected_model,
},
)
)
response = await make_response( response = await make_response(
stream(), stream(),
{ {
@@ -236,7 +215,6 @@ class ChatRoute(Route):
"Connection": "keep-alive", "Connection": "keep-alive",
}, },
) )
response.timeout = None
return response return response
async def delete_conversation(self): async def delete_conversation(self):
@@ -245,6 +223,8 @@ class ChatRoute(Route):
if not conversation_id: if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__ return Response().error("Missing key: conversation_id").__dict__
# Clean up queues when deleting conversation
webchat_queue_mgr.remove_queues(conversation_id)
self.db.delete_conversation(username, conversation_id) self.db.delete_conversation(username, conversation_id)
return Response().ok().__dict__ return Response().ok().__dict__
@@ -279,6 +259,4 @@ class ChatRoute(Route):
conversation = self.db.get_conversation_by_user_id(username, conversation_id) conversation = self.db.get_conversation_by_user_id(username, conversation_id)
self.curr_user_cid[username] = conversation_id
return Response().ok(data=conversation).__dict__ return Response().ok(data=conversation).__dict__

Some files were not shown because too many files have changed in this diff Show More