Compare commits

...

604 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
479a737a03 Implement googleSearch tool in openai_source.py for Gemini(OpenAI Compatible) provider
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-17 02:07:36 +00:00
copilot-swe-agent[bot]
e324b69bf1 Add googleSearch tool alias for OpenAI-compatible Gemini API
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-16 16:39:20 +00:00
copilot-swe-agent[bot]
df86913fbf Initial setup and analysis of googleSearch tool feature request
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-16 16:33:47 +00:00
copilot-swe-agent[bot]
ca77cd6a83 Initial plan 2025-08-16 16:27:25 +00:00
Soulter
02a9769b35 fix: 补充工具调用轮数上限配置 2025-08-14 23:51:22 +08:00
Soulter
7640f11bfc docs: update readme 2025-08-14 17:28:51 +08:00
Soulter
9fa44dbcfa docs: update readme 2025-08-14 14:16:22 +08:00
Copilot
2cae941bae Fix incomplete Gemini streaming responses in chat history (#2429)
* Initial plan

* Fix incomplete Gemini streaming responses in chat history

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Raven95676 <Raven95676@gmail.com>
2025-08-14 11:56:50 +08:00
Copilot
bc0784f41d fix: enable_thinking parameter for qwen3 models in non-streaming calls (#2424)
* Initial plan

* Fix ModelScope enable_thinking parameter for non-streaming calls

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* Tighten enable_thinking condition to only Qwen/Qwen3 models

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* qwen3 model handle

* Update astrbot/core/provider/sources/openai_source.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-08-14 11:18:29 +08:00
Copilot
c57d75e01a feat: add comprehensive GitHub Copilot instructions for AstrBot development (#2426)
* Initial plan

* Initial progress - completed repository exploration and dependency installation

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* Complete copilot-instructions.md with comprehensive development guide

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* Update copilot-instructions.md

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-08-13 23:31:28 +08:00
RC-CHN
73edeae013 perf: 优化hint渲染方式,为部分类型供应商添加默认的温度选项 (#2321)
* feat:为webchat页面添加一个手动上传文件按钮(目前只处理图片)

* fix:上传后清空value,允许触发change事件以多次上传同一张图片

* perf:webchat页面消息发送后清空图片预览缩略图,维持与文本信息行为一致

* perf:将文件输入的值重置为空字符串以提升浏览器兼容性

* feat:webchat文件上传按钮支持多选文件上传

* fix:释放blob URL以防止内存泄漏

* perf:并行化sendMessage中的图片获取逻辑

* perf:优化hint渲染方式,为部分类型供应商添加默认的温度选项
2025-08-12 21:53:06 +08:00
MUKAPP
7d46314dc8 fix: 修复注册文件时由于 file:/// 前缀,导致文件被误判为不存在的问题 (#2325)
fixes #2222
2025-08-12 21:47:31 +08:00
你们的饺子
d5a53a89eb fix: 修复插件的 terminate 无法被正常调用的问题 (#2352) 2025-08-12 21:41:19 +08:00
dependabot[bot]
a85bc510dd chore(deps): bump actions/checkout in the github-actions group (#2400)
Bumps the github-actions group with 1 update: [actions/checkout](https://github.com/actions/checkout).


Updates `actions/checkout` from 4 to 5
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v5)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-08-12 15:15:28 +08:00
Soulter
2beea7d218 📦 release: v3.5.24 2025-08-07 20:36:59 +08:00
Soulter
a93cd3dd5f feat: compshare provider 2025-08-07 20:25:45 +08:00
Soulter
db4d02c2e2 docs: add 1panel deployment method 2025-08-02 19:01:49 +08:00
Soulter
fd7811402b fix: 添加对 metadata 中 description 字段的支持,确保元数据完整性
fixes: #2245
2025-08-02 16:01:10 +08:00
你们的饺子
eb0325e627 fix: 修复了 OpenAI 类型的 LLM 空内容响应导致的无法解析 completion 的错误。 (#2279) 2025-08-02 15:46:11 +08:00
IGCrystal
8b4b04ec09 fix(i18n): add missing noTemplates key (#2292) 2025-08-02 14:16:59 +08:00
Larch-C
9f32c9280f chore: update and rename PLUGIN_PUBLISH.md to PLUGIN_PUBLISH.yml (#2289) 2025-08-02 14:16:19 +08:00
yrk111222
4fcd09cfa8 feat: add ModelScope API support (#2230)
* add ModelScope API support

* update
2025-08-02 14:14:08 +08:00
Misaka Mikoto
7a8d65d37d feat: add plugins local cache and remote file MD5 validation (#2211)
* 修改openai的嵌入模型默认维度为1024

* 为插件市场添加本地缓存
- 优先使用api获取,获取失败时则使用本地缓存
- 每次获取后会更新本地缓存
- 如果获取结果为空,判定为获取失败,使用本地缓存
- 前端页面添加刷新按钮,用于手动刷新本地缓存

* feat: 增强插件市场缓存机制,支持MD5校验以确保数据有效性

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-08-02 14:03:53 +08:00
Raven95676
23129a9ba2 Merge branch 'releases/3.5.23' 2025-07-26 16:49:38 +08:00
Raven95676
7f791e730b fix: changelogs 2025-07-26 16:49:05 +08:00
Raven95676
f7e296b349 Merge branch 'releases/3.5.23' 2025-07-26 16:34:30 +08:00
Raven95676
712d4acaaa release: v3.5.23 2025-07-26 16:32:06 +08:00
Raven95676
74a5c01f21 refactor: remove code and documentation references related to gewechat 2025-07-26 14:19:17 +08:00
Raven95676
3ba8724d77 Merge branch 'master' into dev 2025-07-26 14:02:05 +08:00
鸦羽
6313a7d8a9 Merge pull request #2221 from Raven95676/fix/axios-dependency
fix: update axios version range for vulnerability fix
2025-07-24 18:34:14 +08:00
Raven95676
432a3f520c fix: update axios version range for vulnerability fix 2025-07-24 18:28:02 +08:00
Soulter
191b3e42d4 feat: implement log history retrieval and improve log streaming handling (#2190) 2025-07-23 23:36:08 +08:00
Misaka Mikoto
a27f05fcb4 chore: 修改 OpenAI 嵌入模型提供商默认向量维度为1024 (#2209) 2025-07-23 23:35:04 +08:00
Soulter
2f33e0b873 chore: remove adapters of wechat personal account 2025-07-23 10:51:42 +08:00
Soulter
f0359467f1 chore: remove adapters of wechat personal account 2025-07-23 10:50:43 +08:00
Soulter
d1db8cf2c8 chore: remove adapters of wechat personal account 2025-07-23 10:48:58 +08:00
Soulter
b1985ed2ce Merge branch 'dev' 2025-07-23 00:38:08 +08:00
Gao Jinzhe
140ddc70e6 feat: 使用会话锁保证分段回复时的消息发送顺序 (#2130)
* 优化分段消息发送逻辑,为分段消息添加消息队列

* 删除了不必要的代码

* style: code quality

* 将消息队列机制重构为会话锁机制

* perf: narrow the lock scope

* refactor: replace get_lock with async context manager for session locks

* refactor: optimize session lock management with defaultdict

---------

Co-authored-by: Soulter <905617992@qq.com>
Co-authored-by: Raven95676 <Raven95676@gmail.com>
2025-07-23 00:37:29 +08:00
Soulter
d7fd616470 style: code quality 2025-07-21 17:04:29 +08:00
Soulter
3ccbef141e perf: extension ui 2025-07-21 15:16:49 +08:00
Soulter
e92fbb0443 feat: add ProxySelector component for GitHub proxy configuration and connection testing (#2185) 2025-07-21 15:05:49 +08:00
Soulter
bd270aed68 fix: handle event construction errors in message reply processing 2025-07-20 22:52:14 +08:00
Soulter
28d7864393 perf: tool use page UI (#2182)
* perf: tool use UI

* fix: update background color of item cards in ToolUsePage
2025-07-20 20:24:03 +08:00
RC-CHN
b5d8173ee3 feat: add a file uplod button in WebChat page (#2136)
* feat:为webchat页面添加一个手动上传文件按钮(目前只处理图片)

* fix:上传后清空value,允许触发change事件以多次上传同一张图片

* perf:webchat页面消息发送后清空图片预览缩略图,维持与文本信息行为一致

* perf:将文件输入的值重置为空字符串以提升浏览器兼容性

* feat:webchat文件上传按钮支持多选文件上传

* fix:释放blob URL以防止内存泄漏

* perf:并行化sendMessage中的图片获取逻辑
2025-07-20 16:02:28 +08:00
Soulter
17d62a9af7 refactor: mcp server reload mechanism (#2161)
* refactor: mcp server reload mechanism

* fix: wait for client events

* fix: all other mcp servers are terminated when disable selected server

* fix: resolve type hinting issues in MCPClient and FuncCall methods

* perf: optimize mcp server loaders

* perf: improve MCP client connection testing

* perf: improve error message

* perf: clean code

* perf: increase default timeout for MCP connection and reset dialog message on close

---------

Co-authored-by: Raven95676 <Raven95676@gmail.com>
2025-07-20 15:53:13 +08:00
Soulter
d89fb863ed fix: improve logging and error message details in LLMRequestSubStage 2025-07-18 16:13:27 +08:00
Soulter
a21ad77820 Merge pull request #2146 from Raven95676/fix/mcp
fix: 修复MCP导致的持续占用100% CPU
2025-07-18 13:04:50 +08:00
Raven95676
f86c8e8cab perf: ensure MCP client termination in cleanup process 2025-07-17 23:17:23 +08:00
Raven95676
cb12cbdd3d fix: managing MCP connections with AsyncExitStack 2025-07-16 23:44:51 +08:00
Soulter
6661fa996c fix: audio block does not display 2025-07-14 22:20:03 +08:00
Soulter
c19bca798b fix: xfyun model tool use error workaround
fixes: #1359
2025-07-14 22:07:33 +08:00
Soulter
8f98b411db Merge pull request #2129 from AstrBotDevs/perf-refine-webui-chatpage
Improve: WebUI ChatPage markdown code block background
2025-07-14 21:49:18 +08:00
Soulter
a8aa03847e feat: enhance theme customization with new background properties and markdown styling 2025-07-14 21:47:25 +08:00
Soulter
1bfd747cc6 perf: add system_prompt to payload_vars in dify text_chat method 2025-07-14 11:00:13 +08:00
Soulter
ae06d945a7 Merge pull request #2054 from RC-CHN/master
Feature: Add provider_type field for ProviderMetadata and improve provider availabiliby test
2025-07-13 17:38:22 +08:00
Soulter
9f41d5f34d Merge remote-tracking branch 'origin/master' into RC-CHN/master 2025-07-13 17:35:53 +08:00
Soulter
ef61c52908 fix: remove non-existent Response field 2025-07-13 17:33:13 +08:00
Soulter
d8842ef274 perf: code quality 2025-07-13 17:27:40 +08:00
Soulter
c88fdaf353 Merge pull request #1949 from advent259141/Astrbot_session_manage
[Feature] 支持在 WebUI 上管理会话
2025-07-13 17:23:52 +08:00
Soulter
af295da871 chore: remove /mcp command 2025-07-13 17:11:02 +08:00
Soulter
083235a2fe feat: enhance session management page with tooltips and layout adjustments 2025-07-13 17:06:15 +08:00
Soulter
2a3a5f7eb2 perf: refine session management page ui 2025-07-13 16:57:36 +08:00
Soulter
77c48f280f fix: session management paginator error 2025-07-13 16:36:25 +08:00
Soulter
0ee1eb2f9f chore: remove useless file 2025-07-13 16:30:00 +08:00
Soulter
c2b20365bb Merge pull request #2097 from SheffeyG/fix-status-checking
Fix: add status checking for embedding model providers
2025-07-13 16:19:37 +08:00
Soulter
cfdc7e4452 fix: add debug logging for provider request handling in LLMRequestSubStage
fixes: #2104
2025-07-13 16:12:48 +08:00
Soulter
2363f61aa9 chore: remove 'obvious_hint' fields from configuration metadata and remove some deprecated config 2025-07-13 16:03:47 +08:00
Soulter
557ac6f9fa Merge pull request #2112 from AstrBotDevs/perf-provider-logo
Improve: WebUI provider logo display
2025-07-13 15:34:19 +08:00
Soulter
a49b871cf9 fix: update Azure provider icon URL in getProviderIcon method 2025-07-13 15:33:47 +08:00
Soulter
a0d6b3efba perf: improve provider logo display in webui 2025-07-13 15:27:53 +08:00
Gao Jinzhe
6cabf07bc0 Merge branch 'AstrBotDevs:master' into Astrbot_session_manage 2025-07-13 00:23:29 +08:00
advent259141
a15444ee8c 移除了mcp会话级的启停,增加了批量设置的选项,对相关问题进行了修复 2025-07-13 00:15:21 +08:00
Soulter
ceb5f5669e fix: update active reply bot prefix in logging for clarity 2025-07-13 00:12:31 +08:00
Gao Jinzhe
25b75e05e4 Merge branch 'AstrBotDevs:master' into Astrbot_session_manage 2025-07-12 22:25:20 +08:00
sheffey
4d214bb5c1 check general numbers type instead 2025-07-11 18:36:47 +08:00
sheffey
7cbaed8c6c fix: add status checking for embedding model providers 2025-07-11 18:36:40 +08:00
Soulter
2915fdf665 release: v3.5.22 2025-07-11 12:29:26 +08:00
Soulter
a66c385b08 fix: deadlock when docker is not available 2025-07-11 12:27:49 +08:00
Raven95676
4dace7c5d8 chore: format code 2025-07-11 11:23:53 +08:00
Soulter
8ebf087dbf chore: optimize codes 2025-07-10 23:28:00 +08:00
Soulter
2fa8bda5bb chore: ruff lint 2025-07-10 23:23:29 +08:00
Soulter
a5ae833945 📦 release: v3.5.21 2025-07-10 17:46:36 +08:00
Soulter
d21d42b312 chore: update icon URL for 302.AI to use color version 2025-07-10 17:44:11 +08:00
Soulter
78575f0f0a fix: failed to delete conversation in webchat
fixes: #2071
2025-07-10 17:04:34 +08:00
Soulter
8ccd292d16 Merge pull request #2082 from AstrBotDevs/fix-webchat-segment-reply
fix: 修复 WebChat 下可能消息错位的问题
2025-07-10 17:00:14 +08:00
Soulter
2534f59398 chore: remove debug print statement from chat route 2025-07-10 16:59:58 +08:00
Soulter
5c60dbe2b1 fix: 修复 WebChat 下可能消息错位的问题 2025-07-10 16:52:16 +08:00
Soulter
c99ecde15f Merge pull request #2078 from AstrBotDevs/fix-webchat-image-cannot-render
Fix: webchat cannot render image and audio image normally
2025-07-10 11:57:50 +08:00
Soulter
219f3403d9 fix: webchat cannot render image and audio image normally 2025-07-10 11:51:47 +08:00
Soulter
00f417bad6 Merge pull request #2073 from Raven95676/fix/register_star
fix: 提升兼容性,并尽可能避免数据竞争
2025-07-10 11:03:57 +08:00
Soulter
81649f053b perf: improve log 2025-07-10 10:58:56 +08:00
Raven95676
e5bde50f2d fix: 提升兼容性,并尽可能避免数据竞争 2025-07-09 22:39:30 +08:00
Raven95676
0321e00b0d perf: 移除nh3 2025-07-09 20:32:14 +08:00
Soulter
09528e3292 docs: add model providers 2025-07-09 14:18:59 +08:00
Soulter
e7412a9cbf docs: add model providers 2025-07-09 14:17:39 +08:00
Soulter
01efe5f869 📦 release: v3.5.20 2025-07-09 13:35:44 +08:00
Soulter
28a178a55c Merge pull request #2067 from AstrBotDevs/refactor-aiocqhttp-send-message
Fix: active message cannot handle forward type message properly in aiocqhttp adapter
2025-07-09 13:23:08 +08:00
Soulter
88f130014c perf: streamline message dispatching logic in AiocqhttpMessageEvent 2025-07-09 12:10:18 +08:00
Soulter
af258c590c Merge pull request #2068 from AstrBotDevs/fix-tool-call-result-wrongly-sent
Fix: 修复工具调用被错误地发出到了消息平台上
2025-07-09 12:02:07 +08:00
Soulter
b0eb5733be Merge pull request #2065 from AstrBotDevs/fix-plugin-metadata-load
Improve: add fallback for missing 'desc' in plugin metadata
2025-07-09 12:01:06 +08:00
Soulter
fe35bfba37 Merge pull request #2064 from uersula/fix-image-removal-flag-logic
Fix: 移除 _remove_image_from_context中的flag逻辑
2025-07-09 12:00:30 +08:00
advent259141
7cfbc4ab8f 增加了针对整个会话启停的开关 2025-07-09 11:58:52 +08:00
Soulter
7a9d4f0abd fix: 修复工具调用被错误地发出到了消息平台上
fixes: #2060
2025-07-09 11:43:25 +08:00
Soulter
6f6a5b565c fix: active message cannot handle forward type message properly in aiocqhttp adapter 2025-07-09 11:19:32 +08:00
Soulter
e57deb873c perf: add fallback for missing 'desc' in plugin metadata and improve error logging 2025-07-09 10:47:03 +08:00
Gao Jinzhe
0f692b1608 Merge branch 'master' into Astrbot_session_manage 2025-07-09 10:13:51 +08:00
uersula
8c03e79f99 Fix: Remove buggy flag logic in _remove_image_from_context 2025-07-08 23:01:11 +08:00
Soulter
71290f0929 Merge pull request #2061 from AstrBotDevs/feat-handle-image-in-quote-message
Feature: 支持对引用消息中的图片内容进行理解
2025-07-08 22:11:17 +08:00
Soulter
22364ef7de feat: 支持对引用消息中的图片内容进行理解
fixes: #2056
2025-07-08 22:08:40 +08:00
Ruochen
2cc1eb1abc feat:实现了speech_to_text类型的供应商可用性检查 2025-07-08 21:55:31 +08:00
RC-CHN
90dbcbb4e2 Merge branch 'AstrBotDevs:master' into master 2025-07-08 21:28:50 +08:00
Ruochen
66503d58be feat:实现了text_to_speech类型的供应商可用性测试 2025-07-08 17:52:22 +08:00
Ruochen
8e10f0ce2b feat:实现了embedding类型的供应商可用性检查 2025-07-08 16:51:57 +08:00
Soulter
f51f510f2e perf: enhance date handle in reminder
fixes: #1901
2025-07-08 16:33:46 +08:00
Ruochen
c44f085b47 fix:对非文本生成类供应商暂时跳过测试 2025-07-08 16:32:39 +08:00
RC-CHN
a35f36eeaf Merge branch 'AstrBotDevs:master' into master 2025-07-08 15:34:19 +08:00
Ruochen
14564c392a feat:meta方法增加provider_type字段 2025-07-08 15:33:02 +08:00
Soulter
76e05ea749 Merge pull request #2022 from AstrBotDevs/deprecate/register_star-decorator
[Deprecation] 弃用register_star装饰器
2025-07-08 11:57:28 +08:00
Soulter
ab599dceed Merge branch 'master' into deprecate/register_star-decorator 2025-07-08 11:52:33 +08:00
Soulter
4c37604445 perf: only output deprecation warning once for @register_star decorator 2025-07-08 11:50:55 +08:00
Soulter
bb74018d19 Merge pull request #1998 from diudiu62/feat-wechatpadpro-adapter
增加监听wechatpadpro消息平台的事件
2025-07-08 11:40:13 +08:00
Soulter
575289e5bc feat: complete platform adapter types and update mapping 2025-07-08 11:39:42 +08:00
Soulter
e89da2a7b4 Merge pull request #2035 from cclauss/patch-1
pytest recommendation: `pip install --editable .`
2025-07-08 11:35:34 +08:00
Soulter
bd34959f68 📦 release: v3.5.19 2025-07-08 01:34:08 +08:00
Soulter
622dcf8fd5 fix: 通过指令选择提供商重启后失效 2025-07-08 01:24:19 +08:00
Soulter
9e315739b7 Merge pull request #2051 from AstrBotDevs/perf-ui
Improve: 改善 WebUI 效果
2025-07-08 00:35:52 +08:00
Soulter
7b01adc5df perf: better webui 2025-07-08 00:33:22 +08:00
Soulter
432fc47443 feat: add 302.ai llm provider 2025-07-07 23:01:28 +08:00
Soulter
d8fba44c5e Merge pull request #2049 from uersula/fix/keyerror-in-recovery-handler
Fix: 防止错误恢复机制_remove_image_from_context发生KeyError
2025-07-07 22:13:43 +08:00
Soulter
e29d3d8c01 Merge pull request #2043 from Zhenyi-Wang/master
fix(wechatpadpro): 修复授权码提取逻辑以兼容新旧接口格式
2025-07-07 22:10:20 +08:00
uersula
e678413214 Fix: Prevent KeyError in _remove_image_from_context 2025-07-07 02:30:50 +08:00
Soulter
eaa9d9d087 Merge pull request #2027 from IGCrystal/Branch-2
🐞 fix(WebUI): 解决XSS注入的问题
2025-07-06 18:13:40 +08:00
Soulter
9e3cc076b7 🐞 fix(ReadmeDialog): add variant attribute to close button for consistency 2025-07-06 18:13:00 +08:00
IGCrystal
3bb01fa52c feat(ChatPage): 添加图像预览 2025-07-06 18:08:17 +08:00
IGCrystal
008e49d144 🎈 perf: 优化音频附件的显示 2025-07-06 18:08:17 +08:00
IGCrystal
4e275384b0 🐞 fix(VerticalHeader): 允许HTML渲染 2025-07-06 18:08:17 +08:00
IGCrystal
63ec99f67a 🐞 fix: 添加不存在的翻译键 2025-07-06 18:08:17 +08:00
IGCrystal
14a8bb57df 🐞 fix(WebUI): 解决XSS注入的问题 2025-07-06 18:08:17 +08:00
Soulter
7512bfc710 fix: update user message bubble styling for improved appearance 2025-07-06 18:06:28 +08:00
Soulter
3c3b6dadc3 Merge pull request #2037 from AstrBotDevs/fix/tool_call_result
fix: direct send tool_call_result
2025-07-06 18:05:59 +08:00
Soulter
cd722a0e39 fix: handle direct tool call results 2025-07-06 18:04:46 +08:00
Soulter
a1b5d0a100 Merge remote-tracking branch 'origin/master' into fix/tool_call_result 2025-07-06 17:47:09 +08:00
Raven95676
69d3ae709c fix: direct send tool_call_result 2025-07-06 17:45:07 +08:00
Soulter
67ef993d61 fix: webchat message bubble style 2025-07-06 17:21:57 +08:00
Soulter
20f49890ad fix: provider selection for updating webchat title 2025-07-06 17:18:37 +08:00
Zhenyi Wang
3e4917f0a1 refactor: 重构 wechatpadpro 授权码生成并增强安全性
- 将 generate_auth_key 方法中的授权码提取逻辑重构为新的辅助方法 _extract_auth_key ,以提高代码的可读性和可测试性。
- 在访问 data.get('authKeys') 之前添加 isinstance(data, dict) 检查,以防止潜在的 AttributeError 。
- 移除了 auth_key 的明文日志记录,以避免敏感信息泄露。
- 在生成新密钥之前,将 self.auth_key 初始化为 None ,以避免在失败时保留旧值。
2025-07-06 16:34:55 +08:00
Soulter
99ee75aec6 Merge pull request #2029 from jiongjiongJOJO/master
fix: 增加演示模式下校验插件开启/关闭/安装指令
2025-07-06 16:24:02 +08:00
Zhenyi Wang
1674653a42 fix(wechatpadpro): 修复授权码提取逻辑以兼容新旧接口格式
新接口返回多了一层authKeys字段,同时兼容二者
2025-07-06 16:18:31 +08:00
Christian Clauss
d2f7e55bf5 Run the tests on pull requests 2025-07-05 13:57:58 +02:00
Christian Clauss
9f31df7f3a pytest recommendation: pip install --editable .
https://docs.pytest.org/en/stable/how-to/existingtestsuite.html

This makes setting `PYTHONPATH` unnecessary and will pull requirements from `pyproject.toml` instead of `requirements.txt`, so it is similar to end-user installations.

`makedir -p data/plugins` will do both `mkdir data` and `mkdir data/plugins`.

The `$CI` environment variable might be better to use than `$TESTING` because it is preset to `true` in GitHub Actions.
* https://docs.github.com/en/actions/reference/variables-reference#default-environment-variables
* https://docs.pytest.org/en/stable/explanation/ci.html
2025-07-05 13:52:28 +02:00
Soulter
b8c1b53d67 Merge pull request #2034 from AstrBotDevs/dependabot/github_actions/github-actions-50e66c4123
chore(deps): bump the github-actions group with 4 updates
2025-07-05 19:24:16 +08:00
dependabot[bot]
2495837791 chore(deps): bump the github-actions group with 4 updates
Bumps the github-actions group with 4 updates: [actions/checkout](https://github.com/actions/checkout), [actions/setup-python](https://github.com/actions/setup-python), [codecov/codecov-action](https://github.com/codecov/codecov-action) and [actions/stale](https://github.com/actions/stale).


Updates `actions/checkout` from 3 to 4
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v3...v4)

Updates `actions/setup-python` from 4 to 5
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](https://github.com/actions/setup-python/compare/v4...v5)

Updates `codecov/codecov-action` from 4 to 5
- [Release notes](https://github.com/codecov/codecov-action/releases)
- [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md)
- [Commits](https://github.com/codecov/codecov-action/compare/v4...v5)

Updates `actions/stale` from 5 to 9
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/stale/compare/v5...v9)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/setup-python
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: codecov/codecov-action
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/stale
  dependency-version: '9'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-07-05 11:20:25 +00:00
Soulter
b6562e3c47 Merge pull request #2005 from cclauss/patch-1
Keep GitHub Actions up to date with GitHub's Dependabot
2025-07-05 19:19:18 +08:00
Soulter
c57da046ee Merge pull request #2013 from AstrBotDevs/feat/danger-plugin
[Copilot] feat: 添加风险插件安装确认对话框以及风险插件标签特殊处理
2025-07-05 19:18:14 +08:00
JOJO
ff63134c14 fix: 增加演示模式下校验插件开启/关闭/安装指令 2025-07-05 12:43:19 +08:00
鸦羽
3f5210c587 chore: update plugin publish template 2025-07-04 22:28:00 +08:00
IGCrystal
3df5e7b9b9 🐞 fix: 添加tags.danger的翻译键 2025-07-04 17:28:39 +08:00
Soulter
225db66738 fix: refine streaming logic in chat response handling 2025-07-04 16:59:49 +08:00
Soulter
383ebb8f57 feat: add copy functionality for bot messages with success feedback 2025-07-04 16:27:52 +08:00
Raven95676
e1bed60f1f fix: adjust timing of adding to star_registry 2025-07-04 16:13:10 +08:00
Raven95676
edbb856023 refactor: deprecate register_star decorator 2025-07-04 15:54:23 +08:00
Raven95676
98d3ab646f chore: convert some methods to static 2025-07-04 15:07:14 +08:00
Soulter
81be556f1b Merge pull request #2018 from AstrBotDevs/fix-extension-btn-z-index
Fix: adjust z-index for the add button on ExtensionPage
2025-07-04 11:41:10 +08:00
Soulter
f45a085469 fix: adjust z-index for the add button on ExtensionPage
fixes: #1985
2025-07-04 11:40:14 +08:00
Raven95676
210cc58cc3 fix: 更新风险插件警告对话框内容和按钮文本,修正样式 By @Soulter
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-07-04 11:23:19 +08:00
Soulter
1063b11ef6 fix: check provider availability errors on dify 2025-07-04 10:19:58 +08:00
Raven95676
a4e999c47f feat: 添加风险插件安装确认对话框以及风险插件标签特殊处理 2025-07-03 22:16:00 +08:00
Soulter
543e01c301 perf: webui 删除对话使用 conversation_mgr,以保持状态同步 2025-07-03 15:44:45 +08:00
Soulter
14e0aa3ec5 perf: history 和 persona 指令当对话不存在的时候自动创建
fixes: #1997
2025-07-03 15:40:00 +08:00
Christian Clauss
1a8a171f8b Keep GitHub Actions up to date with GitHub's Dependabot
* [Keeping your software supply chain secure with Dependabot](https://docs.github.com/en/code-security/dependabot)
* [Keeping your actions up to date with Dependabot](https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot)
* [Configuration options for the `dependabot.yml` file - package-ecosystem](https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem)
2025-07-03 08:46:42 +02:00
Soulter
f1954f9a43 Merge pull request #1984 from RC-CHN/master
refactor:将前端测试供应商部分修改为独立并发异步获取各个文本供应商的状态
2025-07-03 10:55:51 +08:00
Soulter
441b148501 Merge pull request #1991 from AstrBotDevs/perf/webchat-title
perf: 优化WebChat对话标题生成
2025-07-03 10:53:35 +08:00
Soulter
bd0f30b81c Merge pull request #2003 from AstrBotDevs/feat-webchat-select-provider
Feature: WebChat 增加可选择提供商和模型的功能
2025-07-03 10:52:42 +08:00
Soulter
ad14e9bf40 chore: remove unnecessary logging of payloads in chat completion 2025-07-03 10:50:03 +08:00
Soulter
6f71301aaf fix: log error when selected provider is not found 2025-07-03 10:49:12 +08:00
Soulter
5f0d601baa feat: add support for selecting provider and models in webchat 2025-07-03 10:42:20 +08:00
Soulter
f234a5bcc2 fix: enhance event hook handling to return status and prevent propagation 2025-07-03 00:23:56 +08:00
chenpeng
ab677ea100 修正pilk依赖提示文案
增加监听wechatpadpro消息平台的事件
2025-07-02 17:30:37 +08:00
Soulter
f3ad53e949 feat: add supports for selecting provider and models in webchat 2025-07-02 17:12:30 +08:00
Soulter
d324cfa84d Merge pull request #1987 from AstrBotDevs/refactor-webchat-streaming
Refactor: 重构 WebChat 的 SSE 监听逻辑
2025-07-02 17:11:12 +08:00
Soulter
dd4319d72a Merge pull request #1990 from AstrBotDevs/fix-stream-multi-tool-use-err
fix: Multi-turn tools use error when using streaming output
2025-07-02 15:44:29 +08:00
Raven95676
1f2de3d3d8 perf: 优化WebChat对话标题生成 2025-07-02 10:43:54 +08:00
Raven95676
72702beb0b chore: clean code 2025-07-02 10:29:10 +08:00
Soulter
adb0cbc5dd fix: handle tool_calls_result as list or single object in context query in streaming mode 2025-07-02 10:16:44 +08:00
Soulter
6a503b82c3 refactor: web chat queue management and streamline chat route handling 2025-07-01 22:34:17 +08:00
advent259141
28a87351f1 新增对会话重命名的功能 2025-07-01 21:41:19 +08:00
Soulter
bcc97378b0 feat: implement code copy functionality and enhance code highlighting in ChatPage 2025-07-01 21:15:01 +08:00
Soulter
eb8a138713 feat: enhance conversation actions with delete functionality and improved styling 2025-07-01 21:00:43 +08:00
advent259141
dcd7dcbbdf 解决了conflict 2025-07-01 17:24:56 +08:00
Gao Jinzhe
1538759ba7 Merge branch 'master' into Astrbot_session_manage 2025-07-01 17:19:30 +08:00
Soulter
30e8ea7fd8 chore: add deploy badge 2025-07-01 16:59:58 +08:00
Ruochen
879b7b582c perf:提取重复的错误处理逻辑,优化循环调用 2025-07-01 16:02:56 +08:00
Ruochen
8ba4236402 refactor:将前端测试供应商部分修改为独立并发异步获取各个文本供应商的状态 2025-07-01 15:41:30 +08:00
鸦羽
5eef8fa9b9 Merge pull request #1981 from AstrBotDevs/feat/r1_filter-integration
feat: 集成r1_filter至框架
2025-07-01 13:56:01 +08:00
Raven95676
d03d035437 perf: 合并嵌套的if条件 2025-07-01 13:53:22 +08:00
Raven95676
68e8e1f70b feat: 集成r1_filter至框架 2025-07-01 12:40:52 +08:00
Soulter
7acb45b157 Update README.md 2025-07-01 11:35:14 +08:00
Soulter
c36142deaf perf: chatpage UI 2025-06-30 15:20:46 +08:00
Soulter
5fd6e316fa Merge pull request #1966 from railgun19457/master
修改了一对大括号
2025-06-30 13:33:10 +08:00
railgun19457
39a9d7765a 修改了一对大括号 2025-06-30 00:21:28 +08:00
Soulter
7cfcba29a6 feat: add loading state for dashboard update process 2025-06-29 21:55:13 +08:00
Soulter
9bf8aadca9 📦 release: v3.5.18 2025-06-29 21:52:45 +08:00
Soulter
714d4af63d Merge pull request #1963 from AstrBotDevs/refactor-llm-request
Refactor: 将 LLM Request 部分抽象为 AgentRunner 并优化多轮工具调用
2025-06-29 21:38:43 +08:00
Soulter
8203fdb4f0 fix: webchat show tool call 2025-06-29 21:35:39 +08:00
Soulter
5e1e2d1a4f perf: 优化 ChatPage UI 2025-06-29 21:19:52 +08:00
Soulter
2f941de65b feat: 支持展示工具使用过程 2025-06-29 21:19:40 +08:00
Raven95676
777c503002 perf: change logging level to debug for agent state transitions and LLM responses 2025-06-29 17:32:53 +08:00
Raven95676
e9b23f68fd perf: add AgentState Enum for improved state management 2025-06-29 17:19:53 +08:00
Soulter
efa45e6203 fix: validate and repair message contexts in LLMRequestSubStage 2025-06-29 16:36:08 +08:00
Raven95676
638f55f83c Merge branch 'refactor-llm-request' of https://github.com/AstrBotDevs/AstrBot into refactor-llm-request 2025-06-29 16:13:18 +08:00
Raven95676
8b2fc29d5b chore: remove accidentally committed file 2025-06-29 16:13:15 +08:00
Soulter
b516fb0550 chore: remove dump_plugins.py 2025-06-29 16:12:40 +08:00
Raven95676
efef34c01e style: format code 2025-06-29 16:06:44 +08:00
Soulter
5f1dfa7599 fix: handle LLM response and execute event hook in ToolLoopAgent 2025-06-29 15:58:22 +08:00
Soulter
8e9c7544cf fix: update type check for async generator in PipelineContext 2025-06-29 15:54:32 +08:00
Soulter
4e3d5641c8 chore: code quality 2025-06-29 15:51:56 +08:00
Soulter
20b760529e fix: anthropic api error when using tools 2025-06-29 15:33:08 +08:00
Soulter
a55a07c5ff remove: useless provider init params 2025-06-29 14:43:36 +08:00
Soulter
94ee8ea297 feat: 支持多轮次工具调用并且存储到数据库
移除了 llm tuner 适配器
2025-06-29 14:27:00 +08:00
advent259141
ec5d71d0e1 修复了一下重复的代码问题,删除了不必要的会话级别 LLM 启停状态检查。 2025-06-29 10:02:04 +08:00
advent259141
d121d08d05 大致凭借自己理解修复了一下整个检查流程,防止钩子出现问题 2025-06-29 09:57:31 +08:00
Gao Jinzhe
be08f4a558 Merge branch 'AstrBotDevs:master' into Astrbot_session_manage 2025-06-29 09:11:25 +08:00
Soulter
010f082fbb Merge pull request #1914 from HakimYu/master
fix(AiocqhttpAdapter): 修复at_info.get("nick", "")的错误
2025-06-28 21:52:01 +08:00
Soulter
073cdf6d51 perf: also consider nick 2025-06-28 21:51:10 +08:00
Soulter
4df8606ab6 style: code quality 2025-06-28 20:08:57 +08:00
Soulter
71442d26ec chore: 移除不必要的 MCP 会话控制 2025-06-28 19:58:36 +08:00
advent259141
4f5528869c Merge branch 'Astrbot_session_manage' of https://github.com/advent259141/AstrBot into Astrbot_session_manage 2025-06-28 17:00:00 +08:00
advent259141
f16feff17b 根据会话mcp开关情况选择性传入 func_tool
修改import的位置
deleted:    astrbot/core/star/session_tts_manager.py
复原被覆盖的修改
2025-06-28 16:59:00 +08:00
Soulter
71b233fe5f Merge pull request #1942 from QiChenSn/fix-CommandFilter-ParseForBool
fix:修复commandfilter对布尔类型的解析
2025-06-28 15:10:29 +08:00
Soulter
770dec9ed6 fix: handle boolean parameter parsing correctly in CommandFilter 2025-06-28 15:08:19 +08:00
Soulter
2ca95a988e fix: lint warnings 2025-06-28 15:05:57 +08:00
Gao Jinzhe
d8aae538cd Merge branch 'AstrBotDevs:master' into Astrbot_session_manage 2025-06-28 14:55:38 +08:00
Soulter
cf1e7ee08a Merge pull request #1947 from RC-CHN/master
允许为html_render方法传递参数
2025-06-28 14:52:09 +08:00
Soulter
d14513ddfd fix: lint warnings 2025-06-28 14:51:35 +08:00
Soulter
9a9017bc6c perf: use union oper for merging dict
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-06-28 14:46:29 +08:00
Soulter
3c9b654713 Merge pull request #1923 from Magstic/patch-1
Fix: 仪表盘的『插件配置』中不显示 JSON 编辑窗
2025-06-28 14:45:14 +08:00
Magstic
80d2ad40bc fix: 仪表盘的『插件配置』中不显示 JSON 编辑窗
该提交与 #1919 关联。

精准定位错误 @Pine-Ln,Fix from Gemini 2.5 Pro.

这个问题是由两个错误叠加造成的:

1. **组件崩溃**:`AstrBotConfig.vue` 混用了 Vue 3 的 `<script setup>` 和旧式 `<script>` 写法,导致作用域冲突,模板无法访问国际化函数 `t`,引发 `ReferenceError: t is not defined`。

2. **设置项不显示**:原代码根据用户已保存的设置数据来渲染字段,导致新增的设置项(如 `editor_mode`)因为用户配置中没有初始值而不显示。

1. **统一 API 写法**:将整个组件重构为纯 `<script setup>` 写法,解决作用域冲突。

2. **修正渲染逻辑**:将 `v-for` 循环改为遍历设置蓝图 (metadata) 而不是用户数据,确保所有定义的设置项都能显示。
2025-06-28 14:42:06 +08:00
advent259141
31670e75e5 Merge branch 'Astrbot_session_manage' of https://github.com/advent259141/AstrBot into Astrbot_session_manage 2025-06-27 18:47:25 +08:00
advent259141
ed6011a2be modified: dashboard/src/i18n/loader.ts
modified:   dashboard/src/i18n/locales/en-US/core/navigation.json
增加会话管理英文页面
	modified:   dashboard/src/i18n/locales/zh-CN/core/navigation.json
增加会话管理中文页面
	modified:   dashboard/src/i18n/translations.ts
	modified:   dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts
	modified:   dashboard/src/views/SessionManagementPage.vue
增加会话管理国际化适配
2025-06-27 18:46:02 +08:00
Gao Jinzhe
cdded38ade Merge branch 'AstrBotDevs:master' into Astrbot_session_manage 2025-06-27 17:10:08 +08:00
advent259141
f536f24833 astrbot/core/pipeline/process_stage/method/llm_request.py
astrbot/core/pipeline/result_decorate/stage.py
   astrbot/core/star/session_llm_manager.py
   astrbot/core/star/session_tts_manager.py
   astrbot/dashboard/routes/session_management.py
   astrbot/dashboard/server.py
   dashboard/src/views/SessionManagementPage.vue
   packages/astrbot/main.py
2025-06-27 17:08:05 +08:00
Ruochen
f5bff00b1f Merge branch 'master' of https://github.com/RC-CHN/AstrBot 2025-06-27 17:03:58 +08:00
Ruochen
27c9717445 feat:允许html_render方法传入配置参数 2025-06-27 17:03:26 +08:00
Soulter
863a1ba8ef Merge pull request #1922 from SXP-Simon/master
[feat] (discord_platform_adapter) 增加了对机器人 Role Mention 方法的响应,并且修复了控制面板上 Discord 平台无法优雅重载的 Bug
2025-06-27 14:59:37 +08:00
Soulter
cb04dd2b83 chore: remove unnecessary codes 2025-06-27 14:59:08 +08:00
Soulter
8c7cf51958 chore: code format 2025-06-27 14:46:23 +08:00
Soulter
244fb1fed6 chore: remove useless logger 2025-06-27 14:38:31 +08:00
Soulter
25f7a68a13 Merge pull request #1709 from shuiping233/fix-qq-offical-session-bug
fix: qq_official适配器使用SessionController(会话控制)功能时机器人回复消息无法发送到聊天平台
2025-06-27 14:35:54 +08:00
Soulter
62d8cf79ef fix: remove deprecated pre_send and post_send calls for specific platforms 2025-06-27 14:31:35 +08:00
Gao Jinzhe
646b18d910 Merge branch 'AstrBotDevs:master' into master 2025-06-27 12:26:15 +08:00
QiChenSn
2f81b2e381 fix:修复commandfilter对布尔类型的解析 2025-06-27 02:32:10 +08:00
Soulter
1f5a7e7885 Merge pull request #1940 from AstrBotDevs/fix-tg-active-reply
fix: cannot make active reply in telegram
2025-06-27 00:05:10 +08:00
Soulter
80fca470f2 fix: cannot make active reply in telegram
Co-authored-by: youtiaoguagua <cloudcranesss@210625568+cloudcranesss@users.noreply.github.com>
2025-06-27 00:04:25 +08:00
Soulter
6e9d9ac856 Merge pull request #1907 from IGCrystal/Branch-2
🐞 fix(WebUI): 修复安装插件按钮不可见
2025-06-26 23:28:37 +08:00
Soulter
8d6fada1eb feat(ExtensionPage): show confirm dialog when click install plugin button 2025-06-26 23:25:59 +08:00
Soulter
3e715399a1 fix: 环境变量代理被忽略 (#1895) 2025-06-26 08:52:33 +08:00
Soulter
81cc8831f9 docs: update plugin issue template
docs: issue template

docs: update issue template

docs: update plugin issue template

fix: issue plugin template

docs: update plugin issue template
2025-06-26 08:28:28 +08:00
Soulter
f7370044a7 Merge pull request #1903 from IGCrystal/branch-1
 feat: 对PlatformPage使用翻译键
2025-06-25 22:49:03 +08:00
Soulter
51b015a629 Merge pull request #1830 from zhx8702/feat-wechat-tts-mp3towav
feat: wechatpadpro 触发tts时 添加对mp3格式音频支持
2025-06-25 22:46:10 +08:00
Soulter
392af7a553 fix: add pydub to requirements 2025-06-25 22:31:44 +08:00
鸦羽
d2dd07bad7 Merge pull request #1920 from AstrBotDevs/feat/gemini-tts
feat: 增加Gemini TTS API实现
2025-06-25 14:05:04 +08:00
回归天空
cebcd6925a [fix] (discord_platform_adapter) 解决了 “Discord 平台无法优雅重载” 的 bug
#### 问题现象(AI总结)

- 在通过 Web 面板或配置变更热重载 Discord 平台时,适配器的 terminate() 方法会被调用,但经常出现“卡死”或长时间无响应,导致 Discord 平台无法优雅重载。

- 日志显示停留在“正在清理已注册的斜杠指令...”等步骤,甚至出现超时或异常。

#### 2. 原因分析

- 适配器的 terminate() 方法中,涉及多个异步操作(如取消 polling 任务、清理斜杠指令、关闭客户端)。

- 某些 await 操作(如 await self.client.sync_commands() 或 await self.client.close())在网络异常、事件循环被取消等情况下,可能会阻塞或抛出 CancelledError,导致整个重载流程卡住。

- 之前的实现没有对这些 await 操作加超时保护,也没有分步日志,难以定位具体卡点。

#### 3. 修复措施

- 分步日志:在 terminate() 的每个关键步骤前后都加了详细日志,便于定位卡点。

- 超时保护:对所有关键 await 操作(如 polling 任务取消、指令清理、客户端关闭)都加了 asyncio.wait_for(..., timeout=10),防止无限阻塞。

- 健壮性提升:先 cancel polling 任务,再清理指令,最后关闭客户端。每一步都捕获异常并输出日志,保证即使某一步失败也能继续后续清理。

- 避免重复终止:移除了 run() 方法中的 finally: await self.terminate(),只允许外部统一调度,防止重复调用导致资源冲突或日志重复。

#### 4. 修复效果

- 现在 Discord 平台适配器在热重载或终止时,能优雅地依次完成所有清理步骤,不会因某一步阻塞导致整个流程卡死。
2025-06-25 11:46:49 +08:00
回归天空
e7b4357fc7 [feat] (discord_platform_adapter) 增加了对机器人 Role Mention 方法的响应 2025-06-25 11:41:55 +08:00
Raven95676
dc279dde4a fix: 简化get_audio方法中的提示文本生成逻辑,清除冗余判断逻辑 2025-06-25 10:55:51 +08:00
Raven95676
c0810a674f feat: 增加Gemini TTS API实现 2025-06-25 10:50:04 +08:00
HakimYu
0760cabbbe feat(AiocqhttpAdapter): 修复reply类型的 Event.from_payload报错 2025-06-24 17:20:30 +08:00
HakimYu
3b149c520b fix(AiocqhttpAdapter): 修复at_info.get("nick", "")的错误,并在message_str中针对At类型添加QQ号 2025-06-24 16:30:23 +08:00
Soulter
3d19fc89ff docs: 10k star banner 2025-06-24 02:07:23 +08:00
Soulter
cd1b1919f4 docs: 10k star banner 2025-06-24 01:51:46 +08:00
IGCrystal
0ed646eb27 🐞 fix(WebUI): 修复安装插件按钮不可见 2025-06-23 19:41:56 +08:00
邹永赫
c0c5859c99 Merge pull request #1905 from zouyonghe/master
使用定义的Plain类型代替原始基础类型str,保持代码统一性
2025-06-23 18:52:56 +09:00
邹永赫
a47121b849 使用定义的Plain类型代替原始基础类型str,保持代码统一性 2025-06-23 18:49:47 +09:00
邹永赫
d9dd20e89a Merge pull request #1904 from zouyonghe/master
修复代码重构造成的无法向前兼容在node中发送简单文本信息的问题
2025-06-23 18:20:52 +09:00
邹永赫
ed4609ebe5 修复代码重构造成的无法向前兼容在node中发送简单文本信息的问题 2025-06-23 18:17:37 +09:00
Gao Jinzhe
e24225c828 Merge branch 'master' into master 2025-06-23 15:21:08 +08:00
IGCrystal
01ef86d658 feat: 对PlatformPage使用翻译键 2025-06-23 14:44:06 +08:00
Soulter
cd4802da04 Merge pull request #1902 from railgun19457/master
修复plugin_enable配置无法保存的问题
2025-06-23 13:30:31 +08:00
Misaka Mikoto
2aca65780f Merge branch 'AstrBotDevs:master' into master 2025-06-23 13:29:31 +08:00
Soulter
2c435f7387 Merge pull request #1899 from IGCrystal/branch-1
🐞 fix: 显示运行时长国际化
2025-06-23 13:21:59 +08:00
Soulter
cc1afd1a9c Merge pull request #1900 from AstrBotDevs/fix-hc-jwt
Fix: JWT secret issue
2025-06-23 13:16:08 +08:00
railgun19457
6f098cdba6 修复plugin_enable配置无法保存的问题 2025-06-23 13:06:46 +08:00
Soulter
d03e9fb90a fix: jwt secret 2025-06-23 12:36:11 +08:00
IGCrystal
9f2966abe9 Merge branch 'branch-1' of https://github.com/IGCrystal/AstrBot into branch-1 2025-06-23 12:09:10 +08:00
IGCrystal
4e28ea1883 🐞 fix: 显示运行时长国际化 2025-06-23 12:08:27 +08:00
Soulter
289214e85c Merge pull request #1898 from IGCrystal/branch-1
🐞 fix(WebUI): 修复platform的logo路径问题
2025-06-23 11:59:58 +08:00
IGCrystal
a20d98bf93 🐞 fix(WebUI): 修复platform的logo路径问题 2025-06-23 11:57:20 +08:00
Soulter
7c3d98acbe 📦 release: v3.5.17
因为 pypi 不允许上传相同的文件名的 wheel
2025-06-23 01:17:38 +08:00
Soulter
7311786f48 fix(dependencies): remove optional 'speed' from py-cord dependency 2025-06-23 01:03:43 +08:00
Soulter
82de9c926e docs: update readme 2025-06-23 00:40:34 +08:00
Soulter
7fd86d4de3 docs: update readme 2025-06-23 00:38:52 +08:00
Soulter
724da29e2a 📦 release: bump to v3.5.16 2025-06-23 00:15:30 +08:00
Soulter
54113d7b94 Merge pull request #1896 from AstrBotDevs/perf-webui-dialog-logo
Improve: improve styles of creating adapter dialog
2025-06-23 00:03:50 +08:00
Soulter
66396e8290 perf(webui): improve styles of creating adapter dialog in platform and provider page 2025-06-23 00:01:04 +08:00
Soulter
72be76215f Merge pull request #1822 from IGCrystal/branch-1
 feat(WebUI): complete dashboard internationalization system refactor
2025-06-22 22:22:33 +08:00
Soulter
ace86703a9 Merge pull request #1888 from HakimYu/master
Discord 实现 SlashCommand 的注册、添加对 At 与 Reply 的支持、设置机器人 Activity
2025-06-22 22:19:19 +08:00
Soulter
7b25495463 style: code quality 2025-06-22 22:11:28 +08:00
HakimYu
3d4b651c1f fix: 修复 send_by_session 的 message_obj 为 None 的错误
fix: 修复 determine_messagee_type 会获取到服务器id的错误,并拆分成两个函数
2025-06-22 20:33:26 +08:00
HakimYu
d305ae064d Merge branch 'AstrBotDevs:master' into master 2025-06-22 16:29:38 +08:00
HakimYu
ac4f3d8907 feat: 添加 Discord 斜杠指令注册功能及相关配置项
feat: 添加 Activity 设置项
fix: 修复 At Reply 未处理的问题
2025-06-22 16:29:02 +08:00
Soulter
af2687771b ci: update dashboard ci to support pull request 2025-06-22 10:38:09 +08:00
Soulter
a67b7f909a Merge branch 'master' into branch-1 2025-06-22 10:28:44 +08:00
Soulter
f9c3e4cdb0 Merge pull request #1821 from Zhalslar/gsv-tts-selfhost
Feature: 新增 GPT_SoVIS 的 TTS 服务商
2025-06-21 23:58:07 +08:00
Soulter
dc62c1f8d4 style: code format 2025-06-21 23:56:06 +08:00
Soulter
0441b51a68 Merge pull request #1867 from lxfight/master
Feature: 添加 Discord 平台适配器及相关组件,支持 Discord Bot 功能
2025-06-21 23:52:54 +08:00
Soulter
5c0c9f687e style: code quality 2025-06-21 23:52:17 +08:00
Soulter
e049c54043 chore: update uv.lock 2025-06-21 23:33:58 +08:00
Soulter
99e47540d5 styles: code quality 2025-06-21 23:33:47 +08:00
Soulter
8e1885ffeb Merge branch 'master' into master 2025-06-21 23:21:37 +08:00
Soulter
8501a0c205 perf: replace slack requirements 2025-06-21 23:19:39 +08:00
Soulter
797f2a3173 Merge pull request #1877 from AstrBotDevs/feat-adapter-slack
Feature: Add platform adapter support for Slack
2025-06-21 23:13:37 +08:00
Soulter
1057b4bc35 style: code quality 2025-06-21 23:12:50 +08:00
Soulter
efc0116595 feat: Verify Slack request signature using HMAC 2025-06-21 23:07:34 +08:00
Soulter
cdc560fad0 chore: remove useless codes 2025-06-21 22:58:30 +08:00
lxfight
75a2803710 fix: 清空交互事件的 message_str,确保仅专门指令处理器响应;优化图片处理逻辑,支持多种图片来源
- 修复了@激活机器人时,指令无法正确处理的问题
- 修复了base64 图片无法发送的问题

注意:本次提交的代码功能还需要针对全部功能进行一次系统完整的测试,计划与6月22日下午完成。
2025-06-21 20:12:38 +08:00
Soulter
fb3169faa4 feat: add platform adapter support for Slack 2025-06-21 18:33:48 +08:00
Soulter
d587bd837e Merge pull request #1845 from RC-CHN/master
feat:在用户未为服务商配置key时添加二次警告确认
2025-06-20 23:27:27 +08:00
lxfight
b9fab74edc feat: 拆分Discord 适配器的部分代码,并处理一些小的问题。
- 基于最小权限原则,修改了 Bot 申请的权限范围
- 拆分了代码,使得文件结构更加清晰
2025-06-20 21:43:23 +08:00
lxfight
50c22bbadb feat: 在 requirements.txt 中添加 py-cord[speed] 依赖 2025-06-20 21:26:55 +08:00
lxfight
d0b10b9195 feat: 添加 Discord 平台适配器及相关组件,支持 Discord Bot 功能
- 添加了一个新的依赖 py-cord[speed]
- 添加了针对 Discord 平台的 Discord Bot 适配器
2025-06-20 21:22:04 +08:00
Gao Jinzhe
50a296de20 Merge branch 'AstrBotDevs:master' into master 2025-06-20 14:39:57 +08:00
IGCrystal
c8fe4f4a3c Merge branch 'AstrBotDevs:master' into branch-1 2025-06-19 11:56:39 +08:00
IGCrystal
a8ba0720af 🎈 perf: 在更新弹窗中提高关闭按钮与控制台的间距
之前的按钮与控制台内容重叠了,就增加一点间距
2025-06-19 11:54:27 +08:00
IGCrystal
745a01246c 🎈 perf: 修改chat的弹窗样式 2025-06-19 10:30:33 +08:00
Zhalslar
bee5d3550f Merge branch 'gsv-tts-selfhost' of https://github.com/Zhalslar/AstrBot_Zhalslar into gsv-tts-selfhost 2025-06-19 00:52:16 +08:00
Zhalslar
1789393151 提供initialize和terminate方法对接上游 2025-06-19 00:52:03 +08:00
Soulter
345afe1338 fix: 修复 PipInstaller 中 pip 调用方式,确保使用当前 Python 解释器 2025-06-19 00:38:23 +08:00
Ruochen
65428aa49f perf: 优化服务商保存流程,并修复UI状态 2025-06-18 23:58:09 +08:00
Soulter
b251ee9322 perf: 优化空文本检测
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-06-18 23:45:59 +08:00
Ruochen
04f00682a0 Merge branch 'master' of https://github.com/RC-CHN/AstrBot 2025-06-18 23:43:09 +08:00
Ruochen
90dcda1475 feat:在用户未为服务商配置key时添加二次警告确认 2025-06-18 23:41:07 +08:00
IGCrystal
f1ee4eb89f 🐞 fix: 修改录音键位为Ctrl+B
Ctrl+A太常用了就修改了
2025-06-18 21:00:28 +08:00
IGCrystal
343fc22168 🎈 perf: 修改chat中录音的键位防止误触
修改键位为Ctrl + A ,以及还加入SSE断连提示
2025-06-18 17:58:15 +08:00
IGCrystal
00ef0d7e3d 🐞 fix: 修复无法实时显示消息
修复chat与chatbox之间切换后sse断开连接导致无法实时显示消息
2025-06-18 16:24:18 +08:00
IGCrystal
f2deaf6199 🎈 perf: 修改滚动条样式 2025-06-18 00:47:43 +08:00
IGCrystal
617a2c010e 🎈 perf: 优化登录页面样式
处理的是分隔线的样式
2025-06-17 22:20:48 +08:00
Gao Jinzhe
c79e38e044 Merge branch 'AstrBotDevs:master' into master 2025-06-17 20:29:32 +08:00
IGCrystal
38eae1d1ee 🐞 fix: 进一步的检查与校准 2025-06-17 12:22:00 +08:00
IGCrystal
7e4c89b0cb 🦄 refactor(i18n): replace manual types with auto-inference
- Migrate from manual TypeScript interfaces to automatic type generation
from JSON files. Eliminates sync issues and maintenance overhead.
2025-06-17 11:10:21 +08:00
Zhalslar
14c29f07bd 优化 2025-06-17 10:55:35 +08:00
Zhalslar
825e3dbcf5 Update default.py 2025-06-17 09:44:09 +08:00
IGCrystal
8275130f04 feat: 继续完成剩下的组件
- AlkaidPage_sigma.vue
- PlatformPage.vue
- LongTermMemory.vue
- KnowledgeBase.vue
2025-06-17 09:24:51 +08:00
Soulter
2c47abea95 fix: 修复 WeChatPadPro 下,开启了会话隔离后,无法发送群聊消息的问题
fixes: #1766
2025-06-16 23:36:11 +08:00
Soulter
85aa28d724 perf: print traceback 2025-06-16 23:27:29 +08:00
Soulter
53a3736b04 fix: 修复可能的类型错误
fixes: #1778
2025-06-16 23:26:22 +08:00
Soulter
86ba3c230e perf: 弱化 WeChatPadPro 的 WS 连接提示
fixes: #1779
2025-06-16 23:21:53 +08:00
Soulter
8d21126bd6 fix: 修复 WeChatPadPro 会话隔离模式下,会话 ID 显示为自身ID 的问题 2025-06-16 23:18:45 +08:00
Soulter
74ded91976 fix: 修复 WeChatPadPro 过期后无法正常的重新登录的问题。 2025-06-16 23:07:10 +08:00
IGCrystal
7c27520d57 feat: 继续完成剩下组件的国际化
ExtensionCard.vue - 插件卡片组件 WaitingForRestart.vue - 重启等待组件 ReadmeDialog.vue - README对话框组件 AstrBotConfig.vue - 配置编辑器组件 ListConfigItem.vue - 列表配置项组件 ItemCardGrid.vue - 卡片网格组件
ChatPage.vue - 聊天页面的录音提示文本 ConfigPage.vue - 配置页面的状态消息 ExtensionPage.vue - 插件页面的加载和状态文本 OnlineTime.vue - 仪表板运行时间组件
2025-06-16 22:44:44 +08:00
Soulter
b54bbc4c5a Merge pull request #1810 from Zhalslar/reply-bot-waking
feat:支持通过引用bot消息来唤醒bot
2025-06-16 21:56:17 +08:00
Soulter
3e09a4ddd4 Merge branch 'master' into reply-bot-waking 2025-06-16 21:55:50 +08:00
Zhalslar
f93f04a536 feat:支持通过引用bot消息来唤醒bot
Update dingtalk_event.py

Update stage.py
2025-06-16 21:54:13 +08:00
Soulter
b93f30b809 docs: update readme 2025-06-16 21:54:13 +08:00
Soulter
95bd2f26a5 Merge pull request #1812 from Zhalslar/dingtalk-image-to-url
feat:钉钉发图时自动将非HTTP图片注册成URL
2025-06-16 21:41:50 +08:00
IGCrystal
7cfcf056f9 🎈 perf: 使用 hash 路由模式以避免404 2025-06-16 21:36:23 +08:00
IGCrystal
96b565e1e8 🎈 perf: comprehensive dashboard improvements
- Enhance i18n error handling and code quality - Fix SSE data processing in chat page - Improve responsive design for extension page - Add better debugging tools for development"
2025-06-16 21:05:20 +08:00
IGCrystal
9d7ad7a18f 🐞 fix(i18n): resolve translation loading issues in production build 2025-06-16 20:14:00 +08:00
IGCrystal
9838c2758b 🐞 fix: resolve vue-i18n module augme 2025-06-16 19:08:19 +08:00
Soulter
1b1f5f5a5e docs(README.md): update logo 2025-06-16 19:06:46 +08:00
IGCrystal
0f95f62aa1 feat: 完成仪表板国际化系统重构
 核心特性:
- 实现模块化i18n架构,支持22个功能模块
- 完成中英双语翻译文件(44个翻译文件)
- 新增懒加载翻译模块,提升性能
- 类型安全的翻译键值验证系统

🌐 国际化覆盖:
- 所有主要页面(15+)完成国际化
- 导航侧边栏、顶栏、共享组件全部支持
- 仪表板统计组件完整国际化
- 登录页面及认证流程完整国际化

🎨 UI/UX 优化:
- 统一顶栏按钮样式(语言切换+主题切换)
- 移动端登录页采用全屏设计
- Logo组件智能换行支持中英文
- 响应式语言切换组件

📱 移动端适配:
- 登录卡片移动端全屏布局
- 悬浮工具栏底部固定定位
- 触摸友好的交互设计
- 多设备响应式支持

🔧 技术改进:
- 模块化翻译文件结构 (core/*, features/*)
- 懒加载机制减少初始包体积
- TypeScript类型定义完整
- 翻译键值自动验证
2025-06-16 13:53:33 +08:00
Zhalslar
9405ba7871 feat:新增GPT_SoVIS适配器 2025-06-16 13:45:50 +08:00
zhx
ccb95f803c feat: wechatpadpro 发送tts时 添加对mp3格式音频支持 2025-06-16 10:05:21 +08:00
Gao Jinzhe
dae745d925 Update server.py 2025-06-16 10:03:18 +08:00
Gao Jinzhe
791db65526 resolve conflict with master branch 2025-06-16 09:50:35 +08:00
IGCrystal
60b2ff0a7a 🐞 fix: 修复iframe跳转问题 2025-06-16 00:47:41 +08:00
IGCrystal
e6c8507379 📃 docs: 删除i18n的叙述 2025-06-15 23:19:46 +08:00
IGCrystal
420db5416e Merge branch 'branch-1' of https://github.com/IGCrystal/AstrBot into branch-1 2025-06-15 23:16:25 +08:00
IGCrystal
6e03218d54 feat: 多语言国际化支持 2025-06-15 23:10:44 +08:00
IGCrystal
5e4bd36b26 Merge branch 'master' of https://github.com/AstrBotDevs/AstrBot 2025-06-15 23:07:03 +08:00
Soulter
bbc039366e Merge pull request #1816 from AstrBotDevs/refactor-webui-merge-extension-page
refactor(webui): 将插件市场与插件管理合并
2025-06-15 22:51:55 +08:00
Soulter
e1ec7dbbba refactor(webui): 将插件市场与插件管理合并 2025-06-15 22:50:06 +08:00
IGCrystal
075b008740 🐞 fix: 错误修复和代码健壮性
- 在 KnowledgeBase.vue 中修正无效的 v-class 指令为 class 属性的问题
- 在 ConsoleDisplayer.vue 中修正 historyNum 属性类型不匹配的问题
- 解决控制台中的 Vue 警告信息
- 在访问 status 前对 err.response 进行空值检查
- 防止“无法读取未定义对象的属性”错误
- 提高 catch 块中错误处理的健壮性
- 对 API 响应数据进行空值检查
- 在处理之前确保数组类型验证
- 修复“无法读取 null 对象的属性”错误
- 改进 beforeUnmount 生命周期中的 D3.js 清理工作
- 对图形数据处理添加防御性编程
2025-06-15 22:45:28 +08:00
IGCrystal
b2c382fa01 feat: 多语言国际化支持 2025-06-15 22:42:43 +08:00
Gao Jinzhe
02e2e617f5 Merge branch 'AstrBotDevs:master' into master 2025-06-15 22:04:06 +08:00
Soulter
c5f9b5861f Merge pull request #1804 from RC-CHN/master
feat:优化聊天页面的UI和UX
2025-06-15 21:22:23 +08:00
Soulter
2dace4c697 Merge pull request #1801 from IGCrystal/master
🎈 perf: 优化登录界面样式和侧边栏样式
2025-06-15 21:15:31 +08:00
Zhalslar
c7891385ca Update dingtalk_event.py 2025-06-14 21:44:37 +08:00
Zhalslar
2059ddcadf Update dingtalk_event.py 2025-06-14 21:39:33 +08:00
Zhalslar
ba1b68df20 Update dingtalk_event.py 2025-06-14 21:23:45 +08:00
advent259141
bfc8024119 modified: astrbot/core/pipeline/process_stage/method/llm_request.py
new file:   astrbot/core/star/session_llm_manager.py
	modified:   astrbot/dashboard/routes/session_management.py
	modified:   dashboard/src/views/SessionManagementPage.vue

 增加了精确到会话的LLM启停管理以及插件启停管理
2025-06-14 03:42:21 +08:00
Gao Jinzhe
f26cf6ed6f Merge branch 'AstrBotDevs:master' into master 2025-06-14 03:03:41 +08:00
Soulter
403b61836d docs: update readme 2025-06-14 02:09:06 +08:00
Ruochen
b5af7d1eb9 为chatbox模式添加了夜间模式切换 2025-06-13 23:11:09 +08:00
Ruochen
f453af6e4c feat:优化聊天页面的UI和UX 2025-06-13 21:30:56 +08:00
advent259141
f2be55bd8e Merge branch 'master' of https://github.com/advent259141/AstrBot 2025-06-13 06:20:49 +08:00
advent259141
d241dd17ca Merge branch 'master' of https://github.com/advent259141/AstrBot 2025-06-13 06:20:09 +08:00
advent259141
cecafdfe6c Merge branch 'master' of https://github.com/advent259141/AstrBot 2025-06-13 03:54:35 +08:00
Soulter
6fecfd1a0e Merge pull request #1800 from AstrBotDevs/feat-weixinkefu-record
feat: 微信客服支持语音的收发
2025-06-13 03:52:15 +08:00
IGCrystal
64245d001c Merge branch 'AstrBotDevs:master' into master 2025-06-13 00:59:21 +08:00
IGCrystal
7d92965cae 🎈 perf: 优化侧边栏样式 2025-06-12 23:51:44 +08:00
IGCrystal
b4fa08c4e2 🎈 perf: 优化登录界面样式 2025-06-12 23:26:01 +08:00
Soulter
d4e9566851 Merge pull request #1800 from AstrBotDevs/feat-weixinkefu-record
feat: 微信客服支持语音的收发
2025-06-12 23:02:22 +08:00
Soulter
a26b494f7f feat: 微信客服支持语音的收发
fixes: #1794
2025-06-12 10:57:16 -04:00
Soulter
b84e22e41f fix: separate provider
fixes #1793
2025-06-12 14:07:23 +08:00
Soulter
cee6efab19 Merge pull request #1783 from Kwicxy/fix
fix(readmeDialog): 修复了readme对话框内markdown渲染样式问题
2025-06-11 22:33:14 +08:00
Soulter
30f71cb550 Merge pull request #1791 from AstrBotDevs/feat-dify-user-param
Feature: supports dify user param
2025-06-11 22:26:07 +08:00
Soulter
771e755a78 feat: supports dify user param 2025-06-11 22:25:10 +08:00
Soulter
16ec462abd feat: WebUI ProviderPage 添加服务提供商会话隔离设置功能 2025-06-11 00:51:18 +08:00
Soulter
ca55465d3c chore: bump to 3.5.15 2025-06-11 00:32:46 +08:00
Soulter
7098c98dde fix: 修复 Windows 下部署项目时可能出现的 UnicodeDecodeError
fixes: #1548
2025-06-11 00:25:14 +08:00
Soulter
f56355da89 perf: 分段回复时,仅在输出的第一句话带上回复/引用
fixes: #521
2025-06-11 00:06:14 +08:00
Soulter
422160debd feat: 支持配置是否忽略@全体成员
fixes: #292
2025-06-10 23:55:50 +08:00
Soulter
8062cf406a fix: 优化配置完整性检查,同时保证配置项顺序的一致性 2025-06-10 23:30:58 +08:00
Soulter
0e802232ec feat: 新配置项,支持配置只@触发等待时是否回复 2025-06-10 23:29:45 +08:00
Soulter
f650a9205d perf(webui): 优化手机端的显示 2025-06-10 22:43:58 +08:00
Soulter
c85dbb2347 fix: 修复某些情况下,会话控制无效的问题 2025-06-10 22:26:11 +08:00
Soulter
a6a79128c8 chore: bump to v3.5.15 2025-06-10 22:18:05 +08:00
Soulter
42839627e8 fix: 修复在设置了 GitHub 加速地址后,插件无法更新的问题 2025-06-10 22:12:46 +08:00
Richard X.
e7f35098e4 fix(readmeDialog): Fix readme dialog markdown rendering over different appearances.
Fix readme dialog markdown rendering over different appearances.
2025-06-10 21:46:35 +08:00
Soulter
267e68a894 chore: bump docker image python version to 3.11 2025-06-10 21:40:20 +08:00
Soulter
b32b444438 Merge pull request #1776 from AstrBotDevs/feat-webchat-title
Feature: 支持重命名和自动生成 WebChat title;WebChat Route 和 UI 优化;支持 WebChatBox
2025-06-10 21:34:17 +08:00
Soulter
522d0f8313 chore: ts lint 2025-06-10 21:33:53 +08:00
Soulter
5715e5de67 chore: fix ts lint 2025-06-10 21:28:06 +08:00
Soulter
cc6b05e8b3 fix: remove fallback for returnUrl in AuthLogin.vue 2025-06-10 21:25:58 +08:00
Soulter
417747d5d0 feat: handle unauthorized access by redirecting to login page in ChatPage 2025-06-10 21:21:38 +08:00
Soulter
a34f439226 fix: update summary output condition and adjust max-width in ChatBoxPage 2025-06-10 18:36:26 +08:00
Soulter
b7ca014fd0 feat: enhance routing to support chatbox and improve path handling in ChatPage 2025-06-10 15:45:06 +08:00
Soulter
fa098d585a feat: add conversation detail routing and handle direct navigation in ChatPage 2025-06-10 15:39:26 +08:00
Soulter
c35a14e3ec fix: adjust padding and clean up unused code in ChatPage.vue 2025-06-10 15:06:33 +08:00
Soulter
60651736a5 feat: chatbox page 2025-06-10 15:02:18 +08:00
Soulter
581f9b7bd3 fix: typo fix
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-06-10 13:02:30 +08:00
Soulter
124eb04807 Merge pull request #1773 from AstrBotDevs/feat-seperate-provider
Feature: 支持对提供商会话隔离
2025-06-10 12:59:42 +08:00
Soulter
1d561da7fb style: clean code
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-06-10 12:59:20 +08:00
Soulter
16e3cd0784 fix: get_using_stt_provider is fetching using ProviderType.TEXT_TO_SPEECH but should use ProviderType.SPEECH_TO_TEXT for STT isolation.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-06-10 12:58:39 +08:00
Soulter
a6d91933dc feat: 支持自动生成webchat title 2025-06-10 10:58:49 +08:00
Raven95676
445c40f758 chore: update version 2025-06-10 10:29:31 +08:00
鸦羽
725a841a3b Merge pull request #1767 from AstrBotDevs/fix/1678
Fix: 调整Gemini原生工具启用行为
2025-06-10 08:22:41 +08:00
鸦羽
f77c453843 fix: clean code 2025-06-10 00:20:35 +00:00
Soulter
ba6718d5bc Merge pull request #1759 from Flartiny/dev
Feature: Add GreedyStr parameter support for commands
2025-06-10 00:06:34 +08:00
Soulter
cdb7a1b3fa style: merge else if into elif 2025-06-09 23:54:51 +08:00
Soulter
a03c79b89d style: use named expression 2025-06-09 23:51:54 +08:00
Soulter
98800d3426 fix(typo): "seperate_provider" -> "separate_provider" 2025-06-09 23:50:31 +08:00
Soulter
a616adaac4 fix: update provider manager set_provider() 2025-06-09 23:46:44 +08:00
Soulter
ffb5605c99 fix: default tts provider selection
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-06-09 23:38:15 +08:00
Soulter
621b556856 feat: 支持对提供商会话隔离
fixes: #1762 #602 #479
2025-06-09 23:33:00 +08:00
Soulter
a3ffecbb2a feat: add support for gemini_embedding provider 2025-06-09 14:43:05 +08:00
Soulter
ea64cebe2a ci: fix cloudflare r2 ci 2025-06-09 13:12:31 +08:00
鸦羽
e79487dd5f fix: add missing config 2025-06-09 05:03:15 +00:00
鸦羽
7fe1c1ec89 feat: add URL context feature to Gemini model configuration 2025-06-09 04:54:24 +00:00
Soulter
ab2bbff369 Merge pull request #1746 from Seayon/fix-wechat-at-message-parsing
 feat(wechatpadpro): 增强群聊消息中的@消息处理逻辑
2025-06-09 12:51:08 +08:00
Soulter
ec32825309 ci: fix cloudflare r2 upload 2025-06-09 12:41:20 +08:00
Soulter
fd0c182087 ci: fix ghcr token 2025-06-09 12:32:38 +08:00
Soulter
49fcff1daf 📦 release: v3.5.14 2025-06-09 12:31:02 +08:00
鸦羽
33b64ddf39 feat: enhance tool selection logic for Gemini model versions 2025-06-09 03:55:59 +00:00
Soulter
4c447aa648 perf: jwt token expire time change to 7 days 2025-06-09 11:52:48 +08:00
Soulter
ccbfc3d274 perf: 强化强制修改默认密码逻辑 2025-06-09 11:47:23 +08:00
Soulter
f83fe43bbb docs: alert 2025-06-09 10:12:09 +08:00
Seayon
19022d67f8 Merge branch 'master' into fix-wechat-at-message-parsing
# Conflicts:
#	astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py
2025-06-09 09:30:09 +08:00
Soulter
58a815dd6b feat: ltm edge fact viewer 2025-06-08 20:34:41 +08:00
shuiping233
1ce95c473d fix : 在stage.py中专门对qq_official的会话控制器消息进行处理 2025-06-08 10:20:09 +08:00
shuiping233
eb365e398d fix: qq_official适配器使用SessionController(会话控制)功能时机器人回复消息无法发送到聊天平台 2025-06-08 10:20:09 +08:00
Soulter
bc9fe82860 Merge pull request #1737 from zhx8702/feat-wehcatpro-voice-adapter
feat: wechatpadpro 添加语音接收和发送的适配
2025-06-07 15:13:10 +08:00
Soulter
b3cd9bf2b9 Merge pull request #1743 from lvboda/hotfix-platform-page-iframe-style-issue-1741
fix(PlatformPage): iframe overflow style issue (#1741)
2025-06-07 15:11:16 +08:00
Soulter
c5c2b829ec Merge pull request #1758 from RC-CHN/master
fix: 修复 asyncio.wait_for 参数顺序错误
2025-06-07 15:08:37 +08:00
Flartiny
9713f96401 feat: Add greedy parameter support for commands 2025-06-07 10:32:31 +08:00
Ruochen
11f35ebf96 fix: 修复 asyncio.wait_for 参数顺序错误 2025-06-07 09:50:30 +08:00
Soulter
7d403aa181 fix: syntax error 2025-06-07 01:20:56 +08:00
Soulter
64af810a4a Merge pull request #1736 from RC-CHN/master
fix:修复了部分模型供应商测试不可用,但实际可用的问题。
2025-06-06 21:37:19 +08:00
Soulter
30821905af perf: remove default list param,fix dashscope_source contexts params 2025-06-06 21:36:01 +08:00
Seayon
a9dbff756b feat(wechatpadpro): 增强群聊消息中的@消息处理逻辑
添加对群聊消息中@机器人场景的精确识别和处理,提升了消息解析的准确性。
支持多种@格式的检测,包括 msg_source 和 push_content 的判断。
2025-06-06 16:53:31 +08:00
lvboda
a6aba10d3d fix(PlatformPage): iframe overflow style issue (#1741) 2025-06-06 15:18:35 +08:00
RC-CHN
9c276c37fe Update astrbot/dashboard/routes/config.py
测试过对于dashscope类型供应商添加上下文是必要的,否则需要改动其_remove_image_from_context方法。

Co-authored-by: Soulter  <37870767+Soulter@users.noreply.github.com>
2025-06-06 14:01:58 +08:00
Soulter
6ab6c0fd4c Merge pull request #1735 from Flartiny/dev
feat: able to parse repo url of specific branch
2025-06-06 12:44:51 +08:00
Soulter
b6b0fe3fff perf: 优化 GitHub 仓库解析和下载的逻辑 2025-06-06 12:02:46 +08:00
zhx
0d5825bda9 feat: wechatpadpro 添加语音接收和发送的适配 2025-06-06 10:30:06 +08:00
Ruochen
cdfb64631a fix:修复dashscope类型供应商测试问题,延长了设置超时时间,改进prompt工程,修复了控制台打印日志超时时间不符 2025-06-06 09:21:09 +08:00
Ruochen
d161c281c8 Merge branch 'master' of https://github.com/RC-CHN/AstrBot 2025-06-06 00:39:25 +08:00
Flartiny
8fed5bf2a1 feat: able to parse repo url of specific branch 2025-06-06 00:09:10 +08:00
Soulter
98d2e9bd27 chore: stage 2025-06-05 23:30:18 +08:00
Soulter
a03af55edd ci 2025-06-05 13:38:20 +08:00
Soulter
86e2fd9aee ci: publish to ghcr.io 2025-06-05 13:35:14 +08:00
Soulter
97bd0e5e58 Merge pull request #1730 from lxfight/master
feat: 添加插件更新后自动刷新插件列表功能
2025-06-05 11:39:32 +08:00
Soulter
ceaba21986 ci: publish to ghcr.io 2025-06-05 11:19:16 +08:00
Soulter
172a77d942 ci: publish to ghcr.io 2025-06-05 11:16:57 +08:00
Soulter
4f9d2d2a7d ci: publish to ghcr.io 2025-06-05 11:12:56 +08:00
lxfight
8c929f6e05 feat: 添加插件更新后自动刷新插件列表功能 2025-06-05 10:56:04 +08:00
Soulter
3319b71f5b Merge pull request #1721 from zhx8702/feat-add-wechat-47-49
feat: 添加wechatpadpro 消息类型47 49的适配
2025-06-04 22:52:29 +08:00
Soulter
46ec028a5b Merge pull request #1718 from Kwicxy/webui_enhancement
feat: webUI优化
2025-06-04 22:48:49 +08:00
Soulter
0ce0ef3e5c Merge pull request #1715 from Flartiny/dev
fix: residual configuration items after plugin configuration modification
2025-06-04 22:32:19 +08:00
kwicxy
375b071cb2 Merge remote-tracking branch 'origin/webui_enhancement' into webui_enhancement 2025-06-04 19:00:54 +08:00
kwicxy
29e1417ff2 feat: optional newUsername field in account editing 2025-06-04 18:59:38 +08:00
kwicxy
75db2bd366 fix(auth): bad localStorage keymapping 2025-06-04 18:58:53 +08:00
zhx
60ca1efbda feat: 添加wechatpadpro 消息类型47 49的适配 2025-06-04 14:36:16 +08:00
Richard X.
2692e4978b fix: remove console.log()
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-06-03 21:06:51 +08:00
Richard X.
91982eb002 Merge branch 'AstrBotDevs:master' into webui_enhancement 2025-06-03 20:36:51 +08:00
Soulter
bb1dec76fa remove: wechat qr code
hahaha
2025-06-03 20:22:08 +08:00
Flartiny
f618b8fcdc fix: residual configuration items after plugin configuration modification 2025-06-03 14:04:04 +08:00
Raven95676
9147cab75b fix: add additional routes for Alkaid knowledge base and long-term memory 2025-05-31 14:29:04 +08:00
Raven95676
5f07bcc8e6 feat: add Gemini embedding provider and update OpenAI provider to support timeout configuration 2025-05-31 14:13:58 +08:00
Soulter
705cf2ea1b docs(README.md): knowledge base 2025-05-31 14:08:01 +08:00
Soulter
42c4394484 ci: upload dashboard artifact to Cloudflare R2 when auto release 2025-05-31 13:50:40 +08:00
Soulter
221221a3c1 ci: upload dashboard artifact to Cloudflare R2 when auto release 2025-05-31 13:47:59 +08:00
Soulter
9564166297 perf: knowledge base displays console when installing 2025-05-31 11:52:24 +08:00
Soulter
f5cf3c3c8e Merge pull request #1691 from AstrBotDevs/perf-pip-async
Feature: 将插件依赖检查和 pip 安装方法改为异步,以提高性能和响应速度
2025-05-31 11:51:39 +08:00
Soulter
18f919fb6b perf: pip_main wrapped in asyncio.to_thread
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-05-31 11:47:29 +08:00
Soulter
0924835253 feat: 将插件依赖检查和 pip 安装方法改为异步,以提高性能和响应速度 2025-05-31 11:44:58 +08:00
Soulter
20d2e5c578 perf: 优化日志流发送频率,防止积压超过 buffer size 导致前端显示异常 2025-05-31 11:25:51 +08:00
Soulter
907801605c 📦 release: v3.5.13 2025-05-31 11:02:56 +08:00
Soulter
93bc684e8c feat: 添加旧版本提供商类型映射以兼容性支持 2025-05-31 11:00:59 +08:00
Soulter
a76c98d57e Merge pull request #1685 from RC-CHN/master
Feature: 添加测试文本生成供应商可用功能
2025-05-31 10:59:46 +08:00
Soulter
d937a800d0 fix: provider name 2025-05-31 10:46:35 +08:00
Soulter
d16f3a227f Merge branch 'master' into master 2025-05-31 10:46:15 +08:00
Soulter
80c9a3eeda style: code style 2025-05-31 09:25:18 +08:00
Soulter
e68173b451 feat: knowledge-base 2025-05-30 23:18:48 +08:00
Soulter
40c27d87f5 feat: knowledge-base 2025-05-30 23:18:19 +08:00
Soulter
3c13b5049d feat: 支持知识库的分片、重叠设置等 2025-05-30 23:00:37 +08:00
Soulter
8288d5e51f feat: embedding provider 2025-05-30 18:07:52 +08:00
Ruochen
6e1449900a feat: 优化单个 provider 可用性测试的回退逻辑 2025-05-30 15:35:13 +08:00
RC-CHN
4ffbb18ab4 Merge branch 'AstrBotDevs:master' into master 2025-05-30 15:12:33 +08:00
Ruochen
b27271b7a3 feat:添加测试文本生成供应商可用功能 2025-05-30 15:10:15 +08:00
Soulter
ebb6665f64 feat: add open_config parameter handling and configuration button in KnowledgeBase 2025-05-30 14:30:04 +08:00
Soulter
e4e5731ffd 📦 release: v3.5.13 2025-05-30 13:30:23 +08:00
Soulter
2ab5810f13 perf: improve transaction performance in vector db 2025-05-30 12:59:26 +08:00
Soulter
af934c5d09 fix: correct dimension typo and enhance API registration logic 2025-05-30 11:42:39 +08:00
Soulter
1e0cf7c112 fix: update ExtensionCard actions and add readme link functionality 2025-05-30 10:50:54 +08:00
Soulter
46859c93c9 perf: improve WebUI 2025-05-30 10:45:05 +08:00
Richard X.
ea1f9cb3b2 Merge branch 'AstrBotDevs:master' into master 2025-05-30 10:37:59 +08:00
Soulter
1641549016 perf: improve WebUI 2025-05-30 10:36:48 +08:00
鸦羽
716a5dbb8a chore: add nh3 to requirements.txt 2025-05-30 10:35:48 +08:00
鸦羽
af98cb11c5 fix: handle missing nh3 library in plugin.py 2025-05-30 10:35:48 +08:00
Soulter
9a4c2cf341 fix: downgrade faiss-cpu dependency to version 1.10.0 2025-05-30 10:21:31 +08:00
Soulter
2bc3bcd102 fix: handle missing nh3 library gracefully for README cleaning 2025-05-30 10:17:33 +08:00
Soulter
d6c663f79d fix: do not display change password dialog in demo mode 2025-05-30 10:09:09 +08:00
kwicxy
9ed86e5f53 feat: Name trim of extension list to improve readability 2025-05-30 09:37:21 +08:00
kwicxy
303e0bc037 fix(dashboard): MessageStat chart tooltips now supports dark appearance 2025-05-30 09:36:06 +08:00
Richard X.
2cc24019f9 Merge branch 'AstrBotDevs:master' into master 2025-05-30 08:50:27 +08:00
kwicxy
83ce774d19 chore: Extension marketplace scroll behaviour updated 2025-05-30 00:01:53 +08:00
Soulter
2b4ee13b5e Merge pull request #1672 from Kwicxy/master
Feat: 暗黑主题功能初步实现
2025-05-29 23:41:10 +08:00
kwicxy
3a964561f0 style: minor code style changes 2025-05-29 22:57:50 +08:00
kwicxy
6959f86632 feat: Using localStorage to remember user's theme setting. 2025-05-29 22:46:02 +08:00
Raven95676
537d373e10 fix: Fix potential XSS risk in plugin README content 2025-05-29 22:35:24 +08:00
Soulter
cceadf222c Merge pull request #1676 from AstrBotDevs/fix-chat-get-file-bug
Fix: fixed a potential vulnerability in `/api/chat/get_file` endpoint.
2025-05-29 21:41:55 +08:00
Soulter
cf5a4af623 chore: remove duplicated auth header 2025-05-29 21:19:39 +08:00
Raven95676
39aea11c22 perf: enhance file access security in get_file method
Co-authored-by: anka-afk <1350989414@qq.com>
2025-05-29 21:03:51 +08:00
Raven95676
c2f1227700 fix: add authorization header to file download request in ChatPage.vue 2025-05-29 19:57:11 +08:00
Soulter
900f14d37c 🐛 fix: fixed a potential vulnerability in /api/chat/get_file endpoint.
I have fixed a potential vulnerability in the `/api/chat/get_file` endpoint that could allow unauthorized access to files by ensuring the request has a jwt token.
2025-05-29 19:17:31 +08:00
kwicxy
598249b1d6 Merge remote-tracking branch 'origin/master' 2025-05-29 18:26:53 +08:00
Richard X.
7ed15bdf04 Merge branch 'AstrBotDevs:master' into master 2025-05-29 18:17:39 +08:00
Raven95676
2fc0ec0f72 fix: update route 2025-05-29 17:28:33 +08:00
kwicxy
5e9c2a669b fix: Various bug fixes and improvements 2025-05-29 16:41:03 +08:00
Soulter
b310521884 📦 release: v3.5.12 2025-05-29 15:55:25 +08:00
Soulter
288945bf7e chore: aiosqlite to requirements.txt 2025-05-29 15:48:21 +08:00
Soulter
4fc07cff36 📦 release: v3.5.12 2025-05-29 15:46:40 +08:00
kwicxy
b884fe0e86 fix: Various bug fixes 2025-05-29 09:31:29 +08:00
kwicxy
855858c236 fix: Changed default theme to PurpleTheme 2025-05-29 09:31:15 +08:00
kwicxy
c11a2a5419 feat: Login page darkened 2025-05-29 09:00:27 +08:00
kwicxy
773a6572af feat: WebUI Dark Appearance 2025-05-29 01:43:21 +08:00
kwicxy
88ad373c9b 深色主题切换功能初步实现 2025-05-29 01:28:45 +08:00
Soulter
51666464b9 Merge pull request #1667 from AstrBotDevs/fix-priority
Fix: plugin priority was not properly applied
2025-05-28 15:34:50 +08:00
Soulter
5af9cf2f52 Merge pull request #1668 from AstrBotDevs/refactor-segment
Refactor: 重构转发节点等消息段的 toDict 相关逻辑
2025-05-28 15:33:32 +08:00
Soulter
12c4ae4b10 perf: to_dict in the base class 2025-05-28 03:26:42 -04:00
Soulter
4e1bef414a perf: empty array 2025-05-28 03:25:19 -04:00
Soulter
e896c18644 perf: video 2025-05-28 15:12:21 +08:00
Soulter
c852685e74 fix: typeerror 2025-05-28 01:18:45 -04:00
Soulter
1e99797df8 refactor: improve message segment handle 2025-05-28 12:53:00 +08:00
Soulter
52a4c986a8 fix: update star_handlers_registry iteration in TelegramPlatformAdapter 2025-05-28 00:31:04 +08:00
Soulter
c501728204 fix: plugin priority
fixes: #1662
2025-05-28 00:23:02 +08:00
Soulter
6b067fa6a7 Merge pull request #1665 from Raven95676/master
fix(telegram): 支持长消息分段发送并优化消息编辑逻辑
2025-05-27 23:39:14 +08:00
Soulter
a1cd5c53a9 chore: add comments 2025-05-27 23:38:35 +08:00
Soulter
a46d487e03 Merge pull request #1644 from RC-CHN/master
fix:为llm和model和provider指令添加了管理员权限检查
2025-05-27 23:25:40 +08:00
Raven95676
3deb6d3ab3 fix: clean code 2025-05-27 20:52:40 +08:00
Raven95676
af34cdd5d2 fix(telegram): 支持长消息分段发送并优化消息编辑逻辑 2025-05-27 20:15:16 +08:00
Soulter
6e1393235a 🐛 fix: provider command error 2025-05-27 17:20:57 +08:00
Soulter
343e0b54b9 feat: MCP supports Streamable HTTP transport method
fixes: #1637 #1342
2025-05-27 15:39:02 +08:00
Soulter
ecb70cb6f7 feat: add support for custom headers in SSE client configuration
fixes: #1659
2025-05-27 15:05:42 +08:00
Soulter
ca50618af6 perf: load providers when llm config is off and rebooting astrbot
fixes: #1466
2025-05-27 15:01:58 +08:00
Soulter
29c07ba83e 🐛 fix: function tools argument type issue
fixes: #1454
2025-05-27 13:54:16 +08:00
Ruochen
45fbb83a9f fix:为llm和model和provider指令添加了管理员权限检查 2025-05-25 00:24:20 +08:00
Soulter
ae7ba2df25 Merge pull request #1553 from Raven95676/Feature/use-file-service
Feature: T2I、TTS使用文件服务
2025-05-23 17:10:38 +08:00
Soulter
c3ef57cc32 Merge pull request #1588 from Zhenyi-Wang/feat/extend-wechatpadpro-for-timetask
feat: wechatpadpro对接获取联系人信息的2个接口
2025-05-23 17:02:54 +08:00
Soulter
7bb4ca5a14 perf: code quality 2025-05-23 17:01:57 +08:00
Soulter
063783d81d Merge pull request #1599 from HendricksJudy/master
Fix initialization bug and improve plugin utility
2025-05-23 16:58:25 +08:00
Soulter
42116c9b65 Merge pull request #1631 from AstrBotDevs/feat/alkaid
[WIP] Feature: 提供 AstrBot 后端服务插件接口、试验性嵌入式知识库(Alkaid)、移除不必要的包
2025-05-23 16:57:04 +08:00
Soulter
a36e11973d perf: code quality
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-23 16:56:09 +08:00
Soulter
5125568ea2 perf: 交换 if/else 表达式的分支以删除否定
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-23 16:49:08 +08:00
Soulter
0fa164e50d perf: 使用 HTML autocomplete 属性禁用浏览器自动填充
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-23 16:48:29 +08:00
Soulter
cf814e81ee chore: delete alkaid route 2025-05-23 16:41:33 +08:00
Soulter
43a45f18ce perf: knowledgebase delete 2025-05-23 15:50:10 +08:00
Soulter
ad51381063 perf: 动态路由注册 2025-05-23 15:18:16 +08:00
Soulter
0b0e4ce904 remove: vpet 2025-05-23 14:22:34 +08:00
Soulter
6a3e04d688 Merge remote-tracking branch 'origin/master' into feat/alkaid 2025-05-23 14:22:06 +08:00
Soulter
4107a17370 chore: add faiss and aiosqlite deps 2025-05-23 14:04:13 +08:00
Soulter
06b4d8f169 perf: vecdb similarity type 2025-05-23 13:45:00 +08:00
Soulter
1c0c820746 remove: loguru 2025-05-23 13:42:17 +08:00
Soulter
d061403a28 remove: loguru 2025-05-23 13:39:20 +08:00
Soulter
5c092321a6 feat: faiss vecdb implementation
remove: old knowledgedb deps
2025-05-23 13:16:24 +08:00
Soulter
bdd3f61c1f remove: old knowledge db impl and useless impls 2025-05-23 11:43:26 +08:00
Raven95676
8023557d6e feat: 强制修改默认密码 2025-05-22 18:30:29 +08:00
Raven95676
074b0ced7a perf: 移除冗余逻辑
经与@Soulter确认,metadata.yaml是必须有的文件,故在建议下删除
2025-05-22 18:21:41 +08:00
Soulter
3864b1ac9b Merge pull request #1620 from YOOkoishi/feat-add-volcengine-support
🐛 fix : 修改description,适配火山引擎基础的语音合成
2025-05-22 17:52:39 +08:00
YOO_koishi
6e9b43457d Merge branch 'master' of https://github.com/AstrBotDevs/AstrBot into feat-add-volcengine-support 2025-05-22 08:09:59 +08:00
YOO_koishi
ca1aec8920 🐛 fix : 修改description,适配火山引擎基础的语音合成 2025-05-22 08:09:36 +08:00
Soulter
acac580862 feat: ltm and kb 2025-05-20 20:50:22 +08:00
Soulter
673e1b2980 remove: vpet 2025-05-20 15:03:40 +08:00
HendricksJudy
39c8cfeda5 Merge pull request #2 from HendricksJudy/codex/fix-core-initialization-failure-handling-in-initialloader
Fix initialization bug and improve plugin utility
2025-05-19 01:43:22 -07:00
HendricksJudy
f38a329be5 Fix initialization and plugin download 2025-05-19 01:43:07 -07:00
Soulter
d2379da478 chore: use d3 2025-05-18 16:43:47 +08:00
Soulter
0f64981b20 feat: alkaid long term memory graph visualize 2025-05-18 13:26:44 +08:00
Soulter
10270b5595 feat: alkaid framework and supports to customize webapi endpoint 2025-05-17 15:38:51 +08:00
Zhenyi Wang
f7458572ed feat: wechatpadpro对接获取联系人信息接口 2025-05-17 15:31:12 +08:00
Raven95676
c5ccc1a084 feat(Video): 增加视频消息组件的文件转换和注册功能 2025-05-15 09:50:27 +08:00
Raven95676
e6981290bc perf: 优化 Record 对象的文件和 URL 字段赋值逻辑 2025-05-14 20:05:38 +08:00
Raven95676
75c3d8abbd feat(t2i): 为本地文本转图像功能添加文件服务支持 2025-05-14 19:28:23 +08:00
Raven95676
d88683f498 feat(tts): 增加使用文件服务提供 TTS 语音文件的功能 2025-05-14 19:28:23 +08:00
Raven95676
40b9aa3a4c style: format code 2025-05-14 19:15:13 +08:00
294 changed files with 23688 additions and 6963 deletions

View File

@@ -1,40 +1,56 @@
name: '🥳 发布插件'
title: "[Plugin] 插件名"
name: 🥳 发布插件
description: 提交插件到插件市场
labels: [ "plugin-publish" ]
title: "[Plugin] 插件名"
labels: ["plugin-publish"]
assignees: []
body:
- type: markdown
attributes:
value: |
欢迎发布插件到插件市场!请确保您的插件经过**完整的**测试。
欢迎发布插件到插件市场!
- type: textarea
- type: markdown
attributes:
label: 插件仓库
description: 插件的 GitHub 仓库链接
placeholder: >
如 https://github.com/Soulter/astrbot-github-cards
- type: textarea
attributes:
label: 描述
value: |
插件名:
插件作者:
插件简介:
支持的消息平台:(必填,如 QQ、微信、飞书)
标签:(可选)
社交链接:(可选, 将会在插件市场作者名称上作为可点击的链接)
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。如果您不知道支持哪些消息平台,请填写测试过的消息平台。
## 插件基本信息
- type: checkboxes
请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。
不熟悉 JSON ?现在可以从 [这里](https://plugins.astrbot.app/#/submit) 获取你的 JSON 啦!获取到了记得复制粘贴过来哦!
- type: textarea
id: plugin-info
attributes:
label: Code of Conduct
options:
- label: >
我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
label: 插件信息
description: 请在下方代码块中填写您的插件信息确保反引号包裹了JSON
value: |
```json
{
"name": "插件名",
"desc": "插件介绍",
"author": "作者名",
"repo": "插件仓库链接",
"tags": [],
"social_link": ""
}
```
validations:
required: true
- type: markdown
attributes:
value: "❤️"
value: |
## 检查
- type: checkboxes
id: checks
attributes:
label: 插件检查清单
description: 请确认以下所有项目
options:
- label: 我的插件经过完整的测试
required: true
- label: 我的插件不包含恶意代码
required: true
- label: 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
required: true

63
.github/copilot-instructions.md vendored Normal file
View File

@@ -0,0 +1,63 @@
# AstrBot Development Instructions
AstrBot is a multi-platform LLM chatbot and development framework written in Python with a Vue.js dashboard. It supports multiple messaging platforms (QQ, Telegram, Discord, etc.) and various LLM providers (OpenAI, Anthropic, Google Gemini, etc.).
Always reference these instructions first and fallback to search or bash commands only when you encounter unexpected information that does not match the info here.
## Working Effectively
### Bootstrap and Install Dependencies
- **Python 3.10+ required** - Check `.python-version` file
- Install UV package manager: `pip install uv`
- Install project dependencies: `uv sync` -- takes 6-7 minutes. NEVER CANCEL. Set timeout to 10+ minutes.
- Create required directories: `mkdir -p data/plugins data/config data/temp`
### Running the Application
- Run main application: `uv run main.py` -- starts in ~3 seconds
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
### Dashboard Build (Vue.js/Node.js)
- **Prerequisites**: Node.js 20+ and npm 10+ required
- Navigate to dashboard: `cd dashboard`
- Install dashboard dependencies: `npm install` -- takes 2-3 minutes. NEVER CANCEL. Set timeout to 5+ minutes.
- Build dashboard: `npm run build` -- takes 25-30 seconds. NEVER CANCEL.
- Dashboard creates optimized production build in `dashboard/dist/`
### Testing
- Do not generate test files for now.
### Code Quality and Linting
- Install ruff linter: `uv add --dev ruff`
- Check code style: `uv run ruff check .` -- takes <1 second
- Check formatting: `uv run ruff format --check .` -- takes <1 second
- Fix formatting: `uv run ruff format .`
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
### Plugin Development
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
- Plugin system supports function tools and message handlers
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
### Common Issues and Workarounds
- **Dashboard download fails**: Known issue with "division by zero" error - application still works
- **Import errors in tests**: Ensure `uv run` is used to run tests in proper environment
=- **Build timeouts**: Always set appropriate timeouts (10+ minutes for uv sync, 5+ minutes for npm install)
## CI/CD Integration
- GitHub Actions workflows in `.github/workflows/`
- Docker builds supported via `Dockerfile`
- Pre-commit hooks enforce ruff formatting and linting
## Docker Support
- Primary deployment method: `docker run soulter/astrbot:latest`
- Compose file available: `compose.yml`
- Exposes ports: 6185 (WebUI), 6195 (WeChat), 6199 (QQ), etc.
- Volume mount required: `./data:/AstrBot/data`
## Multi-language Support
- Documentation in Chinese (README.md), English (README_en.md), Japanese (README_ja.md)
- UI supports internationalization
- Default language is Chinese
Remember: This is a production chatbot framework with real users. Always test thoroughly and ensure changes don't break existing functionality.

13
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,13 @@
# Keep GitHub Actions up to date with GitHub's Dependabot...
# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
version: 2
updates:
- package-ecosystem: github-actions
directory: /
groups:
github-actions:
patterns:
- "*" # Group all Actions updates into a single larger pull request
schedule:
interval: weekly

View File

@@ -13,7 +13,7 @@ jobs:
contents: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v5
- name: Dashboard Build
run: |
@@ -24,6 +24,36 @@ jobs:
echo ${{ github.ref_name }} > dist/assets/version
zip -r dist.zip dist
- name: Upload to Cloudflare R2
env:
R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }}
R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
R2_BUCKET_NAME: "astrbot"
R2_OBJECT_NAME: "astrbot-webui-latest.zip"
VERSION_TAG: ${{ github.ref_name }}
run: |
echo "Installing rclone..."
curl https://rclone.org/install.sh | sudo bash
echo "Configuring rclone remote..."
mkdir -p ~/.config/rclone
cat <<EOF > ~/.config/rclone/rclone.conf
[r2]
type = s3
provider = Cloudflare
access_key_id = $R2_ACCESS_KEY_ID
secret_access_key = $R2_SECRET_ACCESS_KEY
endpoint = https://${R2_ACCOUNT_ID}.r2.cloudflarestorage.com
EOF
echo "Uploading dist.zip to R2 bucket: $R2_BUCKET_NAME/$R2_OBJECT_NAME"
mv dashboard/dist.zip dashboard/$R2_OBJECT_NAME
rclone copy dashboard/$R2_OBJECT_NAME r2:$R2_BUCKET_NAME --progress
mv dashboard/$R2_OBJECT_NAME dashboard/astrbot-webui-${VERSION_TAG}.zip
rclone copy dashboard/astrbot-webui-${VERSION_TAG}.zip r2:$R2_BUCKET_NAME --progress
mv dashboard/astrbot-webui-${VERSION_TAG}.zip dashboard/dist.zip
- name: Fetch Changelog
run: |
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
@@ -40,10 +70,10 @@ jobs:
needs: build-and-publish-to-github-release
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v5
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: '3.10'

View File

@@ -56,7 +56,7 @@ jobs:
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v5
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL

View File

@@ -8,6 +8,7 @@ on:
- 'README.md'
- 'changelogs/**'
- 'dashboard/**'
pull_request:
workflow_dispatch:
jobs:
@@ -16,30 +17,29 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v5
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest pytest-cov pytest-asyncio
pip install pytest pytest-asyncio pytest-cov
pip install --editable .
- name: Run tests
run: |
mkdir data
mkdir data/plugins
mkdir data/config
mkdir data/temp
mkdir -p data/plugins
mkdir -p data/config
mkdir -p data/temp
export TESTING=true
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
pytest --cov=. -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov
uses: codecov/codecov-action@v4
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -1,13 +1,17 @@
name: AstrBot Dashboard CI
on: [push]
on:
push:
branches: [ "master" ]
pull_request:
branches: [ "master" ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v5
- name: npm install, build
run: |

View File

@@ -11,24 +11,42 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: 拉取源码
uses: actions/checkout@v3
- name: Pull The Codes
uses: actions/checkout@v5
with:
fetch-depth: 1
fetch-depth: 0 # Must be 0 so we can fetch tags
- name: 设置 QEMU
- name: Get latest tag (only on manual trigger)
id: get-latest-tag
if: github.event_name == 'workflow_dispatch'
run: |
tag=$(git describe --tags --abbrev=0)
echo "latest_tag=$tag" >> $GITHUB_OUTPUT
- name: Checkout to latest tag (only on manual trigger)
if: github.event_name == 'workflow_dispatch'
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
- name: Set QEMU
uses: docker/setup-qemu-action@v3
- name: 设置 Docker Buildx
- name: Set Docker Buildx
uses: docker/setup-buildx-action@v3
- name: 登录到 DockerHub
- name: Log in to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: 构建和推送 Docker hub
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: Soulter
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
- name: Build and Push Docker to DockerHub and Github GHCR
uses: docker/build-push-action@v6
with:
context: .
@@ -36,8 +54,9 @@ jobs:
push: true
tags: |
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.ref_name }}
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
ghcr.io/soulter/astrbot:latest
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
- name: Post build notifications
run: echo "Docker image has been built and pushed successfully"

View File

@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/stale@v5
- uses: actions/stale@v9
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'Stale issue message'

View File

@@ -1,4 +1,4 @@
FROM python:3.10-slim
FROM python:3.11-slim
WORKDIR /AstrBot
COPY . /AstrBot/

121
README.md
View File

@@ -1,6 +1,6 @@
<p align="center">
![yjtp](https://github.com/user-attachments/assets/dcc74009-c57e-4b66-9ae3-0a81fc001255)
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
</p>
@@ -16,7 +16,7 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B4%BB%E8%B7%83%E9%87%8F&cacheSeconds=3600&style=for-the-badge&color=3b618e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7日消息量&cacheSeconds=3600&style=for-the-badge&color=3b618e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
@@ -27,49 +27,50 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型LLM接入功能的聊天机器人及开发框架。
<!-- [![codecov](https://img.shields.io/codecov/c/github/soulter/astrbot?style=for-the-badge)](https://codecov.io/gh/Soulter/AstrBot)
-->
> [!NOTE]
>
> 个人微信接入所依赖的开源项目 Gewechat 近期已停止维护,`v3.5.10` 已经支持接入 WeChatPadPro 替换 gewechat 方式。详见文档 [WeChatPadPro](https://astrbot.app/deploy/platform/wechat/wechatpadpro.html)
## ✨ 近期更新
1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
## ✨ 主要功能
> [!NOTE]
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力支持图片理解、语音转文字Whisper
2. **多消息平台接入**。支持接入 QQOneBot、QQ 频道、微信Gewechat、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat可在面板上与大模型对话。
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
> [!TIP]
> WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> 用户名: `astrbot`, 密码: `astrbot`。
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
3. **Agent**。完善适配的 Agentic 能力。支持多轮工具调用、内置沙盒代码执行器、网页搜索等功能。
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富
5. **WebUI**。可视化配置和管理机器人,功能齐全
## ✨ 使用方式
#### Docker 部署
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
#### 宝塔面板部署
AstrBot 与宝塔面板合作,已上架至宝塔面板。
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
#### 1Panel 部署
AstrBot 已由 1Panel 官方上架至 1Panel 面板。
请参阅官方文档 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html) 。
#### 在 雨云 上部署
AstrBot 已由雨云官方上架至云应用平台,可一键部署。
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
#### 在 Replit 上部署
社区贡献的部署方式。
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### Windows 一键安装器部署
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
#### 宝塔面板部署
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
#### CasaOS 部署
社区贡献的部署方式。
@@ -93,42 +94,33 @@ git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
uv run main.py
```
或者,直接通过 uvx 安装 AstrBot
```bash
mkdir astrbot && cd astrbot
uvx astrbot init
# uvx astrbot run
```
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
#### Replit 部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
## ⚡ 消息平台支持情况
| 平台 | 支持性 | 详情 | 消息类型 |
| -------- | ------- | ------- | ------ |
| QQ(官方机器人接口) | ✔ | 私聊、群聊QQ 频道私聊、群聊 | 文字、图片 |
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
| 微信个人号 | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
| Telegram | ✔ | 私聊、群聊 | 文字、图片 |
| 企业微信 | ✔ | 私聊 | 文字、图片、语音 |
| 微信客服 | ✔ | 私聊 | 文字、图片 |
| 飞书 | ✔ | 私聊、群聊 | 文字、图片 |
| 钉钉 | ✔ | 私聊、群聊 | 文字、图片 |
| 微信对话开放平台 | 🚧 | 计划内 | - |
| Discord | 🚧 | 计划内 | - |
| WhatsApp | 🚧 | 计划内 | - |
| 小爱音响 | 🚧 | 计划内 | - |
| 平台 | 支持性 |
| -------- | ------- |
| QQ(官方机器人接口) | ✔ |
| QQ(OneBot) | ✔ |
| Telegram | ✔ |
| 企业微信 | ✔ |
| 微信客服 | ✔ |
| 微信公众号 | ✔ |
| 飞书 | ✔ |
| 钉钉 | ✔ |
| Slack | ✔ |
| Discord | |
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | |
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | |
| 微信对话开放平台 | 🚧 |
| WhatsApp | 🚧 |
| 小爱音响 | 🚧 |
## ⚡ 提供商支持情况
| 名称 | 支持性 | 类型 | 备注 |
| -------- | ------- | ------- | ------- |
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、xAI 等兼容 OpenAI API 的服务 |
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Gemini、Kimi、xAI 等兼容 OpenAI API 的服务 |
| Claude API | ✔ | 文本生成 | |
| Google Gemini API | ✔ | 文本生成 | |
| Dify | ✔ | LLMOps | |
@@ -136,6 +128,8 @@ uvx astrbot init
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
| 硅基流动 | ✔ | 模型 API 服务平台 | |
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
| OneAPI | ✔ | LLM 分发系统 | |
@@ -143,6 +137,7 @@ uvx astrbot init
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
| OpenAI TTS API | ✔ | 文本转语音 | |
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
| GPT-SoVITs | ✔ | 文本转语音 | GPT-Sovits-Inference |
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
@@ -171,7 +166,6 @@ pre-commit install
- Star 这个项目!
- 在[爱发电](https://afdian.com/a/soulter)支持我!
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
## ✨ Demo
@@ -211,7 +205,7 @@ _✨ WebUI ✨_
此外,本项目的诞生离不开以下开源项目:
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ)
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
## ⭐ Star History
@@ -225,11 +219,8 @@ _✨ WebUI ✨_
</div>
## Disclaimer
![10k-star-banner-credit-by-kevin](https://github.com/user-attachments/assets/c97fc5fb-20b9-4bc8-9998-c20b930ab097)
1. The project is protected under the `AGPL-v3` opensource license.
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
3. Please ensure compliance with local laws and regulations when using this project.
_私は、高性能ですから!_

View File

@@ -27,7 +27,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
## ✨ 主な機能
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換Whisperをサポートします。
2. **複数のメッセージプラットフォームの接続**。QQOneBot、QQ チャンネル、WeChatGewechatFeishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
2. **複数のメッセージプラットフォームの接続**。QQOneBot、QQ チャンネル、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
@@ -152,8 +152,7 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
## 免責事項
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
2. WeChat個人アカウントのデプロイメントには [Gewechat](https://github.com/Devo919/Gewechat) サービスを利用しています。AstrBot は Gewechat との接続を保証するだけであり、アカウントのリスク管理に関しては、このプロジェクトの著者は一切の責任を負いません
3. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
2. このプロジェクトを使用する際は、現地の法律および規制を遵守してください
<!-- ## ✨ ATRI [ベータテスト]
@@ -165,6 +164,4 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
4. TTS
-->
_私は、高性能ですから!_

View File

@@ -1 +1 @@
__version__ = "3.5.8"
__version__ = "3.5.23"

View File

@@ -3,7 +3,6 @@ import tempfile
import httpx
import yaml
import re
from enum import Enum
from io import BytesIO
from pathlib import Path
@@ -59,6 +58,15 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
proxy=proxy if proxy else None, follow_redirects=True
) as client:
resp = client.get(download_url)
if (
resp.status_code == 404
and "archive/refs/heads/master.zip" in download_url
):
alt_url = download_url.replace("master.zip", "main.zip")
click.echo("master 分支不存在,尝试下载 main 分支")
resp = client.get(alt_url)
resp.raise_for_status()
else:
resp.raise_for_status()
zip_content = BytesIO(resp.content)
with ZipFile(zip_content) as z:
@@ -91,39 +99,6 @@ def load_yaml_metadata(plugin_dir: Path) -> dict:
return {}
def extract_py_metadata(plugin_dir: Path) -> dict:
"""从 Python 文件中提取插件元数据
Args:
plugin_dir: 插件目录路径
Returns:
dict: 包含元数据的字典,如果提取失败则返回空字典
"""
# 检查 main.py 或与目录同名的 py 文件
for pattern in ["main.py", f"{plugin_dir.name}.py"]:
for py_file in plugin_dir.glob(pattern):
try:
content = py_file.read_text(encoding="utf-8")
register_match = re.search(
r'@register_star\s*\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"(?:\s*,\s*"?([^")]+)"?)?\s*\)',
content,
)
if register_match:
# 映射匹配组到元数据键
metadata = {}
keys = ["name", "author", "desc", "version", "repo"]
for i, key in enumerate(keys):
if i + 1 <= len(
register_match.groups()
) and register_match.group(i + 1):
metadata[key] = register_match.group(i + 1)
return metadata
except Exception as e:
click.echo(f"读取 {py_file} 失败: {e}", err=True)
return {}
def build_plug_list(plugins_dir: Path) -> list:
"""构建插件列表,包含本地和在线插件信息
@@ -139,20 +114,16 @@ def build_plug_list(plugins_dir: Path) -> list:
for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]:
plugin_dir = plugins_dir / plugin_name
# 从不同来源加载元数据
# 从 metadata.yaml 加载元数据
metadata = load_yaml_metadata(plugin_dir)
# 如果元数据不完整,尝试从 Python 文件提取
if not metadata or not all(
if "desc" not in metadata and "description" in metadata:
metadata["desc"] = metadata["description"]
# 如果成功加载元数据,添加到结果列表
if metadata and all(
k in metadata for k in ["name", "desc", "version", "author", "repo"]
):
py_metadata = extract_py_metadata(plugin_dir)
# 合并元数据,保留已有的值
for key, value in py_metadata.items():
if key not in metadata or not metadata[key]:
metadata[key] = value
# 如果成功提取元数据,添加到结果列表
if metadata:
result.append(
{
"name": str(metadata.get("name", "")),

View File

@@ -13,7 +13,6 @@ from .utils.astrbot_path import get_astrbot_data_path
# 初始化数据存储文件夹
os.makedirs(get_astrbot_data_path(), exist_ok=True)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
DEMO_MODE = os.getenv("DEMO_MODE", False)
astrbot_config = AstrBotConfig()
@@ -29,6 +28,3 @@ pip_installer = PipInstaller(
astrbot_config.get("pip_install_arg", ""),
astrbot_config.get("pypi_index_url", None),
)
web_chat_queue = asyncio.Queue(maxsize=32)
web_chat_back_queue = asyncio.Queue(maxsize=32)

View File

@@ -43,6 +43,7 @@ class AstrBotConfig(dict):
"""不存在时载入默认配置"""
with open(config_path, "w", encoding="utf-8-sig") as f:
json.dump(default_config, f, indent=4, ensure_ascii=False)
object.__setattr__(self, "first_deploy", True) # 标记第一次部署
with open(config_path, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
@@ -82,23 +83,61 @@ class AstrBotConfig(dict):
return conf
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
"""检查配置完整性,如果有新的配置项则返回 True"""
"""检查配置完整性,如果有新的配置项或顺序不一致则返回 True"""
has_new = False
# 创建一个新的有序字典以保持参考配置的顺序
new_conf = {}
# 先按照参考配置的顺序添加配置项
for key, value in refer_conf.items():
if key not in conf:
# logger.info(f"检查到配置项 {path + "." + key if path else key} 不存在,插入默认值 {value}")
# 配置项不存在,插入默认值
path_ = path + "." + key if path else key
logger.info(f"检查到配置项 {path_} 不存在,已插入默认值 {value}")
conf[key] = value
new_conf[key] = value
has_new = True
else:
if conf[key] is None:
conf[key] = value
# 配置项为 None使用默认值
new_conf[key] = value
has_new = True
elif isinstance(value, dict):
has_new |= self.check_config_integrity(
# 递归检查子配置项
if not isinstance(conf[key], dict):
# 类型不匹配,使用默认值
new_conf[key] = value
has_new = True
else:
# 递归检查并同步顺序
child_has_new = self.check_config_integrity(
value, conf[key], path + "." + key if path else key
)
new_conf[key] = conf[key]
has_new |= child_has_new
else:
# 直接使用现有配置
new_conf[key] = conf[key]
# 检查是否存在参考配置中没有的配置项
for key in list(conf.keys()):
if key not in refer_conf:
path_ = path + "." + key if path else key
logger.info(f"检查到配置项 {path_} 不存在,将从当前配置中删除")
has_new = True
# 顺序不一致也算作变更
if list(conf.keys()) != list(new_conf.keys()):
if path:
logger.info(f"检查到配置项 {path} 的子项顺序不一致,已重新排序")
else:
logger.info("检查到配置项顺序不一致,已重新排序")
has_new = True
# 更新原始配置
conf.clear()
conf.update(new_conf)
return has_new
def save_config(self, replace_config: Dict = None):

File diff suppressed because it is too large Load Diff

View File

@@ -88,7 +88,10 @@ class ConversationManager:
return self.session_conversations.get(unified_msg_origin, None)
async def get_conversation(
self, unified_msg_origin: str, conversation_id: str
self,
unified_msg_origin: str,
conversation_id: str,
create_if_not_exists: bool = False,
) -> Conversation:
"""获取会话的对话
@@ -98,6 +101,13 @@ class ConversationManager:
Returns:
conversation (Conversation): 对话对象
"""
conv = self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
if not conv and create_if_not_exists:
# 如果对话不存在且需要创建,则新建一个对话
conversation_id = await self.new_conversation(unified_msg_origin)
return self.db.get_conversation_by_user_id(
unified_msg_origin, conversation_id
)
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:

View File

@@ -1,6 +1,6 @@
"""
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
工作流程:
@@ -28,7 +28,6 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map
@@ -37,7 +36,7 @@ from astrbot.core.star.star_handler import star_map
class AstrBotCoreLifecycle:
"""
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、
EventBus 等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
"""
@@ -47,14 +46,17 @@ class AstrBotCoreLifecycle:
self.astrbot_config = astrbot_config # 初始化配置
self.db = db # 初始化数据库
# 根据环境变量设置代理
# 设置代理
if self.astrbot_config.get("http_proxy", ""):
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
if proxy := os.environ.get("https_proxy"):
logger.debug(f"Using proxy: {proxy}")
os.environ["no_proxy"] = "localhost"
async def initialize(self):
"""
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
"""
# 初始化日志代理
@@ -73,9 +75,6 @@ class AstrBotCoreLifecycle:
# 初始化平台管理器
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
# 初始化知识库管理器
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
# 初始化对话管理器
self.conversation_manager = ConversationManager(self.db)
@@ -87,7 +86,6 @@ class AstrBotCoreLifecycle:
self.provider_manager,
self.platform_manager,
self.conversation_manager,
self.knowledge_db_manager,
)
# 初始化插件管理器

View File

@@ -1,113 +0,0 @@
import json
import aiosqlite
import os
from typing import Any
from .plugin_storage import PluginStorage
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
DBPATH = os.path.join(get_astrbot_data_path(), "plugin_data", "sqlite", "plugin_data.db")
class SQLitePluginStorage(PluginStorage):
"""插件数据的 SQLite 存储实现类。
该类提供异步方式将插件数据存储到 SQLite 数据库中,支持数据的增删改查操作。
所有数据以 (plugin, key) 作为复合主键进行索引。
"""
_instance = None # Standalone instance of the class
_db_conn = None
db_path = None
def __new__(cls):
"""
创建或获取 SQLitePluginStorage 的单例实例。
如果实例已存在,则返回现有实例;否则创建一个新实例。
数据在 `data/plugin_data/sqlite/plugin_data.db` 下。
"""
os.makedirs(os.path.dirname(DBPATH), exist_ok=True)
if cls._instance is None:
cls._instance = super(SQLitePluginStorage, cls).__new__(cls)
cls._instance.db_path = DBPATH
return cls._instance
async def _init_db(self):
"""初始化数据库连接(只执行一次)"""
if SQLitePluginStorage._db_conn is None:
SQLitePluginStorage._db_conn = await aiosqlite.connect(self.db_path)
await self._setup_db()
async def _setup_db(self):
"""
异步初始化数据库。
创建插件数据表,如果表不存在则创建,表结构包含 plugin、key 和 value 字段,
其中 plugin 和 key 组合作为主键。
"""
await self._db_conn.execute("""
CREATE TABLE IF NOT EXISTS plugin_data (
plugin TEXT,
key TEXT,
value TEXT,
PRIMARY KEY (plugin, key)
)
""")
await self._db_conn.commit()
async def set(self, plugin: str, key: str, value: Any):
"""
异步存储数据。
将指定插件的键值对存入数据库,如果键已存在则更新值。
值会被序列化为 JSON 字符串后存储。
Args:
plugin: 插件标识符
key: 数据键名
value: 要存储的数据值(任意类型,将被 JSON 序列化)
"""
await self._init_db()
await self._db_conn.execute(
"INSERT INTO plugin_data (plugin, key, value) VALUES (?, ?, ?) "
"ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value",
(plugin, key, json.dumps(value)),
)
await self._db_conn.commit()
async def get(self, plugin: str, key: str) -> Any:
"""
异步获取数据。
从数据库中获取指定插件和键名对应的值,
返回的值会从 JSON 字符串反序列化为原始数据类型。
Args:
plugin: 插件标识符
key: 数据键名
Returns:
Any: 存储的数据值,如果未找到则返回 None
"""
await self._init_db()
async with self._db_conn.execute(
"SELECT value FROM plugin_data WHERE plugin = ? AND key = ?",
(plugin, key),
) as cursor:
row = await cursor.fetchone()
return json.loads(row[0]) if row else None
async def delete(self, plugin: str, key: str):
"""
异步删除数据。
从数据库中删除指定插件和键名对应的数据项。
Args:
plugin: 插件标识符
key: 要删除的数据键名
"""
await self._init_db()
await self._db_conn.execute(
"DELETE FROM plugin_data WHERE plugin = ? AND key = ?", (plugin, key)
)
await self._db_conn.commit()

View File

@@ -11,7 +11,9 @@ class SQLiteDatabase(BaseDatabase):
super().__init__()
self.db_path = db_path
with open(os.path.dirname(__file__) + "/sqlite_init.sql", "r") as f:
with open(
os.path.dirname(__file__) + "/sqlite_init.sql", "r", encoding="utf-8"
) as f:
sql = f.read()
# 初始化数据库

View File

@@ -0,0 +1,46 @@
import abc
from dataclasses import dataclass
@dataclass
class Result:
similarity: float
data: dict
class BaseVecDB:
async def initialize(self):
"""
初始化向量数据库
"""
pass
@abc.abstractmethod
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
...
@abc.abstractmethod
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
"""
搜索最相似的文档。
Args:
query (str): 查询文本
top_k (int): 返回的最相似文档的数量
Returns:
List[Result]: 查询结果
"""
...
@abc.abstractmethod
async def delete(self, doc_id: str) -> bool:
"""
删除指定文档。
Args:
doc_id (str): 要删除的文档 ID
Returns:
bool: 删除是否成功
"""
...

View File

@@ -0,0 +1,3 @@
from .vec_db import FaissVecDB
__all__ = ["FaissVecDB"]

View File

@@ -0,0 +1,121 @@
import aiosqlite
import os
class DocumentStorage:
def __init__(self, db_path: str):
self.db_path = db_path
self.connection = None
self.sqlite_init_path = os.path.join(
os.path.dirname(__file__), "sqlite_init.sql"
)
async def initialize(self):
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
if not os.path.exists(self.db_path):
await self.connect()
async with self.connection.cursor() as cursor:
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
sql_script = f.read()
await cursor.executescript(sql_script)
await self.connection.commit()
else:
await self.connect()
async def connect(self):
"""Connect to the SQLite database."""
self.connection = await aiosqlite.connect(self.db_path)
async def get_documents(self, metadata_filters: dict, ids: list = None):
"""Retrieve documents by metadata filters and ids.
Args:
metadata_filters (dict): The metadata filters to apply.
Returns:
list: The list of document IDs(primary key, not doc_id) that match the filters.
"""
# metadata filter -> SQL WHERE clause
where_clauses = []
values = []
for key, val in metadata_filters.items():
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
values.append(val)
if ids is not None and len(ids) > 0:
ids = [str(i) for i in ids if i != -1]
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
values.extend(ids)
where_sql = " AND ".join(where_clauses) or "1=1"
result = []
async with self.connection.cursor() as cursor:
sql = "SELECT * FROM documents WHERE " + where_sql
await cursor.execute(sql, values)
for row in await cursor.fetchall():
result.append(await self.tuple_to_dict(row))
return result
async def get_document_by_doc_id(self, doc_id: str):
"""Retrieve a document by its doc_id.
Args:
doc_id (str): The doc_id of the document to retrieve.
Returns:
dict: The document data.
"""
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
row = await cursor.fetchone()
if row:
return await self.tuple_to_dict(row)
else:
return None
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
"""Retrieve a document by its doc_id.
Args:
doc_id (str): The doc_id.
new_text (str): The new text to update the document with.
"""
async with self.connection.cursor() as cursor:
await cursor.execute(
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
)
await self.connection.commit()
async def get_user_ids(self) -> list[str]:
"""Retrieve all user IDs from the documents table.
Returns:
list: A list of user IDs.
"""
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT DISTINCT user_id FROM documents")
rows = await cursor.fetchall()
return [row[0] for row in rows]
async def tuple_to_dict(self, row):
"""Convert a tuple to a dictionary.
Args:
row (tuple): The row to convert.
Returns:
dict: The converted dictionary.
"""
return {
"id": row[0],
"doc_id": row[1],
"text": row[2],
"metadata": row[3],
"created_at": row[4],
"updated_at": row[5],
}
async def close(self):
"""Close the connection to the SQLite database."""
if self.connection:
await self.connection.close()
self.connection = None

View File

@@ -0,0 +1,59 @@
try:
import faiss
except ModuleNotFoundError:
raise ImportError(
"faiss 未安装。请使用 'pip install faiss-cpu''pip install faiss-gpu' 安装。"
)
import os
import numpy as np
class EmbeddingStorage:
def __init__(self, dimension: int, path: str = None):
self.dimension = dimension
self.path = path
self.index = None
if path and os.path.exists(path):
self.index = faiss.read_index(path)
else:
base_index = faiss.IndexFlatL2(dimension)
self.index = faiss.IndexIDMap(base_index)
self.storage = {}
async def insert(self, vector: np.ndarray, id: int):
"""插入向量
Args:
vector (np.ndarray): 要插入的向量
id (int): 向量的ID
Raises:
ValueError: 如果向量的维度与存储的维度不匹配
"""
if vector.shape[0] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
)
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
self.storage[id] = vector
await self.save_index()
async def search(self, vector: np.ndarray, k: int) -> tuple:
"""搜索最相似的向量
Args:
vector (np.ndarray): 查询向量
k (int): 返回的最相似向量的数量
Returns:
tuple: (距离, 索引)
"""
faiss.normalize_L2(vector)
distances, indices = self.index.search(vector, k)
return distances, indices
async def save_index(self):
"""保存索引
Args:
path (str): 保存索引的路径
"""
faiss.write_index(self.index, self.path)

View File

@@ -0,0 +1,17 @@
-- 创建文档存储表,包含 faiss 中文档的 id文档文本create_atupdated_at
CREATE TABLE documents (
id INTEGER PRIMARY KEY AUTOINCREMENT,
doc_id TEXT NOT NULL,
text TEXT NOT NULL,
metadata TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
ALTER TABLE documents
ADD COLUMN group_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.group_id')) STORED;
ALTER TABLE documents
ADD COLUMN user_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED;
CREATE INDEX idx_documents_user_id ON documents(user_id);
CREATE INDEX idx_documents_group_id ON documents(group_id);

View File

@@ -0,0 +1,117 @@
import uuid
import json
import numpy as np
from .document_storage import DocumentStorage
from .embedding_storage import EmbeddingStorage
from ..base import Result, BaseVecDB
from astrbot.core.provider.provider import EmbeddingProvider
class FaissVecDB(BaseVecDB):
"""
A class to represent a vector database.
"""
def __init__(
self,
doc_store_path: str,
index_store_path: str,
embedding_provider: EmbeddingProvider,
):
self.doc_store_path = doc_store_path
self.index_store_path = index_store_path
self.embedding_provider = embedding_provider
self.document_storage = DocumentStorage(doc_store_path)
self.embedding_storage = EmbeddingStorage(
embedding_provider.get_dim(), index_store_path
)
self.embedding_provider = embedding_provider
async def initialize(self):
await self.document_storage.initialize()
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
metadata = metadata or {}
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
vector = await self.embedding_provider.get_embedding(content)
vector = np.array(vector, dtype=np.float32)
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute(
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
(str_id, content, json.dumps(metadata)),
)
await self.document_storage.connection.commit()
result = await self.document_storage.get_document_by_doc_id(str_id)
int_id = result["id"]
# 插入向量到 FAISS
await self.embedding_storage.insert(vector, int_id)
return int_id
async def retrieve(
self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None
) -> list[Result]:
"""
搜索最相似的文档。
Args:
query (str): 查询文本
k (int): 返回的最相似文档的数量
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
metadata_filters (dict): 元数据过滤器
Returns:
List[Result]: 查询结果
"""
embedding = await self.embedding_provider.get_embedding(query)
scores, indices = await self.embedding_storage.search(
vector=np.array([embedding]).astype("float32"),
k=fetch_k if metadata_filters else k,
)
# TODO: rerank
if len(indices[0]) == 0 or indices[0][0] == -1:
return []
# normalize scores
scores[0] = 1.0 - (scores[0] / 2.0)
# NOTE: maybe the size is less than k.
fetched_docs = await self.document_storage.get_documents(
metadata_filters=metadata_filters or {}, ids=indices[0]
)
if not fetched_docs:
return []
result_docs = []
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
for i, indice_idx in enumerate(indices[0]):
pos = idx_pos.get(indice_idx)
if pos is None:
continue
fetch_doc = fetched_docs[pos]
score = scores[0][i]
result_docs.append(Result(similarity=float(score), data=fetch_doc))
return result_docs[:k]
async def delete(self, doc_id: int):
"""
删除一条文档
"""
await self.document_storage.connection.execute(
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
)
await self.document_storage.connection.commit()
async def close(self):
await self.document_storage.close()
async def count_documents(self) -> int:
"""
计算文档数量
"""
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute("SELECT COUNT(*) FROM documents")
count = await cursor.fetchone()
return count[0] if count else 0

View File

@@ -2,6 +2,8 @@ import asyncio
import os
import uuid
import time
from urllib.parse import urlparse, unquote
import platform
class FileTokenService:
@@ -15,7 +17,9 @@ class FileTokenService:
async def _cleanup_expired_tokens(self):
"""清理过期的令牌"""
now = time.time()
expired_tokens = [token for token, (_, expire) in self.staged_files.items() if expire < now]
expired_tokens = [
token for token, (_, expire) in self.staged_files.items() if expire < now
]
for token in expired_tokens:
self.staged_files.pop(token, None)
@@ -32,15 +36,35 @@ class FileTokenService:
Raises:
FileNotFoundError: 当路径不存在时抛出
"""
# 处理 file:///
try:
parsed_uri = urlparse(file_path)
if parsed_uri.scheme == "file":
local_path = unquote(parsed_uri.path)
if platform.system() == "Windows" and local_path.startswith("/"):
local_path = local_path[1:]
else:
# 如果没有 file:/// 前缀,则认为是普通路径
local_path = file_path
except Exception:
# 解析失败时,按原路径处理
local_path = file_path
async with self.lock:
await self._cleanup_expired_tokens()
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
if not os.path.exists(local_path):
raise FileNotFoundError(
f"文件不存在: {local_path} (原始输入: {file_path})"
)
file_token = str(uuid.uuid4())
expire_time = time.time() + (timeout if timeout is not None else self.default_timeout)
self.staged_files[file_token] = (file_path, expire_time)
expire_time = time.time() + (
timeout if timeout is not None else self.default_timeout
)
# 存储转换后的真实路径
self.staged_files[file_token] = (local_path, expire_time)
return file_token
async def handle_file(self, file_token: str) -> str:

View File

@@ -26,13 +26,14 @@ class InitialLoader:
async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
core_task = []
try:
await core_lifecycle.initialize()
core_task = core_lifecycle.start()
except Exception as e:
logger.critical(traceback.format_exc())
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
return
core_task = core_lifecycle.start()
self.dashboard_server = AstrBotDashboard(
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event

View File

@@ -96,8 +96,6 @@ class LogBroker:
Queue: 订阅者的队列, 可用于接收日志消息
"""
q = Queue(maxsize=CACHED_SIZE + 10)
for log in self.log_cache:
q.put_nowait(log)
self.subscribers.append(q)
return q

View File

@@ -102,6 +102,10 @@ class BaseMessageComponent(BaseModel):
data[k] = v
return {"type": self.type.lower(), "data": data}
async def to_dict(self) -> dict:
# 默认情况下,回退到旧的同步 toDict()
return self.toDict()
class Plain(BaseMessageComponent):
type: ComponentType = "Plain"
@@ -118,6 +122,12 @@ class Plain(BaseMessageComponent):
self.text.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;")
)
def toDict(self):
return {"type": "text", "data": {"text": self.text.strip()}}
async def to_dict(self):
return {"type": "text", "data": {"text": self.text}}
class Face(BaseMessageComponent):
type: ComponentType = "Face"
@@ -235,9 +245,6 @@ class Video(BaseMessageComponent):
path: T.Optional[str] = ""
def __init__(self, file: str, **_):
# for k in _.keys():
# if k == "c" and _[k] not in [2, 3]:
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
super().__init__(file=file, **_)
@staticmethod
@@ -250,6 +257,70 @@ class Video(BaseMessageComponent):
return Video(file=url, **_)
raise Exception("not a valid url")
async def convert_to_file_path(self) -> str:
"""将这个视频统一转换为本地文件路径。这个方法避免了手动判断视频数据类型,直接返回视频数据的本地路径(如果是网络 URL则会自动进行下载
Returns:
str: 视频的本地路径,以绝对路径表示。
"""
url = self.file
if url and url.startswith("file:///"):
return url[8:]
elif url and url.startswith("http"):
download_dir = os.path.join(get_astrbot_data_path(), "temp")
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
await download_file(url, video_file_path)
if os.path.exists(video_file_path):
return os.path.abspath(video_file_path)
else:
raise Exception(f"download failed: {url}")
elif os.path.exists(url):
return os.path.abspath(url)
else:
raise Exception(f"not a valid file: {url}")
async def register_to_file_service(self):
"""
将视频注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
if not callback_host:
raise Exception("未配置 callback_api_base文件服务不可用")
file_path = await self.convert_to_file_path()
token = await file_token_service.register_file(file_path)
logger.debug(f"已注册:{callback_host}/api/file/{token}")
return f"{callback_host}/api/file/{token}"
async def to_dict(self):
"""需要和 toDict 区分开toDict 是同步方法"""
url_or_path = self.file
if url_or_path.startswith("http"):
payload_file = url_or_path
elif callback_host := astrbot_config.get("callback_api_base"):
callback_host = str(callback_host).removesuffix("/")
token = await file_token_service.register_file(url_or_path)
payload_file = f"{callback_host}/api/file/{token}"
logger.debug(f"Generated video file callback link: {payload_file}")
else:
payload_file = url_or_path
return {
"type": "video",
"data": {
"file": payload_file,
},
}
class At(BaseMessageComponent):
type: ComponentType = "At"
@@ -259,6 +330,12 @@ class At(BaseMessageComponent):
def __init__(self, **_):
super().__init__(**_)
def toDict(self):
return {
"type": "at",
"data": {"qq": str(self.qq)},
}
class AtAll(At):
qq: str = "all"
@@ -514,27 +591,51 @@ class Node(BaseMessageComponent):
id: T.Optional[int] = 0 # 忽略
name: T.Optional[str] = "" # qq昵称
uin: T.Optional[str] = "0" # qq号
content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表
content: T.Optional[list[BaseMessageComponent]] = []
seq: T.Optional[T.Union[str, list]] = "" # 忽略
time: T.Optional[int] = 0 # 忽略
def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_):
if isinstance(content, list):
_content = None
if all(isinstance(item, Node) for item in content):
_content = [node.toDict() for node in content]
else:
_content = ""
for chain in content:
_content += chain.toString()
content = _content
elif isinstance(content, Node):
content = content.toDict()
def __init__(self, content: list[BaseMessageComponent], **_):
if isinstance(content, Node):
# back
content = [content]
super().__init__(content=content, **_)
def toString(self):
# logger.warn("Protocol: node doesn't support stringify")
return ""
async def to_dict(self):
data_content = []
for comp in self.content:
if isinstance(comp, (Image, Record)):
# For Image and Record segments, we convert them to base64
bs64 = await comp.convert_to_base64()
data_content.append(
{
"type": comp.type.lower(),
"data": {"file": f"base64://{bs64}"},
}
)
elif isinstance(comp, Plain):
# For Plain segments, we need to handle the plain differently
d = await comp.to_dict()
data_content.append(d)
elif isinstance(comp, File):
# For File segments, we need to handle the file differently
d = await comp.to_dict()
data_content.append(d)
elif isinstance(comp, (Node, Nodes)):
# For Node segments, we recursively convert them to dict
d = await comp.to_dict()
data_content.append(d)
else:
d = comp.toDict()
data_content.append(d)
return {
"type": "node",
"data": {
"user_id": str(self.uin),
"nickname": self.name,
"content": data_content,
},
}
class Nodes(BaseMessageComponent):
@@ -545,12 +646,20 @@ class Nodes(BaseMessageComponent):
super().__init__(nodes=nodes, **_)
def toDict(self):
"""Deprecated. Use to_dict instead"""
ret = {
"messages": [],
}
for node in self.nodes:
d = node.toDict()
d["data"]["uin"] = str(node.uin) # 转为字符串
ret["messages"].append(d)
return ret
async def to_dict(self):
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
ret = {"messages": []}
for node in self.nodes:
d = await node.to_dict()
ret["messages"].append(d)
return ret
@@ -723,6 +832,26 @@ class File(BaseMessageComponent):
return f"{callback_host}/api/file/{token}"
async def to_dict(self):
"""需要和 toDict 区分开toDict 是同步方法"""
url_or_path = await self.get_file(allow_return_url=True)
if url_or_path.startswith("http"):
payload_file = url_or_path
elif callback_host := astrbot_config.get("callback_api_base"):
callback_host = str(callback_host).removesuffix("/")
token = await file_token_service.register_file(url_or_path)
payload_file = f"{callback_host}/api/file/{token}"
logger.debug(f"Generated file callback link: {payload_file}")
else:
payload_file = url_or_path
return {
"type": "file",
"data": {
"name": self.name,
"file": payload_file,
},
}
class WechatEmoji(BaseMessageComponent):
type: ComponentType = "WechatEmoji"

View File

@@ -24,6 +24,8 @@ class MessageChain:
chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: Optional[bool] = None # None 为跟随用户设置
type: Optional[str] = None
"""消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。"""
def message(self, message: str):
"""添加一条文本消息到消息链 `chain` 中。
@@ -98,6 +100,15 @@ class MessageChain:
self.chain.append(Image.fromFileSystem(path))
return self
def base64_image(self, base64_str: str):
"""添加一条图片消息base64 编码字符串)到消息链 `chain` 中。
Example:
CommandResult().base64_image("iVBORw0KGgoAAAANSUhEUgAAAAUA...")
"""
self.chain.append(Image.fromBase64(base64_str))
return self
def use_t2i(self, use_t2i: bool):
"""设置是否使用文本转图片服务。
@@ -157,7 +168,7 @@ class ResultContentType(enum.Enum):
"""普通的消息结果"""
STREAMING_RESULT = enum.auto()
"""调用 LLM 产生的流式结果"""
STREAMING_FINISH= enum.auto()
STREAMING_FINISH = enum.auto()
"""流式输出完成"""

View File

@@ -1,22 +1,24 @@
from astrbot.core.message.message_event_result import (
MessageEventResult,
EventResultType,
MessageEventResult,
)
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .rate_limit_check.stage import RateLimitStage
from .content_safety_check.stage import ContentSafetyCheckStage
from .platform_compatibility.stage import PlatformCompatibilityStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage
from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage
from .result_decorate.stage import ResultDecorateStage
from .session_status_check.stage import SessionStatusCheckStage
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
# 管道阶段顺序
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"SessionStatusCheckStage", # 检查会话是否整体启用
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
@@ -29,6 +31,7 @@ STAGES_ORDER = [
__all__ = [
"WakingCheckStage",
"WhitelistCheckStage",
"SessionStatusCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PlatformCompatibilityStage",

View File

@@ -1,6 +1,14 @@
import inspect
import traceback
import typing as T
from dataclasses import dataclass
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star import PluginManager
from astrbot.api import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
@dataclass
@@ -9,3 +17,97 @@ class PipelineContext:
astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: PluginManager # 插件管理器对象
async def call_event_hook(
self,
event: AstrMessageEvent,
hook_type: EventType,
*args,
) -> bool:
"""调用事件钩子函数
Returns:
bool: 如果事件被终止,返回 True
"""
platform_id = event.get_platform_id()
handlers = star_handlers_registry.get_handlers_by_event_type(
hook_type, platform_id=platform_id
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, *args)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return event.is_stopped()
async def call_handler(
self,
event: AstrMessageEvent,
handler: T.Awaitable,
*args,
**kwargs,
) -> T.AsyncGenerator[None, None]:
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
2. 协程: 执行一次并处理返回值
Args:
ctx (PipelineContext): 消息管道上下文对象
event (AstrMessageEvent): 事件对象
handler (Awaitable): 事件处理函数
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器
trace_ = None
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError as _:
# 向下兼容
trace_ = traceback.format_exc()
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret
if not _has_yielded:
# 如果这个异步生成器没有执行到 yield 分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret

View File

@@ -43,10 +43,10 @@ class PreProcessStage(Stage):
# STT
if self.stt_settings.get("enable", False):
# TODO: 独立
stt_provider = (
self.plugin_manager.context.provider_manager.curr_stt_provider_inst
)
if stt_provider:
ctx = self.plugin_manager.context
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
if not stt_provider:
return
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record) and component.url:

View File

@@ -0,0 +1,58 @@
import abc
import typing as T
from dataclasses import dataclass
from astrbot.core.provider.entities import LLMResponse
from ....message.message_event_result import MessageChain
from enum import Enum, auto
class AgentState(Enum):
"""Agent 状态枚举"""
IDLE = auto() # 初始状态
RUNNING = auto() # 运行中
DONE = auto() # 完成
ERROR = auto() # 错误状态
class AgentResponseData(T.TypedDict):
chain: MessageChain
@dataclass
class AgentResponse:
type: str
data: AgentResponseData
class BaseAgentRunner:
@abc.abstractmethod
async def reset(self) -> None:
"""
Reset the agent to its initial state.
This method should be called before starting a new run.
"""
...
@abc.abstractmethod
async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
"""
Process a single step of the agent.
"""
...
@abc.abstractmethod
def done(self) -> bool:
"""
Check if the agent has completed its task.
Returns True if the agent is done, False otherwise.
"""
...
@abc.abstractmethod
def get_final_llm_resp(self) -> LLMResponse | None:
"""
Get the final observation from the agent.
This method should be called after the agent is done.
"""
...

View File

@@ -0,0 +1,306 @@
import sys
import traceback
import typing as T
from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState
from ...context import PipelineContext
from astrbot.core.provider.provider import Provider
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import (
MessageChain,
)
from astrbot.core.provider.entities import (
ProviderRequest,
LLMResponse,
ToolCallMessageSegment,
AssistantMessageSegment,
ToolCallsResult,
)
from mcp.types import (
TextContent,
ImageContent,
EmbeddedResource,
TextResourceContents,
BlobResourceContents,
)
from astrbot.core.star.star_handler import EventType
from astrbot import logger
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# TODO:
# 1. 处理平台不兼容的处理器
class ToolLoopAgent(BaseAgentRunner):
def __init__(
self, provider: Provider, event: AstrMessageEvent, pipeline_ctx: PipelineContext
) -> None:
self.provider = provider
self.req = None
self.event = event
self.pipeline_ctx = pipeline_ctx
self._state = AgentState.IDLE
self.final_llm_resp = None
self.streaming = False
@override
async def reset(self, req: ProviderRequest, streaming: bool) -> None:
self.req = req
self.streaming = streaming
self.final_llm_resp = None
self._state = AgentState.IDLE
def _transition_state(self, new_state: AgentState) -> None:
"""转换 Agent 状态"""
if self._state != new_state:
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
self._state = new_state
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
if self.streaming:
stream = self.provider.text_chat_stream(**self.req.__dict__)
async for resp in stream: # type: ignore
yield resp
else:
yield await self.provider.text_chat(**self.req.__dict__)
@override
async def step(self):
"""
Process a single step of the agent.
This method should return the result of the step.
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
# 开始处理,转换到运行状态
self._transition_state(AgentState.RUNNING)
llm_resp_result = None
async for llm_response in self._iter_llm_responses():
assert isinstance(llm_response, LLMResponse)
if llm_response.is_chunk:
if llm_response.result_chain:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(chain=llm_response.result_chain),
)
else:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(llm_response.completion_text)
),
)
continue
llm_resp_result = llm_response
break # got final response
if not llm_resp_result:
return
# 处理 LLM 响应
llm_resp = llm_resp_result
if llm_resp.role == "err":
# 如果 LLM 响应错误,转换到错误状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.ERROR)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(
f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
)
),
)
if not llm_resp.tools_call_name:
# 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE)
# 执行事件钩子
if await self.pipeline_ctx.call_event_hook(
self.event, EventType.OnLLMResponseEvent, llm_resp
):
return
# 返回 LLM 结果
if llm_resp.result_chain:
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=llm_resp.result_chain),
)
elif llm_resp.completion_text:
yield AgentResponse(
type="llm_result",
data=AgentResponseData(
chain=MessageChain().message(llm_resp.completion_text)
),
)
# 如果有工具调用,还需处理工具调用
if llm_resp.tools_call_name:
tool_call_result_blocks = []
for tool_call_name in llm_resp.tools_call_name:
yield AgentResponse(
type="tool_call",
data=AgentResponseData(
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}")
),
)
async for result in self._handle_function_tools(self.req, llm_resp):
if isinstance(result, list):
tool_call_result_blocks = result
elif isinstance(result, MessageChain):
yield AgentResponse(
type="tool_call_result",
data=AgentResponseData(chain=result),
)
# 将结果添加到上下文中
tool_calls_result = ToolCallsResult(
tool_calls_info=AssistantMessageSegment(
role="assistant",
tool_calls=llm_resp.to_openai_tool_calls(),
content=llm_resp.completion_text,
),
tool_calls_result=tool_call_result_blocks,
)
self.req.append_tool_calls_result(tool_calls_result)
async def _handle_function_tools(
self,
req: ProviderRequest,
llm_response: LLMResponse,
) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]:
"""处理函数工具调用。"""
tool_call_result_blocks: list[ToolCallMessageSegment] = []
logger.info(f"Agent 使用工具: {llm_response.tools_call_name}")
# 执行函数调用
for func_tool_name, func_tool_args, func_tool_id in zip(
llm_response.tools_call_name,
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
try:
if not req.func_tool:
return
func_tool = req.func_tool.get_func(func_tool_name)
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
)
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
res = await client.session.call_tool(func_tool.name, func_tool_args)
if not res:
continue
if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
)
)
yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resource.text,
)
)
yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回的数据类型不受支持",
)
)
yield MessageChain().message("返回的数据类型不受支持。")
else:
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
# 尝试调用工具函数
wrapper = self.pipeline_ctx.call_handler(
self.event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None:
# Tool 返回结果
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resp,
)
)
yield MessageChain().message(resp)
else:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
self._transition_state(AgentState.DONE)
if res := self.event.get_result():
if res.chain:
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
self.event.clear_result()
except Exception as e:
logger.warning(traceback.format_exc())
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: {str(e)}",
)
)
# 处理函数调用响应
if tool_call_result_blocks:
yield tool_call_result_blocks
def done(self) -> bool:
"""检查 Agent 是否已完成工作"""
return self._state in (AgentState.DONE, AgentState.ERROR)
def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp

View File

@@ -2,56 +2,47 @@
本地 Agent 模式的 LLM 调用 Stage
"""
import traceback
import asyncio
import copy
import json
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
import traceback
from typing import AsyncGenerator, Union
from astrbot.core import logger
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
ResultContentType,
MessageChain,
)
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import (
ProviderRequest,
LLMResponse,
ToolCallMessageSegment,
AssistantMessageSegment,
ToolCallsResult,
)
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from mcp.types import (
TextContent,
ImageContent,
EmbeddedResource,
TextResourceContents,
BlobResourceContents,
ProviderRequest,
)
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from ...context import PipelineContext
from ..agent_runner.tool_loop_agent import ToolLoopAgent
from ..stage import Stage
class LLMRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.bot_wake_prefixs = ctx.astrbot_config["wake_prefix"] # list
self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][
"wake_prefix"
] # str
self.max_context_length = ctx.astrbot_config["provider_settings"][
"max_context_length"
] # int
self.dequeue_context_length = min(
max(1, ctx.astrbot_config["provider_settings"]["dequeue_context_length"]),
conf = ctx.astrbot_config
settings = conf["provider_settings"]
self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
self.provider_wake_prefix: str = settings["wake_prefix"] # str
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
) # int
self.streaming_response = ctx.astrbot_config["provider_settings"][
"streaming_response"
] # bool
)
self.streaming_response: bool = settings["streaming_response"]
self.max_step: int = settings.get("max_agent_step", 10)
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
@@ -62,12 +53,33 @@ class LLMRequestSubStage(Stage):
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
"""选择使用的 LLM 提供商"""
sel_provider = event.get_extra("selected_provider")
_ctx = self.ctx.plugin_manager.context
if sel_provider and isinstance(sel_provider, str):
provider = _ctx.get_provider_by_id(sel_provider)
if not provider:
logger.error(f"未找到指定的提供商: {sel_provider}")
return provider
return _ctx.get_using_provider(umo=event.unified_msg_origin)
async def process(
self, event: AstrMessageEvent, _nested: bool = False
) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
req: ProviderRequest | None = None
provider = self.ctx.plugin_manager.context.get_using_provider()
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
# 检查会话级别的LLM启停状态
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM跳过处理。")
return
provider = self._select_provider(event)
if provider is None:
return
@@ -78,13 +90,12 @@ class LLMRequestSubStage(Stage):
)
if req.conversation:
all_contexts = json.loads(req.conversation.history)
req.contexts = self._process_tool_message_pairs(
all_contexts, remove_tags=True
)
req.contexts = json.loads(req.conversation.history)
else:
req = ProviderRequest(prompt="", image_urls=[])
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if self.provider_wake_prefix:
if not event.message_str.startswith(self.provider_wake_prefix):
return
@@ -122,25 +133,7 @@ class LLMRequestSubStage(Stage):
return
# 执行请求 LLM 前事件钩子。
# 装饰 system_prompt 等功能
# 获取当前平台ID
platform_id = event.get_platform_id()
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMRequestEvent, platform_id=platform_id
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, req)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
if await self.ctx.call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
if isinstance(req.contexts, str):
@@ -171,77 +164,77 @@ class LLMRequestSubStage(Stage):
if not req.session_id:
req.session_id = event.unified_msg_origin
async def requesting(req: ProviderRequest):
try:
need_loop = True
while need_loop:
need_loop = False
logger.debug(f"提供商请求 Payload: {req}")
# fix messages
req.contexts = self.fix_messages(req.contexts)
final_llm_response = None
if self.streaming_response:
stream = provider.text_chat_stream(**req.__dict__)
async for llm_response in stream:
if llm_response.is_chunk:
if llm_response.result_chain:
yield llm_response.result_chain # MessageChain
else:
yield MessageChain().message(
llm_response.completion_text
# Call Agent
tool_loop_agent = ToolLoopAgent(
provider=provider,
event=event,
pipeline_ctx=self.ctx,
)
else:
final_llm_response = llm_response
else:
final_llm_response = await provider.text_chat(
**req.__dict__
) # 请求 LLM
if not final_llm_response:
raise Exception("LLM response is None.")
# 执行 LLM 响应后的事件钩子。
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMResponseEvent
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
f"handle provider[id: {provider.provider_config['id']}] request: {req}"
)
await handler.handler(event, final_llm_response)
except BaseException:
logger.error(traceback.format_exc())
await tool_loop_agent.reset(req=req, streaming=self.streaming_response)
async def requesting():
step_idx = 0
while step_idx < self.max_step:
step_idx += 1
try:
async for resp in tool_loop_agent.step():
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
resp.data["chain"].type = "tool_call_result"
await event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if self.streaming_response:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if (
self.show_tool_use
or event.get_platform_name() == "webchat"
):
resp.data["chain"].type = "tool_call"
await event.send(resp.data["chain"])
continue
if not self.streaming_response:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
)
)
yield
event.clear_result()
else:
if resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
if tool_loop_agent.done():
break
except Exception as e:
logger.error(traceback.format_exc())
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
)
)
return
if self.streaming_response:
# 流式输出的处理
async for result in self._handle_llm_stream_response(
event, req, final_llm_response
):
if isinstance(result, ProviderRequest):
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
req = result
need_loop = True
else:
yield
else:
# 非流式输出的处理
async for result in self._handle_llm_response(
event, req, final_llm_response
):
if isinstance(result, ProviderRequest):
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
req = result
need_loop = True
else:
yield
asyncio.create_task(
Metric.upload(
llm_tick=1,
@@ -250,357 +243,127 @@ class LLMRequestSubStage(Stage):
)
)
# 保存到历史记录
await self._save_to_history(event, req, final_llm_response)
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
)
)
if not self.streaming_response:
event.set_extra("tool_call_result", None)
async for _ in requesting(req):
yield
else:
if self.streaming_response:
# 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(requesting(req))
.set_async_stream(requesting())
)
# 这里使用yield来暂停当前阶段等待流式输出完成后继续处理
yield
if event.get_extra("tool_call_result"):
event.set_result(event.get_extra("tool_call_result"))
event.set_extra("tool_call_result", None)
yield
# 暂时直接发出去
if img_b64 := event.get_extra("tool_call_img_respond"):
await event.send(MessageChain(chain=[Image.fromBase64(img_b64)]))
event.set_extra("tool_call_img_respond", None)
yield
async def _handle_llm_response(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse,
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
"""处理非流式 LLM 响应。
Returns:
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest表示需要再次调用 LLM
Yields:
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
"""
if llm_response.role == "assistant":
# text completion
if llm_response.result_chain:
if tool_loop_agent.done():
if final_llm_resp := tool_loop_agent.get_final_llm_resp():
if final_llm_resp.completion_text:
chain = (
MessageChain().message(final_llm_resp.completion_text).chain
)
else:
chain = final_llm_resp.result_chain.chain
event.set_result(
MessageEventResult(
chain=llm_response.result_chain.chain
).set_result_content_type(ResultContentType.LLM_RESULT)
chain=chain,
result_content_type=ResultContentType.STREAMING_FINISH,
)
)
else:
event.set_result(
MessageEventResult()
.message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT)
)
elif llm_response.role == "err":
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
)
)
elif llm_response.role == "tool":
# 处理函数工具调用
async for result in self._handle_function_tools(event, req, llm_response):
yield result
async for _ in requesting():
yield
async def _handle_llm_stream_response(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse,
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
"""处理流式 LLM 响应。
# 异步处理 WebChat 特殊情况
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
专门用于处理流式输出完成后的响应,与非流式响应处理分离。
await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp())
Returns:
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest表示需要再次调用 LLM
Yields:
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
"""
if llm_response.role == "assistant":
# text completion
if llm_response.result_chain:
event.set_result(
MessageEventResult(
chain=llm_response.result_chain.chain
).set_result_content_type(ResultContentType.STREAMING_FINISH)
)
else:
event.set_result(
MessageEventResult()
.message(llm_response.completion_text)
.set_result_content_type(ResultContentType.STREAMING_FINISH)
)
elif llm_response.role == "err":
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
)
)
elif llm_response.role == "tool":
# 处理函数工具调用
async for result in self._handle_function_tools(event, req, llm_response):
yield result
async def _handle_function_tools(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse,
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
"""处理函数工具调用。
Returns:
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest表示需要再次调用 LLM
"""
# function calling
tool_call_result: list[ToolCallMessageSegment] = []
logger.info(
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
)
for func_tool_name, func_tool_args, func_tool_id in zip(
llm_response.tools_call_name,
llm_response.tools_call_args,
llm_response.tools_call_ids,
async def _handle_webchat(
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
try:
func_tool = req.func_tool.get_func(func_tool_name)
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, req.conversation.cid
)
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
res = await client.session.call_tool(func_tool.name, func_tool_args)
if res:
# TODO 仅对ImageContent | EmbeddedResource进行了简单的Fallback
if isinstance(res.content[0], TextContent):
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
if conversation and not req.conversation.title:
messages = json.loads(conversation.history)
latest_pair = messages[-2:]
if not latest_pair:
return
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
llm_resp = await prov.text_chat(
system_prompt="You are expert in summarizing user's query.",
prompt=(
f"Please summarize the following query of user:\n"
f"{cleaned_text}\n"
"Only output the summary within 10 words, DO NOT INCLUDE any other text."
"You must use the same language as the user."
"If you think the dialog is too short to summarize, only output a special mark: `<None>`"
),
)
)
elif isinstance(res.content[0], ImageContent):
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
event.set_extra(
"tool_call_img_respond",
res.content[0].data,
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resource.text,
)
)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
event.set_extra(
"tool_call_img_respond",
res.content[0].data,
)
else:
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回的数据类型不受支持",
)
)
else:
# 获取处理器,过滤掉平台不兼容的处理器
platform_id = event.get_platform_id()
star_md = star_map.get(func_tool.handler_module_path)
if (
star_md
and platform_id in star_md.supported_platforms
and not star_md.supported_platforms[platform_id]
):
if llm_resp and llm_resp.completion_text:
logger.debug(
f"处理器 {func_tool_name}({star_md.name}) 在当前平台不兼容或者被禁用,跳过执行"
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}"
)
# 直接跳过不添加任何消息到tool_call_result
continue
logger.info(
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
title = llm_resp.completion_text.strip()
if not title or "<None>" in title:
return
await self.conv_manager.update_conversation_title(
event.unified_msg_origin, title=title
)
# 尝试调用工具函数
wrapper = self._call_handler(
self.ctx, event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None: # 有 return 返回
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resp,
)
)
else:
res = event.get_result()
if res and res.chain:
event.set_extra("tool_call_result", res)
yield # 有生成器返回
event.clear_result() # 清除上一个 handler 的结果
except BaseException as e:
logger.warning(traceback.format_exc())
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: {str(e)}",
)
)
if tool_call_result:
# 函数调用结果
req.func_tool = None # 暂时不支持递归工具调用
assistant_msg_seg = AssistantMessageSegment(
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
)
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
req.tool_calls_result = ToolCallsResult(
tool_calls_info=assistant_msg_seg,
tool_calls_result=tool_call_result,
)
yield req # 再次执行 LLM 请求
else:
if llm_response.completion_text:
event.set_result(
MessageEventResult().message(llm_response.completion_text)
# 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题
# webchat adapter 中session_id 的格式是 f"webchat!{username}!{cid}"
# TODO: 优化 WebChat 适配器的对话管理
if event.session_id:
username, cid = event.session_id.split("!")[1:3]
db_helper = self.ctx.plugin_manager.context._db
db_helper.update_conversation_title(
user_id=username,
cid=cid,
title=title,
)
async def _save_to_history(
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse | None,
):
if (
not req
or not req.conversation
or not llm_response
or llm_response.role != "assistant"
):
if not req or not req.conversation or not llm_response:
return
if llm_response.role == "assistant":
# 文本回复
contexts = req.contexts.copy()
contexts.append(await req.assemble_context())
# 记录并标记函数调用结果
# 历史上下文
messages = copy.deepcopy(req.contexts)
# 这一轮对话请求的用户输入
messages.append(await req.assemble_context())
# 这一轮对话的 LLM 响应
if req.tool_calls_result:
tool_calls_messages = req.tool_calls_result.to_openai_messages()
# 添加标记
for message in tool_calls_messages:
message["_tool_call_history"] = True
processed_tool_messages = self._process_tool_message_pairs(
tool_calls_messages, remove_tags=False
)
contexts.extend(processed_tool_messages)
contexts.append(
{"role": "assistant", "content": llm_response.completion_text}
)
contexts_to_save = list(
filter(lambda item: "_no_save" not in item, contexts)
)
if not isinstance(req.tool_calls_result, list):
messages.extend(req.tool_calls_result.to_openai_messages())
elif isinstance(req.tool_calls_result, list):
for tcr in req.tool_calls_result:
messages.extend(tcr.to_openai_messages())
messages.append({"role": "assistant", "content": llm_response.completion_text})
messages = list(filter(lambda item: "_no_save" not in item, messages))
await self.conv_manager.update_conversation(
event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
event.unified_msg_origin, req.conversation.cid, history=messages
)
def _process_tool_message_pairs(self, messages, remove_tags=True):
"""处理工具调用消息确保assistant和tool消息成对出现
Args:
messages (list): 消息列表
remove_tags (bool): 是否移除_tool_call_history标记
Returns:
list: 处理后的消息列表保证了assistant和对应tool消息的成对出现
"""
result = []
i = 0
while i < len(messages):
current_msg = messages[i]
# 普通消息直接添加
if "_tool_call_history" not in current_msg:
result.append(current_msg.copy() if remove_tags else current_msg)
i += 1
continue
# 工具调用消息成对处理
if current_msg.get("role") == "assistant" and "tool_calls" in current_msg:
assistant_msg = current_msg.copy()
if remove_tags and "_tool_call_history" in assistant_msg:
del assistant_msg["_tool_call_history"]
related_tools = []
j = i + 1
while (
j < len(messages)
and messages[j].get("role") == "tool"
and "_tool_call_history" in messages[j]
):
tool_msg = messages[j].copy()
if remove_tags:
del tool_msg["_tool_call_history"]
related_tools.append(tool_msg)
j += 1
# 成对的时候添加到结果
if related_tools:
result.append(assistant_msg)
result.extend(related_tools)
i = j # 跳过已处理
def fix_messages(self, messages: list[dict]) -> list[dict]:
"""验证并且修复上下文"""
fixed_messages = []
for message in messages:
if message.get("role") == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
else:
# 单独的tool消息
i += 1
return result
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages

View File

@@ -50,7 +50,7 @@ class StarRequestSubStage(Stage):
logger.debug(
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
)
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
wrapper = self.ctx.call_handler(event, handler.handler, **params)
async for ret in wrapper:
yield ret
event.clear_result() # 清除上一个 handler 的结果

View File

@@ -13,6 +13,7 @@ from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.utils.path_util import path_Mapping
from astrbot.core.utils.session_lock import session_lock_manager
@register_stage
@@ -29,11 +30,10 @@ class RespondStage(Stage):
Comp.Image: lambda comp: bool(comp.file), # 图片
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
Comp.Node: lambda comp: bool(comp.name)
and comp.uin != 0
and bool(comp.content), # 一个转发节点
Comp.Node: lambda comp: bool(comp.content), # 转发节点
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
Comp.File: lambda comp: bool(comp.file_ or comp.url),
Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情
}
async def initialize(self, ctx: PipelineContext):
@@ -129,9 +129,7 @@ class RespondStage(Stage):
"streaming_segmented", False
)
logger.info(f"应用流式输出({event.get_platform_name()})")
await event._pre_send()
await event.send_streaming(result.async_stream, use_fallback)
await event._post_send()
return
elif len(result.chain) > 0:
# 检查路径映射
@@ -142,8 +140,6 @@ class RespondStage(Stage):
component.file = path_Mapping(mappings, component.file)
event.get_result().chain[idx] = component
await event._pre_send()
# 检查消息链是否为空
try:
if await self._is_empty_message_chain(result.chain):
@@ -159,9 +155,14 @@ class RespondStage(Stage):
c for c in result.chain if not isinstance(c, Comp.Record)
]
if self.enable_seg and (
if (
self.enable_seg
and (
(self.only_llm_result and result.is_llm_result())
or not self.only_llm_result
)
and event.get_platform_name()
not in ["qq_official", "weixin_official_account", "dingtalk"]
):
decorated_comps = []
if self.reply_with_mention:
@@ -177,6 +178,8 @@ class RespondStage(Stage):
result.chain.remove(comp)
break
# leverage lock to guarentee the order of message sending among different events
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
for rcomp in record_comps:
i = await self._calc_comp_interval(rcomp)
await asyncio.sleep(i)
@@ -185,13 +188,13 @@ class RespondStage(Stage):
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
# 分段回复
for comp in non_record_comps:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
try:
await event.send(MessageChain([*decorated_comps, comp]))
decorated_comps = [] # 清空已发送的装饰组件
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
@@ -208,7 +211,6 @@ class RespondStage(Stage):
logger.error(traceback.format_exc())
logger.error(f"发送消息失败: {e} chain: {result.chain}")
await event._post_send()
logger.info(
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
)

View File

@@ -1,17 +1,19 @@
import time
import re
import time
import traceback
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage, registered_stages
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from typing import AsyncGenerator, Union
from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core import logger
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
from astrbot.core import html_renderer
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
from ..context import PipelineContext
from ..stage import Stage, register_stage, registered_stages
@register_stage
@@ -140,7 +142,11 @@ class ResultDecorateStage(Stage):
break
# 分段回复
if self.enable_segmented_reply:
if self.enable_segmented_reply and event.get_platform_name() not in [
"qq_official",
"weixin_official_account",
"dingtalk",
]:
if (
self.only_llm_result and result.is_llm_result()
) or not self.only_llm_result:
@@ -168,30 +174,57 @@ class ResultDecorateStage(Stage):
result.chain = new_chain
# TTS
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
event.unified_msg_origin
)
if (
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
and result.is_llm_result()
and tts_provider
and SessionServiceManager.should_process_tts_request(event)
):
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info("TTS 请求: " + comp.text)
logger.info(f"TTS 请求: {comp.text}")
audio_path = await tts_provider.get_audio(comp.text)
logger.info("TTS 结果: " + audio_path)
if audio_path:
new_chain.append(
Record(file=audio_path, url=audio_path)
)
if(self.ctx.astrbot_config["provider_tts_settings"]["dual_output"]):
new_chain.append(comp)
else:
logger.info(f"TTS 结果: {audio_path}")
if not audio_path:
logger.error(
f"由于 TTS 音频文件找到,消息段转语音失败: {comp.text}"
f"由于 TTS 音频文件找到,消息段转语音失败: {comp.text}"
)
new_chain.append(comp)
except BaseException:
continue
use_file_service = self.ctx.astrbot_config[
"provider_tts_settings"
]["use_file_service"]
callback_api_base = self.ctx.astrbot_config[
"callback_api_base"
]
dual_output = self.ctx.astrbot_config[
"provider_tts_settings"
]["dual_output"]
url = None
if use_file_service and callback_api_base:
token = await file_token_service.register_file(
audio_path
)
url = f"{callback_api_base}/api/file/{token}"
logger.debug(f"已注册:{url}")
new_chain.append(
Record(
file=url or audio_path,
url=url or audio_path,
)
)
if dual_output:
new_chain.append(comp)
except Exception:
logger.error(traceback.format_exc())
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
@@ -225,6 +258,14 @@ class ResultDecorateStage(Stage):
if url:
if url.startswith("http"):
result.chain = [Image.fromURL(url)]
elif (
self.ctx.astrbot_config["t2i_use_file_service"]
and self.ctx.astrbot_config["callback_api_base"]
):
token = await file_token_service.register_file(url)
url = f"{self.ctx.astrbot_config['callback_api_base']}/api/file/{token}"
logger.debug(f"已注册:{url}")
result.chain = [Image.fromURL(url)]
else:
result.chain = [Image.fromFileSystem(url)]

View File

@@ -73,7 +73,7 @@ class PipelineScheduler:
await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if not event._has_send_oper and event.get_platform_name() == "webchat":
if event.get_platform_name() == "webchat":
await event.send(None)
logger.debug("pipeline 执行完毕。")

View File

@@ -0,0 +1,22 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core import logger
@register_stage
class SessionStatusCheckStage(Stage):
"""检查会话是否整体启用"""
async def initialize(self, ctx: PipelineContext) -> None:
pass
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
# 检查会话是否整体启用
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
event.stop_event()

View File

@@ -1,12 +1,8 @@
from __future__ import annotations
import abc
import inspect
import traceback
from astrbot.api import logger
from typing import List, AsyncGenerator, Union, Awaitable
from typing import List, AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
@@ -41,70 +37,3 @@ class Stage(abc.ABC):
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
"""
raise NotImplementedError
async def _call_handler(
self,
ctx: PipelineContext,
event: AstrMessageEvent,
handler: Awaitable,
*args,
**kwargs,
) -> AsyncGenerator[None, None]:
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
1. 异步生成器: 实现洋葱模型每次yield都会将控制权交回上层
2. 协程: 执行一次并处理返回值
Args:
ctx (PipelineContext): 消息管道上下文对象
event (AstrMessageEvent): 待处理的事件对象
handler (Awaitable): 事件处理函数
*args: 传递给handler的位置参数
**kwargs: 传递给handler的关键字参数
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器(async def)
trace_ = None
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError as _:
# 向下兼容
trace_ = traceback.format_exc()
# 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
if isinstance(ready_to_call, AsyncGenerator):
# 如果是一个异步生成器, 进入洋葱模型
_has_yielded = False # 是否返回过值
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield # 传递控制权给上一层的process函数
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret # 传递控制权给上一层的process函数
if not _has_yielded:
# 如果这个异步生成器没有执行到yield分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield # 传递控制权给上一层的process函数
else:
yield ret # 传递控制权给上一层的process函数

View File

@@ -1,13 +1,16 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import AsyncGenerator, Union
from astrbot import logger
from typing import Union, AsyncGenerator
from astrbot.core.message.components import At, AtAll, Reply
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.message.components import At
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
from ..context import PipelineContext
from ..stage import Stage, register_stage
@register_stage
@@ -39,6 +42,9 @@ class WakingCheckStage(Stage):
self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get(
"ignore_bot_self_message", False
)
self.ignore_at_all = self.ctx.astrbot_config["platform_settings"].get(
"ignore_at_all", False
)
async def process(
self, event: AstrMessageEvent
@@ -77,11 +83,18 @@ class WakingCheckStage(Stage):
event.message_str = event.message_str[len(wake_prefix) :].strip()
break
if not is_wake:
# 检查是否有 at 消息
# 检查是否有at消息 / at全体成员消息 / 引用了bot的消息
for message in messages:
if isinstance(message, At) and (
str(message.qq) == str(event.get_self_id())
or str(message.qq) == "all"
if (
(
isinstance(message, At)
and (str(message.qq) == str(event.get_self_id()))
)
or (isinstance(message, AtAll) and not self.ignore_at_all)
or (
isinstance(message, Reply)
and str(message.sender_id) == str(event.get_self_id())
)
):
is_wake = True
event.is_wake = True
@@ -125,7 +138,6 @@ class WakingCheckStage(Stage):
f"插件 {star_map[handler.handler_module_path].name}: {e}"
)
)
await event._post_send()
event.stop_event()
passed = False
break
@@ -140,7 +152,6 @@ class WakingCheckStage(Stage):
f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。"
)
)
await event._post_send()
logger.info(
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
)
@@ -156,7 +167,12 @@ class WakingCheckStage(Stage):
"parsed_params"
)
event.clear_extra()
event._extras.pop("parsed_params", None)
# 根据会话配置过滤插件处理器
activated_handlers = SessionPluginManager.filter_handlers_by_session(
event, activated_handlers
)
event.set_extra("activated_handlers", activated_handlers)
event.set_extra("handlers_parsed_params", handlers_parsed_params)

View File

@@ -227,7 +227,7 @@ class AstrMessageEvent(abc.ABC):
):
"""发送流式消息到消息平台,使用异步生成器。
目前仅支持: telegramqq official 私聊。
Fallback仅支持 aiocqhttp, gewechat
Fallback仅支持 aiocqhttp。
"""
asyncio.create_task(
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
@@ -235,10 +235,10 @@ class AstrMessageEvent(abc.ABC):
self._has_send_oper = True
async def _pre_send(self):
"""调度器会在执行 send() 前调用该方法"""
"""调度器会在执行 send() 前调用该方法 deprecated in v3.5.18"""
async def _post_send(self):
"""调度器会在执行 send() 后调用该方法"""
"""调度器会在执行 send() 后调用该方法 deprecated in v3.5.18"""
def set_result(self, result: Union[MessageEventResult, str]):
"""设置消息事件的结果。
@@ -419,7 +419,6 @@ class AstrMessageEvent(abc.ABC):
适配情况:
- gewechat
- aiocqhttp(OneBotv11)
"""
...

View File

@@ -58,10 +58,6 @@ class PlatformManager:
from .sources.qqofficial_webhook.qo_webhook_adapter import (
QQOfficialWebhookPlatformAdapter, # noqa: F401
)
case "gewechat":
from .sources.gewechat.gewechat_platform_adapter import (
GewechatPlatformAdapter, # noqa: F401
)
case "wechatpadpro":
from .sources.wechatpadpro.wechatpadpro_adapter import (
WeChatPadProAdapter, # noqa: F401
@@ -77,7 +73,15 @@ class PlatformManager:
case "wecom":
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
case "weixin_official_account":
from .sources.weixin_official_account.weixin_offacc_adapter import WeixinOfficialAccountPlatformAdapter # noqa
from .sources.weixin_official_account.weixin_offacc_adapter import (
WeixinOfficialAccountPlatformAdapter, # noqa
)
case "discord":
from .sources.discord.discord_platform_adapter import (
DiscordPlatformAdapter, # noqa: F401
)
case "slack":
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"

View File

@@ -1,11 +1,19 @@
import asyncio
import re
from typing import AsyncGenerator, Dict, List
from aiocqhttp import CQHttp
from aiocqhttp import CQHttp, Event
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import At, Image, Node, Nodes, Plain, Record, File
from astrbot.api.message_components import (
Image,
Node,
Nodes,
Plain,
Record,
Video,
File,
BaseMessageComponent,
)
from astrbot.api.platform import Group, MessageMember
from astrbot.core import file_token_service, astrbot_config, logger
class AiocqhttpMessageEvent(AstrMessageEvent):
@@ -15,88 +23,120 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
@staticmethod
async def _from_segment_to_dict(segment: BaseMessageComponent) -> dict:
"""修复部分字段"""
if isinstance(segment, (Image, Record)):
# For Image and Record segments, we convert them to base64
bs64 = await segment.convert_to_base64()
return {
"type": segment.type.lower(),
"data": {
"file": f"base64://{bs64}",
},
}
elif isinstance(segment, File):
# For File segments, we need to handle the file differently
d = await segment.to_dict()
return d
elif isinstance(segment, Video):
d = await segment.to_dict()
return d
else:
# For other segments, we simply convert them to a dict by calling toDict
return segment.toDict()
@staticmethod
async def _parse_onebot_json(message_chain: MessageChain):
"""解析成 OneBot json 格式"""
ret = []
for segment in message_chain.chain:
d = segment.toDict()
if isinstance(segment, Plain):
d["type"] = "text"
d["data"]["text"] = segment.text.strip()
# 如果是空文本或者只带换行符的文本,不发送
if not d["data"]["text"]:
if not segment.text.strip():
continue
elif isinstance(segment, (Image, Record)):
# convert to base64
bs64 = await segment.convert_to_base64()
d["data"] = {
"file": f"base64://{bs64}",
}
elif isinstance(segment, At):
d["data"] = {
"qq": str(segment.qq), # 转换为字符串
}
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
ret.append(d)
return ret
async def send(self, message: MessageChain):
@classmethod
async def _dispatch_send(
cls,
bot: CQHttp,
event: Event | None,
is_group: bool,
session_id: str,
messages: list[dict],
):
if event:
await bot.send(event=event, message=messages)
elif is_group:
await bot.send_group_msg(group_id=session_id, message=messages)
else:
await bot.send_private_msg(user_id=session_id, message=messages)
@classmethod
async def send_message(
cls,
bot: CQHttp,
message_chain: MessageChain,
event: Event | None = None,
is_group: bool = False,
session_id: str = None,
):
"""发送消息"""
# 转发消息、文件消息不能和普通消息混在一起发送
send_one_by_one = any(
isinstance(seg, (Node, Nodes, File)) for seg in message.chain
isinstance(seg, (Node, Nodes, File)) for seg in message_chain.chain
)
if send_one_by_one:
for seg in message.chain:
if not send_one_by_one:
ret = await cls._parse_onebot_json(message_chain)
if not ret:
return
await cls._dispatch_send(bot, event, is_group, session_id, ret)
return
for seg in message_chain.chain:
if isinstance(seg, (Node, Nodes)):
# 合并转发消息
if isinstance(seg, Node):
nodes = Nodes([seg])
seg = nodes
payload = seg.toDict()
if self.get_group_id():
payload["group_id"] = self.get_group_id()
await self.bot.call_action("send_group_forward_msg", **payload)
else:
payload["user_id"] = self.get_sender_id()
await self.bot.call_action(
"send_private_forward_msg", **payload
)
elif isinstance(seg, File):
d = seg.toDict()
url_or_path = await seg.get_file(allow_return_url=True)
if url_or_path.startswith("http"):
payload_file = url_or_path
elif callback_host := astrbot_config.get("callback_api_base"):
callback_host = str(callback_host).removesuffix("/")
token = await file_token_service.register_file(url_or_path)
payload_file = f"{callback_host}/api/file/{token}"
logger.debug(f"Generated file callback link: {payload_file}")
else:
payload_file = url_or_path
d["data"] = {
"name": seg.name,
"file": payload_file,
}
await self.bot.send(
self.message_obj.raw_message,
[d],
)
else:
await self.bot.send(
self.message_obj.raw_message,
await AiocqhttpMessageEvent._parse_onebot_json(
MessageChain([seg])
),
)
await asyncio.sleep(0.5)
else:
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
if not ret:
return
await self.bot.send(self.message_obj.raw_message, ret)
payload = await seg.to_dict()
if is_group:
payload["group_id"] = session_id
await bot.call_action("send_group_forward_msg", **payload)
else:
payload["user_id"] = session_id
await bot.call_action("send_private_forward_msg", **payload)
elif isinstance(seg, File):
d = await cls._from_segment_to_dict(seg)
await cls._dispatch_send(bot, event, is_group, session_id, [d])
else:
messages = await cls._parse_onebot_json(MessageChain([seg]))
if not messages:
continue
await cls._dispatch_send(bot, event, is_group, session_id, messages)
await asyncio.sleep(0.5)
async def send(self, message: MessageChain):
"""发送消息"""
event = self.message_obj.raw_message
assert isinstance(event, Event), "Event must be an instance of aiocqhttp.Event"
is_group = False
if self.get_group_id():
is_group = True
session_id = self.get_group_id()
else:
session_id = self.get_sender_id()
await self.send_message(
bot=self.bot,
message_chain=message,
event=event,
is_group=is_group,
session_id=session_id,
)
await super().send(message)
async def send_streaming(

View File

@@ -83,19 +83,18 @@ class AiocqhttpAdapter(Platform):
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain)
match session.message_type.value:
case MessageType.GROUP_MESSAGE.value:
if "_" in session.session_id:
# 独立会话
_, group_id = session.session_id.split("_")
await self.bot.send_group_msg(group_id=group_id, message=ret)
is_group = session.message_type == MessageType.GROUP_MESSAGE
if is_group:
session_id = session.session_id.split("_")[-1]
else:
await self.bot.send_group_msg(
group_id=session.session_id, message=ret
session_id = session.session_id
await AiocqhttpMessageEvent.send_message(
bot=self.bot,
message_chain=message_chain,
event=None, # 这里不需要 event因为是通过 session 发送的
is_group=is_group,
session_id=session_id,
)
case MessageType.FRIEND_MESSAGE.value:
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
await super().send_by_session(session, message_chain)
async def convert_message(self, event: Event) -> AstrBotMessage:
@@ -168,9 +167,7 @@ class AiocqhttpAdapter(Platform):
if "sub_type" in event:
if event["sub_type"] == "poke" and "target_id" in event:
abm.message.append(
Poke(qq=str(event["target_id"]), type="poke")
) # noqa: F405
abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) # noqa: F405
return abm
@@ -221,6 +218,9 @@ class AiocqhttpAdapter(Platform):
a = None
if t == "text":
current_text = "".join(m["data"]["text"] for m in m_group).strip()
if not current_text:
# 如果文本段为空,则跳过
continue
message_str += current_text
a = ComponentTypes[t](text=current_text) # noqa: F405
abm.message.append(a)
@@ -270,8 +270,16 @@ class AiocqhttpAdapter(Platform):
action="get_msg",
message_id=int(m["data"]["id"]),
)
# 添加必要的 post_type 字段,防止 Event.from_payload 报错
reply_event_data["post_type"] = "message"
new_event = Event.from_payload(reply_event_data)
if not new_event:
logger.error(
f"无法从回复消息数据构造 Event 对象: {reply_event_data}"
)
continue
abm_reply = await self._convert_handle_message_event(
Event.from_payload(reply_event_data), get_reply=False
new_event, get_reply=False
)
reply_seg = Reply(
@@ -304,7 +312,9 @@ class AiocqhttpAdapter(Platform):
user_id=int(m["data"]["qq"]),
)
if at_info:
nickname = at_info.get("nick", "")
nickname = at_info.get("nick", "") or at_info.get(
"nickname", ""
)
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
abm.message.append(
@@ -319,7 +329,7 @@ class AiocqhttpAdapter(Platform):
first_at_self_processed = True
else:
# 非第一个@机器人或@其他用户添加到message_str
message_str += f" @{nickname} "
message_str += f" @{nickname}({m['data']['qq']}) "
else:
abm.message.append(At(qq=str(m["data"]["qq"]), name=""))
except ActionFailed as e:

View File

@@ -32,21 +32,17 @@ class DingtalkMessageEvent(AstrMessageEvent):
)
elif isinstance(segment, Comp.Image):
markdown_str = ""
if segment.file and segment.file.startswith("file:///"):
logger.warning(
"dingtalk only support url image, not: " + segment.file
)
continue
elif segment.file and segment.file.startswith("http"):
markdown_str += f"![image]({segment.file})\n\n"
elif segment.file and segment.file.startswith("base64://"):
logger.warning("dingtalk only support url image, not base64")
try:
if not segment.file:
logger.warning("钉钉图片 segment 缺少 file 字段,跳过")
continue
if segment.file.startswith(("http://", "https://")):
image_url = segment.file
else:
logger.warning(
"dingtalk only support url image, not: " + segment.file
)
continue
image_url = await segment.register_to_file_service()
markdown_str = f"![image]({image_url})\n\n"
ret = await asyncio.get_event_loop().run_in_executor(
None,
@@ -57,6 +53,11 @@ class DingtalkMessageEvent(AstrMessageEvent):
)
logger.debug(f"send image: {ret}")
except Exception as e:
logger.error(f"钉钉图片处理失败: {e}")
logger.warning(f"跳过图片发送: {image_path}")
continue
async def send(self, message: MessageChain):
await self.send_with_client(self.client, message)
await super().send(message)

View File

@@ -0,0 +1,126 @@
import discord
from astrbot import logger
import sys
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# Discord Bot客户端
class DiscordBotClient(discord.Bot):
"""Discord客户端封装"""
def __init__(self, token: str, proxy: str = None):
self.token = token
self.proxy = proxy
# 设置Intent权限遵循权限最小化原则
intents = discord.Intents.default()
intents.message_content = True # 订阅消息内容事件 (Privileged)
intents.members = True # 订阅成员事件 (Privileged)
# 初始化Bot
super().__init__(intents=intents, proxy=proxy)
# 回调函数
self.on_message_received = None
self.on_ready_once_callback = None
self._ready_once_fired = False
@override
async def on_ready(self):
"""当机器人成功连接并准备就绪时触发"""
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
logger.info("[Discord] 客户端已准备就绪。")
if self.on_ready_once_callback and not self._ready_once_fired:
self._ready_once_fired = True
try:
await self.on_ready_once_callback()
except Exception as e:
logger.error(
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True
)
def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典"""
is_mentioned = self.user in message.mentions
return {
"message": message,
"bot_id": str(self.user.id),
"content": message.content,
"username": message.author.display_name,
"userid": str(message.author.id),
"message_id": str(message.id),
"channel_id": str(message.channel.id),
"guild_id": str(message.guild.id) if message.guild else None,
"type": "message",
"is_mentioned": is_mentioned,
"clean_content": message.clean_content,
}
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
"""从 discord.Interaction 创建数据字典"""
return {
"interaction": interaction,
"bot_id": str(self.user.id),
"content": self._extract_interaction_content(interaction),
"username": interaction.user.display_name,
"userid": str(interaction.user.id),
"message_id": str(interaction.id),
"channel_id": str(interaction.channel_id)
if interaction.channel_id
else None,
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
"type": "interaction",
}
@override
async def on_message(self, message: discord.Message):
"""当接收到消息时触发"""
if message.author.bot:
return
logger.debug(
f"[Discord] 收到原始消息 from {message.author.name}: {message.content}"
)
if self.on_message_received:
message_data = self._create_message_data(message)
await self.on_message_received(message_data)
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
"""从交互中提取内容"""
interaction_type = interaction.type
interaction_data = getattr(interaction, "data", {})
if not interaction_data:
return ""
if interaction_type == discord.InteractionType.application_command:
command_name = interaction_data.get("name", "")
if options := interaction_data.get("options", []):
params = " ".join(
[f"{opt['name']}:{opt.get('value', '')}" for opt in options]
)
return f"/{command_name} {params}"
return f"/{command_name}"
elif interaction_type == discord.InteractionType.component:
custom_id = interaction_data.get("custom_id", "")
component_type = interaction_data.get("component_type", "")
return f"component:{custom_id}:{component_type}"
return str(interaction_data)
async def start_polling(self):
"""开始轮询消息,这是个阻塞方法"""
await self.start(self.token)
@override
async def close(self):
"""关闭客户端"""
if not self.is_closed():
await super().close()

View File

@@ -0,0 +1,135 @@
import discord
from typing import List
from astrbot.api.message_components import BaseMessageComponent
# Discord专用组件
class DiscordEmbed(BaseMessageComponent):
"""Discord Embed消息组件"""
type: str = "discord_embed"
def __init__(
self,
title: str = None,
description: str = None,
color: int = None,
url: str = None,
thumbnail: str = None,
image: str = None,
footer: str = None,
fields: List[dict] = None,
):
self.title = title
self.description = description
self.color = color
self.url = url
self.thumbnail = thumbnail
self.image = image
self.footer = footer
self.fields = fields or []
def to_discord_embed(self) -> discord.Embed:
"""转换为Discord Embed对象"""
embed = discord.Embed()
if self.title:
embed.title = self.title
if self.description:
embed.description = self.description
if self.color:
embed.color = self.color
if self.url:
embed.url = self.url
if self.thumbnail:
embed.set_thumbnail(url=self.thumbnail)
if self.image:
embed.set_image(url=self.image)
if self.footer:
embed.set_footer(text=self.footer)
for field in self.fields:
embed.add_field(
name=field.get("name", ""),
value=field.get("value", ""),
inline=field.get("inline", False),
)
return embed
class DiscordButton(BaseMessageComponent):
"""Discord按钮组件"""
type: str = "discord_button"
def __init__(
self,
label: str,
custom_id: str = None,
style: str = "primary",
emoji: str = None,
url: str = None,
disabled: bool = False,
):
self.label = label
self.custom_id = custom_id
self.style = style
self.emoji = emoji
self.url = url
self.disabled = disabled
class DiscordReference(BaseMessageComponent):
"""Discord引用组件"""
type: str = "discord_reference"
def __init__(self, message_id: str, channel_id: str):
self.message_id = message_id
self.channel_id = channel_id
class DiscordView(BaseMessageComponent):
"""Discord视图组件包含按钮和选择菜单"""
type: str = "discord_view"
def __init__(
self, components: List[BaseMessageComponent] = None, timeout: float = None
):
self.components = components or []
self.timeout = timeout
def to_discord_view(self) -> discord.ui.View:
"""转换为Discord View对象"""
view = discord.ui.View(timeout=self.timeout)
for component in self.components:
if isinstance(component, DiscordButton):
button_style = getattr(
discord.ButtonStyle, component.style, discord.ButtonStyle.primary
)
if component.url:
# URL按钮
button = discord.ui.Button(
label=component.label,
style=discord.ButtonStyle.link,
url=component.url,
emoji=component.emoji,
disabled=component.disabled,
)
else:
# 普通按钮
button = discord.ui.Button(
label=component.label,
style=button_style,
custom_id=component.custom_id,
emoji=component.emoji,
disabled=component.disabled,
)
view.add_item(button)
return view

View File

@@ -0,0 +1,455 @@
import asyncio
import discord
import sys
import re
from discord.abc import Messageable
from discord.channel import DMChannel
from astrbot.api.platform import (
Platform,
AstrBotMessage,
MessageMember,
PlatformMetadata,
MessageType,
)
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain, Image, File
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.api.platform import register_platform_adapter
from astrbot import logger
from .client import DiscordBotClient
from .discord_platform_event import DiscordPlatformEvent
from typing import Any, Tuple
from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# 注册平台适配器
@register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)")
class DiscordPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
self.client_self_id = None
self.registered_handlers = []
# 指令注册相关
self.enable_command_register = self.config.get("discord_command_register", True)
self.guild_id = self.config.get("discord_guild_id_for_debug", None)
self.activity_name = self.config.get("discord_activity_name", None)
self.shutdown_event = asyncio.Event()
self._polling_task = None
@override
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
"""通过会话发送消息"""
# 创建一个 message_obj 以便在 event 中使用
message_obj = AstrBotMessage()
if "_" in session.session_id:
session.session_id = session.session_id.split("_")[1]
channel_id_str = session.session_id
channel = None
try:
channel_id = int(channel_id_str)
channel = self.client.get_channel(channel_id)
except (ValueError, TypeError):
logger.warning(f"[Discord] Invalid channel ID format: {channel_id_str}")
if channel:
message_obj.type = self._get_message_type(channel)
message_obj.group_id = self._get_channel_id(channel)
else:
logger.warning(
f"[Discord] Can't get channel info for {channel_id_str}, will guess message type."
)
message_obj.type = MessageType.GROUP_MESSAGE
message_obj.group_id = session.session_id
message_obj.message_str = message_chain.get_plain_text()
message_obj.sender = MessageMember(
user_id=str(self.client_self_id), nickname=self.client.user.display_name
)
message_obj.self_id = self.client_self_id
message_obj.session_id = session.session_id
message_obj.message = message_chain
# 创建临时事件对象来发送消息
temp_event = DiscordPlatformEvent(
message_str=message_chain.get_plain_text(),
message_obj=message_obj,
platform_meta=self.meta(),
session_id=session.session_id,
client=self.client,
)
await temp_event.send(message_chain)
await super().send_by_session(session, message_chain)
@override
def meta(self) -> PlatformMetadata:
"""返回平台元数据"""
return PlatformMetadata(
"discord",
"Discord 适配器",
id=self.config.get("id"),
default_config_tmpl=self.config,
)
@override
async def run(self):
"""主要运行逻辑"""
# 初始化回调函数
async def on_received(message_data):
logger.debug(f"[Discord] 收到消息: {message_data}")
if self.client_self_id is None:
self.client_self_id = message_data.get("bot_id")
abm = await self.convert_message(data=message_data)
await self.handle_msg(abm)
# 初始化 Discord 客户端
token = str(self.config.get("discord_token"))
if not token:
logger.error("[Discord] Bot Token 未配置。请在配置文件中正确设置 token。")
return
proxy = self.config.get("discord_proxy") or None
self.client = DiscordBotClient(token, proxy)
self.client.on_message_received = on_received
async def callback():
if self.enable_command_register:
await self._collect_and_register_commands()
if self.activity_name:
await self.client.change_presence(
status=discord.Status.online,
activity=discord.CustomActivity(name=self.activity_name),
)
self.client.on_ready_once_callback = callback
try:
self._polling_task = asyncio.create_task(self.client.start_polling())
await self.shutdown_event.wait()
except discord.errors.LoginFailure:
logger.error("[Discord] 登录失败。请检查你的 Bot Token 是否正确。")
except discord.errors.ConnectionClosed:
logger.warning("[Discord] 与 Discord 的连接已关闭。")
except Exception as e:
logger.error(f"[Discord] 适配器运行时发生意外错误: {e}", exc_info=True)
def _get_message_type(
self, channel: Messageable, guild_id: int | None = None
) -> MessageType:
"""根据 channel 对象和 guild_id 判断消息类型"""
if guild_id is not None:
return MessageType.GROUP_MESSAGE
if isinstance(channel, DMChannel) or getattr(channel, "guild", None) is None:
return MessageType.FRIEND_MESSAGE
return MessageType.GROUP_MESSAGE
def _get_channel_id(self, channel: Messageable) -> str:
"""根据 channel 对象获取ID"""
return str(getattr(channel, "id", None))
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
"""将普通消息转换为 AstrBotMessage"""
message: discord.Message = data["message"]
content = message.content
# 如果机器人被@,移除@部分
# 剥离 User Mention (<@id>, <@!id>)
if self.client and self.client.user:
mention_str = f"<@{self.client.user.id}>"
mention_str_nickname = f"<@!{self.client.user.id}>"
if content.startswith(mention_str):
content = content[len(mention_str) :].lstrip()
elif content.startswith(mention_str_nickname):
content = content[len(mention_str_nickname) :].lstrip()
# 剥离 Role Mentionbot 拥有的任一角色被提及,<@&role_id>
if (
hasattr(message, "role_mentions")
and hasattr(message, "guild")
and message.guild
):
bot_member = (
message.guild.get_member(self.client.user.id)
if self.client and self.client.user
else None
)
if bot_member and hasattr(bot_member, "roles"):
for role in bot_member.roles:
role_mention_str = f"<@&{role.id}>"
if content.startswith(role_mention_str):
content = content[len(role_mention_str) :].lstrip()
break # 只剥离第一个匹配的角色 mention
abm = AstrBotMessage()
abm.type = self._get_message_type(message.channel)
abm.group_id = self._get_channel_id(message.channel)
abm.message_str = content
abm.sender = MessageMember(
user_id=str(message.author.id), nickname=message.author.display_name
)
message_chain = []
if abm.message_str:
message_chain.append(Plain(text=abm.message_str))
if message.attachments:
for attachment in message.attachments:
if attachment.content_type and attachment.content_type.startswith(
"image/"
):
message_chain.append(
Image(file=attachment.url, filename=attachment.filename)
)
else:
message_chain.append(
File(name=attachment.filename, url=attachment.url)
)
abm.message = message_chain
abm.raw_message = message
abm.self_id = self.client_self_id
abm.session_id = str(message.channel.id)
abm.message_id = str(message.id)
return abm
async def convert_message(self, data: dict) -> AstrBotMessage:
"""将平台消息转换成 AstrBotMessage"""
# 由于 on_interaction 已被禁用,我们只处理普通消息
return self._convert_message_to_abm(data)
async def handle_msg(self, message: AstrBotMessage, followup_webhook=None):
"""处理消息"""
message_event = DiscordPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client,
interaction_followup_webhook=followup_webhook,
)
# 检查是否为斜杠指令
is_slash_command = message_event.interaction_followup_webhook is not None
# 检查是否被@User Mention 或 Bot 拥有的 Role Mention
is_mention = False
# User Mention
if (
self.client
and self.client.user
and hasattr(message.raw_message, "mentions")
):
if self.client.user in message.raw_message.mentions:
is_mention = True
# Role MentionBot 拥有的角色被提及)
if not is_mention and hasattr(message.raw_message, "role_mentions"):
bot_member = None
if hasattr(message.raw_message, "guild") and message.raw_message.guild:
try:
bot_member = message.raw_message.guild.get_member(
self.client.user.id
)
except Exception:
bot_member = None
if bot_member and hasattr(bot_member, "roles"):
bot_roles = set(bot_member.roles)
mentioned_roles = set(message.raw_message.role_mentions)
if (
bot_roles
and mentioned_roles
and bot_roles.intersection(mentioned_roles)
):
is_mention = True
# 如果是斜杠指令或被@的消息,设置为唤醒状态
if is_slash_command or is_mention:
message_event.is_wake = True
message_event.is_at_or_wake_command = True
self.commit_event(message_event)
@override
async def terminate(self):
"""终止适配器"""
logger.info("[Discord] 正在终止适配器... (step 1: cancel polling task)")
self.shutdown_event.set()
# 优先 cancel polling_task
if self._polling_task:
self._polling_task.cancel()
try:
await asyncio.wait_for(self._polling_task, timeout=10)
except asyncio.CancelledError:
logger.info("[Discord] polling_task 已取消。")
except Exception as e:
logger.warning(f"[Discord] polling_task 取消异常: {e}")
logger.info("[Discord] 正在清理已注册的斜杠指令... (step 2)")
# 清理指令
if self.enable_command_register and self.client:
try:
await asyncio.wait_for(
self.client.sync_commands(
commands=[],
guild_ids=[self.guild_id] if self.guild_id else None,
),
timeout=10,
)
logger.info("[Discord] 指令清理完成。")
except Exception as e:
logger.error(f"[Discord] 清理指令时发生错误: {e}", exc_info=True)
logger.info("[Discord] 正在关闭 Discord 客户端... (step 3)")
if self.client and hasattr(self.client, "close"):
try:
await asyncio.wait_for(self.client.close(), timeout=10)
except Exception as e:
logger.warning(f"[Discord] 客户端关闭异常: {e}")
logger.info("[Discord] 适配器已终止。")
def register_handler(self, handler_info):
"""注册处理器信息"""
self.registered_handlers.append(handler_info)
async def _collect_and_register_commands(self):
"""收集所有指令并注册到Discord"""
logger.info("[Discord] 开始收集并注册斜杠指令...")
registered_commands = []
for handler_md in star_handlers_registry:
if not star_map[handler_md.handler_module_path].activated:
continue
for event_filter in handler_md.event_filters:
cmd_info = self._extract_command_info(event_filter, handler_md)
if not cmd_info:
continue
cmd_name, description, cmd_filter_instance = cmd_info
# 创建动态回调
callback = self._create_dynamic_callback(cmd_name)
# 创建一个通用的参数选项来接收所有文本输入
options = [
discord.Option(
name="params",
description="指令的所有参数",
type=discord.SlashCommandOptionType.string,
required=False,
)
]
# 创建SlashCommand
slash_command = discord.SlashCommand(
name=cmd_name,
description=description,
func=callback,
options=options,
guild_ids=[self.guild_id] if self.guild_id else None,
)
self.client.add_application_command(slash_command)
registered_commands.append(cmd_name)
if registered_commands:
logger.info(
f"[Discord] 准备同步 {len(registered_commands)} 个指令: {', '.join(registered_commands)}"
)
else:
logger.info("[Discord] 没有发现可注册的指令。")
# 使用 Pycord 的方法同步指令
# 注意:这可能需要一些时间,并且有频率限制
await self.client.sync_commands()
logger.info("[Discord] 指令同步完成。")
def _create_dynamic_callback(self, cmd_name: str):
"""为每个指令动态创建一个异步回调函数"""
async def dynamic_callback(ctx: discord.ApplicationContext, params: str = None):
# 将平台特定的前缀'/'剥离以适配通用的CommandFilter
logger.debug(f"[Discord] 回调函数触发: {cmd_name}")
logger.debug(f"[Discord] 回调函数参数: {ctx}")
logger.debug(f"[Discord] 回调函数参数: {params}")
message_str_for_filter = cmd_name
if params:
message_str_for_filter += f" {params}"
logger.debug(
f"[Discord] 斜杠指令 '{cmd_name}' 被触发。 "
f"原始参数: '{params}'. "
f"构建的指令字符串: '{message_str_for_filter}'"
)
# 尝试立即响应,防止超时
followup_webhook = None
try:
await ctx.defer()
followup_webhook = ctx.followup
except Exception as e:
logger.warning(f"[Discord] 指令 '{cmd_name}' defer 失败: {e}")
# 2. 构建 AstrBotMessage
abm = AstrBotMessage()
abm.type = self._get_message_type(ctx.channel, ctx.guild_id)
abm.group_id = self._get_channel_id(ctx.channel)
abm.message_str = message_str_for_filter
abm.sender = MessageMember(
user_id=str(ctx.author.id), nickname=ctx.author.display_name
)
abm.message = [Plain(text=message_str_for_filter)]
abm.raw_message = ctx.interaction
abm.self_id = self.client_self_id
abm.session_id = str(ctx.channel_id)
abm.message_id = str(ctx.interaction.id)
# 3. 将消息和 webhook 分别交给 handle_msg 处理
await self.handle_msg(abm, followup_webhook)
return dynamic_callback
@staticmethod
def _extract_command_info(
event_filter: Any, handler_metadata: StarHandlerMetadata
) -> Tuple[str, str, CommandFilter] | None:
"""从事件过滤器中提取指令信息"""
cmd_name = None
# is_group = False
cmd_filter_instance = None
if isinstance(event_filter, CommandFilter):
# 暂不支持子指令注册为斜杠指令
if (
event_filter.parent_command_names
and event_filter.parent_command_names != [""]
):
return None
cmd_name = event_filter.command_name
cmd_filter_instance = event_filter
elif isinstance(event_filter, CommandGroupFilter):
# 暂不支持指令组直接注册为斜杠指令,因为它们没有 handle 方法
return None
if not cmd_name:
return None
# Discord 斜杠指令名称规范
if not re.match(r"^[a-z0-9_-]{1,32}$", cmd_name):
logger.debug(f"[Discord] 跳过不符合规范的指令: {cmd_name}")
return None
description = handler_metadata.desc or f"指令: {cmd_name}"
if len(description) > 100:
description = f"{description[:97]}..."
return cmd_name, description, cmd_filter_instance

View File

@@ -0,0 +1,296 @@
import asyncio
import discord
import base64
from io import BytesIO
from pathlib import Path
from typing import Optional
import sys
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, At
from astrbot.api.message_components import (
Plain,
Image,
File,
BaseMessageComponent,
Reply,
)
from astrbot import logger
from .client import DiscordBotClient
from .components import DiscordEmbed, DiscordView
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# 自定义Discord视图组件兼容旧版本
class DiscordViewComponent(BaseMessageComponent):
type: str = "discord_view"
def __init__(self, view: discord.ui.View):
self.view = view
class DiscordPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client: DiscordBotClient,
interaction_followup_webhook: Optional[discord.Webhook] = None,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
self.interaction_followup_webhook = interaction_followup_webhook
@override
async def send(self, message: MessageChain):
"""发送消息到Discord平台"""
# 解析消息链为 Discord 所需的对象
try:
(
content,
files,
view,
embeds,
reference_message_id,
) = await self._parse_to_discord(message)
except Exception as e:
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
return
kwargs = {}
if content:
kwargs["content"] = content
if files:
kwargs["files"] = files
if view:
kwargs["view"] = view
if embeds:
kwargs["embeds"] = embeds
if reference_message_id and not self.interaction_followup_webhook:
kwargs["reference"] = self.client.get_message(int(reference_message_id))
if not kwargs:
logger.debug("[Discord] 尝试发送空消息,已忽略。")
return
# 根据上下文执行发送/回复操作
try:
# -- 斜杠指令/交互上下文 --
if self.interaction_followup_webhook:
await self.interaction_followup_webhook.send(**kwargs)
# -- 常规消息上下文 --
else:
channel = await self._get_channel()
if not channel:
return
else:
await channel.send(**kwargs)
except Exception as e:
logger.error(f"[Discord] 发送消息时发生未知错误: {e}", exc_info=True)
await super().send(message)
async def _get_channel(self) -> Optional[discord.abc.Messageable]:
"""获取当前事件对应的频道对象"""
try:
channel_id = int(self.session_id)
return self.client.get_channel(
channel_id
) or await self.client.fetch_channel(channel_id)
except (ValueError, discord.errors.NotFound, discord.errors.Forbidden):
logger.error(f"[Discord] 无法获取频道 {self.session_id}")
return None
async def _parse_to_discord(
self,
message: MessageChain,
) -> tuple[str, list[discord.File], Optional[discord.ui.View], list[discord.Embed]]:
"""将 MessageChain 解析为 Discord 发送所需的内容"""
content = ""
files = []
view = None
embeds = []
reference_message_id = None
for i in message.chain: # 遍历消息链
if isinstance(i, Plain): # 如果是文字类型的
content += i.text
elif isinstance(i, Reply):
reference_message_id = i.id
elif isinstance(i, At):
content += f"<@{i.qq}>"
elif isinstance(i, Image):
logger.debug(f"[Discord] 开始处理 Image 组件: {i}")
try:
filename = getattr(i, "filename", None)
file_content = getattr(i, "file", None)
if not file_content:
logger.warning(f"[Discord] Image 组件没有 file 属性: {i}")
continue
discord_file = None
# 1. URL
if file_content.startswith("http"):
logger.debug(f"[Discord] 处理 URL 图片: {file_content}")
embed = discord.Embed().set_image(url=file_content)
embeds.append(embed)
continue
# 2. File URI
elif file_content.startswith("file:///"):
logger.debug(f"[Discord] 处理 File URI: {file_content}")
path = Path(file_content[8:])
if await asyncio.to_thread(path.exists):
file_bytes = await asyncio.to_thread(path.read_bytes)
discord_file = discord.File(
BytesIO(file_bytes), filename=filename or path.name
)
else:
logger.warning(f"[Discord] 图片文件不存在: {path}")
# 3. Base64 URI
elif file_content.startswith("base64://"):
logger.debug("[Discord] 处理 Base64 URI")
b64_data = file_content.split("base64://", 1)[1]
missing_padding = len(b64_data) % 4
if missing_padding:
b64_data += "=" * (4 - missing_padding)
img_bytes = base64.b64decode(b64_data)
discord_file = discord.File(
BytesIO(img_bytes), filename=filename or "image.png"
)
# 4. 裸 Base64 或本地路径
else:
try:
logger.debug("[Discord] 尝试作为裸 Base64 处理")
b64_data = file_content
missing_padding = len(b64_data) % 4
if missing_padding:
b64_data += "=" * (4 - missing_padding)
img_bytes = base64.b64decode(b64_data)
discord_file = discord.File(
BytesIO(img_bytes), filename=filename or "image.png"
)
except (ValueError, TypeError, base64.binascii.Error):
logger.debug(
f"[Discord] 裸 Base64 解码失败,作为本地路径处理: {file_content}"
)
path = Path(file_content)
if await asyncio.to_thread(path.exists):
file_bytes = await asyncio.to_thread(path.read_bytes)
discord_file = discord.File(
BytesIO(file_bytes), filename=filename or path.name
)
else:
logger.warning(f"[Discord] 图片文件不存在: {path}")
if discord_file:
files.append(discord_file)
except Exception:
# 使用 getattr 来安全地访问 i.file以防 i 本身就是问题
file_info = getattr(i, "file", "未知")
logger.error(
f"[Discord] 处理图片时发生未知严重错误: {file_info}",
exc_info=True,
)
elif isinstance(i, File):
try:
file_path_str = await i.get_file()
if file_path_str:
path = Path(file_path_str)
if await asyncio.to_thread(path.exists):
file_bytes = await asyncio.to_thread(path.read_bytes)
files.append(
discord.File(BytesIO(file_bytes), filename=i.name)
)
else:
logger.warning(
f"[Discord] 获取文件失败,路径不存在: {file_path_str}"
)
else:
logger.warning(f"[Discord] 获取文件失败: {i.name}")
except Exception as e:
logger.warning(f"[Discord] 处理文件失败: {i.name}, 错误: {e}")
elif isinstance(i, DiscordEmbed):
# Discord Embed消息
embeds.append(i.to_discord_embed())
elif isinstance(i, DiscordView):
# Discord视图组件按钮、选择菜单等
view = i.to_discord_view()
elif isinstance(i, DiscordViewComponent):
# 如果消息链中包含Discord视图组件兼容旧版本
if isinstance(i.view, discord.ui.View):
view = i.view
else:
logger.debug(f"[Discord] 忽略了不支持的消息组件: {i.type}")
if len(content) > 2000:
logger.warning("[Discord] 消息内容超过2000字符将被截断。")
content = content[:2000]
return content, files, view, embeds, reference_message_id
async def react(self, emoji: str):
"""对原消息添加反应"""
try:
if hasattr(self.message_obj, "raw_message") and hasattr(
self.message_obj.raw_message, "add_reaction"
):
await self.message_obj.raw_message.add_reaction(emoji)
except Exception as e:
logger.error(f"[Discord] 添加反应失败: {e}")
def is_slash_command(self) -> bool:
"""判断是否为斜杠命令"""
return (
hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type
== discord.InteractionType.application_command
)
def is_button_interaction(self) -> bool:
"""判断是否为按钮交互"""
return (
hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type == discord.InteractionType.component
)
def get_interaction_custom_id(self) -> str:
"""获取交互组件的custom_id"""
if self.is_button_interaction():
try:
return self.message_obj.raw_message.data.get("custom_id", "")
except Exception:
pass
return ""
def is_mentioned(self) -> bool:
"""判断机器人是否被@"""
if hasattr(self.message_obj, "raw_message") and hasattr(
self.message_obj.raw_message, "mentions"
):
return any(
mention.id == int(self.message_obj.self_id)
for mention in self.message_obj.raw_message.mentions
)
return False
def get_mention_clean_content(self) -> str:
"""获取去除@后的清洁内容"""
if hasattr(self.message_obj, "raw_message") and hasattr(
self.message_obj.raw_message, "clean_content"
):
return self.message_obj.raw_message.clean_content
return self.message_str

View File

@@ -1,812 +0,0 @@
import asyncio
import base64
import datetime
import os
import re
import uuid
import threading
import aiohttp
import anyio
import quart
from astrbot.api import logger, sp
from astrbot.api.message_components import Plain, Image, At, Record, Video
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.core.utils.io import download_image_by_url
from .downloader import GeweDownloader
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
try:
from .xml_data_parser import GeweDataParser
except (ImportError, ModuleNotFoundError) as e:
logger.warning(
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
)
class SimpleGewechatClient:
"""针对 Gewechat 的简单实现。
@author: Soulter
@website: https://github.com/Soulter
"""
def __init__(
self,
base_url: str,
nickname: str,
host: str,
port: int,
event_queue: asyncio.Queue,
):
self.base_url = base_url
if self.base_url.endswith("/"):
self.base_url = self.base_url[:-1]
self.download_base_url = self.base_url.split(":")[:-1] # 去掉端口
self.download_base_url = ":".join(self.download_base_url) + ":2532/download/"
self.base_url += "/v2/api"
logger.info(f"Gewechat API: {self.base_url}")
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
if isinstance(port, str):
port = int(port)
self.token = None
self.headers = {}
self.nickname = nickname
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
self.server = quart.Quart(__name__)
self.server.add_url_rule(
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
)
self.server.add_url_rule(
"/astrbot-gewechat/file/<file_token>",
view_func=self._handle_file,
methods=["GET"],
)
self.host = host
self.port = port
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
self.event_queue = event_queue
self.multimedia_downloader = None
self.userrealnames = {}
self.shutdown_event = asyncio.Event()
self.staged_files = {}
"""存储了允许外部访问的文件列表。auth_token: file_path。通过 register_file 方法注册。"""
self.lock = asyncio.Lock()
async def get_token_id(self):
"""获取 Gewechat Token。"""
async with aiohttp.ClientSession() as session:
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
json_blob = await resp.json()
self.token = json_blob["data"]
logger.info(f"获取到 Gewechat Token: {self.token}")
self.headers = {"X-GEWE-TOKEN": self.token}
async def _convert(self, data: dict) -> AstrBotMessage:
if "TypeName" in data:
type_name = data["TypeName"]
elif "type_name" in data:
type_name = data["type_name"]
else:
raise Exception("无法识别的消息类型")
# 以下没有业务处理,只是避免控制台打印太多的日志
if type_name == "ModContacts":
logger.info("gewechat下发ModContacts消息通知。")
return
if type_name == "DelContacts":
logger.info("gewechat下发DelContacts消息通知。")
return
if type_name == "Offline":
logger.critical("收到 gewechat 下线通知。")
return
d = None
if "Data" in data:
d = data["Data"]
elif "data" in data:
d = data["data"]
if not d:
logger.warning(f"消息不含 data 字段: {data}")
return
if "CreateTime" in d:
# 得到系统 UTF+8 的 ts
tz_offset = datetime.timedelta(hours=8)
tz = datetime.timezone(tz_offset)
ts = datetime.datetime.now(tz).timestamp()
create_time = d["CreateTime"]
if create_time < ts - 30:
logger.warning(f"消息时间戳过旧: {create_time},当前时间戳: {ts}")
return
abm = AstrBotMessage()
from_user_name = d["FromUserName"]["string"] # 消息来源
d["to_wxid"] = from_user_name # 用于发信息
abm.message_id = str(d.get("MsgId"))
abm.session_id = from_user_name
abm.self_id = data["Wxid"] # 机器人的 wxid
user_id = "" # 发送人 wxid
content = d["Content"]["string"] # 消息内容
at_me = False
at_wxids = []
if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE
_t = content.split(":\n")
user_id = _t[0]
content = _t[1]
# at
msg_source = d["MsgSource"]
if "\u2005" in content:
# at
# content = content.split('\u2005')[1]
content = re.sub(r"@[^\u2005]*\u2005", "", content)
at_wxids = re.findall(
r"<atuserlist><!\[CDATA\[.*?(?:,|\b)([^,]+?)(?=,|\]\]></atuserlist>)",
msg_source,
)
abm.group_id = from_user_name
if (
f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source
or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source
):
at_me = True
if "在群聊中@了你" in d.get("PushContent", ""):
at_me = True
else:
abm.type = MessageType.FRIEND_MESSAGE
user_id = from_user_name
# 检查消息是否由自己发送,若是则忽略
# 已经有可配置项专门配置是否需要响应自己的消息,因此这里注释掉。
# if user_id == abm.self_id:
# logger.info("忽略自己发送的消息")
# return None
abm.message = []
# 解析用户真实名字
user_real_name = "unknown"
if abm.group_id:
if (
abm.group_id not in self.userrealnames
or user_id not in self.userrealnames[abm.group_id]
):
# 获取群成员列表,并且缓存
if abm.group_id not in self.userrealnames:
self.userrealnames[abm.group_id] = {}
member_list = await self.get_chatroom_member_list(abm.group_id)
logger.debug(f"获取到 {abm.group_id} 的群成员列表。")
if member_list and "memberList" in member_list:
for member in member_list["memberList"]:
self.userrealnames[abm.group_id][member["wxid"]] = member[
"nickName"
]
if user_id in self.userrealnames[abm.group_id]:
user_real_name = self.userrealnames[abm.group_id][user_id]
else:
user_real_name = self.userrealnames[abm.group_id][user_id]
else:
try:
info = (await self.get_user_or_group_info(user_id))["data"][0]
user_real_name = info["nickName"]
except Exception as e:
logger.debug(f"获取用户 {user_id} 昵称失败: {e}")
user_real_name = user_id
if at_me:
abm.message.insert(0, At(qq=abm.self_id, name=self.nickname))
for wxid in at_wxids:
# 群聊里 At 其他人的列表
_username = self.userrealnames.get(abm.group_id, {}).get(wxid, wxid)
abm.message.append(At(qq=wxid, name=_username))
abm.sender = MessageMember(user_id, user_real_name)
abm.raw_message = d
abm.message_str = ""
if user_id == "weixin":
# 忽略微信团队消息
return
# 不同消息类型
match d["MsgType"]:
case 1:
# 文本消息
abm.message.append(Plain(content))
abm.message_str = content
case 3:
# 图片消息
file_url = await self.multimedia_downloader.download_image(
self.appid, content
)
logger.debug(f"下载图片: {file_url}")
file_path = await download_image_by_url(file_url)
abm.message.append(Image(file=file_path, url=file_path))
case 34:
# 语音消息
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(
temp_dir, f"gewe_voice_{abm.message_id}.silk"
)
async with await anyio.open_file(file_path, "wb") as f:
await f.write(voice_data)
abm.message.append(Record(file=file_path, url=file_path))
# 以下已知消息类型,没有业务处理,只是避免控制台打印太多的日志
case 37: # 好友申请
logger.info("消息类型(37):好友申请")
case 42: # 名片
logger.info("消息类型(42):名片")
case 43: # 视频
video = Video(file="", cover=content)
abm.message.append(video)
case 47: # emoji
data_parser = GeweDataParser(content, abm.group_id == "")
emoji = data_parser.parse_emoji()
abm.message.append(emoji)
case 48: # 地理位置
logger.info("消息类型(48):地理位置")
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
data_parser = GeweDataParser(content, abm.group_id == "")
segments = data_parser.parse_mutil_49()
if segments:
abm.message.extend(segments)
for seg in segments:
if isinstance(seg, Plain):
abm.message_str += seg.text
case 51: # 帐号消息同步?
logger.info("消息类型(51):帐号消息同步?")
case 10000: # 被踢出群聊/更换群主/修改群名称
logger.info("消息类型(10000):被踢出群聊/更换群主/修改群名称")
case 10002: # 撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办
logger.info(
"消息类型(10002):撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办"
)
case _:
logger.info(f"未实现的消息类型: {d['MsgType']}")
abm.raw_message = d
logger.debug(f"abm: {abm}")
return abm
async def _callback(self):
data = await quart.request.json
logger.debug(f"收到 gewechat 回调: {data}")
if data.get("testMsg", None):
return quart.jsonify({"r": "AstrBot ACK"})
abm = None
try:
abm = await self._convert(data)
except BaseException as e:
logger.warning(
f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}"
)
if abm:
coro = getattr(self, "on_event_received")
if coro:
await coro(abm)
return quart.jsonify({"r": "AstrBot ACK"})
async def _register_file(self, file_path: str) -> str:
"""向 AstrBot 回调服务器 注册一个允许外部访问的文件。
Args:
file_path (str): 文件路径。
Returns:
str: 返回一个 auth_token文件路径为 file_path。通过 /astrbot-gewechat/file/auth_token 得到文件。
"""
async with self.lock:
if not os.path.exists(file_path):
raise Exception(f"文件不存在: {file_path}")
file_token = str(uuid.uuid4())
self.staged_files[file_token] = file_path
return file_token
async def _handle_file(self, file_token):
async with self.lock:
if file_token not in self.staged_files:
logger.warning(f"请求的文件 {file_token} 不存在。")
return quart.abort(404)
if not os.path.exists(self.staged_files[file_token]):
logger.warning(f"请求的文件 {self.staged_files[file_token]} 不存在。")
return quart.abort(404)
file_path = self.staged_files[file_token]
self.staged_files.pop(file_token, None)
return await quart.send_file(file_path)
async def _set_callback_url(self):
logger.info("设置回调,请等待...")
await asyncio.sleep(3)
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/tools/setCallback",
headers=self.headers,
json={"token": self.token, "callbackUrl": self.callback_url},
) as resp:
json_blob = await resp.json()
logger.info(f"设置回调结果: {json_blob}")
if json_blob["ret"] != 200:
raise Exception(f"设置回调失败: {json_blob}")
logger.info(
f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。如果仍没收到请到管理面板聊天页输入 /gewe_logout 重新登录。"
)
async def start_polling(self):
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
await self.server.run_task(
host="0.0.0.0",
port=self.port,
shutdown_trigger=self.shutdown_trigger,
)
async def shutdown_trigger(self):
await self.shutdown_event.wait()
async def check_online(self, appid: str):
"""检查 APPID 对应的设备是否在线。"""
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/checkOnline",
headers=self.headers,
json={"appId": appid},
) as resp:
json_blob = await resp.json()
return json_blob["data"]
async def logout(self):
"""登出 gewechat。"""
if self.appid:
online = await self.check_online(self.appid)
if online:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/logout",
headers=self.headers,
json={"appId": self.appid},
) as resp:
json_blob = await resp.json()
logger.info(f"登出结果: {json_blob}")
async def login(self):
"""登录 gewechat。一般来说插件用不到这个方法。"""
if self.token is None:
await self.get_token_id()
self.multimedia_downloader = GeweDownloader(
self.base_url, self.download_base_url, self.token
)
if self.appid:
try:
online = await self.check_online(self.appid)
if online:
logger.info(f"APPID: {self.appid} 已在线")
return
except Exception as e:
logger.error(f"检查在线状态失败: {e}")
sp.put(f"gewechat-appid-{self.nickname}", "")
self.appid = None
payload = {"appId": self.appid}
if self.appid:
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/getLoginQrCode",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
if json_blob["ret"] != 200:
error_msg = json_blob.get("data", {}).get("msg", "")
if "设备不存在" in error_msg:
logger.error(
f"检测到无效的appid: {self.appid},将清除并重新登录。"
)
sp.put(f"gewechat-appid-{self.nickname}", "")
self.appid = None
return await self.login()
else:
raise Exception(f"获取二维码失败: {json_blob}")
qr_data = json_blob["data"]["qrData"]
qr_uuid = json_blob["data"]["uuid"]
appid = json_blob["data"]["appId"]
logger.info(f"APPID: {appid}")
logger.warning(
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
)
except Exception as e:
raise e
# 执行登录
retry_cnt = 64
payload.update({"uuid": qr_uuid, "appId": appid})
while retry_cnt > 0:
retry_cnt -= 1
# 需要验证码
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
code_file_path = os.path.join(temp_dir, "gewe_code")
if os.path.exists(code_file_path):
with open(code_file_path, "r") as f:
code = f.read().strip()
if not code:
logger.warning(
"未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
)
await asyncio.sleep(5)
continue
payload["captchCode"] = code
logger.info(f"使用验证码: {code}")
try:
os.remove(code_file_path)
except Exception:
logger.warning(f"删除验证码文件 {code_file_path} 失败。")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/checkLogin",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.info(f"检查登录状态: {json_blob}")
ret = json_blob["ret"]
msg = ""
if json_blob["data"] and "msg" in json_blob["data"]:
msg = json_blob["data"]["msg"]
if ret == 500 and "安全验证码" in msg:
logger.warning(
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
)
else:
if "status" in json_blob["data"]:
status = json_blob["data"]["status"]
nickname = json_blob["data"].get("nickName", "")
if status == 1:
logger.info(f"等待确认...{nickname}")
elif status == 2:
logger.info(f"绿泡泡平台登录成功: {nickname}")
break
elif status == 0:
logger.info("等待扫码...")
else:
logger.warning(f"未知状态: {status}")
await asyncio.sleep(5)
if appid:
sp.put(f"gewechat-appid-{self.nickname}", appid)
self.appid = appid
logger.info(f"已保存 APPID: {appid}")
"""API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1
"""
async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict:
"""获取群成员列表。
Args:
chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。
Returns:
dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。
"""
payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomMemberList",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
return json_blob["data"]
async def post_text(self, to_wxid, content: str, ats: str = ""):
"""发送纯文本消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"content": content,
}
if ats:
payload["ats"] = ats
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postText", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送消息结果: {json_blob}")
async def post_image(self, to_wxid, image_url: str):
"""发送图片消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"imgUrl": image_url,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postImage", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送图片结果: {json_blob}")
async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""):
"""发送emoji消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"emojiMd5": emoji_md5,
"emojiSize": emoji_size,
}
# 优先表情包若拿不到表情包的md5就用当作图片发
try:
if emoji_md5 != "" and emoji_size != "":
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postEmoji",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.info(
f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}"
)
else:
await self.post_image(to_wxid, cdnurl)
except Exception as e:
logger.error(e)
async def post_video(
self, to_wxid, video_url: str, thumb_url: str, video_duration: int
):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"videoUrl": video_url,
"thumbUrl": thumb_url,
"videoDuration": video_duration,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postVideo", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送视频结果: {json_blob}")
async def forward_video(self, to_wxid, cnd_xml: str):
"""转发视频
Args:
to_wxid (str): 发送给谁
cnd_xml (str): 视频消息的cdn信息
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"xml": cnd_xml,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/forwardVideo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"转发视频结果: {json_blob}")
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
"""发送语音信息
Args:
voice_url (str): 语音文件的网络链接
voice_duration (int): 语音时长,毫秒
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"voiceUrl": voice_url,
"voiceDuration": voice_duration,
}
logger.debug(f"发送语音: {payload}")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}")
async def post_file(self, to_wxid, file_url: str, file_name: str):
"""发送文件
Args:
to_wxid (string): 微信ID
file_url (str): 文件的网络链接
file_name (str): 文件名
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"fileUrl": file_url,
"fileName": file_name,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postFile", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送文件结果: {json_blob}")
async def add_friend(self, v3: str, v4: str, content: str):
"""申请添加好友"""
payload = {
"appId": self.appid,
"scene": 3,
"content": content,
"v4": v4,
"v3": v3,
"option": 2,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/addContacts",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"申请添加好友结果: {json_blob}")
return json_blob
async def get_group(self, group_id: str):
payload = {
"appId": self.appid,
"chatroomId": group_id,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomInfo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_group_member(self, group_id: str):
payload = {
"appId": self.appid,
"chatroomId": group_id,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomMemberList",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def accept_group_invite(self, url: str):
"""同意进群"""
payload = {"appId": self.appid, "url": url}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/agreeJoinRoom",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def add_group_member_to_friend(
self, group_id: str, to_wxid: str, content: str
):
payload = {
"appId": self.appid,
"chatroomId": group_id,
"content": content,
"memberWxid": to_wxid,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/addGroupMemberAsFriend",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_user_or_group_info(self, *ids):
"""
获取用户或群组信息。
:param ids: 可变数量的 wxid 参数
"""
wxids_str = list(ids)
payload = {
"appId": self.appid,
"wxids": wxids_str, # 使用逗号分隔的字符串
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/getDetailInfo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_contacts_list(self):
"""
获取通讯录列表
见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504
"""
payload = {"appId": self.appid}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/fetchContactsList",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取通讯录列表结果: {json_blob}")
return json_blob

View File

@@ -1,55 +0,0 @@
from astrbot import logger
import aiohttp
import json
class GeweDownloader:
def __init__(self, base_url: str, download_base_url: str, token: str):
self.base_url = base_url
self.download_base_url = download_base_url
self.headers = {"Content-Type": "application/json", "X-GEWE-TOKEN": token}
async def _post_json(self, baseurl: str, route: str, payload: dict):
async with aiohttp.ClientSession() as session:
async with session.post(
f"{baseurl}{route}", headers=self.headers, json=payload
) as resp:
return await resp.read()
async def download_voice(self, appid: str, xml: str, msg_id: str):
payload = {"appId": appid, "xml": xml, "msgId": msg_id}
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
async def download_image(self, appid: str, xml: str) -> str:
"""返回一个可下载的 URL"""
choices = [2, 3] # 2:常规图片 3:缩略图
for choice in choices:
try:
payload = {"appId": appid, "xml": xml, "type": choice}
data = await self._post_json(
self.base_url, "/message/downloadImage", payload
)
json_blob = json.loads(data)
if "fileUrl" in json_blob["data"]:
return self.download_base_url + json_blob["data"]["fileUrl"]
except BaseException as e:
logger.error(f"gewe download image: {e}")
continue
raise Exception("无法下载图片")
async def download_emoji_md5(self, app_id, emoji_md5):
"""下载emoji"""
try:
payload = {"appId": app_id, "emojiMd5": emoji_md5}
# gewe 计划中的接口暂时没有实现。返回代码404
data = await self._post_json(
self.base_url, "/message/downloadEmojiMd5", payload
)
json_blob = json.loads(data)
return json_blob
except BaseException as e:
logger.error(f"gewe download emoji: {e}")

View File

@@ -1,264 +0,0 @@
import asyncio
import re
import wave
import uuid
import traceback
import os
from typing import AsyncGenerator
from astrbot.core.utils.io import download_file
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember
from astrbot.api.message_components import (
Plain,
Image,
Record,
At,
File,
Video,
WechatEmoji as Emoji,
)
from .client import SimpleGewechatClient
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
def get_wav_duration(file_path):
with wave.open(file_path, "rb") as wav_file:
file_size = os.path.getsize(file_path)
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
if n_frames == 2147483647:
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
elif n_frames == 0:
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
else:
duration = n_frames / float(framerate)
return duration
class GewechatPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client: SimpleGewechatClient,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
@staticmethod
async def send_with_client(
message: MessageChain, to_wxid: str, client: SimpleGewechatClient
):
if not to_wxid:
logger.error("无法获取到 to_wxid。")
return
# 检查@
ats = []
ats_names = []
for comp in message.chain:
if isinstance(comp, At):
ats.append(comp.qq)
ats_names.append(comp.name)
has_at = False
for comp in message.chain:
if isinstance(comp, Plain):
text = comp.text
payload = {
"to_wxid": to_wxid,
"content": text,
}
if not has_at and ats:
ats = f"{','.join(ats)}"
ats_names = f"@{' @'.join(ats_names)}"
text = f"{ats_names} {text}"
payload["content"] = text
payload["ats"] = ats
has_at = True
await client.post_text(**payload)
elif isinstance(comp, Image):
img_path = await comp.convert_to_file_path()
# 为了安全,向 AstrBot 回调服务注册可被 gewechat 访问的文件,并获得文件 token
token = await client._register_file(img_path)
img_url = f"{client.file_server_url}/{token}"
logger.debug(f"gewe callback img url: {img_url}")
await client.post_image(to_wxid, img_url)
elif isinstance(comp, Video):
if comp.cover != "":
await client.forward_video(to_wxid, comp.cover)
else:
try:
from pyffmpeg import FFmpeg
except (ImportError, ModuleNotFoundError):
logger.error(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
)
raise ModuleNotFoundError(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
)
video_url = comp.file
# 根据 url 下载视频
if video_url.startswith("http"):
video_filename = f"{uuid.uuid4()}.mp4"
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
video_path = os.path.join(temp_dir, video_filename)
await download_file(video_url, video_path)
else:
video_path = video_url
video_token = await client._register_file(video_path)
video_callback_url = f"{client.file_server_url}/{video_token}"
# 获取视频第一帧
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
thumb_path = os.path.join(
temp_dir, f"gewechat_video_thumb_{uuid.uuid4()}.jpg"
)
video_path = video_path.replace(" ", "\\ ")
try:
ff = FFmpeg()
command = f"-i {video_path} -ss 0 -vframes 1 {thumb_path}"
ff.options(command)
thumb_token = await client._register_file(thumb_path)
thumb_url = f"{client.file_server_url}/{thumb_token}"
except Exception as e:
logger.error(f"获取视频第一帧失败: {e}")
# 获取视频时长
try:
from pyffmpeg import FFprobe
# 创建 FFprobe 实例
ffprobe = FFprobe(video_url)
# 获取时长字符串
duration_str = ffprobe.duration
# 处理时长字符串
video_duration = float(duration_str.replace(":", ""))
except Exception as e:
logger.error(f"获取时长失败: {e}")
video_duration = 10
# 发送视频
await client.post_video(
to_wxid, video_callback_url, thumb_url, video_duration
)
# 删除临时缩略图文件
if os.path.exists(thumb_path):
os.remove(thumb_path)
elif isinstance(comp, Record):
# 默认已经存在 data/temp 中
record_url = comp.file
record_path = await comp.convert_to_file_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
try:
duration = await wav_to_tencent_silk(record_path, silk_path)
except Exception as e:
logger.error(traceback.format_exc())
await client.post_text(to_wxid, f"语音文件转换失败。{str(e)}")
logger.info("Silk 语音文件格式转换至: " + record_path)
if duration == 0:
duration = get_wav_duration(record_path)
token = await client._register_file(silk_path)
record_url = f"{client.file_server_url}/{token}"
logger.debug(f"gewe callback record url: {record_url}")
await client.post_voice(to_wxid, record_url, duration * 1000)
elif isinstance(comp, File):
file_path = comp.file
file_name = comp.name
if file_path.startswith("file:///"):
file_path = file_path[8:]
elif file_path.startswith("http"):
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
temp_file_path = os.path.join(temp_dir, file_name)
await download_file(file_path, temp_file_path)
file_path = temp_file_path
else:
file_path = file_path
token = await client._register_file(file_path)
file_url = f"{client.file_server_url}/{token}"
logger.debug(f"gewe callback file url: {file_url}")
await client.post_file(to_wxid, file_url, file_name)
elif isinstance(comp, Emoji):
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
elif isinstance(comp, At):
pass
else:
logger.debug(f"gewechat 忽略: {comp.type}")
async def send(self, message: MessageChain):
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client)
await super().send(message)
async def get_group(self, group_id=None, **kwargs):
# 确定有效的 group_id
if group_id is None:
group_id = self.get_group_id()
if not group_id:
return None
res = await self.client.get_group(group_id)
data: dict = res["data"]
if not data["chatroomId"]:
return None
members = [
MessageMember(user_id=member["wxid"], nickname=member["nickName"])
for member in data.get("memberList", [])
]
return Group(
group_id=data["chatroomId"],
group_name=data.get("nickName"),
group_avatar=data.get("smallHeadImgUrl"),
group_owner=data.get("chatRoomOwner"),
members=members,
)
async def send_streaming(
self, generator: AsyncGenerator, use_fallback: bool = False
):
if not use_fallback:
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
buffer = ""
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
async for chain in generator:
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
buffer = await self.process_buffer(buffer, pattern)
else:
await self.send(MessageChain(chain=[comp]))
await asyncio.sleep(1.5) # 限速
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator, use_fallback)

View File

@@ -1,103 +0,0 @@
import sys
import asyncio
import os
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from .gewechat_event import GewechatPlatformEvent
from .client import SimpleGewechatClient
from astrbot import logger
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器")
class GewechatPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settingss = platform_settings
self.test_mode = os.environ.get("TEST_MODE", "off") == "on"
self.client = None
self.client = SimpleGewechatClient(
self.config["base_url"],
self.config["nickname"],
self.config["host"],
self.config["port"],
self._event_queue,
)
async def on_event_received(abm: AstrBotMessage):
await self.handle_msg(abm)
self.client.on_event_received = on_event_received
@override
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
session_id = session.session_id
if "#" in session_id:
# unique session
to_wxid = session_id.split("#")[1]
else:
to_wxid = session_id
await GewechatPlatformEvent.send_with_client(
message_chain, to_wxid, self.client
)
await super().send_by_session(session, message_chain)
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="gewechat",
description="基于 gewechat 的 Wechat 适配器",
id=self.config.get("id"),
)
async def terminate(self):
self.client.shutdown_event.set()
try:
await self.client.server.shutdown()
except Exception as _:
pass
logger.info("Gewechat 适配器已被优雅地关闭。")
async def logout(self):
await self.client.logout()
@override
def run(self):
return self._run()
async def _run(self):
await self.client.login()
await self.client.start_polling()
async def handle_msg(self, message: AstrBotMessage):
if message.type == MessageType.GROUP_MESSAGE:
if self.settingss["unique_session"]:
message.session_id = message.sender.user_id + "#" + message.group_id
message_event = GewechatPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client,
)
self.commit_event(message_event)
def get_client(self) -> SimpleGewechatClient:
return self.client

View File

@@ -1,110 +0,0 @@
from defusedxml import ElementTree as eT
from astrbot.api import logger
from astrbot.api.message_components import (
WechatEmoji as Emoji,
Reply,
Plain,
BaseMessageComponent,
)
class GeweDataParser:
def __init__(self, data, is_private_chat):
self.data = data
self.is_private_chat = is_private_chat
def _format_to_xml(self):
return eT.fromstring(self.data)
def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
appmsg_type = self._format_to_xml().find(".//appmsg/type")
if appmsg_type is None:
return
match appmsg_type.text:
case "57":
return self.parse_reply()
def parse_emoji(self) -> Emoji | None:
try:
emoji_element = self._format_to_xml().find(".//emoji")
# 提取 md5 和 len 属性
if emoji_element is not None:
md5_value = emoji_element.get("md5")
emoji_size = emoji_element.get("len")
cdnurl = emoji_element.get("cdnurl")
return Emoji(md5=md5_value, md5_len=emoji_size, cdnurl=cdnurl)
except Exception as e:
logger.error(f"gewechat: parse_emoji failed, {e}")
def parse_reply(self) -> list[Reply, Plain] | None:
"""解析引用消息
Returns:
list[Reply, Plain]: 一个包含两个元素的列表。Reply 消息对象和引用者说的文本内容。微信平台下引用消息时只能发送文本消息。
"""
try:
replied_id = -1
replied_uid = 0
replied_nickname = ""
replied_content = "" # 被引用者说的内容
content = "" # 引用者说的内容
root = self._format_to_xml()
refermsg = root.find(".//refermsg")
if refermsg is not None:
# 被引用的信息
svrid = refermsg.find("svrid")
fromusr = refermsg.find("fromusr")
displayname = refermsg.find("displayname")
refermsg_content = refermsg.find("content")
if svrid is not None:
replied_id = svrid.text
if fromusr is not None:
replied_uid = fromusr.text
if displayname is not None:
replied_nickname = displayname.text
if refermsg_content is not None:
# 处理引用嵌套,包括嵌套公众号消息
if refermsg_content.text.startswith(
"<msg>"
) or refermsg_content.text.startswith("<?xml"):
try:
logger.debug("gewechat: Reference message is nested")
refer_root = eT.fromstring(refermsg_content.text)
img = refer_root.find("img")
if img is not None:
replied_content = "[图片]"
else:
app_msg = refer_root.find("appmsg")
refermsg_content_title = app_msg.find("title")
logger.debug(
f"gewechat: Reference message nesting: {refermsg_content_title.text}"
)
replied_content = refermsg_content_title.text
except Exception as e:
logger.error(f"gewechat: nested failed, {e}")
# 处理异常情况
replied_content = refermsg_content.text
else:
replied_content = refermsg_content.text
# 提取引用者说的内容
title = root.find(".//appmsg/title")
if title is not None:
content = title.text
reply_seg = Reply(
id=replied_id,
chain=[Plain(replied_content)],
sender_id=replied_uid,
sender_nickname=replied_nickname,
message_str=replied_content,
)
plain_seg = Plain(content)
return [reply_seg, plain_seg]
except Exception as e:
logger.error(f"gewechat: parse_reply failed, {e}")

View File

@@ -28,10 +28,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
self.send_buffer = None
async def send(self, message: MessageChain):
if not self.send_buffer:
self.send_buffer = message
else:
self.send_buffer.chain.extend(message.chain)
await self._post_send()
async def send_streaming(self, generator, use_fallback: bool = False):
"""流式输出仅支持消息列表私聊"""

View File

@@ -0,0 +1,162 @@
import json
import hmac
import hashlib
import asyncio
import logging
from typing import Callable, Optional
from quart import Quart, request, Response
from slack_sdk.web.async_client import AsyncWebClient
from slack_sdk.socket_mode.aiohttp import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from astrbot.api import logger
class SlackWebhookClient:
"""Slack Webhook 模式客户端,使用 Quart 作为 Web 服务器"""
def __init__(
self,
web_client: AsyncWebClient,
signing_secret: str,
host: str = "0.0.0.0",
port: int = 3000,
path: str = "/slack/events",
event_handler: Optional[Callable] = None,
):
self.web_client = web_client
self.signing_secret = signing_secret
self.host = host
self.port = port
self.path = path
self.event_handler = event_handler
self.app = Quart(__name__)
self._setup_routes()
# 禁用 Quart 的默认日志输出
logging.getLogger("quart.app").setLevel(logging.WARNING)
logging.getLogger("quart.serving").setLevel(logging.WARNING)
self.shutdown_event = asyncio.Event()
def _setup_routes(self):
"""设置路由"""
@self.app.route(self.path, methods=["POST"])
async def slack_events():
"""处理 Slack 事件"""
try:
# 获取请求体和头部
body = await request.get_data()
event_data = json.loads(body.decode("utf-8"))
# Verify Slack request signature
timestamp = request.headers.get("X-Slack-Request-Timestamp")
signature = request.headers.get("X-Slack-Signature")
if not timestamp or not signature:
return Response("Missing headers", status=400)
# Calculate the HMAC signature
sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}"
my_signature = (
"v0="
+ hmac.new(
self.signing_secret.encode("utf-8"),
sig_basestring.encode("utf-8"),
hashlib.sha256,
).hexdigest()
)
# Verify the signature
if not hmac.compare_digest(my_signature, signature):
logger.warning("Slack request signature verification failed")
return Response("Invalid signature", status=400)
logger.info(f"Received Slack event: {event_data}")
# 处理 URL 验证事件
if event_data.get("type") == "url_verification":
return {"challenge": event_data.get("challenge")}
# 处理事件
if self.event_handler and event_data.get("type") == "event_callback":
await self.event_handler(event_data)
return Response("", status=200)
except Exception as e:
logger.error(f"处理 Slack 事件时出错: {e}")
return Response("Internal Server Error", status=500)
@self.app.route("/health", methods=["GET"])
async def health_check():
"""健康检查端点"""
return {"status": "ok", "service": "slack-webhook"}
async def start(self):
"""启动 Webhook 服务器"""
logger.info(
f"Slack Webhook 服务器启动中,监听 {self.host}:{self.port}{self.path}..."
)
await self.app.run_task(
host=self.host,
port=self.port,
debug=False,
shutdown_trigger=self.shutdown_trigger,
)
async def shutdown_trigger(self):
await self.shutdown_event.wait()
async def stop(self):
"""停止 Webhook 服务器"""
self.shutdown_event.set()
logger.info("Slack Webhook 服务器已停止")
class SlackSocketClient:
"""Slack Socket 模式客户端"""
def __init__(
self,
web_client: AsyncWebClient,
app_token: str,
event_handler: Optional[Callable] = None,
):
self.web_client = web_client
self.app_token = app_token
self.event_handler = event_handler
self.socket_client = None
async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest):
"""处理 Socket Mode 事件"""
try:
# 确认收到事件
response = SocketModeResponse(envelope_id=req.envelope_id)
await self.socket_client.send_socket_mode_response(response)
# 处理事件
if self.event_handler:
await self.event_handler(req)
except Exception as e:
logger.error(f"处理 Socket Mode 事件时出错: {e}")
async def start(self):
"""启动 Socket Mode 连接"""
self.socket_client = SocketModeClient(
app_token=self.app_token,
logger=logger,
web_client=self.web_client,
)
# 注册事件处理器
self.socket_client.socket_mode_request_listeners.append(self._handle_events)
logger.info("Slack Socket Mode 客户端启动中...")
await self.socket_client.connect()
async def stop(self):
"""停止 Socket Mode 连接"""
if self.socket_client:
await self.socket_client.disconnect()
await self.socket_client.close()
logger.info("Slack Socket Mode 客户端已停止")

View File

@@ -0,0 +1,398 @@
import time
import asyncio
import uuid
import aiohttp
import re
import base64
from typing import Awaitable, Any
from slack_sdk.web.async_client import AsyncWebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from astrbot.api.platform import (
Platform,
AstrBotMessage,
MessageMember,
MessageType,
PlatformMetadata,
)
from astrbot.api.event import MessageChain
from .slack_event import SlackMessageEvent
from .client import SlackWebhookClient, SlackSocketClient
from astrbot.api.message_components import * # noqa: F403
from astrbot.api import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
@register_platform_adapter(
"slack", "适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。"
)
class SlackAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings.get("unique_session", False)
self.bot_token = platform_config.get("bot_token")
self.app_token = platform_config.get("app_token")
self.signing_secret = platform_config.get("signing_secret")
self.connection_mode = platform_config.get("slack_connection_mode", "socket")
self.webhook_host = platform_config.get("slack_webhook_host", "0.0.0.0")
self.webhook_port = platform_config.get("slack_webhook_port", 3000)
self.webhook_path = platform_config.get(
"slack_webhook_path", "/astrbot-slack-webhook/callback"
)
if not self.bot_token:
raise ValueError("Slack bot_token 是必需的")
if self.connection_mode == "socket" and not self.app_token:
raise ValueError("Socket Mode 需要 app_token")
if self.connection_mode == "webhook" and not self.signing_secret:
raise ValueError("Webhook Mode 需要 signing_secret")
self.metadata = PlatformMetadata(
name="slack",
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
id=self.config.get("id"),
)
# 初始化 Slack Web Client
self.web_client = AsyncWebClient(token=self.bot_token, logger=logger)
self.socket_client = None
self.webhook_client = None
self.bot_self_id = None
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
blocks, text = SlackMessageEvent._parse_slack_blocks(
message_chain=message_chain, web_client=self.web_client
)
try:
if session.message_type == MessageType.GROUP_MESSAGE:
# 发送到频道
channel_id = (
session.session_id.split("_")[-1]
if "_" in session.session_id
else session.session_id
)
await self.web_client.chat_postMessage(
channel=channel_id,
text=text,
blocks=blocks if blocks else None,
)
else:
# 发送私信
await self.web_client.chat_postMessage(
channel=session.session_id,
text=text,
blocks=blocks if blocks else None,
)
except Exception as e:
logger.error(f"Slack 发送消息失败: {e}")
await super().send_by_session(session, message_chain)
async def convert_message(self, event: dict) -> AstrBotMessage:
logger.debug(f"[slack] RawMessage {event}")
abm = AstrBotMessage()
abm.self_id = self.bot_self_id
# 获取用户信息
user_id = event.get("user", "")
try:
user_info = await self.web_client.users_info(user=user_id)
user_data = user_info["user"]
user_name = user_data.get("real_name") or user_data.get("name", user_id)
except Exception:
user_name = user_id
abm.sender = MessageMember(user_id=user_id, nickname=user_name)
# 判断消息类型
channel_id = event.get("channel", "")
try:
channel_info = await self.web_client.conversations_info(channel=channel_id)
is_im = channel_info["channel"]["is_im"]
if is_im:
abm.type = MessageType.FRIEND_MESSAGE
else:
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = channel_id
except Exception:
# 默认作为群组消息处理
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = channel_id
# 设置会话ID
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = f"{user_id}_{channel_id}"
else:
abm.session_id = (
channel_id if abm.type == MessageType.GROUP_MESSAGE else user_id
)
abm.message_id = event.get("client_msg_id", uuid.uuid4().hex)
abm.timestamp = int(float(event.get("ts", time.time())))
# 处理消息内容
message_text = event.get("text", "")
abm.message_str = message_text
abm.message = []
# 优先使用 blocks 字段解析消息
if "blocks" in event and event["blocks"]:
abm.message = self._parse_blocks(event["blocks"])
# 更新 message_str
abm.message_str = ""
for component in abm.message:
if isinstance(component, Plain):
abm.message_str += component.text
elif message_text:
# 处理传统的文本消息
if "<@" in message_text:
mentions = re.findall(r"<@([^>]+)>", message_text)
for mention in mentions:
try:
mentioned_user = await self.web_client.users_info(user=mention)
user_data = mentioned_user["user"]
user_name = user_data.get("real_name") or user_data.get(
"name", mention
)
abm.message.append(At(qq=mention, name=user_name))
except Exception:
abm.message.append(At(qq=mention, name=""))
# 清理消息文本中的@标记
if clean_text := re.sub(r"<@[^>]+>", "", message_text).strip():
abm.message.append(Plain(text=clean_text))
else:
abm.message.append(Plain(text=message_text))
# 处理文件附件
if "files" in event:
for file_info in event["files"]:
file_name = file_info.get("name", "unknown")
file_url = file_info.get("url_private", "")
if file_info.get("mimetype", "").startswith("image/"):
file_url = await self.get_file_base64(file_url)
abm.message.append(Image.fromBase64(base64=file_url))
else:
# TODO: 下载鉴权
abm.message.append(
File(name=file_name, file=file_url, url=file_url)
)
abm.raw_message = event
return abm
def _parse_blocks(self, blocks: list) -> list:
"""解析 Slack blocks 格式的消息内容"""
message_components = []
for block in blocks:
block_type = block.get("type", "")
if block_type == "rich_text":
# 处理富文本块
elements = block.get("elements", [])
for element in elements:
if element.get("type") == "rich_text_section":
# 处理富文本段落
section_elements = element.get("elements", [])
text_content = ""
for section_element in section_elements:
element_type = section_element.get("type", "")
if element_type == "text":
# 普通文本
text_content += section_element.get("text", "")
elif element_type == "user":
# @用户提及
user_id = section_element.get("user_id", "")
if user_id:
# 将之前的文本内容先添加到组件中
if text_content.strip():
message_components.append(
Plain(text=text_content)
)
text_content = ""
# 添加@提及组件
message_components.append(At(qq=user_id, name=""))
elif element_type == "channel":
# #频道提及
channel_id = section_element.get("channel_id", "")
text_content += f"#{channel_id}"
elif element_type == "link":
# 链接
url = section_element.get("url", "")
link_text = section_element.get("text", url)
text_content += f"[{link_text}]({url})"
elif element_type == "emoji":
# 表情符号
emoji_name = section_element.get("name", "")
text_content += f":{emoji_name}:"
if text_content.strip():
message_components.append(Plain(text=text_content))
elif element.get("type") == "rich_text_list":
# 处理列表
list_items = element.get("elements", [])
list_text = ""
for item in list_items:
if item.get("type") == "rich_text_section":
item_elements = item.get("elements", [])
item_text = ""
for item_element in item_elements:
if item_element.get("type") == "text":
item_text += item_element.get("text", "")
list_text += f"{item_text}\n"
if list_text.strip():
message_components.append(Plain(text=list_text.strip()))
elif block_type == "section":
# 处理段落块
if "text" in block:
text_obj = block["text"]
if text_obj.get("type") == "mrkdwn":
text_content = text_obj.get("text", "")
message_components.append(Plain(text=text_content))
return message_components
async def _handle_socket_event(self, req: SocketModeRequest):
"""处理 Socket Mode 事件"""
if req.type == "events_api":
# 事件 API
event = req.payload.get("event", {})
# 忽略机器人自己的消息和消息编辑
if event.get("subtype") in [
"bot_message",
"message_changed",
"message_deleted",
]:
return
if event.get("bot_id"):
return
if event.get("type") in ["message", "app_mention"]:
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
async def get_bot_user_id(self):
auth_info = await self.web_client.auth_test()
return auth_info.get("user_id")
async def get_file_base64(self, url: str) -> str:
"""下载 Slack 文件并返回 Base64 编码的内容"""
headers = {"Authorization": f"Bearer {self.bot_token}"}
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as resp:
if resp.status == 200:
content = await resp.read()
base64_content = base64.b64encode(content).decode("utf-8")
return base64_content
else:
logger.error(
f"Failed to download slack file: {resp.status} {await resp.text()}"
)
raise Exception(f"下载文件失败: {resp.status}")
async def run(self) -> Awaitable[Any]:
self.bot_self_id = await self.get_bot_user_id()
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
if self.connection_mode == "socket":
if not self.app_token:
raise ValueError("Socket Mode 需要 app_token")
# 创建 Socket 客户端
self.socket_client = SlackSocketClient(
self.web_client, self.app_token, self._handle_socket_event
)
logger.info("Slack 适配器 (Socket Mode) 启动中...")
await self.socket_client.start()
elif self.connection_mode == "webhook":
if not self.signing_secret:
raise ValueError("Webhook Mode 需要 signing_secret")
# 创建 Webhook 客户端
self.webhook_client = SlackWebhookClient(
self.web_client,
self.signing_secret,
self.webhook_host,
self.webhook_port,
self.webhook_path,
self._handle_webhook_event,
)
logger.info(
f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}..."
)
await self.webhook_client.start()
else:
raise ValueError(
f"不支持的连接模式: {self.connection_mode},请使用 'socket''webhook'"
)
async def _handle_webhook_event(self, event_data: dict):
"""处理 Webhook 事件"""
event = event_data.get("event", {})
# 忽略机器人自己的消息和消息编辑
if event.get("subtype") in [
"bot_message",
"message_changed",
"message_deleted",
]:
return
if event.get("bot_id"):
return
if event.get("type") in ["message", "app_mention"]:
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
async def terminate(self):
if self.socket_client:
await self.socket_client.stop()
if self.webhook_client:
await self.webhook_client.stop()
logger.info("Slack 适配器已被优雅地关闭")
def meta(self) -> PlatformMetadata:
return self.metadata
async def handle_msg(self, message: AstrBotMessage):
message_event = SlackMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
web_client=self.web_client,
)
self.commit_event(message_event)
def get_client(self):
return self.web_client

View File

@@ -0,0 +1,243 @@
import asyncio
import re
from typing import AsyncGenerator
from slack_sdk.web.async_client import AsyncWebClient
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import (
Image,
Plain,
File,
BaseMessageComponent,
)
from astrbot.api.platform import Group, MessageMember
from astrbot.api import logger
class SlackMessageEvent(AstrMessageEvent):
def __init__(
self,
message_str,
message_obj,
platform_meta,
session_id,
web_client: AsyncWebClient,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.web_client = web_client
@staticmethod
async def _from_segment_to_slack_block(
segment: BaseMessageComponent, web_client: AsyncWebClient
) -> dict:
"""将消息段转换为 Slack 块格式"""
if isinstance(segment, Plain):
return {"type": "section", "text": {"type": "mrkdwn", "text": segment.text}}
elif isinstance(segment, Image):
# upload file
url = segment.url or segment.file
if url.startswith("http"):
return {
"type": "image",
"image_url": url,
"alt_text": "图片",
}
path = await segment.convert_to_file_path()
response = await web_client.files_upload_v2(
file=path,
filename="image.jpg",
)
if not response["ok"]:
logger.error(f"Slack file upload failed: {response['error']}")
return {
"type": "section",
"text": {"type": "mrkdwn", "text": "图片上传失败"},
}
image_url = response["files"][0]["url_private"]
logger.debug(f"Slack file upload response: {response}")
return {
"type": "image",
"slack_file": {
"url": image_url,
},
"alt_text": "图片",
}
elif isinstance(segment, File):
# upload file
url = segment.url or segment.file
response = await web_client.files_upload_v2(
file=url,
filename=segment.name or "file",
)
if not response["ok"]:
logger.error(f"Slack file upload failed: {response['error']}")
return {
"type": "section",
"text": {"type": "mrkdwn", "text": "文件上传失败"},
}
file_url = response["files"][0]["permalink"]
return {
"type": "section",
"text": {
"type": "mrkdwn",
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
},
}
else:
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
@staticmethod
async def _parse_slack_blocks(
message_chain: MessageChain, web_client: AsyncWebClient
):
"""解析成 Slack 块格式"""
blocks = []
text_content = ""
for segment in message_chain.chain:
if isinstance(segment, Plain):
text_content += segment.text
else:
# 如果有文本内容,先添加文本块
if text_content.strip():
blocks.append(
{
"type": "section",
"text": {"type": "mrkdwn", "text": text_content},
}
)
text_content = ""
# 添加其他类型的块
block = await SlackMessageEvent._from_segment_to_slack_block(
segment, web_client
)
blocks.append(block)
# 如果最后还有文本内容
if text_content.strip():
blocks.append(
{"type": "section", "text": {"type": "mrkdwn", "text": text_content}}
)
return blocks, "" if blocks else text_content
async def send(self, message: MessageChain):
blocks, text = await SlackMessageEvent._parse_slack_blocks(
message, self.web_client
)
try:
if self.get_group_id():
# 发送到频道
await self.web_client.chat_postMessage(
channel=self.get_group_id(),
text=text,
blocks=blocks or None,
)
else:
# 发送私信
await self.web_client.chat_postMessage(
channel=self.get_sender_id(),
text=text,
blocks=blocks or None,
)
except Exception:
# 如果块发送失败,尝试只发送文本
fallback_text = ""
for segment in message.chain:
if isinstance(segment, Plain):
fallback_text += segment.text
elif isinstance(segment, File):
fallback_text += f" [文件: {segment.name}] "
elif isinstance(segment, Image):
fallback_text += " [图片] "
if self.get_group_id():
await self.web_client.chat_postMessage(
channel=self.get_group_id(), text=fallback_text
)
else:
await self.web_client.chat_postMessage(
channel=self.get_sender_id(), text=fallback_text
)
await super().send(message)
async def send_streaming(
self, generator: AsyncGenerator, use_fallback: bool = False
):
if not use_fallback:
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
buffer = ""
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
async for chain in generator:
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
buffer = await self.process_buffer(buffer, pattern)
else:
await self.send(MessageChain(chain=[comp]))
await asyncio.sleep(1.5) # 限速
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator, use_fallback)
async def get_group(self, group_id=None, **kwargs):
if group_id:
channel_id = group_id
elif self.get_group_id():
channel_id = self.get_group_id()
else:
return None
try:
# 获取频道信息
channel_info = await self.web_client.conversations_info(channel=channel_id)
# 获取频道成员
members_response = await self.web_client.conversations_members(
channel=channel_id
)
members = []
for member_id in members_response["members"]:
try:
user_info = await self.web_client.users_info(user=member_id)
user_data = user_info["user"]
members.append(
MessageMember(
user_id=member_id,
nickname=user_data.get("real_name")
or user_data.get("name", member_id),
)
)
except Exception:
# 如果获取用户信息失败,使用默认信息
members.append(MessageMember(user_id=member_id, nickname=member_id))
channel_data = channel_info["channel"]
return Group(
group_id=channel_id,
group_name=channel_data.get("name", ""),
group_avatar="",
group_admins=[], # Slack 的管理员信息需要特殊权限获取
group_owner=channel_data.get("creator", ""),
members=members,
)
except Exception:
return None

View File

@@ -144,8 +144,8 @@ class TelegramPlatformAdapter(Platform):
command_dict = {}
skip_commands = {"start"}
for handler_md in star_handlers_registry._handlers:
handler_metadata = handler_md[1]
for handler_md in star_handlers_registry:
handler_metadata = handler_md
if not star_map[handler_metadata.handler_module_path].activated:
continue
for event_filter in handler_metadata.event_filters:

View File

@@ -1,4 +1,5 @@
import os
import re
import asyncio
import telegramify_markdown
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -18,6 +19,16 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class TelegramPlatformEvent(AstrMessageEvent):
# Telegram 的最大消息长度限制
MAX_MESSAGE_LENGTH = 4096
SPLIT_PATTERNS = {
"paragraph": re.compile(r"\n\n"),
"line": re.compile(r"\n"),
"sentence": re.compile(r"[.!?。!?]"),
"word": re.compile(r"\s"),
}
def __init__(
self,
message_str: str,
@@ -29,8 +40,35 @@ class TelegramPlatformEvent(AstrMessageEvent):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
@staticmethod
async def send_with_client(client: ExtBot, message: MessageChain, user_name: str):
@classmethod
def _split_message(cls, text: str) -> list[str]:
if len(text) <= cls.MAX_MESSAGE_LENGTH:
return [text]
chunks = []
while text:
if len(text) <= cls.MAX_MESSAGE_LENGTH:
chunks.append(text)
break
split_point = cls.MAX_MESSAGE_LENGTH
segment = text[: cls.MAX_MESSAGE_LENGTH]
for _, pattern in cls.SPLIT_PATTERNS.items():
if matches := list(pattern.finditer(segment)):
last_match = matches[-1]
split_point = last_match.end()
break
chunks.append(text[:split_point])
text = text[split_point:].lstrip()
return chunks
@classmethod
async def send_with_client(
cls, client: ExtBot, message: MessageChain, user_name: str
):
image_path = None
has_reply = False
@@ -59,19 +97,22 @@ class TelegramPlatformEvent(AstrMessageEvent):
if isinstance(i, Plain):
if at_user_id and not at_flag:
i.text = f"@{at_user_id} " + i.text
i.text = f"@{at_user_id} {i.text}"
at_flag = True
text = i.text
chunks = cls._split_message(i.text)
for chunk in chunks:
try:
text = telegramify_markdown.markdownify(
i.text, max_line_length=None, normalize_whitespace=False
md_text = telegramify_markdown.markdownify(
chunk, max_line_length=None, normalize_whitespace=False
)
await client.send_message(
text=md_text, parse_mode="MarkdownV2", **payload
)
except Exception as e:
logger.warning(
f"MarkdownV2 conversion failed: {e}. Using plain text instead."
f"MarkdownV2 send failed: {e}. Using plain text instead."
)
return
await client.send_message(text=text, parse_mode="MarkdownV2", **payload)
await client.send_message(text=chunk, **payload)
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await client.send_photo(photo=image_path, **payload)
@@ -119,6 +160,12 @@ class TelegramPlatformEvent(AstrMessageEvent):
async for chain in generator:
if isinstance(chain, MessageChain):
if chain.type == "break":
# 分割符
message_id = None # 重置消息 ID
delta = "" # 重置 delta
continue
# 处理消息链中的每个组件
for i in chain.chain:
if isinstance(i, Plain):
@@ -147,17 +194,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
continue
# Plain
if not message_id:
try:
msg = await self.client.send_message(text=delta, **payload)
current_content = delta
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id
last_edit_time = (
asyncio.get_event_loop().time()
) # 记录初始消息发送时间
else:
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
current_time = asyncio.get_event_loop().time()
time_since_last_edit = current_time - last_edit_time
@@ -176,6 +213,18 @@ class TelegramPlatformEvent(AstrMessageEvent):
last_edit_time = (
asyncio.get_event_loop().time()
) # 更新上次编辑的时间
else:
# delta 长度一般不会大于 4096因此这里直接发送
try:
msg = await self.client.send_message(text=delta, **payload)
current_content = delta
delta = ""
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id
last_edit_time = (
asyncio.get_event_loop().time()
) # 记录初始消息发送时间
try:
if delta and current_content != delta:

View File

@@ -2,7 +2,7 @@ import time
import asyncio
import uuid
import os
from typing import Awaitable, Any
from typing import Awaitable, Any, Callable
from astrbot.core.platform import (
Platform,
AstrBotMessage,
@@ -13,7 +13,7 @@ from astrbot.core.platform import (
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.message.components import Plain, Image, Record # noqa: F403
from astrbot import logger
from astrbot.core import web_chat_queue
from .webchat_queue_mgr import webchat_queue_mgr, WebChatQueueMgr
from .webchat_event import WebChatMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
@@ -21,14 +21,46 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class QueueListener:
def __init__(self, queue: asyncio.Queue, callback: callable) -> None:
self.queue = queue
def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None:
self.webchat_queue_mgr = webchat_queue_mgr
self.callback = callback
self.running_tasks = set()
async def listen_to_queue(self, conversation_id: str):
"""Listen to a specific conversation queue"""
queue = self.webchat_queue_mgr.get_or_create_queue(conversation_id)
while True:
try:
data = await queue.get()
await self.callback(data)
except Exception as e:
logger.error(
f"Error processing message from conversation {conversation_id}: {e}"
)
break
async def run(self):
"""Monitor for new conversation queues and start listeners"""
monitored_conversations = set()
while True:
data = await self.queue.get()
await self.callback(data)
# Check for new conversations
current_conversations = set(self.webchat_queue_mgr.queues.keys())
new_conversations = current_conversations - monitored_conversations
# Start listeners for new conversations
for conversation_id in new_conversations:
task = asyncio.create_task(self.listen_to_queue(conversation_id))
self.running_tasks.add(task)
task.add_done_callback(self.running_tasks.discard)
monitored_conversations.add(conversation_id)
logger.debug(f"Started listener for conversation: {conversation_id}")
# Clean up monitored conversations that no longer exist
removed_conversations = monitored_conversations - current_conversations
monitored_conversations -= removed_conversations
await asyncio.sleep(1) # Check for new conversations every second
@register_platform_adapter("webchat", "webchat")
@@ -45,7 +77,7 @@ class WebChatAdapter(Platform):
os.makedirs(self.imgs_dir, exist_ok=True)
self.metadata = PlatformMetadata(
name="webchat", description="webchat", id=self.config.get("id")
name="webchat", description="webchat", id=self.config.get("id", "")
)
async def send_by_session(
@@ -105,7 +137,7 @@ class WebChatAdapter(Platform):
abm = await self.convert_message(data)
await self.handle_msg(abm)
bot = QueueListener(web_chat_queue, callback)
bot = QueueListener(webchat_queue_mgr, callback)
return bot.run()
def meta(self) -> PlatformMetadata:
@@ -119,6 +151,10 @@ class WebChatAdapter(Platform):
session_id=message.session_id,
)
_, _, payload = message.raw_message # type: ignore
message_event.set_extra("selected_provider", payload.get("selected_provider"))
message_event.set_extra("selected_model", payload.get("selected_model"))
self.commit_event(message_event)
async def terminate(self):

View File

@@ -5,8 +5,8 @@ from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image, Record
from astrbot.core.utils.io import download_image_by_url
from astrbot.core import web_chat_back_queue
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .webchat_queue_mgr import webchat_queue_mgr
imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
@@ -18,13 +18,18 @@ class WebChatMessageEvent(AstrMessageEvent):
@staticmethod
async def _send(message: MessageChain, session_id: str, streaming: bool = False):
cid = session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
if not message:
await web_chat_back_queue.put(
{"type": "end", "data": "", "streaming": False}
{
"type": "end",
"data": "",
"streaming": False,
} # end means this request is finished
)
return ""
cid = session_id.split("!")[-1]
data = ""
for comp in message.chain:
if isinstance(comp, Plain):
@@ -35,6 +40,7 @@ class WebChatMessageEvent(AstrMessageEvent):
"cid": cid,
"data": data,
"streaming": streaming,
"chain_type": message.type,
}
)
elif isinstance(comp, Image):
@@ -97,29 +103,35 @@ class WebChatMessageEvent(AstrMessageEvent):
async def send(self, message: MessageChain):
await WebChatMessageEvent._send(message, session_id=self.session_id)
await web_chat_back_queue.put(
{
"type": "end",
"data": "",
"streaming": False,
"cid": self.session_id.split("!")[-1],
}
)
await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False):
final_data = ""
cid = self.session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
async for chain in generator:
if chain.type == "break" and final_data:
# 分割符
await web_chat_back_queue.put(
{
"type": "break", # break means a segment end
"data": final_data,
"streaming": True,
"cid": cid,
}
)
final_data = ""
continue
final_data += await WebChatMessageEvent._send(
chain, session_id=self.session_id, streaming=True
)
await web_chat_back_queue.put(
{
"type": "end",
"type": "complete", # complete means we return the final result
"data": final_data,
"streaming": True,
"cid": self.session_id.split("!")[-1],
"cid": cid,
}
)
await super().send_streaming(generator, use_fallback)

View File

@@ -0,0 +1,35 @@
import asyncio
class WebChatQueueMgr:
def __init__(self) -> None:
self.queues = {}
"""Conversation ID to asyncio.Queue mapping"""
self.back_queues = {}
"""Conversation ID to asyncio.Queue mapping for responses"""
def get_or_create_queue(self, conversation_id: str) -> asyncio.Queue:
"""Get or create a queue for the given conversation ID"""
if conversation_id not in self.queues:
self.queues[conversation_id] = asyncio.Queue()
return self.queues[conversation_id]
def get_or_create_back_queue(self, conversation_id: str) -> asyncio.Queue:
"""Get or create a back queue for the given conversation ID"""
if conversation_id not in self.back_queues:
self.back_queues[conversation_id] = asyncio.Queue()
return self.back_queues[conversation_id]
def remove_queues(self, conversation_id: str):
"""Remove queues for the given conversation ID"""
if conversation_id in self.queues:
del self.queues[conversation_id]
if conversation_id in self.back_queues:
del self.back_queues[conversation_id]
def has_queue(self, conversation_id: str) -> bool:
"""Check if a queue exists for the given conversation ID"""
return conversation_id in self.queues
webchat_queue_mgr = WebChatQueueMgr()

View File

@@ -1,14 +1,16 @@
import asyncio
import base64
import json
import os
import traceback
import time
from typing import Optional
import aiohttp
import anyio
import websockets
from astrbot import logger
from astrbot.api.message_components import Plain, Image
from astrbot.api.message_components import Plain, Image, At, Record
from astrbot.api.platform import Platform, PlatformMetadata
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.astrbot_message import (
@@ -22,6 +24,13 @@ from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from .wechatpadpro_message_event import WeChatPadProMessageEvent
try:
from .xml_data_parser import GeweDataParser
except ImportError as e:
logger.warning(
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
)
@register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器")
class WeChatPadProAdapter(Platform):
@@ -59,6 +68,18 @@ class WeChatPadProAdapter(Platform):
) # 持久化文件路径
self.ws_handle_task = None
# 添加图片消息缓存,用于引用消息处理
self.cached_images = {}
"""缓存图片消息。key是NewMsgId (对应引用消息的svrid)value是图片的base64数据"""
# 设置缓存大小限制,避免内存占用过大
self.max_image_cache = 50
# 添加文本消息缓存,用于引用消息处理
self.cached_texts = {}
"""缓存文本消息。key是NewMsgId (对应引用消息的svrid)value是消息文本内容"""
# 设置文本缓存大小限制
self.max_text_cache = 100
async def run(self) -> None:
"""
启动平台适配器的运行实例。
@@ -138,7 +159,6 @@ class WeChatPadProAdapter(Platform):
os.makedirs(data_dir, exist_ok=True)
with open(self.credentials_file, "w") as f:
json.dump(credentials, f)
logger.info("成功保存 WeChatPadPro 凭据。")
except Exception as e:
logger.error(f"保存 WeChatPadPro 凭据失败: {e}")
@@ -146,6 +166,8 @@ class WeChatPadProAdapter(Platform):
"""
检查 WeChatPadPro 设备是否在线。
"""
if not self.auth_key:
return False
url = f"{self.base_url}/login/GetLoginStatus"
params = {"key": self.auth_key}
@@ -161,20 +183,18 @@ class WeChatPadProAdapter(Platform):
return True
# login_state == 3 为离线状态
elif login_state == 3:
logger.info(
"WeChatPadPro 设备不在线。"
)
logger.info("WeChatPadPro 设备不在线。")
return False
else:
logger.error(
f"未知的在线状态: {login_state:}"
)
logger.error(f"未知的在线状态: {response_data}")
return False
# Code == 300 为微信退出状态。
elif response.status == 200 and response_data.get("Code") == 300:
logger.info(
"WeChatPadPro 设备已退出。"
)
logger.info("WeChatPadPro 设备已退出。")
return False
elif response.status == 200 and response_data.get("Code") == -2:
# 该链接不存在
self.auth_key = None
return False
else:
logger.error(
@@ -187,8 +207,19 @@ class WeChatPadProAdapter(Platform):
return False
except Exception as e:
logger.error(f"检查在线状态时发生错误: {e}")
logger.error(traceback.format_exc())
return False
def _extract_auth_key(self, data):
"""Helper method to extract auth_key from response data."""
if isinstance(data, dict):
auth_keys = data.get("authKeys") # 新接口
if isinstance(auth_keys, list) and auth_keys:
return auth_keys[0]
elif isinstance(data, list) and data: # 旧接口
return data[0]
return None
async def generate_auth_key(self):
"""
生成授权码。
@@ -197,28 +228,30 @@ class WeChatPadProAdapter(Platform):
params = {"key": self.admin_key}
payload = {"Count": 1, "Days": 365} # 生成一个有效期365天的授权码
self.auth_key = None # Reset auth_key before generating a new one
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status != 200:
logger.error(
f"生成授权码失败: {response.status}, {await response.text()}"
)
return
response_data = await response.json()
# 修正成功判断条件和授权码提取路径
if response.status == 200 and response_data.get("Code") == 200:
# 授权码在 Data 字段的列表中
if (
response_data.get("Data")
and isinstance(response_data["Data"], list)
and len(response_data["Data"]) > 0
):
self.auth_key = response_data["Data"][0]
if response_data.get("Code") == 200:
if data := response_data.get("Data"):
self.auth_key = self._extract_auth_key(data)
if self.auth_key:
logger.info("成功获取授权码")
else:
logger.error(
f"生成授权码成功但未找到授权码: {response_data}"
)
else:
logger.error(
f"生成授权码失败: {response.status}, {response_data}"
)
logger.error(f"生成授权码失败: {response_data}")
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
except Exception as e:
@@ -236,7 +269,6 @@ class WeChatPadProAdapter(Platform):
try:
async with session.post(url, params=params, json=payload) as response:
response_data = await response.json()
# 修正成功判断条件和数据提取路径
if response.status == 200 and response_data.get("Code") == 200:
# 二维码地址在 Data.QrCodeUrl 字段中
if response_data.get("Data") and response_data["Data"].get(
@@ -248,6 +280,13 @@ class WeChatPadProAdapter(Platform):
f"获取登录二维码成功但未找到二维码地址: {response_data}"
)
return None
elif "该 key 无效" in response_data.get("Text"):
logger.error(
"授权码无效,已经清除。请重新启动 AstrBot 或者本消息适配器。原因也可能是 WeChatPadPro 的 MySQL 服务没有启动成功,请检查 WeChatPadPro 服务的日志。"
)
self.auth_key = None
self.save_credentials()
return None
else:
logger.error(
f"获取登录二维码失败: {response.status}, {response_data}"
@@ -340,7 +379,7 @@ class WeChatPadProAdapter(Platform):
while True:
try:
async with websockets.connect(ws_url) as websocket:
logger.info("WebSocket 连接成功。")
logger.debug("WebSocket 连接成功。")
# 设置空闲超时重连
wait_time = (
self.active_message_poll_interval
@@ -355,9 +394,7 @@ class WeChatPadProAdapter(Platform):
# logger.debug(message) # 不显示原始消息内容
asyncio.create_task(self.handle_websocket_message(message))
except asyncio.TimeoutError:
logger.warning(
f"WebSocket 连接空闲超过 {wait_time} s"
)
logger.debug(f"WebSocket 连接空闲超过 {wait_time} s")
break
except websockets.exceptions.ConnectionClosedOK:
logger.info("WebSocket 连接正常关闭。")
@@ -366,7 +403,9 @@ class WeChatPadProAdapter(Platform):
logger.error(f"处理 WebSocket 消息时发生错误: {e}")
break
except Exception as e:
logger.error(f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态或尝试重启WeChatPadPro适配器。")
logger.error(
f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态或尝试重启WeChatPadPro适配器。"
)
await asyncio.sleep(5)
async def handle_websocket_message(self, message: str):
@@ -459,6 +498,7 @@ class WeChatPadProAdapter(Platform):
"""
if from_user_name == "weixin":
return False
at_me = False
if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = from_user_name
@@ -477,9 +517,17 @@ class WeChatPadProAdapter(Platform):
# 对于群聊session_id 可以是群聊 ID 或发送者 ID + 群聊 ID (如果 unique_session 为 True)
if self.unique_session:
abm.session_id = f"{from_user_name}_{to_user_name}"
abm.session_id = f"{from_user_name}#{abm.sender.user_id}"
else:
abm.session_id = from_user_name
msg_source = raw_message.get("msg_source", "")
if self.wxid in msg_source:
at_me = True
if "在群聊中@了你" in raw_message.get("push_content", ""):
at_me = True
if at_me:
abm.message.insert(0, At(qq=abm.self_id, name=""))
else:
abm.type = MessageType.FRIEND_MESSAGE
abm.group_id = ""
@@ -560,6 +608,32 @@ class WeChatPadProAdapter(Platform):
logger.error(f"下载图片时发生错误: {e}")
return None
async def download_voice(
self, to_user_name: str, new_msg_id: str, bufid: str, length: int
):
"""下载原始音频。"""
url = f"{self.base_url}/message/GetMsgVoice"
params = {"key": self.auth_key}
payload = {
"Bufid": bufid,
"ToUserName": to_user_name,
"NewMsgId": new_msg_id,
"Length": length,
}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status == 200:
return await response.json()
logger.error(f"下载音频失败: {response.status}")
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"下载音频时发生错误: {e}")
return None
async def _process_message_content(
self, abm: AstrBotMessage, raw_message: dict, msg_type: int, content: str
):
@@ -571,12 +645,82 @@ class WeChatPadProAdapter(Platform):
if abm.type == MessageType.GROUP_MESSAGE:
parts = content.split(":\n", 1)
if len(parts) == 2:
abm.message_str = parts[1]
abm.message.append(Plain(abm.message_str))
message_content = parts[1]
abm.message_str = message_content
# 检查是否@了机器人,参考 gewechat 的实现方式
# 微信大部分客户端在@用户昵称后面,紧接着是一个\u2005字符四分之一空格
at_me = False
# 检查 msg_source 中是否包含机器人的 wxid
# wechatpadpro 的格式: <atuserlist>wxid</atuserlist>
# gewechat 的格式: <atuserlist><![CDATA[wxid]]></atuserlist>
msg_source = raw_message.get("msg_source", "")
if (
f"<atuserlist>{abm.self_id}</atuserlist>" in msg_source
or f"<atuserlist>{abm.self_id}," in msg_source
or f",{abm.self_id}</atuserlist>" in msg_source
):
at_me = True
# 也检查 push_content 中是否有@提示
push_content = raw_message.get("push_content", "")
if "在群聊中@了你" in push_content:
at_me = True
if at_me:
# 被@了在消息开头插入At组件参考gewechat的做法
bot_nickname = await self._get_group_member_nickname(
abm.group_id, abm.self_id
)
abm.message.insert(
0, At(qq=abm.self_id, name=bot_nickname or abm.self_id)
)
# 只有当消息内容不仅仅是@时才添加Plain组件
if "\u2005" in message_content:
# 检查@之后是否还有其他内容
parts = message_content.split("\u2005")
if len(parts) > 1 and any(
part.strip() for part in parts[1:]
):
abm.message.append(Plain(message_content))
else:
# 检查是否只包含@机器人
is_pure_at = False
if (
bot_nickname
and message_content.strip() == f"@{bot_nickname}"
):
is_pure_at = True
if not is_pure_at:
abm.message.append(Plain(message_content))
else:
# 没有@机器人,作为普通文本处理
abm.message.append(Plain(message_content))
else:
abm.message.append(Plain(abm.message_str))
else: # 私聊消息
abm.message.append(Plain(abm.message_str))
# 缓存文本消息,以便引用消息可以查找
try:
# 获取msg_id作为缓存的key
new_msg_id = raw_message.get("new_msg_id")
if new_msg_id:
# 限制缓存大小
if (
len(self.cached_texts) >= self.max_text_cache
and self.cached_texts
):
# 删除最早的一条缓存
oldest_key = next(iter(self.cached_texts))
self.cached_texts.pop(oldest_key)
logger.debug(f"缓存文本消息new_msg_id={new_msg_id}")
self.cached_texts[str(new_msg_id)] = content
except Exception as e:
logger.error(f"缓存文本消息失败: {e}")
elif msg_type == 3:
# 图片消息
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
@@ -590,15 +734,87 @@ class WeChatPadProAdapter(Platform):
)
if image_bs64_data:
abm.message.append(Image.fromBase64(image_bs64_data))
# 缓存图片,以便引用消息可以查找
try:
# 获取msg_id作为缓存的key
new_msg_id = raw_message.get("new_msg_id")
if new_msg_id:
# 限制缓存大小
if (
len(self.cached_images) >= self.max_image_cache
and self.cached_images
):
# 删除最早的一条缓存
oldest_key = next(iter(self.cached_images))
self.cached_images.pop(oldest_key)
logger.debug(f"缓存图片消息new_msg_id={new_msg_id}")
self.cached_images[str(new_msg_id)] = image_bs64_data
except Exception as e:
logger.error(f"缓存图片消息失败: {e}")
elif msg_type == 47:
# 视频消息 (注意:表情消息也是 47需要区分)
logger.warning("收到视频消息,待实现。")
data_parser = GeweDataParser(
content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
raw_message=raw_message,
)
emoji_message = data_parser.parse_emoji()
if emoji_message is not None:
abm.message.append(emoji_message)
elif msg_type == 50:
# 语音/视频
logger.warning("收到语音/视频消息,待实现。")
elif msg_type == 34:
# 语音消息
bufid = 0
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
new_msg_id = raw_message.get("new_msg_id")
data_parser = GeweDataParser(
content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
raw_message=raw_message,
)
voicemsg = data_parser._format_to_xml().find("voicemsg")
bufid = voicemsg.get("bufid") or "0"
length = int(voicemsg.get("length") or 0)
voice_resp = await self.download_voice(
to_user_name=to_user_name,
new_msg_id=new_msg_id,
bufid=bufid,
length=length,
)
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
if voice_bs64_data:
voice_bs64_data = base64.b64decode(voice_bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(
temp_dir, f"wechatpadpro_voice_{abm.message_id}.silk"
)
async with await anyio.open_file(file_path, "wb") as f:
await f.write(voice_bs64_data)
abm.message.append(Record(file=file_path, url=file_path))
elif msg_type == 49:
# 引用消息
logger.warning("收到引用消息,待实现。")
try:
parser = GeweDataParser(
content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
cached_texts=self.cached_texts,
cached_images=self.cached_images,
raw_message=raw_message,
downloader=self._download_raw_image,
)
components = await parser.parse_mutil_49()
if components:
abm.message.extend(components)
abm.message_str = "\n".join(
c.text for c in components if isinstance(c, Plain)
)
except Exception as e:
logger.warning(f"msg_type 49 处理失败: {e}")
abm.message.append(Plain("[XML 消息处理失败]"))
abm.message_str = "[XML 消息处理失败]"
else:
logger.warning(f"收到未处理的消息类型: {msg_type}")
@@ -628,6 +844,9 @@ class WeChatPadProAdapter(Platform):
# 根据 session_id 判断消息类型
if "@chatroom" in session.session_id:
dummy_message_obj.type = MessageType.GROUP_MESSAGE
if "#" in session.session_id:
dummy_message_obj.group_id = session.session_id.split("#")[0]
else:
dummy_message_obj.group_id = session.session_id
dummy_message_obj.sender = MessageMember(user_id="", nickname="")
else:
@@ -643,3 +862,67 @@ class WeChatPadProAdapter(Platform):
)
# 调用实例方法 send
await sending_event.send(message_chain)
async def get_contact_list(self):
"""
获取联系人列表。
"""
url = f"{self.base_url}/friend/GetContactList"
params = {"key": self.auth_key}
payload = {"CurrentChatRoomContactSeq": 0, "CurrentWxcontactSeq": 0}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status != 200:
logger.error(f"获取联系人列表失败: {response.status}")
return None
result = await response.json()
if result.get("Code") == 200 and result.get("Data"):
contact_list = (
result.get("Data", {})
.get("ContactList", {})
.get("contactUsernameList", [])
)
return contact_list
else:
logger.error(f"获取联系人列表失败: {result}")
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"获取联系人列表时发生错误: {e}")
return None
async def get_contact_details_list(
self, room_wx_id_list: list[str] = None, user_names: list[str] = None
) -> Optional[dict]:
"""
获取联系人详情列表。
"""
if room_wx_id_list is None:
room_wx_id_list = []
if user_names is None:
user_names = []
url = f"{self.base_url}/friend/GetContactDetailsList"
params = {"key": self.auth_key}
payload = {"RoomWxIDList": room_wx_id_list, "UserNames": user_names}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status != 200:
logger.error(f"获取联系人详情列表失败: {response.status}")
return None
result = await response.json()
if result.get("Code") == 200 and result.get("Data"):
contact_list = result.get("Data", {}).get("contactList", {})
return contact_list
else:
logger.error(f"获取联系人详情列表失败: {result}")
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"获取联系人详情列表时发生错误: {e}")
return None

View File

@@ -7,11 +7,17 @@ import aiohttp
from PIL import Image as PILImage # 使用别名避免冲突
from astrbot import logger
from astrbot.core.message.components import Image, Plain # Import Image
from astrbot.core.message.components import (
Image,
Plain,
WechatEmoji,
Record,
) # Import Image
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType
from astrbot.core.platform.platform_metadata import PlatformMetadata
from astrbot.core.utils.tencent_record_helper import audio_to_tencent_silk_base64
if TYPE_CHECKING:
from .wechatpadpro_adapter import WeChatPadProAdapter
@@ -38,6 +44,10 @@ class WeChatPadProMessageEvent(AstrMessageEvent):
await self._send_text(session, comp.text)
elif isinstance(comp, Image):
await self._send_image(session, comp)
elif isinstance(comp, WechatEmoji):
await self._send_emoji(session, comp)
elif isinstance(comp, Record):
await self._send_voice(session, comp)
await super().send(message)
async def _send_image(self, session: aiohttp.ClientSession, comp: Image):
@@ -71,14 +81,48 @@ class WeChatPadProMessageEvent(AstrMessageEvent):
# logger.info(f"已添加 @ 信息: {message_text}")
else:
message_text = text
if self.get_group_id() and "#" in self.session_id:
session_id = self.session_id.split("#")[0]
else:
session_id = self.session_id
payload = {
"MsgItem": [
{"MsgType": 1, "TextContent": message_text, "ToUserName": self.session_id}
{
"MsgType": 1,
"TextContent": message_text,
"ToUserName": session_id,
}
]
}
url = f"{self.adapter.base_url}/message/SendTextMessage"
await self._post(session, url, payload)
async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji):
payload = {
"EmojiList": [
{
"EmojiMd5": comp.md5,
"EmojiSize": comp.md5_len,
"ToUserName": self.session_id,
}
]
}
url = f"{self.adapter.base_url}/message/SendEmojiMessage"
await self._post(session, url, payload)
async def _send_voice(self, session: aiohttp.ClientSession, comp: Record):
record_path = await comp.convert_to_file_path()
# 默认已经存在 data/temp 中
b64, duration = await audio_to_tencent_silk_base64(record_path)
payload = {
"ToUserName": self.session_id,
"VoiceData": b64,
"VoiceFormat": 4,
"VoiceSecond": duration,
}
url = f"{self.adapter.base_url}/message/SendVoice"
await self._post(session, url, payload)
@staticmethod
def _validate_base64(b64: str) -> bytes:
return base64.b64decode(b64, validate=True)

View File

@@ -0,0 +1,160 @@
from defusedxml import ElementTree as eT
from astrbot.api import logger
from astrbot.api.message_components import (
WechatEmoji as Emoji,
Plain,
Image,
BaseMessageComponent,
)
class GeweDataParser:
def __init__(
self,
content: str,
is_private_chat: bool = False,
cached_texts=None,
cached_images=None,
raw_message: dict = None,
downloader=None,
):
self._xml = None
self.content = content
self.is_private_chat = is_private_chat
self.cached_texts = cached_texts or {}
self.cached_images = cached_images or {}
self.downloader = downloader
raw_message = raw_message or {}
self.from_user_name = raw_message.get("from_user_name", {}).get("str", "")
self.to_user_name = raw_message.get("to_user_name", {}).get("str", "")
self.msg_id = raw_message.get("msg_id", "")
def _format_to_xml(self):
if self._xml:
return self._xml
try:
msg_str = self.content
if not self.is_private_chat:
parts = self.content.split(":\n", 1)
msg_str = parts[1] if len(parts) == 2 else self.content
self._xml = eT.fromstring(msg_str)
return self._xml
except Exception as e:
logger.error(f"[XML解析失败] {e}")
raise
async def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
"""
处理 msg_type == 49 的多种 appmsg 类型(目前支持 type==57
"""
try:
appmsg_type = self._format_to_xml().findtext(".//appmsg/type")
if appmsg_type == "57":
return await self.parse_reply()
except Exception as e:
logger.warning(f"[parse_mutil_49] 解析失败: {e}")
return None
async def parse_reply(self) -> list[BaseMessageComponent]:
"""
处理 type == 57 的引用消息支持文本1、图片3、嵌套4949
"""
components = []
try:
appmsg = self._format_to_xml().find("appmsg")
if appmsg is None:
return [Plain("[引用消息解析失败]")]
refermsg = appmsg.find("refermsg")
if refermsg is None:
return [Plain("[引用消息解析失败]")]
quote_type = int(refermsg.findtext("type", "0"))
nickname = refermsg.findtext("displayname", "未知发送者")
quote_content = refermsg.findtext("content", "")
svrid = refermsg.findtext("svrid")
match quote_type:
case 1: # 文本引用
quoted_text = self.cached_texts.get(str(svrid), quote_content)
components.append(Plain(f"[引用] {nickname}: {quoted_text}"))
case 3: # 图片引用
quoted_image_b64 = self.cached_images.get(str(svrid))
if not quoted_image_b64:
try:
quote_xml = eT.fromstring(quote_content)
img = quote_xml.find("img")
cdn_url = (
img.get("cdnbigimgurl") or img.get("cdnmidimgurl")
if img is not None
else None
)
if cdn_url and self.downloader:
image_resp = await self.downloader(
self.from_user_name, self.to_user_name, self.msg_id
)
quoted_image_b64 = (
image_resp.get("Data", {})
.get("Data", {})
.get("Buffer")
)
except Exception as e:
logger.warning(f"[引用图片解析失败] svrid={svrid} err={e}")
if quoted_image_b64:
components.extend(
[
Image.fromBase64(quoted_image_b64),
Plain(f"[引用] {nickname}: [引用的图片]"),
]
)
else:
components.append(
Plain(f"[引用] {nickname}: [引用的图片 - 未能获取]")
)
case 49: # 嵌套引用
try:
nested_root = eT.fromstring(quote_content)
nested_title = nested_root.findtext(".//appmsg/title", "")
components.append(Plain(f"[引用] {nickname}: {nested_title}"))
except Exception as e:
logger.warning(f"[嵌套引用解析失败] err={e}")
components.append(Plain(f"[引用] {nickname}: [嵌套引用消息]"))
case _: # 其他未识别类型
logger.info(f"[未知引用类型] quote_type={quote_type}")
components.append(Plain(f"[引用] {nickname}: [不支持的引用类型]"))
# 主消息标题
title = appmsg.findtext("title", "")
if title:
components.append(Plain(title))
except Exception as e:
logger.error(f"[parse_reply] 总体解析失败: {e}")
return [Plain("[引用消息解析失败]")]
return components
def parse_emoji(self) -> Emoji | None:
"""
处理 msg_type == 47 的表情消息emoji
"""
try:
emoji_element = self._format_to_xml().find(".//emoji")
if emoji_element is not None:
return Emoji(
md5=emoji_element.get("md5"),
md5_len=emoji_element.get("len"),
cdnurl=emoji_element.get("cdnurl"),
)
except Exception as e:
logger.error(f"[parse_emoji] 解析失败: {e}")
return None

View File

@@ -303,6 +303,7 @@ class WecomPlatformAdapter(Platform):
abm.session_id = external_userid
abm.type = MessageType.FRIEND_MESSAGE
abm.message_id = msg.get("msgid", uuid.uuid4().hex[:8])
abm.message_str = ""
if msgtype == "text":
text = msg.get("text", {}).get("content", "").strip()
abm.message = [Plain(text=text)]
@@ -316,7 +317,29 @@ class WecomPlatformAdapter(Platform):
with open(path, "wb") as f:
f.write(resp.content)
abm.message = [Image(file=path, url=path)]
abm.message_str = "[图片]"
elif msgtype == "voice":
media_id = msg.get("voice", {}).get("media_id", "")
resp: Response = await asyncio.get_event_loop().run_in_executor(
None, self.client.media.download, media_id
)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"weixinkefu_{media_id}.amr")
with open(path, "wb") as f:
f.write(resp.content)
try:
from pydub import AudioSegment
path_wav = os.path.join(temp_dir, f"weixinkefu_{media_id}.wav")
audio = AudioSegment.from_file(path)
audio.export(path_wav, format="wav")
except Exception as e:
logger.error(f"转换音频失败: {e}。如果没有安装 ffmpeg 请先安装。")
path_wav = path
return
abm.message = [Record(file=path_wav, url=path_wav)]
else:
logger.warning(f"未实现的微信客服消息事件: {msg}")
return

View File

@@ -120,6 +120,30 @@ class WecomPlatformEvent(AstrMessageEvent):
self.get_self_id(),
response["media_id"],
)
elif isinstance(comp, Record):
record_path = await comp.convert_to_file_path()
# 转成amr
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
record_path_amr = os.path.join(temp_dir, f"{uuid.uuid4()}.amr")
pydub.AudioSegment.from_wav(record_path).export(
record_path_amr, format="amr"
)
with open(record_path_amr, "rb") as f:
try:
response = self.client.media.upload("voice", f)
except Exception as e:
logger.error(f"微信客服上传语音失败: {e}")
await self.send(
MessageChain().message(f"微信客服上传语音失败: {e}")
)
return
logger.info(f"微信客服上传语音返回: {response}")
kf_message_api.send_voice(
user_id,
self.get_self_id(),
response["media_id"],
)
else:
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}")
else:

View File

@@ -48,7 +48,12 @@ class WeChatKF(BaseWeChatAPI):
注意可能会出现返回条数少于limit的情况需结合返回的has_more字段判断是否继续请求。
:return: 接口调用结果
"""
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid}
data = {
"token": token,
"cursor": cursor,
"limit": limit,
"open_kfid": open_kfid,
}
return self._post("kf/sync_msg", data=data)
def get_service_state(self, open_kfid, external_userid):
@@ -72,7 +77,9 @@ class WeChatKF(BaseWeChatAPI):
}
return self._post("kf/service_state/get", data=data)
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""):
def trans_service_state(
self, open_kfid, external_userid, service_state, servicer_userid=""
):
"""
变更会话状态
@@ -180,7 +187,9 @@ class WeChatKF(BaseWeChatAPI):
"""
return self._get("kf/customer/get_upgrade_service_config")
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None):
def upgrade_service(
self, open_kfid, external_userid, service_type, member=None, groupchat=None
):
"""
为客户升级为专员或客户群服务
@@ -246,7 +255,9 @@ class WeChatKF(BaseWeChatAPI):
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
return self._post("kf/get_corp_statistic", data=data)
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None):
def get_servicer_statistic(
self, start_time, end_time, open_kfid=None, servicer_userid=None
):
"""
获取「客户数据统计」接待人员明细数据

View File

@@ -26,6 +26,7 @@ from optionaldict import optionaldict
from wechatpy.client.api.base import BaseWeChatAPI
class WeChatKFMessage(BaseWeChatAPI):
"""
发送微信客服消息
@@ -125,35 +126,55 @@ class WeChatKFMessage(BaseWeChatAPI):
msg={"msgtype": "news", "link": {"link": articles_data}},
)
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""):
def send_msgmenu(
self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""
):
return self.send(
user_id,
open_kfid,
msgid,
msg={
"msgtype": "msgmenu",
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content},
"msgmenu": {
"head_content": head_content,
"list": menu_list,
"tail_content": tail_content,
},
},
)
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""):
def send_location(
self, user_id, open_kfid, name, address, latitude, longitude, msgid=""
):
return self.send(
user_id,
open_kfid,
msgid,
msg={
"msgtype": "location",
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude},
"msgmenu": {
"name": name,
"address": address,
"latitude": latitude,
"longitude": longitude,
},
},
)
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""):
def send_miniprogram(
self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""
):
return self.send(
user_id,
open_kfid,
msgid,
msg={
"msgtype": "miniprogram",
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath},
"msgmenu": {
"appid": appid,
"title": title,
"thumb_media_id": thumb_media_id,
"pagepath": pagepath,
},
},
)

View File

@@ -160,7 +160,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.wexin_event_workers[msg.id] = future
await self.convert_message(msg, future)
# I love shield so much!
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
result = await asyncio.wait_for(
asyncio.shield(future), 60
) # wait for 60s
logger.debug(f"Got future result: {result}")
self.wexin_event_workers.pop(msg.id, None)
return result # xml. see weixin_offacc_event.py

View File

@@ -150,7 +150,6 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
return
logger.info(f"微信公众平台上传语音返回: {response}")
if active_send_mode:
self.client.message.send_voice(
message_obj.sender.user_id,

View File

@@ -19,6 +19,7 @@ class ProviderType(enum.Enum):
CHAT_COMPLETION = "chat_completion"
SPEECH_TO_TEXT = "speech_to_text"
TEXT_TO_SPEECH = "text_to_speech"
EMBEDDING = "embedding"
@dataclass
@@ -57,7 +58,7 @@ class AssistantMessageSegment:
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
content: str = None
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list)
role: str = "assistant"
def to_dict(self):
@@ -66,7 +67,7 @@ class AssistantMessageSegment:
}
if self.content:
ret["content"] = self.content
elif self.tool_calls:
if self.tool_calls:
ret["tool_calls"] = self.tool_calls
return ret
@@ -94,27 +95,38 @@ class ProviderRequest:
"""提示词"""
session_id: str = ""
"""会话 ID"""
image_urls: List[str] = None
image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表"""
func_tool: FuncCall = None
func_tool: FuncCall | None = None
"""可用的函数工具"""
contexts: List = None
contexts: list[dict] = field(default_factory=list)
"""上下文。格式与 openai 的上下文格式一致:
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
"""
system_prompt: str = ""
"""系统提示词"""
conversation: Conversation = None
conversation: Conversation | None = None
tool_calls_result: ToolCallsResult = None
tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
model: str | None = None
"""模型名称,为 None 时使用提供商的默认模型"""
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
def __str__(self):
return self.__repr__()
def append_tool_calls_result(self, tool_calls_result: ToolCallsResult):
"""添加工具调用结果到请求中"""
if not self.tool_calls_result:
self.tool_calls_result = []
if isinstance(self.tool_calls_result, ToolCallsResult):
self.tool_calls_result = [self.tool_calls_result]
self.tool_calls_result.append(tool_calls_result)
def _print_friendly_context(self):
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
if not self.contexts:
@@ -155,7 +167,9 @@ class ProviderRequest:
if self.image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": self.prompt if self.prompt else "[图片]"}],
"content": [
{"type": "text", "text": self.prompt if self.prompt else "[图片]"}
],
}
for image_url in self.image_urls:
if image_url.startswith("http"):

View File

@@ -4,6 +4,7 @@ import textwrap
import os
import asyncio
import logging
from datetime import timedelta
from typing import Dict, List, Awaitable, Literal, Any
from dataclasses import dataclass
@@ -20,6 +21,13 @@ try:
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
try:
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
)
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
SUPPORTED_TYPES = [
@@ -31,6 +39,72 @@ SUPPORTED_TYPES = [
] # json schema 支持的数据类型
def _prepare_config(config: dict) -> dict:
"""准备配置,处理嵌套格式"""
if "mcpServers" in config and config["mcpServers"]:
first_key = next(iter(config["mcpServers"]))
config = config["mcpServers"][first_key]
config.pop("active", None)
return config
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"""快速测试 MCP 服务器可达性"""
import aiohttp
cfg = _prepare_config(config.copy())
url = cfg["url"]
headers = cfg.get("headers", {})
timeout = cfg.get("timeout", 10)
try:
async with aiohttp.ClientSession() as session:
if cfg.get("transport") == "streamable_http":
test_payload = {
"jsonrpc": "2.0",
"method": "initialize",
"id": 0,
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.2.3"},
},
}
async with session.post(
url,
headers={
**headers,
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
},
json=test_payload,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
else:
return False, f"HTTP {response.status}: {response.reason}"
else:
async with session.get(
url,
headers={
**headers,
"Accept": "application/json, text/event-stream",
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 200:
return True, ""
else:
return False, f"HTTP {response.status}: {response.reason}"
except asyncio.TimeoutError:
return False, f"连接超时: {timeout}"
except Exception as e:
return False, f"{e!s}"
@dataclass
class FuncTool:
"""
@@ -72,12 +146,10 @@ class FuncTool:
if not self.mcp_client or not self.mcp_client.session:
raise Exception(f"MCP client for {self.name} is not available")
# 使用name属性而不是额外的mcp_tool_name
if ":" in self.name:
# 如果名字是格式为 mcp:server:tool_name提取实际的工具名
actual_tool_name = self.name.split(":")[-1]
actual_tool_name = (
self.name.split(":")[-1] if ":" in self.name else self.name
)
return await self.mcp_client.session.call_tool(actual_tool_name, args)
else:
return await self.mcp_client.session.call_tool(self.name, args)
else:
raise Exception(f"Unknown function origin: {self.origin}")
@@ -92,30 +164,77 @@ class MCPClient:
self.active: bool = True
self.tools: List[mcp.Tool] = []
self.server_errlogs: List[str] = []
self.running_event = asyncio.Event()
async def connect_to_server(self, mcp_server_config: dict, name: str):
"""连接到 MCP 服务器
如果 `url` 参数存在,则使用 SSE 的方式连接到 MCP 服务。
如果 `url` 参数存在
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
cfg = mcp_server_config.copy()
if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
key_0 = list(cfg["mcpServers"].keys())[0]
cfg = cfg["mcpServers"][key_0]
cfg.pop("active", None) # Remove active flag from config
cfg = _prepare_config(mcp_server_config.copy())
def logging_callback(msg: str):
# 处理 MCP 服务的错误日志
print(f"MCP Server {name} Error: {msg}")
self.server_errlogs.append(msg)
if "url" in cfg:
success, error_msg = await _quick_test_mcp_connection(cfg)
if not success:
raise Exception(error_msg)
if cfg.get("transport") != "streamable_http":
# SSE transport method
self._streams_context = sse_client(url=cfg["url"])
streams = await self._streams_context.__aenter__()
self._streams_context = sse_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=cfg.get("timeout", 5),
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
)
streams = await self.exit_stack.enter_async_context(
self._streams_context
)
# Create a new client session
# self.session = await self._session_context.__aenter__()
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*streams)
mcp.ClientSession(
*streams,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
)
)
else:
timeout = timedelta(seconds=cfg.get("timeout", 30))
sse_read_timeout = timedelta(
seconds=cfg.get("sse_read_timeout", 60 * 5)
)
self._streams_context = streamablehttp_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
read_s, write_s, _ = await self.exit_stack.enter_async_context(
self._streams_context
)
# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
read_stream=read_s,
write_stream=write_s,
read_timeout_seconds=read_timeout,
logging_callback=logging_callback, # type: ignore
)
)
else:
@@ -135,7 +254,7 @@ class MCPClient:
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
),
), # type: ignore
),
)
@@ -143,19 +262,18 @@ class MCPClient:
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport)
)
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
response = await self.session.list_tools()
logger.debug(f"MCP server {self.name} list tools response: {response}")
self.tools = response.tools
return response
async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()
self.running_event.set() # Set the running event to indicate cleanup is done
class FuncCall:
@@ -164,8 +282,6 @@ class FuncCall:
"""内部加载的 func tools"""
self.mcp_client_dict: Dict[str, MCPClient] = {}
"""MCP 服务列表"""
self.mcp_service_queue = asyncio.Queue()
"""用于外部控制 MCP 服务的启停"""
self.mcp_client_event: Dict[str, asyncio.Event] = {}
def empty(self) -> bool:
@@ -221,7 +337,7 @@ class FuncCall:
return f
return None
async def _init_mcp_clients(self) -> None:
async def init_mcp_clients(self) -> None:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
```
{
@@ -263,72 +379,32 @@ class FuncCall:
)
self.mcp_client_event[name] = event
async def mcp_service_selector(self):
"""为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
{"type": "init"} 初始化所有MCP客户端
{"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
{"type": "terminate"} 终止所有MCP客户端
{"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
"""
while True:
data = await self.mcp_service_queue.get()
if data["type"] == "init":
if "name" in data:
event = asyncio.Event()
asyncio.create_task(
self._init_mcp_client_task_wrapper(
data["name"], data["cfg"], event
)
)
self.mcp_client_event[data["name"]] = event
else:
await self._init_mcp_clients()
elif data["type"] == "terminate":
if "name" in data:
# await self._terminate_mcp_client(data["name"])
if data["name"] in self.mcp_client_event:
self.mcp_client_event[data["name"]].set()
self.mcp_client_event.pop(data["name"], None)
self.func_list = [
f
for f in self.func_list
if not (
f.origin == "mcp" and f.mcp_server_name == data["name"]
)
]
else:
for name in self.mcp_client_dict.keys():
# await self._terminate_mcp_client(name)
# self.mcp_client_event[name].set()
if name in self.mcp_client_event:
self.mcp_client_event[name].set()
self.mcp_client_event.pop(name, None)
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
async def _init_mcp_client_task_wrapper(
self, name: str, cfg: dict, event: asyncio.Event
self,
name: str,
cfg: dict,
event: asyncio.Event,
ready_future: asyncio.Future = None,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:
await self._init_mcp_client(name, cfg)
tools = await self.mcp_client_dict[name].list_tools_and_save()
if ready_future and not ready_future.done():
# tell the caller we are ready
ready_future.set_result(tools)
await event.wait()
logger.info(f"收到 MCP 客户端 {name} 终止信号")
await self._terminate_mcp_client(name)
except Exception as e:
import traceback
traceback.print_exc()
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
if ready_future and not ready_future.done():
ready_future.set_exception(e)
finally:
# 无论如何都能清理
await self._terminate_mcp_client(name)
async def _init_mcp_client(self, name: str, config: dict) -> None:
"""初始化单个MCP客户端"""
try:
# 先清理之前的客户端,如果存在
if name in self.mcp_client_dict:
await self._terminate_mcp_client(name)
@@ -338,6 +414,7 @@ class FuncCall:
self.mcp_client_dict[name] = mcp_client
await mcp_client.connect_to_server(config, name)
tools_res = await mcp_client.list_tools_and_save()
logger.debug(f"MCP server {name} list tools response: {tools_res}")
tool_names = [tool.name for tool in tools_res.tools]
# 移除该MCP服务之前的工具如有
@@ -360,16 +437,6 @@ class FuncCall:
self.func_list.append(func_tool)
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
return
except Exception as e:
import traceback
logger.error(traceback.format_exc())
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
# 发生错误时确保客户端被清理
if name in self.mcp_client_dict:
await self._terminate_mcp_client(name)
return
async def _terminate_mcp_client(self, name: str) -> None:
"""关闭并清理MCP客户端"""
@@ -377,9 +444,9 @@ class FuncCall:
try:
# 关闭MCP连接
await self.mcp_client_dict[name].cleanup()
del self.mcp_client_dict[name]
self.mcp_client_dict.pop(name)
except Exception as e:
logger.info(f"清空 MCP 客户端资源 {name}: {e}")
logger.error(f"清空 MCP 客户端资源 {name}: {e}")
# 移除关联的FuncTool
self.func_list = [
f
@@ -388,6 +455,103 @@ class FuncCall:
]
logger.info(f"已关闭 MCP 服务 {name}")
@staticmethod
async def test_mcp_server_connection(config: dict) -> list[str]:
if "url" in config:
success, error_msg = await _quick_test_mcp_connection(config)
if not success:
raise Exception(error_msg)
mcp_client = MCPClient()
try:
logger.debug(f"testing MCP server connection with config: {config}")
await mcp_client.connect_to_server(config, "test")
tools_res = await mcp_client.list_tools_and_save()
tool_names = [tool.name for tool in tools_res.tools]
finally:
logger.debug("Cleaning up MCP client after testing connection.")
await mcp_client.cleanup()
return tool_names
async def enable_mcp_server(
self,
name: str,
config: dict,
event: asyncio.Event | None = None,
ready_future: asyncio.Future | None = None,
timeout: int = 30,
) -> None:
"""Enable_mcp_server a new MCP server to the manager and initialize it.
Args:
name (str): The name of the MCP server.
config (dict): Configuration for the MCP server.
event (asyncio.Event): Event to signal when the MCP client is ready.
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
timeout (int): Timeout for the initialization.
Raises:
TimeoutError: If the initialization does not complete within the specified timeout.
Exception: If there is an error during initialization.
"""
if not event:
event = asyncio.Event()
if not ready_future:
ready_future = asyncio.Future()
if name in self.mcp_client_dict:
return
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, config, event, ready_future)
)
try:
await asyncio.wait_for(ready_future, timeout=timeout)
finally:
self.mcp_client_event[name] = event
if ready_future.done() and ready_future.exception():
exc = ready_future.exception()
if exc is not None:
raise exc
async def disable_mcp_server(
self, name: str | None = None, timeout: float = 10
) -> None:
"""Disable an MCP server by its name.
Args:
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
timeout (int): Timeout.
"""
if name:
if name not in self.mcp_client_event:
return
client = self.mcp_client_dict.get(name)
self.mcp_client_event[name].set()
if not client:
return
client_running_event = client.running_event
try:
await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
finally:
self.mcp_client_event.pop(name, None)
self.func_list = [
f
for f in self.func_list
if f.origin != "mcp" or f.mcp_server_name != name
]
else:
running_events = [
client.running_event.wait() for client in self.mcp_client_dict.values()
]
for key, event in self.mcp_client_event.items():
event.set()
# waiting for all clients to finish
try:
await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
finally:
self.mcp_client_event.clear()
self.mcp_client_dict.clear()
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
"""
获得 OpenAI API 风格的**已经激活**的工具描述
@@ -592,8 +756,3 @@ class FuncCall:
def __repr__(self):
return str(self.func_list)
async def terminate(self):
for name in self.mcp_client_dict.keys():
await self._terminate_mcp_client(name)
logger.debug(f"清理 MCP 客户端 {name} 资源")

View File

@@ -1,12 +1,14 @@
import traceback
import asyncio
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider, STTProvider, TTSProvider, Personality
from .entities import ProviderType
import traceback
from typing import List
from astrbot.core.db import BaseDatabase
from .register import provider_cls_map, llm_tools
from astrbot.core import logger, sp
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.db import BaseDatabase
from .entities import ProviderType
from .provider import Personality, Provider, STTProvider, TTSProvider, EmbeddingProvider
from .register import llm_tools, provider_cls_map
class ProviderManager:
@@ -18,13 +20,6 @@ class ProviderManager:
self.persona_configs: list = config.get("persona", [])
self.astrbot_config = config
self.selected_provider_id = sp.get("curr_provider")
self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
self.selected_tts_provider_id = self.provider_settings.get("provider_id")
self.provider_enabled = self.provider_settings.get("enable", False)
self.stt_enabled = self.provider_stt_settings.get("enable", False)
self.tts_enabled = self.provider_tts_settings.get("enable", False)
# 人格情景管理
# 目前没有拆成独立的模块
self.default_persona_name = self.provider_settings.get(
@@ -98,15 +93,18 @@ class ProviderManager:
"""加载的 Speech To Text Provider 的实例"""
self.tts_provider_insts: List[TTSProvider] = []
"""加载的 Text To Speech Provider 的实例"""
self.inst_map = {}
self.embedding_provider_insts: List[EmbeddingProvider] = []
"""加载的 Embedding Provider 的实例"""
self.inst_map: dict[str, Provider] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None
"""当前使用的 Provider 实例"""
self.curr_stt_provider_inst: STTProvider = None
"""当前使用的 Speech To Text Provider 实例"""
self.curr_tts_provider_inst: TTSProvider = None
"""当前使用的 Text To Speech Provider 实例"""
self.curr_provider_inst: Provider | None = None
"""默认的 Provider 实例"""
self.curr_stt_provider_inst: STTProvider | None = None
"""默认的 Speech To Text Provider 实例"""
self.curr_tts_provider_inst: TTSProvider | None = None
"""默认的 Text To Speech Provider 实例"""
self.db_helper = db_helper
# kdb(experimental)
@@ -115,24 +113,63 @@ class ProviderManager:
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
async def set_provider(
self, provider_id: str, provider_type: ProviderType, umo: str = None
):
"""设置提供商。
Args:
provider_id (str): 提供商 ID。
provider_type (ProviderType): 提供商类型。
umo (str, optional): 用户会话 ID用于提供商会话隔离。当用户启用了提供商会话隔离时此参数才生效。
"""
if provider_id not in self.inst_map:
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
if umo and self.provider_settings["separate_provider"]:
perf = sp.get("session_provider_perf", {})
session_perf = perf.get(umo, {})
session_perf[provider_type.value] = provider_id
perf[umo] = session_perf
sp.put("session_provider_perf", perf)
return
# 不启用提供商会话隔离模式的情况
self.curr_provider_inst = self.inst_map[provider_id]
if provider_type == ProviderType.TEXT_TO_SPEECH:
sp.put("curr_provider_tts", provider_id)
elif provider_type == ProviderType.SPEECH_TO_TEXT:
sp.put("curr_provider_stt", provider_id)
elif provider_type == ProviderType.CHAT_COMPLETION:
sp.put("curr_provider", provider_id)
async def initialize(self):
# 逐个初始化提供商
for provider_config in self.providers_config:
await self.load_provider(provider_config)
if not self.curr_provider_inst:
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
# 设置默认提供商
selected_provider_id = sp.get(
"curr_provider", self.provider_settings.get("default_provider_id")
)
selected_stt_provider_id = sp.get(
"curr_provider_stt", self.provider_stt_settings.get("provider_id")
)
selected_tts_provider_id = sp.get(
"curr_provider_tts", self.provider_tts_settings.get("provider_id")
)
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
if not self.curr_provider_inst and self.provider_insts:
self.curr_provider_inst = self.provider_insts[0]
if self.stt_enabled and not self.curr_stt_provider_inst:
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
if not self.curr_stt_provider_inst and self.stt_provider_insts:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
if self.tts_enabled and not self.curr_tts_provider_inst:
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
# 初始化 MCP Client 连接
asyncio.create_task(
self.llm_tools.mcp_service_selector(), name="mcp-service-handler"
)
self.llm_tools.mcp_service_queue.put_nowait({"type": "init"})
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
async def load_provider(self, provider_config: dict):
if not provider_config["enable"]:
@@ -155,11 +192,6 @@ class ProviderManager:
from .sources.anthropic_source import (
ProviderAnthropic as ProviderAnthropic,
)
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import (
LLMTunerModelLoader as LLMTunerModelLoader,
)
case "dify":
from .sources.dify_source import ProviderDify as ProviderDify
case "dashscope":
@@ -190,6 +222,10 @@ class ProviderManager:
from .sources.edge_tts_source import (
ProviderEdgeTTS as ProviderEdgeTTS,
)
case "gsv_tts_selfhost":
from .sources.gsv_selfhosted_source import (
ProviderGSVTTS as ProviderGSVTTS,
)
case "gsvi_tts_api":
from .sources.gsvi_tts_source import (
ProviderGSVITTS as ProviderGSVITTS,
@@ -214,6 +250,18 @@ class ProviderManager:
from .sources.volcengine_tts import (
ProviderVolcengineTTS as ProviderVolcengineTTS,
)
case "gemini_tts":
from .sources.gemini_tts_source import (
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
)
case "openai_embedding":
from .sources.openai_embedding_source import (
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
)
case "gemini_embedding":
from .sources.gemini_embedding_source import (
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
)
except (ImportError, ModuleNotFoundError) as e:
logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
@@ -246,14 +294,14 @@ class ProviderManager:
self.stt_provider_insts.append(inst)
if (
self.selected_stt_provider_id == provider_config["id"]
and self.stt_enabled
self.provider_stt_settings.get("provider_id")
== provider_config["id"]
):
self.curr_stt_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。"
)
if not self.curr_stt_provider_inst and self.stt_enabled:
if not self.curr_stt_provider_inst:
self.curr_stt_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
@@ -266,15 +314,12 @@ class ProviderManager:
await inst.initialize()
self.tts_provider_insts.append(inst)
if (
self.selected_tts_provider_id == provider_config["id"]
and self.tts_enabled
):
if self.provider_settings.get("provider_id") == provider_config["id"]:
self.curr_tts_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。"
)
if not self.curr_tts_provider_inst and self.tts_enabled:
if not self.curr_tts_provider_inst:
self.curr_tts_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
@@ -282,8 +327,6 @@ class ProviderManager:
inst = provider_metadata.cls_type(
provider_config,
self.provider_settings,
self.db_helper,
self.provider_settings.get("persistant_history", True),
self.selected_default_persona,
)
@@ -292,16 +335,24 @@ class ProviderManager:
self.provider_insts.append(inst)
if (
self.selected_provider_id == provider_config["id"]
and self.provider_enabled
self.provider_settings.get("default_provider_id")
== provider_config["id"]
):
self.curr_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。"
)
if not self.curr_provider_inst and self.provider_enabled:
if not self.curr_provider_inst:
self.curr_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
if getattr(inst, "initialize", None):
await inst.initialize()
self.embedding_provider_insts.append(inst)
self.inst_map[provider_config["id"]] = inst
except Exception as e:
logger.error(traceback.format_exc())
@@ -322,39 +373,24 @@ class ProviderManager:
if len(self.provider_insts) == 0:
self.curr_provider_inst = None
elif (
self.curr_provider_inst is None
and len(self.provider_insts) > 0
and self.provider_enabled
):
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
self.curr_provider_inst = self.provider_insts[0]
self.selected_provider_id = self.curr_provider_inst.meta().id
logger.info(
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。"
)
if len(self.stt_provider_insts) == 0:
self.curr_stt_provider_inst = None
elif (
self.curr_stt_provider_inst is None
and len(self.stt_provider_insts) > 0
and self.stt_enabled
):
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id
logger.info(
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。"
)
if len(self.tts_provider_insts) == 0:
self.curr_tts_provider_inst = None
elif (
self.curr_tts_provider_inst is None
and len(self.tts_provider_insts) > 0
and self.tts_enabled
):
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id
logger.info(
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。"
)
@@ -383,7 +419,7 @@ class ProviderManager:
self.curr_tts_provider_inst = None
if getattr(self.inst_map[provider_id], "terminate", None):
await self.inst_map[provider_id].terminate()
await self.inst_map[provider_id].terminate() # type: ignore
logger.info(
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
@@ -393,6 +429,8 @@ class ProviderManager:
async def terminate(self):
for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"):
await provider_inst.terminate()
# 清理 MCP Client 连接
await self.llm_tools.mcp_service_queue.put({"type": "terminate"})
await provider_inst.terminate() # type: ignore
try:
await self.llm_tools.disable_mcp_server()
except Exception:
logger.error("Error while disabling MCP servers", exc_info=True)

View File

@@ -1,9 +1,9 @@
import abc
from typing import List
from astrbot.core.db import BaseDatabase
from typing import TypedDict, AsyncGenerator
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult, ProviderType
from astrbot.core.provider.register import provider_cls_map
from dataclasses import dataclass
@@ -23,6 +23,7 @@ class ProviderMeta:
id: str
model: str
type: str
provider_type: ProviderType
class AbstractProvider(abc.ABC):
@@ -41,10 +42,14 @@ class AbstractProvider(abc.ABC):
def meta(self) -> ProviderMeta:
"""获取 Provider 的元数据"""
provider_type_name = self.provider_config["type"]
meta_data = provider_cls_map.get(provider_type_name)
provider_type = meta_data.provider_type if meta_data else None
return ProviderMeta(
id=self.provider_config["id"],
model=self.get_model(),
type=self.provider_config["type"],
type=provider_type_name,
provider_type=provider_type,
)
@@ -53,15 +58,13 @@ class Provider(AbstractProvider):
self,
provider_config: dict,
provider_settings: dict,
persistant_history: bool = True,
db_helper: BaseDatabase = None,
default_persona: Personality = None,
default_persona: Personality | None = None,
) -> None:
super().__init__(provider_config)
self.provider_settings = provider_settings
self.curr_personality: Personality = default_persona
self.curr_personality = default_persona
"""维护了当前的使用的 persona即人格。可能为 None"""
@abc.abstractmethod
@@ -86,11 +89,12 @@ class Provider(AbstractProvider):
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
image_urls: list[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
model: str | None = None,
**kwargs,
) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -114,11 +118,12 @@ class Provider(AbstractProvider):
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
image_urls: list[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
model: str | None = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
@@ -179,3 +184,25 @@ class TTSProvider(AbstractProvider):
async def get_audio(self, text: str) -> str:
"""获取文本的音频,返回音频文件路径"""
raise NotImplementedError()
class EmbeddingProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
@abc.abstractmethod
async def get_embedding(self, text: str) -> list[float]:
"""获取文本的向量"""
...
@abc.abstractmethod
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的向量"""
...
@abc.abstractmethod
def get_dim(self) -> int:
"""获取向量的维度"""
...

View File

@@ -1,3 +1,6 @@
import json
import anthropic
import base64
from typing import List
from mimetypes import guess_type
@@ -5,41 +8,33 @@ from anthropic import AsyncAnthropic
from anthropic.types import Message
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider, Personality
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from ..register import register_provider_adapter
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from .openai_source import ProviderOpenAIOfficial
from astrbot.core.provider.entities import LLMResponse
from typing import AsyncGenerator
@register_provider_adapter(
"anthropic_chat_completion", "Anthropic Claude API 提供商适配器"
)
class ProviderAnthropic(ProviderOpenAIOfficial):
class ProviderAnthropic(Provider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=True,
default_persona: Personality = None,
) -> None:
# Skip OpenAI's __init__ and call Provider's __init__ directly
Provider.__init__(
self,
provider_config,
provider_settings,
persistant_history,
db_helper,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.chosen_api_key = None
self.chosen_api_key: str = ""
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
@@ -51,10 +46,63 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
self.set_model(provider_config["model_config"]["model"])
def _prepare_payload(self, messages: list[dict]):
"""准备 Anthropic API 的请求 payload
Args:
messages: OpenAI 格式的消息列表,包含用户输入和系统提示等信息
Returns:
system_prompt: 系统提示内容
new_messages: 处理后的消息列表,去除系统提示
"""
system_prompt = ""
new_messages = []
for message in messages:
if message["role"] == "system":
system_prompt = message["content"]
elif message["role"] == "assistant":
blocks = []
if isinstance(message["content"], str):
blocks.append({"type": "text", "text": message["content"]})
if "tool_calls" in message:
for tool_call in message["tool_calls"]:
blocks.append( # noqa: PERF401
{
"type": "tool_use",
"name": tool_call["function"]["name"],
"input": json.loads(tool_call["function"]["arguments"])
if isinstance(tool_call["function"]["arguments"], str)
else tool_call["function"]["arguments"],
"id": tool_call["id"],
}
)
new_messages.append(
{
"role": "assistant",
"content": blocks,
}
)
elif message["role"] == "tool":
new_messages.append(
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": message["tool_call_id"],
"content": message["content"],
}
],
}
)
else:
new_messages.append(message)
return system_prompt, new_messages
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
tool_list = tools.get_func_desc_anthropic_style()
if tool_list:
if tool_list := tools.get_func_desc_anthropic_style():
payloads["tools"] = tool_list
completion = await self.client.messages.create(**payloads, stream=False)
@@ -64,68 +112,158 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
if len(completion.content) == 0:
raise Exception("API 返回的 completion 为空。")
# TODO: 如果进行函数调用,思维链被截断,用户可能需要思维链的内容
# 选最后一条消息如果要进行函数调用anthropic会先返回文本消息的思维链然后再返回函数调用请求
content = completion.content[-1]
llm_response = LLMResponse("assistant")
llm_response = LLMResponse(role="assistant")
if content.type == "text":
# text completion
completion_text = str(content.text).strip()
# llm_response.completion_text = completion_text
llm_response.result_chain = MessageChain().message(completion_text)
# Anthropic每次只返回一个函数调用
if completion.stop_reason == "tool_use":
# tools call (function calling)
args_ls = []
func_name_ls = []
tool_use_ids = []
func_name_ls.append(content.name)
args_ls.append(content.input)
tool_use_ids.append(content.id)
llm_response.role = "tool"
llm_response.tools_call_args = args_ls
llm_response.tools_call_name = func_name_ls
llm_response.tools_call_ids = tool_use_ids
for content_block in completion.content:
if content_block.type == "text":
completion_text = str(content_block.text).strip()
llm_response.completion_text = completion_text
if content_block.type == "tool_use":
llm_response.tools_call_args.append(content_block.input)
llm_response.tools_call_name.append(content_block.name)
llm_response.tools_call_ids.append(content_block.id)
# TODO(Soulter): 处理 end_turn 情况
if not llm_response.completion_text and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}")
raise Exception(f"API 返回的 completion 无法解析:{completion}")
llm_response.raw_completion = completion
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}")
return llm_response
async def _query_stream(
self, payloads: dict, tools: FuncCall
) -> AsyncGenerator[LLMResponse, None]:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
payloads["tools"] = tool_list
# 用于累积工具调用信息
tool_use_buffer = {}
# 用于累积最终结果
final_text = ""
final_tool_calls = []
async with self.client.messages.stream(**payloads) as stream:
assert isinstance(stream, anthropic.AsyncMessageStream)
async for event in stream:
if event.type == "content_block_start":
if event.content_block.type == "text":
# 文本块开始
yield LLMResponse(
role="assistant", completion_text="", is_chunk=True
)
elif event.content_block.type == "tool_use":
# 工具使用块开始,初始化缓冲区
tool_use_buffer[event.index] = {
"id": event.content_block.id,
"name": event.content_block.name,
"input": {},
}
elif event.type == "content_block_delta":
if event.delta.type == "text_delta":
# 文本增量
final_text += event.delta.text
yield LLMResponse(
role="assistant",
completion_text=event.delta.text,
is_chunk=True,
)
elif event.delta.type == "input_json_delta":
# 工具调用参数增量
if event.index in tool_use_buffer:
# 累积 JSON 输入
if "input_json" not in tool_use_buffer[event.index]:
tool_use_buffer[event.index]["input_json"] = ""
tool_use_buffer[event.index]["input_json"] += (
event.delta.partial_json
)
elif event.type == "content_block_stop":
# 内容块结束
if event.index in tool_use_buffer:
# 解析完整的工具调用
tool_info = tool_use_buffer[event.index]
try:
if "input_json" in tool_info:
tool_info["input"] = json.loads(tool_info["input_json"])
# 添加到最终结果
final_tool_calls.append(
{
"id": tool_info["id"],
"name": tool_info["name"],
"input": tool_info["input"],
}
)
yield LLMResponse(
role="tool",
completion_text="",
tools_call_args=[tool_info["input"]],
tools_call_name=[tool_info["name"]],
tools_call_ids=[tool_info["id"]],
is_chunk=True,
)
except json.JSONDecodeError:
# JSON 解析失败,跳过这个工具调用
logger.warning(f"工具调用参数 JSON 解析失败: {tool_info}")
# 清理缓冲区
del tool_use_buffer[event.index]
# 返回最终的完整结果
final_response = LLMResponse(
role="assistant", completion_text=final_text, is_chunk=False
)
if final_tool_calls:
final_response.tools_call_args = [
call["input"] for call in final_tool_calls
]
final_response.tools_call_name = [call["name"] for call in final_tool_calls]
final_response.tools_call_ids = [call["id"] for call in final_tool_calls]
yield final_response
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts=[],
prompt,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result: ToolCallsResult = None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
if not prompt:
prompt = "<image>"
if contexts is None:
contexts = []
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if "_no_save" in part:
del part["_no_save"]
# tool calls result
if tool_calls_result:
# 暂时这样写。
prompt += f"Here are the related results via using tools: {str(tool_calls_result.tool_calls_result)}"
if not isinstance(tool_calls_result, list):
context_query.extend(tool_calls_result.to_openai_messages())
else:
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {})
model_config["model"] = model or self.get_model()
payloads = {"messages": new_messages, **model_config}
payloads = {"messages": context_query, **model_config}
# Anthropic has a different way of handling system prompts
if system_prompt:
payloads["system"] = system_prompt
@@ -133,30 +271,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
retry_cnt = 20
while retry_cnt > 0:
logger.warning(
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
)
try:
await self.pop_record(context_query)
response = await self.client.messages.create(
messages=context_query, **model_config
)
llm_response = LLMResponse("assistant")
llm_response.result_chain = MessageChain().message(response.content[0].text)
llm_response.raw_completion = response
return llm_response
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
return LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
else:
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
raise e
@@ -171,22 +286,40 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
contexts=...,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
if contexts is None:
contexts = []
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if "_no_save" in part:
del part["_no_save"]
# tool calls result
if tool_calls_result:
if not isinstance(tool_calls_result, list):
context_query.extend(tool_calls_result.to_openai_messages())
else:
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {})
model_config["model"] = model or self.get_model()
payloads = {"messages": new_messages, **model_config}
# Anthropic has a different way of handling system prompts
if system_prompt:
payloads["system"] = system_prompt
async for llm_response in self._query_stream(payloads, func_tool):
yield llm_response
async def assemble_context(self, text: str, image_urls: List[str] = None):
@@ -230,3 +363,28 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
)
return {"role": "user", "content": content}
async def encode_image_bs64(self, image_url: str) -> str:
"""
将图片转换为 base64
"""
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
return ""
def get_current_key(self) -> str:
return self.chosen_api_key
async def get_models(self) -> List[str]:
models_str = []
models = await self.client.models.list()
models = sorted(models.data, key=lambda x: x.id)
for model in models:
models_str.append(model.id)
return models_str
def set_key(self, key: str):
self.chosen_api_key = key

View File

@@ -19,6 +19,7 @@ from ..register import register_provider_adapter
TEMP_DIR = Path("data/temp/azure_tts")
TEMP_DIR.mkdir(parents=True, exist_ok=True)
class OTTSProvider:
def __init__(self, config: Dict):
self.skey = config["OTTS_SKEY"]
@@ -70,12 +71,12 @@ class OTTSProvider:
"style": voice_params["style"],
"role": voice_params["role"],
"rate": voice_params["rate"],
"volume": voice_params["volume"]
"volume": voice_params["volume"],
},
headers={
"User-Agent": f"AstrBot/{VERSION}",
"UAK": "AstrBot/AzureTTS"
}
"UAK": "AstrBot/AzureTTS",
},
)
response.raise_for_status()
file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -88,14 +89,19 @@ class OTTSProvider:
raise RuntimeError(f"OTTS请求失败: {str(e)}") from e
await asyncio.sleep(0.5 * (attempt + 1))
class AzureNativeProvider(TTSProvider):
def __init__(self, provider_config: dict, provider_settings: dict):
super().__init__(provider_config, provider_settings)
self.subscription_key = provider_config.get("azure_tts_subscription_key", "").strip()
self.subscription_key = provider_config.get(
"azure_tts_subscription_key", ""
).strip()
if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
raise ValueError("无效的Azure订阅密钥")
self.region = provider_config.get("azure_tts_region", "eastus").strip()
self.endpoint = f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
self.endpoint = (
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
)
self.client = None
self.token = None
self.token_expire = 0
@@ -104,15 +110,17 @@ class AzureNativeProvider(TTSProvider):
"style": provider_config.get("azure_tts_style", "cheerful"),
"role": provider_config.get("azure_tts_role", "Boy"),
"rate": provider_config.get("azure_tts_rate", "1"),
"volume": provider_config.get("azure_tts_volume", "100")
"volume": provider_config.get("azure_tts_volume", "100"),
}
async def __aenter__(self):
self.client = AsyncClient(headers={
self.client = AsyncClient(
headers={
"User-Agent": f"AstrBot/{VERSION}",
"Content-Type": "application/ssml+xml",
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm"
})
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm",
}
)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
@@ -120,10 +128,11 @@ class AzureNativeProvider(TTSProvider):
await self.client.aclose()
async def _refresh_token(self):
token_url = f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
token_url = (
f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
)
response = await self.client.post(
token_url,
headers={"Ocp-Apim-Subscription-Key": self.subscription_key}
token_url, headers={"Ocp-Apim-Subscription-Key": self.subscription_key}
)
response.raise_for_status()
self.token = response.text
@@ -150,8 +159,8 @@ class AzureNativeProvider(TTSProvider):
content=ssml,
headers={
"Authorization": f"Bearer {self.token}",
"User-Agent": f"AstrBot/{VERSION}"
}
"User-Agent": f"AstrBot/{VERSION}",
},
)
response.raise_for_status()
file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -160,6 +169,7 @@ class AzureNativeProvider(TTSProvider):
f.write(chunk)
return str(file_path.resolve())
@register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH)
class AzureTTSProvider(TTSProvider):
def __init__(self, provider_config: dict, provider_settings: dict):
@@ -183,7 +193,7 @@ class AzureTTSProvider(TTSProvider):
error_msg = (
f"JSON解析失败请检查格式错误位置{e.lineno}{e.colno}\n"
f"错误详情: {e.msg}\n"
f"错误上下文: {json_str[max(0, e.pos-30):e.pos+30]}"
f"错误上下文: {json_str[max(0, e.pos - 30) : e.pos + 30]}"
)
raise ValueError(error_msg) from e
except KeyError as e:
@@ -202,8 +212,8 @@ class AzureTTSProvider(TTSProvider):
"style": self.provider_config.get("azure_tts_style"),
"role": self.provider_config.get("azure_tts_role"),
"rate": self.provider_config.get("azure_tts_rate"),
"volume": self.provider_config.get("azure_tts_volume")
}
"volume": self.provider_config.get("azure_tts_volume"),
},
)
else:
async with self.provider as provider:

View File

@@ -5,7 +5,6 @@ from typing import List
from .. import Provider, Personality
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from astrbot.core.message.message_event_result import MessageChain
from .openai_source import ProviderOpenAIOfficial
@@ -19,16 +18,12 @@ class ProviderDashscope(ProviderOpenAIOfficial):
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=False,
default_persona: Personality = None,
default_persona: Personality | None = None,
) -> None:
Provider.__init__(
self,
provider_config,
provider_settings,
persistant_history,
db_helper,
default_persona,
)
self.api_key = provider_config.get("dashscope_api_key", "")
@@ -72,8 +67,11 @@ class ProviderDashscope(ProviderOpenAIOfficial):
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
@@ -166,6 +164,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
contexts=...,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")

View File

@@ -1,10 +1,9 @@
import astrbot.core.message.components as Comp
import os
from typing import List
from .. import Provider, Personality
from .. import Provider
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_image_by_url, download_file
@@ -17,17 +16,13 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class ProviderDify(Provider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=False,
default_persona: Personality = None,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
persistant_history,
db_helper,
default_persona,
)
self.api_key = provider_config.get("dify_api_key", "")
@@ -61,13 +56,18 @@ class ProviderDify(Provider):
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
if image_urls is None:
image_urls = []
result = ""
session_id = session_id or kwargs.get("user") or "unknown" # 1734
conversation_id = self.conversation_ids.get(session_id, "")
files_payload = []
@@ -100,6 +100,7 @@ class ProviderDify(Provider):
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
payload_vars.update(session_var)
payload_vars["system_prompt"] = system_prompt
try:
match self.api_type:
@@ -199,6 +200,7 @@ class ProviderDify(Provider):
contexts=...,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")

View File

@@ -0,0 +1,63 @@
from google import genai
from google.genai import types
from google.genai.errors import APIError
from ..provider import EmbeddingProvider
from ..register import register_provider_adapter
from ..entities import ProviderType
@register_provider_adapter(
"gemini_embedding",
"Google Gemini Embedding 提供商适配器",
provider_type=ProviderType.EMBEDDING,
)
class GeminiEmbeddingProvider(EmbeddingProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
self.provider_settings = provider_settings
api_key: str = provider_config.get("embedding_api_key")
api_base: str = provider_config.get("embedding_api_base", None)
timeout: int = int(provider_config.get("timeout", 20))
http_options = types.HttpOptions(timeout=timeout * 1000)
if api_base:
if api_base.endswith("/"):
api_base = api_base[:-1]
http_options.base_url = api_base
self.client = genai.Client(api_key=api_key, http_options=http_options).aio
self.model = provider_config.get(
"embedding_model", "gemini-embedding-exp-03-07"
)
self.dimension = provider_config.get("embedding_dimensions", 768)
async def get_embedding(self, text: str) -> list[float]:
"""
获取文本的嵌入
"""
try:
result = await self.client.models.embed_content(
model=self.model, contents=text
)
return result.embeddings[0].values
except APIError as e:
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
"""
批量获取文本的嵌入
"""
try:
result = await self.client.models.embed_content(
model=self.model, contents=texts
)
return [embedding.values for embedding in result.embeddings]
except APIError as e:
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
def get_dim(self) -> int:
"""获取向量的维度"""
return self.dimension

View File

@@ -12,10 +12,9 @@ from google.genai.errors import APIError
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Personality, Provider
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.utils.io import download_image_by_url
@@ -52,17 +51,13 @@ class ProviderGoogleGenAI(Provider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=True,
default_persona: Personality = None,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
persistant_history,
db_helper,
default_persona,
)
self.api_keys: list = provider_config.get("key", [])
@@ -141,24 +136,66 @@ class ProviderGoogleGenAI(Provider):
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
modalities = ["Text"]
tool_list = None
tool_list = []
model_name = self.get_model()
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
native_search = self.provider_config.get("gm_native_search", False)
url_context = self.provider_config.get("gm_url_context", False)
if "gemini-2.5" in model_name:
if native_coderunner:
tool_list = [types.Tool(code_execution=types.ToolCodeExecution())]
tool_list.append(types.Tool(code_execution=types.ToolCodeExecution()))
if native_search:
logger.warning("已启用代码执行工具搜索工具将被忽略")
if tools:
logger.warning("已启用代码执行工具,函数工具将被忽略")
logger.warning("代码执行工具搜索工具互斥,已忽略搜索工具")
if url_context:
logger.warning(
"代码执行工具与URL上下文工具互斥已忽略URL上下文工具"
)
else:
if native_search:
tool_list.append(types.Tool(google_search=types.GoogleSearch()))
if url_context:
if hasattr(types, "UrlContext"):
tool_list.append(types.Tool(url_context=types.UrlContext()))
else:
logger.warning(
"当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包"
)
elif "gemini-2.0-lite" in model_name:
if native_coderunner or native_search or url_context:
logger.warning(
"gemini-2.0-lite 不支持代码执行、搜索工具和URL上下文将忽略这些设置"
)
tool_list = None
else:
if native_coderunner:
tool_list.append(types.Tool(code_execution=types.ToolCodeExecution()))
if native_search:
logger.warning("代码执行工具与搜索工具互斥,已忽略搜索工具")
elif native_search:
tool_list = [types.Tool(google_search=types.GoogleSearch())]
if tools:
logger.warning("已启用搜索工具,函数工具将被忽略")
tool_list.append(types.Tool(google_search=types.GoogleSearch()))
if url_context and not native_coderunner:
if hasattr(types, "UrlContext"):
tool_list.append(types.Tool(url_context=types.UrlContext()))
else:
logger.warning(
"当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包"
)
if not tool_list:
tool_list = None
if tools and tool_list:
logger.warning("已启用原生工具,函数工具将被忽略")
elif tools and (func_desc := tools.get_func_desc_google_genai_style()):
tool_list = [
types.Tool(function_declarations=func_desc["function_declarations"])
]
return types.GenerateContentConfig(
system_instruction=system_instruction,
temperature=temperature,
@@ -433,6 +470,10 @@ class ProviderGoogleGenAI(Provider):
raise
continue
# Accumulate the complete response text for the final response
accumulated_text = ""
final_response = None
async for chunk in result:
llm_response = LLMResponse("assistant", is_chunk=True)
@@ -444,32 +485,47 @@ class ProviderGoogleGenAI(Provider):
chunk, llm_response
)
yield llm_response
break
return
if chunk.text:
accumulated_text += chunk.text
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
yield llm_response
if chunk.candidates[0].finish_reason:
llm_response = LLMResponse("assistant", is_chunk=False)
if not chunk.candidates[0].content.parts:
llm_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
else:
llm_response.result_chain = self._process_content_parts(
chunk, llm_response
# Process the final chunk for potential tool calls or other content
if chunk.candidates[0].content.parts:
final_response = LLMResponse("assistant", is_chunk=False)
final_response.result_chain = self._process_content_parts(
chunk, final_response
)
yield llm_response
break
# Yield final complete response with accumulated text
if not final_response:
final_response = LLMResponse("assistant", is_chunk=False)
# Set the complete accumulated text in the final response
if accumulated_text:
final_response.result_chain = MessageChain(
chain=[Comp.Plain(accumulated_text)]
)
elif not final_response.result_chain:
# If no text was accumulated and no final response was set, provide empty space
final_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
yield final_response
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: list[str] = None,
func_tool: FuncCall = None,
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult = None,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -485,10 +541,14 @@ class ProviderGoogleGenAI(Provider):
# tool calls result
if tool_calls_result:
if not isinstance(tool_calls_result, list):
context_query.extend(tool_calls_result.to_openai_messages())
else:
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model()
model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config}
@@ -505,13 +565,14 @@ class ProviderGoogleGenAI(Provider):
async def text_chat_stream(
self,
prompt: str,
session_id: str = None,
image_urls: list[str] = None,
func_tool: FuncCall = None,
contexts: str = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult = None,
prompt,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
if contexts is None:
@@ -527,10 +588,14 @@ class ProviderGoogleGenAI(Provider):
# tool calls result
if tool_calls_result:
if not isinstance(tool_calls_result, list):
context_query.extend(tool_calls_result.to_openai_messages())
else:
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model()
model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config}
@@ -590,7 +655,10 @@ class ProviderGoogleGenAI(Provider):
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
{"type": "image_url", "image_url": {"url": image_data}}
{
"type": "image_url",
"image_url": {"url": image_data},
}
)
return user_content
else:

View File

@@ -0,0 +1,79 @@
import os
import uuid
import wave
from google import genai
from google.genai import types
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ..entities import ProviderType
from ..provider import TTSProvider
from ..register import register_provider_adapter
@register_provider_adapter(
"gemini_tts", "Gemini TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
)
class ProviderGeminiTTSAPI(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
api_key: str = provider_config.get("gemini_tts_api_key", "")
api_base: str | None = provider_config.get("gemini_tts_api_base")
timeout: int = int(provider_config.get("gemini_tts_timeout", 20))
http_options = types.HttpOptions(timeout=timeout * 1000)
if api_base:
if api_base.endswith("/"):
api_base = api_base[:-1]
http_options.base_url = api_base
self.client = genai.Client(api_key=api_key, http_options=http_options).aio
self.model: str = provider_config.get(
"gemini_tts_model", "gemini-2.5-flash-preview-tts"
)
self.prefix: str | None = provider_config.get(
"gemini_tts_prefix",
)
self.voice_name: str = provider_config.get("gemini_tts_voice_name", "Leda")
async def get_audio(self, text: str) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"gemini_tts_{uuid.uuid4()}.wav")
prompt = f"{self.prefix}: {text}" if self.prefix else text
response = await self.client.models.generate_content(
model=self.model,
contents=prompt,
config=types.GenerateContentConfig(
response_modalities=["AUDIO"],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(
voice_name=self.voice_name,
)
)
),
),
)
# 不想看类型检查报错
if (
not response.candidates
or not response.candidates[0].content
or not response.candidates[0].content.parts
or not response.candidates[0].content.parts[0].inline_data
or not response.candidates[0].content.parts[0].inline_data.data
):
raise Exception("No audio content returned from Gemini TTS API.")
with wave.open(path, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(24000)
wf.writeframes(response.candidates[0].content.parts[0].inline_data.data)
return path

View File

@@ -0,0 +1,148 @@
import asyncio
import os
import uuid
import aiohttp
from ..provider import TTSProvider
from ..entities import ProviderType
from ..register import register_provider_adapter
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@register_provider_adapter(
provider_type_name="gsv_tts_selfhost",
desc="GPT-SoVITS TTS(本地加载)",
provider_type=ProviderType.TEXT_TO_SPEECH,
)
class ProviderGSVTTS(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880").rstrip(
"/"
)
self.gpt_weights_path: str = provider_config.get("gpt_weights_path", "")
self.sovits_weights_path: str = provider_config.get("sovits_weights_path", "")
# TTS 请求的默认参数移除前缀gsv_
self.default_params: dict = {
key.removeprefix("gsv_"): str(value).lower()
for key, value in provider_config.get("gsv_default_parms", {}).items()
}
self.timeout = provider_config.get("timeout", 60)
self._session: aiohttp.ClientSession | None = None
async def initialize(self):
"""异步初始化:在 ProviderManager 中被调用"""
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self.timeout)
)
try:
await self._set_model_weights()
logger.info("[GSV TTS] 初始化完成")
except Exception as e:
logger.error(f"[GSV TTS] 初始化失败:{e}")
raise
def get_session(self) -> aiohttp.ClientSession:
if not self._session or self._session.closed:
raise RuntimeError(
"[GSV TTS] Provider HTTP session is not ready or closed."
)
return self._session
async def _make_request(
self, endpoint: str, params=None, retries: int = 3
) -> bytes | None:
"""发起请求"""
for attempt in range(retries):
logger.debug(f"[GSV TTS] 请求地址:{endpoint},参数:{params}")
try:
async with self.get_session().get(endpoint, params=params) as response:
if response.status != 200:
error_text = await response.text()
raise Exception(
f"[GSV TTS] Request to {endpoint} failed with status {response.status}: {error_text}"
)
return await response.read()
except Exception as e:
if attempt < retries - 1:
logger.warning(
f"[GSV TTS] 请求 {endpoint}{attempt + 1} 次失败:{e},重试中..."
)
await asyncio.sleep(1)
else:
logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}")
raise
async def _set_model_weights(self):
"""设置模型路径"""
try:
if self.gpt_weights_path:
await self._make_request(
f"{self.api_base}/set_gpt_weights",
{"weights_path": self.gpt_weights_path},
)
logger.info(f"[GSV TTS] 成功设置 GPT 模型路径:{self.gpt_weights_path}")
else:
logger.info("[GSV TTS] GPT 模型路径未配置,将使用内置 GPT 模型")
if self.sovits_weights_path:
await self._make_request(
f"{self.api_base}/set_sovits_weights",
{"weights_path": self.sovits_weights_path},
)
logger.info(
f"[GSV TTS] 成功设置 SoVITS 模型路径:{self.sovits_weights_path}"
)
else:
logger.info("[GSV TTS] SoVITS 模型路径未配置,将使用内置 SoVITS 模型")
except aiohttp.ClientError as e:
logger.error(f"[GSV TTS] 设置模型路径时发生网络错误:{e}")
except Exception as e:
logger.error(f"[GSV TTS] 设置模型路径时发生未知错误:{e}")
async def get_audio(self, text: str) -> str:
"""实现 TTS 核心方法,根据文本内容自动切换情绪"""
if not text.strip():
raise ValueError("[GSV TTS] TTS 文本不能为空")
endpoint = f"{self.api_base}/tts"
params = self.build_synthesis_params(text)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav")
logger.debug(f"[GSV TTS] 正在调用语音合成接口,参数:{params}")
result = await self._make_request(endpoint, params)
if isinstance(result, bytes):
with open(path, "wb") as f:
f.write(result)
return path
else:
raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}")
def build_synthesis_params(self, text: str) -> dict:
"""
构建语音合成所需的参数字典。
当前仅包含默认参数 + 文本,未来可在此基础上动态添加如情绪、角色等语义控制字段。
"""
params = self.default_params.copy()
params["text"] = text
# TODO: 在此处添加情绪分析,例如 params["emotion"] = detect_emotion(text)
return params
async def terminate(self):
"""终止释放资源:在 ProviderManager 中被调用"""
if self._session and not self._session.closed:
await self._session.close()
logger.info("[GSV TTS] Session 已关闭")

View File

@@ -1,132 +0,0 @@
import os
from llmtuner.chat import ChatModel
from typing import List
from .. import Provider
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
@register_provider_adapter(
"llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型"
)
class LLMTunerModelLoader(Provider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=True,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
persistant_history,
db_helper,
default_persona,
)
if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists(
provider_config["adapter_model_path"]
):
raise FileNotFoundError("模型文件路径不存在。")
self.base_model_path = provider_config["base_model_path"]
self.adapter_model_path = provider_config["adapter_model_path"]
self.model = ChatModel(
{
"model_name_or_path": self.base_model_path,
"adapter_name_or_path": self.adapter_model_path,
"template": provider_config["llmtuner_template"],
"finetuning_type": provider_config["finetuning_type"],
"quantization_bit": provider_config["quantization_bit"],
}
)
self.set_model(
os.path.basename(self.base_model_path)
+ "_"
+ os.path.basename(self.adapter_model_path)
)
async def assemble_context(self, text: str, image_urls: List[str] = None):
"""
组装上下文。
"""
return {"role": "user", "content": text}
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = [],
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
system_prompt = ""
new_record = {"role": "user", "content": prompt}
query_context = [*contexts, new_record]
# 提取出系统提示
system_idxs = []
for idx, context in enumerate(query_context):
if context["role"] == "system":
system_idxs.append(idx)
if "_no_save" in context:
del context["_no_save"]
for idx in reversed(system_idxs):
system_prompt += " " + query_context.pop(idx)["content"]
conf = {
"messages": query_context,
"system": system_prompt,
}
if func_tool:
tool_list = func_tool.get_func_desc_openai_style()
if tool_list:
conf["tools"] = tool_list
responses = await self.model.achat(**conf)
llm_response = LLMResponse("assistant", responses[-1].response_text)
return llm_response
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=...,
func_tool=None,
contexts=...,
system_prompt=None,
tool_calls_result=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def get_current_key(self):
return "none"
async def set_key(self, key):
pass
async def get_models(self):
return [self.get_model()]

View File

@@ -0,0 +1,43 @@
from openai import AsyncOpenAI
from ..provider import EmbeddingProvider
from ..register import register_provider_adapter
from ..entities import ProviderType
@register_provider_adapter(
"openai_embedding",
"OpenAI API Embedding 提供商适配器",
provider_type=ProviderType.EMBEDDING,
)
class OpenAIEmbeddingProvider(EmbeddingProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
self.provider_settings = provider_settings
self.client = AsyncOpenAI(
api_key=provider_config.get("embedding_api_key"),
base_url=provider_config.get(
"embedding_api_base", "https://api.openai.com/v1"
),
timeout=int(provider_config.get("timeout", 20)),
)
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
self.dimension = provider_config.get("embedding_dimensions", 1024)
async def get_embedding(self, text: str) -> list[float]:
"""
获取文本的嵌入
"""
embedding = await self.client.embeddings.create(input=text, model=self.model)
return embedding.data[0].embedding
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
"""
批量获取文本的嵌入
"""
embeddings = await self.client.embeddings.create(input=texts, model=self.model)
return [item.embedding for item in embeddings.data]
def get_dim(self) -> int:
"""获取向量的维度"""
return self.dimension

View File

@@ -9,14 +9,12 @@ import astrbot.core.message.components as Comp
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion
# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai._exceptions import NotFoundError, UnprocessableEntityError
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider, Personality
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List, AsyncGenerator
@@ -30,17 +28,13 @@ from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
class ProviderOpenAIOfficial(Provider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=True,
default_persona: Personality = None,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
persistant_history,
db_helper,
default_persona,
)
self.chosen_api_key = None
@@ -87,6 +81,17 @@ class ProviderOpenAIOfficial(Provider):
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
# Check if we need to add googleSearch function for Gemini(OpenAI Compatible)
if (
self.provider_config.get("enable_google_search", False)
and self.provider_config.get("api_base", "").find(
"generativelanguage.googleapis.com"
)
!= -1
):
# Add googleSearch function as alias to web_search
await self._add_google_search_tool(tools)
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
tool_list = tools.get_func_desc_openai_style(
@@ -105,6 +110,11 @@ class ProviderOpenAIOfficial(Provider):
for key in to_del:
del payloads[key]
# 针对 qwen3 模型的特殊处理:非流式调用必须设置 enable_thinking=false
model = payloads.get("model", "")
if "qwen3" in model.lower():
extra_body["enable_thinking"] = False
completion = await self.client.chat.completions.create(
**payloads, stream=False, extra_body=extra_body
)
@@ -125,6 +135,17 @@ class ProviderOpenAIOfficial(Provider):
) -> AsyncGenerator[LLMResponse, None]:
"""流式查询API逐步返回结果"""
if tools:
# Check if we need to add googleSearch function for Gemini(OpenAI Compatible)
if (
self.provider_config.get("enable_google_search", False)
and self.provider_config.get("api_base", "").find(
"generativelanguage.googleapis.com"
)
!= -1
):
# Add googleSearch function as alias to web_search
await self._add_google_search_tool(tools)
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
tool_list = tools.get_func_desc_openai_style(
@@ -182,7 +203,7 @@ class ProviderOpenAIOfficial(Provider):
raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0]
if choice.message.content:
if choice.message.content is not None:
# text completion
completion_text = str(choice.message.content).strip()
llm_response.result_chain = MessageChain().message(completion_text)
@@ -193,9 +214,16 @@ class ProviderOpenAIOfficial(Provider):
func_name_ls = []
tool_call_ids = []
for tool_call in choice.message.tool_calls:
if isinstance(tool_call, str):
# workaround for #1359
tool_call = json.loads(tool_call)
for tool in tools.func_list:
if tool.name == tool_call.function.name:
# workaround for #1454
if isinstance(tool_call.function.arguments, str):
args = json.loads(tool_call.function.arguments)
else:
args = tool_call.function.arguments
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
tool_call_ids.append(tool_call.id)
@@ -209,7 +237,7 @@ class ProviderOpenAIOfficial(Provider):
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。"
)
if not llm_response.completion_text and not llm_response.tools_call_args:
if llm_response.completion_text is None and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}")
raise Exception(f"API 返回的 completion 无法解析:{completion}")
@@ -220,12 +248,11 @@ class ProviderOpenAIOfficial(Provider):
async def _prepare_chat_payload(
self,
prompt: str,
session_id: str = None,
image_urls: list[str] = None,
func_tool: FuncCall = None,
contexts: list=None,
system_prompt: str=None,
tool_calls_result: ToolCallsResult=None,
image_urls: list[str] | None = None,
contexts: list | None = None,
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
**kwargs,
) -> tuple:
"""准备聊天所需的有效载荷和上下文"""
@@ -242,14 +269,18 @@ class ProviderOpenAIOfficial(Provider):
# tool calls result
if tool_calls_result:
if isinstance(tool_calls_result, ToolCallsResult):
context_query.extend(tool_calls_result.to_openai_messages())
else:
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model()
model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config}
return payloads, context_query, func_tool
return payloads, context_query
async def _handle_api_error(
self,
@@ -340,22 +371,22 @@ class ProviderOpenAIOfficial(Provider):
async def text_chat(
self,
prompt,
session_id = None,
image_urls = None,
func_tool = None,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
payloads, context_query, func_tool = await self._prepare_chat_payload(
payloads, context_query = await self._prepare_chat_payload(
prompt,
session_id,
image_urls,
func_tool,
contexts,
system_prompt,
tool_calls_result,
model=model,
**kwargs,
)
@@ -415,17 +446,17 @@ class ProviderOpenAIOfficial(Provider):
contexts=[],
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话,与服务商交互并逐步返回结果"""
payloads, context_query, func_tool = await self._prepare_chat_payload(
payloads, context_query = await self._prepare_chat_payload(
prompt,
session_id,
image_urls,
func_tool,
contexts,
system_prompt,
tool_calls_result,
model=model,
**kwargs,
)
@@ -481,13 +512,8 @@ class ProviderOpenAIOfficial(Provider):
"""
new_contexts = []
flag = False
for context in contexts:
if flag:
flag = False # 删除 image 后下一条LLM 响应)也要删除
continue
if isinstance(context["content"], list):
flag = True
if "content" in context and isinstance(context["content"], list):
# continue
new_content = []
for item in context["content"]:
@@ -530,7 +556,10 @@ class ProviderOpenAIOfficial(Provider):
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
{"type": "image_url", "image_url": {"url": image_data}}
{
"type": "image_url",
"image_url": {"url": image_data},
}
)
return user_content
else:
@@ -546,3 +575,35 @@ class ProviderOpenAIOfficial(Provider):
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
return ""
async def _add_google_search_tool(self, tools: FuncCall) -> None:
"""Add googleSearch function as an alias to web_search for Gemini(OpenAI Compatible)"""
# Check if googleSearch is already added
for func in tools.func_list:
if func.name == "googleSearch":
return
# Check if web_search exists
web_search_func = None
for func in tools.func_list:
if func.name == "web_search":
web_search_func = func
break
if web_search_func is None:
# If web_search is not available, don't add googleSearch
return
# Add googleSearch as an alias to web_search with English description
tools.add_func(
name="googleSearch",
func_args=[
{
"type": "string",
"name": "query",
"description": "The most relevant search keywords for the user's question, used to search on Google.",
}
],
desc="Search the internet to answer user questions using Google search. Call this tool when users need to search the web for real-time information.",
handler=web_search_func.handler,
)

View File

@@ -5,12 +5,12 @@ import os
import traceback
import asyncio
import aiohttp
import requests
from ..provider import TTSProvider
from ..entities import ProviderType
from ..register import register_provider_adapter
from astrbot import logger
@register_provider_adapter(
"volcengine_tts", "火山引擎 TTS", provider_type=ProviderType.TEXT_TO_SPEECH
)
@@ -22,7 +22,9 @@ class ProviderVolcengineTTS(TTSProvider):
self.cluster = provider_config.get("volcengine_cluster", "")
self.voice_type = provider_config.get("volcengine_voice_type", "")
self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0)
self.api_base = provider_config.get("api_base", f"https://openspeech.bytedance.com/api/v1/tts")
self.api_base = provider_config.get(
"api_base", "https://openspeech.bytedance.com/api/v1/tts"
)
self.timeout = provider_config.get("timeout", 20)
def _build_request_payload(self, text: str) -> dict:
@@ -30,11 +32,9 @@ class ProviderVolcengineTTS(TTSProvider):
"app": {
"appid": self.appid,
"token": self.api_key,
"cluster": self.cluster
},
"user": {
"uid": str(uuid.uuid4())
"cluster": self.cluster,
},
"user": {"uid": str(uuid.uuid4())},
"audio": {
"voice_type": self.voice_type,
"encoding": "mp3",
@@ -48,15 +48,15 @@ class ProviderVolcengineTTS(TTSProvider):
"text_type": "plain",
"operation": "query",
"with_frontend": 1,
"frontend_type": "unitTson"
}
"frontend_type": "unitTson",
},
}
async def get_audio(self, text: str) -> str:
"""异步方法获取语音文件路径"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer; {self.api_key}"
"Authorization": f"Bearer; {self.api_key}",
}
payload = self._build_request_payload(text)
@@ -71,7 +71,7 @@ class ProviderVolcengineTTS(TTSProvider):
self.api_base,
data=json.dumps(payload),
headers=headers,
timeout=self.timeout
timeout=self.timeout,
) as response:
logger.debug(f"响应状态码: {response.status}")
@@ -90,8 +90,7 @@ class ProviderVolcengineTTS(TTSProvider):
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None,
lambda: open(file_path, "wb").write(audio_data)
None, lambda: open(file_path, "wb").write(audio_data)
)
return file_path
@@ -99,7 +98,9 @@ class ProviderVolcengineTTS(TTSProvider):
error_msg = resp_data.get("message", "未知错误")
raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
else:
raise Exception(f"火山引擎 TTS API 请求失败: {response.status}, {response_text}")
raise Exception(
f"火山引擎 TTS API 请求失败: {response.status}, {response_text}"
)
except Exception as e:
error_details = traceback.format_exc()

View File

@@ -1,4 +1,3 @@
from astrbot.core.db import BaseDatabase
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
@@ -13,15 +12,11 @@ class ProviderZhipu(ProviderOpenAIOfficial):
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=True,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
db_helper,
persistant_history,
default_persona,
)
@@ -31,17 +26,20 @@ class ProviderZhipu(ProviderOpenAIOfficial):
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts=[],
contexts=None,
system_prompt=None,
model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
context_query = [*contexts, new_record]
model_cfgs: dict = self.provider_config.get("model_config", {})
model = self.get_model()
model = model or self.get_model()
# glm-4v-flash 只支持一张图片
if model.lower() == "glm-4v-flash" and image_urls and len(context_query) > 1:
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")

View File

@@ -1,20 +0,0 @@
from typing import List
from openai import AsyncOpenAI
class SimpleOpenAIEmbedding:
def __init__(
self,
model,
api_key,
api_base=None,
) -> None:
self.client = AsyncOpenAI(api_key=api_key, base_url=api_base)
self.model = model
async def get_embedding(self, text) -> List[float]:
"""
获取文本的嵌入
"""
embedding = await self.client.embeddings.create(input=text, model=self.model)
return embedding.data[0].embedding

View File

@@ -1,95 +0,0 @@
import os
from typing import List, Dict
from astrbot.core import logger
from .store import Store
from astrbot.core.config import AstrBotConfig
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class KnowledgeDBManager:
def __init__(self, astrbot_config: AstrBotConfig) -> None:
self.db_path = os.path.join(get_astrbot_data_path(), "knowledge_db")
self.config = astrbot_config.get("knowledge_db", {})
self.astrbot_config = astrbot_config
if not os.path.exists(self.db_path):
os.makedirs(self.db_path)
self.store_insts: Dict[str, Store] = {}
for name, cfg in self.config.items():
if cfg["strategy"] == "embedding":
logger.info(f"加载 Chroma Vector Store{name}")
try:
from .store.chroma_db import ChromaVectorStore
except ImportError as ie:
logger.error(f"{ie} 可能未安装 chromadb 库。")
continue
self.store_insts[name] = ChromaVectorStore(
name, cfg["embedding_config"]
)
else:
logger.error(f"不支持的策略:{cfg['strategy']}")
async def list_knowledge_db(self) -> List[str]:
return [
f
for f in os.listdir(self.db_path)
if os.path.isfile(os.path.join(self.db_path, f))
]
async def create_knowledge_db(self, name: str, config: Dict):
"""
config 格式:
```
{
"strategy": "embedding", # 目前只支持 embedding
"chunk_method": {
"strategy": "fixed",
"chunk_size": 100,
"overlap_size": 10
},
"embedding_config": {
"strategy": "openai",
"base_url": "",
"model": "",
"api_key": ""
}
}
```
"""
if name in self.config:
raise ValueError(f"知识库已存在:{name}")
self.config[name] = config
self.astrbot_config["knowledge_db"] = self.config
self.astrbot_config.save_config()
async def insert_record(self, name: str, text: str):
if name not in self.store_insts:
raise ValueError(f"未找到知识库:{name}")
ret = []
match self.config[name]["chunk_method"]["strategy"]:
case "fixed":
chunk_size = self.config[name]["chunk_method"]["chunk_size"]
chunk_overlap = self.config[name]["chunk_method"]["overlap_size"]
ret = self._fixed_chunk(text, chunk_size, chunk_overlap)
case _:
pass
for chunk in ret:
await self.store_insts[name].save(chunk)
async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]:
if name not in self.store_insts:
raise ValueError(f"未找到知识库:{name}")
inst = self.store_insts[name]
return await inst.query(query, top_n)
def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
chunks = []
start = 0
while start < len(text):
end = start + chunk_size
chunks.append(text[start:end])
start += chunk_size - chunk_overlap
return chunks

View File

@@ -1,9 +0,0 @@
from typing import List
class Store:
async def save(self, text: str):
pass
async def query(self, query: str, top_n: int = 3) -> List[str]:
pass

View File

@@ -1,44 +0,0 @@
import chromadb
import uuid
from typing import List, Dict
from astrbot.api import logger
from ..embedding.openai_source import SimpleOpenAIEmbedding
from . import Store
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class ChromaVectorStore(Store):
def __init__(self, name: str, embedding_cfg: Dict) -> None:
import os
self.chroma_client = chromadb.PersistentClient(
path=os.path.join(get_astrbot_data_path(), "long_term_memory_chroma.db")
)
self.collection = self.chroma_client.get_or_create_collection(name=name)
self.embedding = None
if embedding_cfg["strategy"] == "openai":
self.embedding = SimpleOpenAIEmbedding(
model=embedding_cfg["model"],
api_key=embedding_cfg["api_key"],
api_base=embedding_cfg.get("base_url", None),
)
async def save(self, text: str, metadata: Dict = None):
logger.debug(f"Saving text: {text}")
embedding = await self.embedding.get_embedding(text)
self.collection.upsert(
documents=text,
metadatas=metadata,
ids=str(uuid.uuid4()),
embeddings=embedding,
)
async def query(
self, query: str, top_n=3, metadata_filter: Dict = None
) -> List[str]:
embedding = await self.embedding.get_embedding(query)
results = self.collection.query(
query_embeddings=embedding, n_results=top_n, where=metadata_filter
)
return results["documents"][0]

View File

@@ -1,4 +1,4 @@
from .star import StarMetadata
from .star import StarMetadata, star_map, star_registry
from .star_manager import PluginManager
from .context import Context
from astrbot.core.provider import Provider
@@ -10,23 +10,48 @@ from astrbot.core.star.star_tools import StarTools
class Star(CommandParserMixin):
"""所有插件Star的父类所有插件都应该继承于这个类"""
def __init__(self, context: Context):
def __init__(self, context: Context, config: dict | None = None):
StarTools.initialize(context)
self.context = context
async def text_to_image(self, text: str, return_url=True) -> str:
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not star_map.get(cls.__module__):
metadata = StarMetadata(
star_cls_type=cls,
module_path=cls.__module__,
)
star_map[cls.__module__] = metadata
star_registry.append(metadata)
else:
star_map[cls.__module__].star_cls_type = cls
star_map[cls.__module__].module_path = cls.__module__
@staticmethod
async def text_to_image(text: str, return_url=True) -> str:
"""将文本转换为图片"""
return await html_renderer.render_t2i(text, return_url=return_url)
async def html_render(self, tmpl: str, data: dict, return_url=True) -> str:
@staticmethod
async def html_render(
tmpl: str, data: dict, return_url=True, options: dict = None
) -> str:
"""渲染 HTML"""
return await html_renderer.render_custom_template(
tmpl, data, return_url=return_url
tmpl, data, return_url=return_url, options=options
)
async def initialize(self):
"""当插件被激活时会调用这个方法"""
pass
async def terminate(self):
"""当插件被禁用、重载插件时会调用这个方法"""
pass
def __del__(self):
"""[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
pass
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]

View File

@@ -2,7 +2,13 @@ from asyncio import Queue
from typing import List, Union
from astrbot.core import sp
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
from astrbot.core.provider.provider import (
Provider,
TTSProvider,
STTProvider,
EmbeddingProvider,
)
from astrbot.core.provider.entities import ProviderType
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.provider.func_tool_manager import FuncCall
@@ -16,7 +22,6 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
@@ -42,6 +47,8 @@ class Context:
platform_manager: PlatformManager = None
registered_web_apis: list = []
# back compatibility
_register_tasks: List[Awaitable] = []
_star_manager = None
@@ -54,14 +61,12 @@ class Context:
provider_manager: ProviderManager = None,
platform_manager: PlatformManager = None,
conversation_manager: ConversationManager = None,
knowledge_db_manager: KnowledgeDBManager = None,
):
self._event_queue = event_queue
self._config = config
self._db = db
self.provider_manager = provider_manager
self.platform_manager = platform_manager
self.knowledge_db_manager = knowledge_db_manager
self.conversation_manager = conversation_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
@@ -126,11 +131,8 @@ class Context:
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
"""通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
for provider in self.provider_manager.provider_insts:
if provider.meta().id == provider_id:
return provider
return None
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
return self.provider_manager.inst_map.get(provider_id)
def get_all_providers(self) -> List[Provider]:
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
@@ -144,24 +146,50 @@ class Context:
"""获取所有用于 STT 任务的 Provider。"""
return self.provider_manager.stt_provider_insts
def get_using_provider(self) -> Provider:
"""
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
def get_all_embedding_providers(self) -> List[EmbeddingProvider]:
"""获取所有用于 Embedding 任务的 Provider。"""
return self.provider_manager.embedding_provider_insts
通过 /provider 指令切换。
def get_using_provider(self, umo: str = None) -> Provider:
"""
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
Args:
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
"""
if umo and self._config["provider_settings"]["separate_provider"]:
perf = sp.get("session_provider_perf", {})
prov_id = perf.get(umo, {}).get(ProviderType.CHAT_COMPLETION.value, None)
if inst := self.provider_manager.inst_map.get(prov_id, None):
return inst
return self.provider_manager.curr_provider_inst
def get_using_tts_provider(self) -> TTSProvider:
def get_using_tts_provider(self, umo: str = None) -> TTSProvider:
"""
获取当前使用的用于 TTS 任务的 Provider。
Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
"""
if umo and self._config["provider_settings"]["separate_provider"]:
perf = sp.get("session_provider_perf", {})
prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None)
if inst := self.provider_manager.inst_map.get(prov_id, None):
return inst
return self.provider_manager.curr_tts_provider_inst
def get_using_stt_provider(self) -> STTProvider:
def get_using_stt_provider(self, umo: str = None) -> STTProvider:
"""
获取当前使用的用于 STT 任务的 Provider。
Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
"""
if umo and self._config["provider_settings"]["separate_provider"]:
perf = sp.get("session_provider_perf", {})
prov_id = perf.get(umo, {}).get(ProviderType.SPEECH_TO_TEXT.value, None)
if inst := self.provider_manager.inst_map.get(prov_id, None):
return inst
return self.provider_manager.curr_stt_provider_inst
def get_config(self) -> AstrBotConfig:
@@ -301,3 +329,12 @@ class Context:
注册一个异步任务。
"""
self._register_tasks.append(task)
def register_web_api(
self, route: str, view_handler: Awaitable, methods: list, desc: str
):
for idx, api in enumerate(self.registered_web_apis):
if api[0] == route and methods == api[2]:
self.registered_web_apis[idx] = (route, view_handler, methods, desc)
return
self.registered_web_apis.append((route, view_handler, methods, desc))

Some files were not shown because too many files have changed in this diff Show More