Compare commits

...

904 Commits

Author SHA1 Message Date
Soulter
9564166297 perf: knowledge base displays console when installing 2025-05-31 11:52:24 +08:00
Soulter
f5cf3c3c8e Merge pull request #1691 from AstrBotDevs/perf-pip-async
Feature: 将插件依赖检查和 pip 安装方法改为异步,以提高性能和响应速度
2025-05-31 11:51:39 +08:00
Soulter
18f919fb6b perf: pip_main wrapped in asyncio.to_thread
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-05-31 11:47:29 +08:00
Soulter
0924835253 feat: 将插件依赖检查和 pip 安装方法改为异步,以提高性能和响应速度 2025-05-31 11:44:58 +08:00
Soulter
20d2e5c578 perf: 优化日志流发送频率,防止积压超过 buffer size 导致前端显示异常 2025-05-31 11:25:51 +08:00
Soulter
907801605c 📦 release: v3.5.13 2025-05-31 11:02:56 +08:00
Soulter
93bc684e8c feat: 添加旧版本提供商类型映射以兼容性支持 2025-05-31 11:00:59 +08:00
Soulter
a76c98d57e Merge pull request #1685 from RC-CHN/master
Feature: 添加测试文本生成供应商可用功能
2025-05-31 10:59:46 +08:00
Soulter
d937a800d0 fix: provider name 2025-05-31 10:46:35 +08:00
Soulter
d16f3a227f Merge branch 'master' into master 2025-05-31 10:46:15 +08:00
Soulter
80c9a3eeda style: code style 2025-05-31 09:25:18 +08:00
Soulter
e68173b451 feat: knowledge-base 2025-05-30 23:18:48 +08:00
Soulter
40c27d87f5 feat: knowledge-base 2025-05-30 23:18:19 +08:00
Soulter
3c13b5049d feat: 支持知识库的分片、重叠设置等 2025-05-30 23:00:37 +08:00
Soulter
8288d5e51f feat: embedding provider 2025-05-30 18:07:52 +08:00
RC-CHN
4ffbb18ab4 Merge branch 'AstrBotDevs:master' into master 2025-05-30 15:12:33 +08:00
Ruochen
b27271b7a3 feat:添加测试文本生成供应商可用功能 2025-05-30 15:10:15 +08:00
Soulter
ebb6665f64 feat: add open_config parameter handling and configuration button in KnowledgeBase 2025-05-30 14:30:04 +08:00
Soulter
e4e5731ffd 📦 release: v3.5.13 2025-05-30 13:30:23 +08:00
Soulter
2ab5810f13 perf: improve transaction performance in vector db 2025-05-30 12:59:26 +08:00
Soulter
af934c5d09 fix: correct dimension typo and enhance API registration logic 2025-05-30 11:42:39 +08:00
Soulter
1e0cf7c112 fix: update ExtensionCard actions and add readme link functionality 2025-05-30 10:50:54 +08:00
Soulter
46859c93c9 perf: improve WebUI 2025-05-30 10:45:05 +08:00
Soulter
1641549016 perf: improve WebUI 2025-05-30 10:36:48 +08:00
鸦羽
716a5dbb8a chore: add nh3 to requirements.txt 2025-05-30 10:35:48 +08:00
鸦羽
af98cb11c5 fix: handle missing nh3 library in plugin.py 2025-05-30 10:35:48 +08:00
Soulter
9a4c2cf341 fix: downgrade faiss-cpu dependency to version 1.10.0 2025-05-30 10:21:31 +08:00
Soulter
2bc3bcd102 fix: handle missing nh3 library gracefully for README cleaning 2025-05-30 10:17:33 +08:00
Soulter
d6c663f79d fix: do not display change password dialog in demo mode 2025-05-30 10:09:09 +08:00
Soulter
2b4ee13b5e Merge pull request #1672 from Kwicxy/master
Feat: 暗黑主题功能初步实现
2025-05-29 23:41:10 +08:00
kwicxy
6959f86632 feat: Using localStorage to remember user's theme setting. 2025-05-29 22:46:02 +08:00
Raven95676
537d373e10 fix: Fix potential XSS risk in plugin README content 2025-05-29 22:35:24 +08:00
Soulter
cceadf222c Merge pull request #1676 from AstrBotDevs/fix-chat-get-file-bug
Fix: fixed a potential vulnerability in `/api/chat/get_file` endpoint.
2025-05-29 21:41:55 +08:00
Soulter
cf5a4af623 chore: remove duplicated auth header 2025-05-29 21:19:39 +08:00
Raven95676
39aea11c22 perf: enhance file access security in get_file method
Co-authored-by: anka-afk <1350989414@qq.com>
2025-05-29 21:03:51 +08:00
Raven95676
c2f1227700 fix: add authorization header to file download request in ChatPage.vue 2025-05-29 19:57:11 +08:00
Soulter
900f14d37c 🐛 fix: fixed a potential vulnerability in /api/chat/get_file endpoint.
I have fixed a potential vulnerability in the `/api/chat/get_file` endpoint that could allow unauthorized access to files by ensuring the request has a jwt token.
2025-05-29 19:17:31 +08:00
kwicxy
598249b1d6 Merge remote-tracking branch 'origin/master' 2025-05-29 18:26:53 +08:00
Richard X.
7ed15bdf04 Merge branch 'AstrBotDevs:master' into master 2025-05-29 18:17:39 +08:00
Raven95676
2fc0ec0f72 fix: update route 2025-05-29 17:28:33 +08:00
kwicxy
5e9c2a669b fix: Various bug fixes and improvements 2025-05-29 16:41:03 +08:00
Soulter
b310521884 📦 release: v3.5.12 2025-05-29 15:55:25 +08:00
Soulter
288945bf7e chore: aiosqlite to requirements.txt 2025-05-29 15:48:21 +08:00
Soulter
4fc07cff36 📦 release: v3.5.12 2025-05-29 15:46:40 +08:00
kwicxy
b884fe0e86 fix: Various bug fixes 2025-05-29 09:31:29 +08:00
kwicxy
855858c236 fix: Changed default theme to PurpleTheme 2025-05-29 09:31:15 +08:00
kwicxy
c11a2a5419 feat: Login page darkened 2025-05-29 09:00:27 +08:00
kwicxy
773a6572af feat: WebUI Dark Appearance 2025-05-29 01:43:21 +08:00
kwicxy
88ad373c9b 深色主题切换功能初步实现 2025-05-29 01:28:45 +08:00
Soulter
51666464b9 Merge pull request #1667 from AstrBotDevs/fix-priority
Fix: plugin priority was not properly applied
2025-05-28 15:34:50 +08:00
Soulter
5af9cf2f52 Merge pull request #1668 from AstrBotDevs/refactor-segment
Refactor: 重构转发节点等消息段的 toDict 相关逻辑
2025-05-28 15:33:32 +08:00
Soulter
12c4ae4b10 perf: to_dict in the base class 2025-05-28 03:26:42 -04:00
Soulter
4e1bef414a perf: empty array 2025-05-28 03:25:19 -04:00
Soulter
e896c18644 perf: video 2025-05-28 15:12:21 +08:00
Soulter
c852685e74 fix: typeerror 2025-05-28 01:18:45 -04:00
Soulter
1e99797df8 refactor: improve message segment handle 2025-05-28 12:53:00 +08:00
Soulter
52a4c986a8 fix: update star_handlers_registry iteration in TelegramPlatformAdapter 2025-05-28 00:31:04 +08:00
Soulter
c501728204 fix: plugin priority
fixes: #1662
2025-05-28 00:23:02 +08:00
Soulter
6b067fa6a7 Merge pull request #1665 from Raven95676/master
fix(telegram): 支持长消息分段发送并优化消息编辑逻辑
2025-05-27 23:39:14 +08:00
Soulter
a1cd5c53a9 chore: add comments 2025-05-27 23:38:35 +08:00
Soulter
a46d487e03 Merge pull request #1644 from RC-CHN/master
fix:为llm和model和provider指令添加了管理员权限检查
2025-05-27 23:25:40 +08:00
Raven95676
3deb6d3ab3 fix: clean code 2025-05-27 20:52:40 +08:00
Raven95676
af34cdd5d2 fix(telegram): 支持长消息分段发送并优化消息编辑逻辑 2025-05-27 20:15:16 +08:00
Soulter
6e1393235a 🐛 fix: provider command error 2025-05-27 17:20:57 +08:00
Soulter
343e0b54b9 feat: MCP supports Streamable HTTP transport method
fixes: #1637 #1342
2025-05-27 15:39:02 +08:00
Soulter
ecb70cb6f7 feat: add support for custom headers in SSE client configuration
fixes: #1659
2025-05-27 15:05:42 +08:00
Soulter
ca50618af6 perf: load providers when llm config is off and rebooting astrbot
fixes: #1466
2025-05-27 15:01:58 +08:00
Soulter
29c07ba83e 🐛 fix: function tools argument type issue
fixes: #1454
2025-05-27 13:54:16 +08:00
Ruochen
45fbb83a9f fix:为llm和model和provider指令添加了管理员权限检查 2025-05-25 00:24:20 +08:00
Soulter
ae7ba2df25 Merge pull request #1553 from Raven95676/Feature/use-file-service
Feature: T2I、TTS使用文件服务
2025-05-23 17:10:38 +08:00
Soulter
c3ef57cc32 Merge pull request #1588 from Zhenyi-Wang/feat/extend-wechatpadpro-for-timetask
feat: wechatpadpro对接获取联系人信息的2个接口
2025-05-23 17:02:54 +08:00
Soulter
7bb4ca5a14 perf: code quality 2025-05-23 17:01:57 +08:00
Soulter
063783d81d Merge pull request #1599 from HendricksJudy/master
Fix initialization bug and improve plugin utility
2025-05-23 16:58:25 +08:00
Soulter
42116c9b65 Merge pull request #1631 from AstrBotDevs/feat/alkaid
[WIP] Feature: 提供 AstrBot 后端服务插件接口、试验性嵌入式知识库(Alkaid)、移除不必要的包
2025-05-23 16:57:04 +08:00
Soulter
a36e11973d perf: code quality
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-23 16:56:09 +08:00
Soulter
5125568ea2 perf: 交换 if/else 表达式的分支以删除否定
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-23 16:49:08 +08:00
Soulter
0fa164e50d perf: 使用 HTML autocomplete 属性禁用浏览器自动填充
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-23 16:48:29 +08:00
Soulter
cf814e81ee chore: delete alkaid route 2025-05-23 16:41:33 +08:00
Soulter
43a45f18ce perf: knowledgebase delete 2025-05-23 15:50:10 +08:00
Soulter
ad51381063 perf: 动态路由注册 2025-05-23 15:18:16 +08:00
Soulter
0b0e4ce904 remove: vpet 2025-05-23 14:22:34 +08:00
Soulter
6a3e04d688 Merge remote-tracking branch 'origin/master' into feat/alkaid 2025-05-23 14:22:06 +08:00
Soulter
4107a17370 chore: add faiss and aiosqlite deps 2025-05-23 14:04:13 +08:00
Soulter
06b4d8f169 perf: vecdb similarity type 2025-05-23 13:45:00 +08:00
Soulter
1c0c820746 remove: loguru 2025-05-23 13:42:17 +08:00
Soulter
d061403a28 remove: loguru 2025-05-23 13:39:20 +08:00
Soulter
5c092321a6 feat: faiss vecdb implementation
remove: old knowledgedb deps
2025-05-23 13:16:24 +08:00
Soulter
bdd3f61c1f remove: old knowledge db impl and useless impls 2025-05-23 11:43:26 +08:00
Raven95676
8023557d6e feat: 强制修改默认密码 2025-05-22 18:30:29 +08:00
Raven95676
074b0ced7a perf: 移除冗余逻辑
经与@Soulter确认,metadata.yaml是必须有的文件,故在建议下删除
2025-05-22 18:21:41 +08:00
Soulter
3864b1ac9b Merge pull request #1620 from YOOkoishi/feat-add-volcengine-support
🐛 fix : 修改description,适配火山引擎基础的语音合成
2025-05-22 17:52:39 +08:00
YOO_koishi
6e9b43457d Merge branch 'master' of https://github.com/AstrBotDevs/AstrBot into feat-add-volcengine-support 2025-05-22 08:09:59 +08:00
YOO_koishi
ca1aec8920 🐛 fix : 修改description,适配火山引擎基础的语音合成 2025-05-22 08:09:36 +08:00
Soulter
acac580862 feat: ltm and kb 2025-05-20 20:50:22 +08:00
Soulter
673e1b2980 remove: vpet 2025-05-20 15:03:40 +08:00
Soulter
f62157be72 📦 release: v3.5.11 2025-05-20 02:00:54 -04:00
Soulter
f894ecf3b6 Merge pull request #1592 from YOOkoishi/feat-add-volcengine-support
 feat: add volcengine support
2025-05-20 13:58:44 +08:00
Soulter
66dd4e28ad Merge pull request #1604 from Siztas/fix-refresh-device-when-login-WeChatPadPro
fix:修复了WeChatPadPro在重新登录时为新设备的问题,延长初始化Auth_Key有效期至365天
2025-05-20 13:57:40 +08:00
YOO_koishi
939dc1b0fb Merge branch 'master' of https://github.com/AstrBotDevs/AstrBot into feat-add-volcengine-support 2025-05-20 13:52:03 +08:00
YOO_koishi
56bf5d38a1 🔧fix: 修改logger输出等级为debug级别 2025-05-20 13:51:11 +08:00
Soulter
d09b70b295 fix: 修复微信公众号(个人认证)下无法回复消息的问题 2025-05-20 01:38:13 -04:00
MiSeya
205180387a Fix:修复了WeChatPadPro在重新登录时为新设备的问题,延长初始化Auth_Key有效期至365天 2025-05-19 21:12:09 +08:00
HendricksJudy
39c8cfeda5 Merge pull request #2 from HendricksJudy/codex/fix-core-initialization-failure-handling-in-initialloader
Fix initialization bug and improve plugin utility
2025-05-19 01:43:22 -07:00
HendricksJudy
f38a329be5 Fix initialization and plugin download 2025-05-19 01:43:07 -07:00
YOO_koishi
a0cd069539 Merge branch 'master' of https://github.com/AstrBotDevs/AstrBot into feat-add-volcengine-support 2025-05-19 16:17:43 +08:00
YOO_koishi
bf306a2f01 🩹fix: 修改添加logger函数,添加speed_ratio选项,为一些选项添加description 2025-05-19 16:16:25 +08:00
Soulter
c31f93a8d1 Merge pull request #1595 from HendricksJudy/master
Fix lint issues and highlight typos
2025-05-19 09:29:02 +08:00
HendricksJudy
4730ab6309 Merge pull request #1 from HendricksJudy/codex/find-bugs-or-typos
Fix lint issues and highlight typos
2025-05-18 02:31:17 -07:00
HendricksJudy
1ae78ca98c chore: fix lint issues 2025-05-18 02:30:31 -07:00
Soulter
d2379da478 chore: use d3 2025-05-18 16:43:47 +08:00
Soulter
0f64981b20 feat: alkaid long term memory graph visualize 2025-05-18 13:26:44 +08:00
YOO_koishi
0002e49bb5 Merge branch 'master' of https://github.com/AstrBotDevs/AstrBot into feat-add-volcengine-support 2025-05-18 03:20:05 +08:00
YOO_koishi
db13a60274 feat: add-volcengine-tts-support 2025-05-18 03:18:36 +08:00
Soulter
db0f11a359 Merge pull request #1589 from Larch-C/master
🎈 perf: 优化了登录界面,解决了登录未自行跳转的问题
2025-05-17 21:40:14 +08:00
Soulter
ac7f43520b 🎈 perf: adjust login input padding style 2025-05-17 21:30:05 +08:00
Larch-C
f67b9f5f6e 🐞 fix: 解决了如果此前已经登录但未自行跳转的问题 2025-05-17 18:09:49 +08:00
Larch-C
c75156c4ce 🎈 perf: 优化了登录界面样式 2025-05-17 18:08:55 +08:00
Soulter
10270b5595 feat: alkaid framework and supports to customize webapi endpoint 2025-05-17 15:38:51 +08:00
Zhenyi Wang
f7458572ed feat: wechatpadpro对接获取联系人信息接口 2025-05-17 15:31:12 +08:00
Soulter
d57b7222b2 perf: 优化 WebUI About 页面、侧边栏和顶栏 2025-05-17 13:30:33 +08:00
Soulter
62e70a673a perf: 优化 Gemini 报错提示 2025-05-17 12:04:36 +08:00
Soulter
5e9eba6478 fix: extension market plugin card cannot apply installation 2025-05-16 22:43:38 -04:00
Soulter
cb02dfe1a4 perf: 优化超时时间 2025-05-16 20:00:14 +08:00
Soulter
b50739e1af perf: 优化登录超时时间 2025-05-16 19:33:37 +08:00
Soulter
8da1b0212d Update README.md 2025-05-16 18:46:26 +08:00
Soulter
ca1f2acb33 Merge pull request #1551 from GowayLee/master
Feature: 添加对 MiniMax TTS API的支持
2025-05-16 18:32:49 +08:00
Soulter
c15f966669 fix: 修复 minimax 相关问题 2025-05-16 18:32:08 +08:00
Soulter
7705b8781a 📦 release: v3.5.10 2025-05-16 17:50:56 +08:00
Soulter
b2502746f0 perf: QQ 下,屏蔽 QQ 管家的消息事件 2025-05-16 17:49:17 +08:00
Soulter
ab68094386 docs: update platform tutprial map 2025-05-16 17:33:57 +08:00
Soulter
bbec701223 Merge pull request #1569 from xiamuceer-j/master
适配一个个人微信适配器——wechatpadpro
2025-05-16 17:29:57 +08:00
Soulter
b29d14e600 perf: 优化适配器终止流程 2025-05-16 17:29:33 +08:00
Soulter
86e51c5cd1 perf: 改进 wechatpadpro 超时重连 2025-05-16 17:22:10 +08:00
Soulter
cb8267be3f feat: wechatpadpro 支持图片接收 2025-05-16 17:18:42 +08:00
xiamuceer
eaed43915c Merge remote-tracking branch 'origin/master' 2025-05-16 17:18:04 +08:00
xiamuceer
bd91fd2c38 Merge branch 'master' of https://github.com/xiamuceer-j/AstrBot 2025-05-16 17:17:51 +08:00
xiamuceer
1203b214cd Merge branch 'master' of https://github.com/xiamuceer-j/AstrBot 2025-05-16 17:05:16 +08:00
xiamuceer
c3fec15f11 update: 添加ws超时重连机制,避免过长时间收不到消息 2025-05-16 17:00:06 +08:00
Soulter
0545653494 feat: 支持轮询消息 2025-05-16 16:54:49 +08:00
Soulter
db2989bdb4 perf: guess private message username 2025-05-16 15:42:33 +08:00
xiamuceer
587bd00a19 update: 新增send_by_session方法,接受处理来自AstrBot核心的消息 2025-05-16 14:30:05 +08:00
Soulter
960ff438e8 🎈perf: 旧消息丢弃 2025-05-16 13:26:45 +08:00
Raven95676
98e7ea85d3 fix: 正确导入WeChatPadProAdapter 2025-05-16 12:39:14 +08:00
xiamuceer
2549e44710 fix: 移除错误引用 2025-05-16 12:26:54 +08:00
xiamuceer
4d32b563ca fix: 对auth_key授权码进行脱敏处理 2025-05-16 12:08:49 +08:00
xiamuceer
3a4b732977 fix: 修复@消息适配,并写明适配器 2025-05-16 11:52:54 +08:00
夏目侧耳
500909a28e Update astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py
Co-authored-by: 鸦羽 <Raven95676@gmail.com>
2025-05-16 11:47:52 +08:00
Soulter
07753eb25b Merge pull request #1561 from Raven95676/Fix/1554
fix(tts): record组件单独发送以保证兼容性
2025-05-16 11:10:45 +08:00
Soulter
c6eaf3d010 refactor: use aiohttp 2025-05-16 11:04:01 +08:00
Soulter
6723fe8271 🐛 fix: cannot save value when fullscreen editor mode 2025-05-16 10:37:30 +08:00
Raven95676
3348b70435 chore: add dependency 2025-05-16 10:30:29 +08:00
Soulter
35a8527c16 🎈 perf: update defaule value of minimax-timber-weight 2025-05-16 10:29:46 +08:00
Soulter
7afc475290 🐛 fix: value cannot displayed when fullscreen editior mode 2025-05-16 10:29:22 +08:00
Soulter
789bceaa3a Merge remote-tracking branch 'origin/master' into GowayLee/master 2025-05-16 10:23:30 +08:00
Soulter
abbc043969 Merge pull request #1575 from AstrBotDevs/feat-code-editor
Feature: WebUI 配置项支持代码编辑器模式
2025-05-16 10:22:16 +08:00
Soulter
654e5762f1 🐛 fix: 修复 VueMonacoEditor 的 v-model 绑定方式 2025-05-16 10:20:03 +08:00
Soulter
507c3e3629 feat: 配置项支持代码编辑器模式 2025-05-16 10:14:16 +08:00
Raven95676
991dfeb2f2 style: format code, disable redundant logs 2025-05-16 09:28:15 +08:00
夏目侧耳
26482fc2d3 Update astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-15 20:59:53 +08:00
夏目侧耳
e0ce6d9688 Update astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-15 20:57:22 +08:00
xiamuceer
946595216a 优化wechapadpro代码结构 2025-05-15 20:43:33 +08:00
anka
864b6bc56d fix: 🤠 修复指令后有@导致无法触发指令的问题 2025-05-15 20:00:46 +08:00
夏目侧耳
6ea5b7581f Update astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-15 19:12:42 +08:00
夏目侧耳
f70b8f0c10 Update astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-15 19:09:56 +08:00
夏目侧耳
1593bcb537 Update astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-15 17:50:29 +08:00
xiamuceer
bf7fc02c8d 适配一个个人微信适配器——wechatpadpro 2025-05-15 17:26:31 +08:00
Raven95676
143702b92b fix(tts): record组件单独发送以保证兼容性 2025-05-15 10:18:05 +08:00
Raven95676
c5ccc1a084 feat(Video): 增加视频消息组件的文件转换和注册功能 2025-05-15 09:50:27 +08:00
Soulter
2ecb52a9b2 Merge pull request #1529 from anka-afk/1446-bug-mcp
feat: 😽将At字段(非唤起)添加至message_str,修正message_str构造方式
2025-05-14 23:06:25 +08:00
YOO_koishi
6439917cbe Merge branch 'master' of https://github.com/AstrBotDevs/AstrBot into feat-add-volcengine-support 2025-05-14 22:45:02 +08:00
YOO_koishi
d21c18f657 change defualt.py 2025-05-14 22:43:40 +08:00
Li Haoyuan
25ef0039e4 refactor: Optimize MiniMax TTS API Provider 2025-05-14 20:59:45 +08:00
Raven95676
e6981290bc perf: 优化 Record 对象的文件和 URL 字段赋值逻辑 2025-05-14 20:05:38 +08:00
Raven95676
75c3d8abbd feat(t2i): 为本地文本转图像功能添加文件服务支持 2025-05-14 19:28:23 +08:00
Raven95676
d88683f498 feat(tts): 增加使用文件服务提供 TTS 语音文件的功能 2025-05-14 19:28:23 +08:00
Raven95676
40b9aa3a4c style: format code 2025-05-14 19:15:13 +08:00
渡鸦95676
b6d1515d58 Merge pull request #1541 from Raven95676/fix/astrbot-reboot
fix: 回退至os.execl以兼容docker,改用双引号处理路径空格
2025-05-14 14:57:13 +08:00
Li Haoyuan
e01d4264e3 docs: Adjust MiniMax TTS timber_weights description 2025-05-14 14:40:25 +08:00
Li Haoyuan
2117b65487 feat: Support timber_weights for MiniMax TTS 2025-05-14 14:21:23 +08:00
Li Haoyuan
a7823b352f docs: Adjust MiniMax TTS configuration info 2025-05-14 13:09:09 +08:00
Li Haoyuan
c543b62a08 Merge branch 'AstrBotDevs:master' into master 2025-05-14 13:02:54 +08:00
Li Haoyuan
3923b87f08 feat: Add MiniMax TTS API provider 2025-05-14 13:02:31 +08:00
Soulter
b7ecdadb83 docs: update providers 2025-05-14 09:35:59 +08:00
Soulter
5ff121e1ed docs: PPIO 派欧云 2025-05-14 09:33:35 +08:00
Soulter
f486e5448f Merge pull request #1539 from Raven95676/Feature/ppio
feat: 接入PPIO派欧云
2025-05-14 09:07:38 +08:00
Raven95676
c5aae98558 fix: update reboot logic to handle executable paths correctly 2025-05-13 16:03:04 +08:00
Raven95676
6d8a3b9897 fix: 回退至os.execl以兼容docker,改用双引号处理路径空格 2025-05-13 10:18:11 +08:00
Raven95676
6d98780e19 feat: 接入PPIO派欧云 2025-05-12 18:22:02 +08:00
Raven95676
3ad2c46f3f perf: tg适配器同步aiocqhttp处理逻辑 2025-05-12 15:04:23 +08:00
Raven95676
a730cee7fd fix: at全体不加入message_str 2025-05-12 14:48:31 +08:00
anka
77c823c100 fix: 增加对全体成员的支持 2025-05-12 11:32:40 +08:00
anka
124f21c67a Merge remote-tracking branch 'origin/1446-bug-mcp' into 1446-bug-mcp 2025-05-12 11:24:09 +08:00
anka
e46cf20dd3 fix: 不再添加唤醒的@到message_str 2025-05-12 11:22:46 +08:00
Raven95676
4bef5e8313 fix: 避免message_str被覆盖 2025-05-12 00:21:48 +08:00
anka
22e93b0af4 Merge branch 'AstrBotDevs:master' into 1446-bug-mcp 2025-05-11 22:59:02 +08:00
anka
5aeca9662b feat: 对aiocqhttp中, At字段新增处理: 现在At字段同时也会被解析为文本信息(但消息链并没有修改, 只是在用于llm请求的文本中添加了At信息) 2025-05-11 22:57:50 +08:00
Raven95676
b996cf1f05 chore: update multiple dependencies 2025-05-11 22:16:16 +08:00
渡鸦95676
878a106877 fix changelog 2025-05-11 21:31:27 +08:00
Soulter
45d36f86fd fix: 优化限流逻辑,确保在达到限流阈值时正确处理请求 2025-05-11 21:22:14 +08:00
Soulter
b108ae403a docs: uvx 2025-05-11 20:31:46 +08:00
Soulter
887ed66768 docs: uvx 2025-05-11 20:30:30 +08:00
Soulter
dac840a887 📦release: v3.5.9 2025-05-11 20:08:14 +08:00
Soulter
238de4ba8c fix: 修复企业微信和微信公众平台下无法应用 api_base_url 的问题
fixes: #1505
2025-05-11 19:55:24 +08:00
Soulter
9a7bdade43 Merge pull request #1526 from AstrBotDevs/fix-weixin-kefu
Fix: 修复微信客服下接收消息时可能报错的问题
2025-05-11 19:46:14 +08:00
Soulter
aa84556204 🐛fix: 修复微信客服下接收消息时可能报错的问题
fixes #1504
2025-05-11 19:45:19 +08:00
Soulter
6b68069fcd Merge pull request #1525 from AstrBotDevs/fix-path-issue-cli
Fix: 修复 CLI 模式下路径问题导致 WebUI 和 MCP Server 无法加载的问题
2025-05-11 18:39:12 +08:00
Soulter
42c7034fb2 🐛 fix: 修复路径 2025-05-11 18:17:06 +08:00
Soulter
060c7e0145 🐛fix: 修复 CLI 模式下路径问题导致 WebUI 和 MCP Server 无法加载的问题 2025-05-11 18:09:36 +08:00
Soulter
b5b085dfb1 Merge pull request #1524 from AstrBotDevs/feat-provider-type-webui
Improve: 优化 WebUI 服务提供商的选择界面
2025-05-11 17:46:11 +08:00
Soulter
fc06ce9d7f perf: hint 2025-05-11 17:36:16 +08:00
Soulter
d8d81b05a7 feat: 更直观的模型提供商选择 2025-05-11 17:30:20 +08:00
Soulter
a60f42b1f2 feat: 在配置模板指定提供商能力类型 2025-05-11 04:04:05 -04:00
Soulter
6e18be88d0 Merge pull request #1519 from NanoRocky/master
Add Support for Azure TTS
2025-05-11 15:31:11 +08:00
Soulter
b45e439c48 Merge pull request #1520 from Raven95676/master
feat: 为部分组件提供register_to_file_service方法
2025-05-11 14:55:33 +08:00
Raven95676
b87061c18c feat: add file registration methods for audio, image, and file components 2025-05-11 10:08:55 +08:00
NanoRocky
f78aca7752 Fix provider_config by sourcery-ai 2025-05-11 02:15:37 +08:00
NanoRocky
3ccca2aa10 Update astrbot/core/provider/sources/azure_tts_source.py
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-11 02:11:03 +08:00
NanoRocky
6d7c40eb76 Fix AsyncClient 2025-05-11 01:54:44 +08:00
NanoRocky
da4cd7fb65 Add Support for Azure TTS 2025-05-11 01:20:17 +08:00
Soulter
c97cda6b84 Merge pull request #1517 from anchorAnc/fix-issue-1460
Fix issue 1460
2025-05-11 00:22:11 +08:00
Soulter
7a7fd4167a style: format code 2025-05-10 12:21:21 -04:00
Soulter
dffc1a43d5 Merge pull request #1518 from AstrBotDevs/fix-plugin-command
优化 plugin 指令的权限
2025-05-11 00:02:36 +08:00
Soulter
36897fea1e fix: 更正 plugin ls 指令提示 2025-05-10 12:01:49 -04:00
Soulter
c7b34735f0 fix: 更正 plugin help 指令提示 2025-05-10 12:00:48 -04:00
Soulter
5b07176c88 perf: 优化一些报错显示 2025-05-10 11:57:15 -04:00
Soulter
474b40d660 perf: 分离 plugin 指令为指令组,优化权限控制 2025-05-10 11:54:15 -04:00
Anchor
a62901b948 Merge branch 'AstrBotDevs:master' into fix-issue-1460 2025-05-10 23:02:18 +08:00
Anchor
25d8746327 补充一个import 2025-05-10 23:00:55 +08:00
Anchor
aff1698223 fix: 修复重启报错问题(关联 #1460)
使用subprocess.Popen启动新进程,修复原方案识别路径空格的问题
2025-05-10 22:54:38 +08:00
Raven95676
7f8941745f clean code 2025-05-10 22:51:50 +08:00
Raven95676
b858401098 chore: format code 2025-05-10 18:47:56 +08:00
渡鸦95676
d5a158b80f Merge pull request #1512 from Raven95676/Feature/cli-conf
feat: CLI支持部分配置文件项的设定
2025-05-10 16:42:53 +08:00
Raven95676
f315f284aa fix: improve error handling for config loading and setting 2025-05-10 16:24:52 +08:00
Raven95676
c367f5009d feat: CLI支持部分配置文件项的设定 2025-05-10 16:03:08 +08:00
渡鸦95676
6db1e63bda chore: add .astrbot to ignore file 2025-05-10 10:02:18 +08:00
渡鸦95676
e22ab2ede6 Merge pull request #1508 from Raven95676/master
fix: 设置thinking_budget前,先检查是否存在
2025-05-10 09:54:49 +08:00
Raven95676
b7d7e0b682 fix: 设置thinking_budget前,先检查是否存在 2025-05-10 09:51:30 +08:00
Raven95676
96bba15f2f chore: update version 2025-05-09 23:22:18 +08:00
Soulter
fcf965a595 Merge pull request #1480 from Raven95676/feature/cli
Feature: CLI功能增强,问题修复
2025-05-09 21:49:11 +08:00
渡鸦95676
e1a20d3c22 Merge branch 'master' into feature/cli 2025-05-09 20:22:33 +08:00
Soulter
2abd7d8c5d Merge pull request #1501 from AstrBotDevs/test
refactor: QQ 采用 http 回调的方式上报文件消息段中的文件信息。
2025-05-09 19:40:05 +08:00
Soulter
5b8f73cdd7 feat: 新增令牌超时时间 2025-05-09 07:29:37 -04:00
anka
7fd765421f fix: [File] remove unused tags "_downloaded" 2025-05-09 09:58:37 +00:00
Soulter
d9d94af022 perf: 优化异常处理和显示 2025-05-09 04:00:12 -04:00
Soulter
790b924e57 refactor: QQ 采用 http 回调的方式上报文件消息段中的文件信息。
fix: 修复 Lagrange 下合并转发消息失败的问题

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-09 03:47:19 -04:00
Soulter
4a62f877df 🐛 fix: 修复单独文件发送时被认为是空消息导致文件无法发送的问题 2025-05-09 10:45:50 +08:00
Raven95676
ac47c57bb7 perf: cli统一使用pathlib,修正typo 2025-05-08 20:25:12 +08:00
Soulter
3ace4199a1 📦 release: v3.5.8 2025-05-07 09:51:45 -04:00
Soulter
e6bd7524c1 🎈 perf: 优化 persona 错误显示 2025-05-07 09:49:07 -04:00
Soulter
699c86e8c1 Merge pull request #1486 from AstrBotDevs/feat-weixin-official-account
 feat: 支持微信公众平台
2025-05-07 21:00:27 +08:00
Soulter
f40fa0ecea chore: remove useless config 2025-05-07 08:59:48 -04:00
Soulter
626f94686b feat: 支持微信公众平台 2025-05-07 08:57:22 -04:00
Raven95676
752d13b1b1 perf: 优化 gemini_source 方法默认参数 2025-05-07 19:04:24 +08:00
Soulter
54c0dc1b2b docs(README.md): 个人微信接入说明 2025-05-07 14:50:24 +08:00
Soulter
c5bc709898 🎈 perf: 优化 openai_source 方法默认参数 2025-05-06 23:15:11 +08:00
Raven95676
ccdbb01513 perf: 修改move为copy,clean code 2025-05-06 18:39:11 +08:00
Raven95676
5206d750ac refactor: 减少重复和嵌套 2025-05-06 18:29:55 +08:00
Raven95676
a800e3df67 chore: 添加依赖 2025-05-06 18:18:15 +08:00
Raven95676
ccb1f87a20 feat: cli支持插件自动热重载;cli支持插件管理;cli支持指定Dashboard端口 2025-05-06 17:56:56 +08:00
Raven95676
c111da4681 refactor: 修改框架路径获取方式,规范化路径拼接 2025-05-06 17:30:34 +08:00
Soulter
9cc4e97a53 docs(README.md): update special thanks 2025-05-06 13:57:39 +08:00
Soulter
dca1c0b0f3 docs(README.md): update special thanks and platform 2025-05-06 13:56:26 +08:00
Raven95676
f06be6ed21 refactor: 拆分cli以便后续拓展功能 2025-05-06 00:53:00 +08:00
Soulter
3c8ec2f42e 📦 release: v3.5.7 2025-05-05 12:47:21 -04:00
Soulter
7e193f7f52 Merge pull request #1473 from AstrBotDevs/feat-wechat-kf
Feature: 支持接入微信客服
2025-05-06 00:15:37 +08:00
Soulter
7069b02929 chore: add license 2025-05-05 12:11:55 -04:00
Soulter
66995db927 feat: 支持微信客服图片消息 2025-05-05 12:08:23 -04:00
Soulter
c36054ca1b feat: 微信客服支持文本消息 2025-05-05 11:53:50 -04:00
Soulter
3e07fbf3dc feat: 微信客服 2025-05-05 11:32:35 -04:00
Soulter
bf3fbe3e96 fix: workflow job dependency 2025-05-04 19:52:27 +08:00
Soulter
0a93d22bc8 📦 release: v3.5.6 2025-05-04 12:46:40 +08:00
Raven95676
f5b3d94d16 fix: 修正thinking_config 2025-05-02 15:36:07 +08:00
Raven95676
4d1a6994aa fix: 保证Gemini anyOf 字段唯一 2025-05-02 10:56:05 +08:00
Raven95676
05c686782c Merge remote-tracking branch 'origin/master' 2025-05-02 10:51:01 +08:00
Raven95676
85609ea742 feat: 支持Gemini思考设置 2025-05-02 10:49:45 +08:00
Soulter
20dabc0615 Merge pull request #1333 from LIghtJUNction/master
Feature: 新增CLI命令行程序
2025-05-01 20:53:58 +08:00
Soulter
356dd9bc2b cd: upload to pypi 2025-05-01 20:48:11 +08:00
Soulter
cd5d7534c4 chore: imporove help message 2025-05-01 20:35:10 +08:00
LIghtJUNction
b4f12fc933 feat: supports CLI mode
Squashed by:

STEP1 - 新增CLI命令行程序

🎨 style: improve code style and some typo fixes

remove: llms.txt
2025-05-01 20:32:05 +08:00
Soulter
cbea387ce0 Merge pull request #1445 from AstrBotDevs/fix-download-file
Improve: 优化 QQ 下自动下载文件的问题
2025-05-01 20:15:06 +08:00
Soulter
345b155374 Merge pull request #1447 from anka-afk/1446-bug-mcp
fix: mcp 服务器页面搜索功能无法使用: 在前端实现搜索
2025-05-01 14:08:54 +08:00
Soulter
29d216950e Merge pull request #1427 from AstrBotDevs/fix-gewechat
Improve: 优化 Gewechat 下文件回调逻辑
2025-05-01 12:54:03 +08:00
anka
321b04772c refactor: 🍩将本地路径和url分离, 需要本地文件时提供下载接口, 同时向前兼容 2025-05-01 01:16:30 +08:00
anka
5b924aee98 Merge remote-tracking branch 'origin/1360-featurereset' into 1446-bug-mcp 2025-04-30 23:53:52 +08:00
anka
46d44e3405 fix: 🧩在前端实现mcp服务器的搜索 2025-04-30 23:52:55 +08:00
Raven95676
4d5332fe25 fix: 处理旧版本不存在ws_reverse_token的情况 2025-04-30 22:39:54 +08:00
Raven95676
18bd4c54f4 fix: 修正判断逻辑 2025-04-30 22:31:56 +08:00
Soulter
31c7768ca0 🎈 perf: 优化 QQ 下自动下载文件的问题 2025-04-30 21:47:14 +08:00
Raven95676
6ec643e9d1 fix: add self.lock 2025-04-30 00:51:49 +08:00
Soulter
2b39f6f61c Merge pull request #1426 from Raven95676/aiocqhttp-token
feat: 添加aiocqhttp对Token设置的支持
2025-04-30 00:04:52 +08:00
Soulter
bf3ca13961 Update astrbot/core/platform/sources/gewechat/client.py
Co-authored-by: 渡鸦95676 <Raven95676@gmail.com>
2025-04-30 00:03:21 +08:00
Soulter
82026370ec feat: 插件支持基于 Star 和 updated_at 排序 2025-04-29 11:17:00 +08:00
Soulter
6d49bf5346 fix: 修正 _handle_file 方法下的变量名 2025-04-28 23:49:36 +08:00
Soulter
67431d87fb fix: gewechat file 2025-04-28 23:31:45 +08:00
Raven95676
fdf55221e6 feat: 添加aiocqhttp对Token设置的支持 2025-04-28 22:14:51 +08:00
Soulter
07f277dd3b Merge pull request #1321 from XiGuang/master
bug: 修复私聊中接收引用消息无法准确获取用户昵称的问题
2025-04-26 23:21:22 +08:00
Soulter
cf8f0603ca 🐛 fix: gewechat 去除强制忽略自身消息的逻辑
fixes: #1388
2025-04-26 22:57:41 +08:00
Soulter
5592408ab8 Merge pull request #1386 from Raven95676/feature/mcp-img
feat: 处理MCP返回ImageContent、EmbeddedResource的情况,提供简单fallback
2025-04-26 21:29:14 +08:00
Soulter
a01617b45c fix: OneBot v11 request 类事件 补全 session_id 的获取 2025-04-26 21:00:30 +08:00
Soulter
7abb4087b3 Update README.md 2025-04-26 19:50:30 +08:00
渡鸦95676
dff15cf27a Merge pull request #1383 from Raven95676/feature/tg-optional-command
feat: 允许用户自定义telegram适配器指令注册行为,优化命令注册机制
2025-04-25 09:40:44 +08:00
Soulter
aa858137e5 Merge pull request #1240 from BigFace123/master
bug: 修复gewechat在群组中无法获取被at人的wxid问题
2025-04-25 00:51:11 +08:00
Soulter
45cb143202 perf: 实现解析微信群聊下对其他人的 At 2025-04-25 00:46:40 +08:00
Soulter
7a9c6ab8c4 Merge pull request #1374 from Raven95676/fix/gemini-func
fix: Gemini保证偶数索引为用户消息,奇数索引为模型消息
2025-04-23 23:27:10 +08:00
Raven95676
e2c26c292d feat: 处理MCP返回ImageContent、EmbeddedResource的情况,提供简单fallback 2025-04-23 19:55:15 +08:00
Soulter
be7c3fd00e docs: update PR template 2025-04-23 16:31:59 +08:00
Soulter
7e5461a2cf Merge pull request #1362 from anka-afk/1360-featurereset
feat: 😽对reset在不同情况下的权限特殊处理, 使其兼容alter_cmd 🤠为new指令增加清理上下文选项, 默认为清理, 更符合直觉
2025-04-23 16:21:20 +08:00
Raven95676
6ee9010645 feat: 允许用户自定义telegram适配器指令注册行为,优化命令注册机制 2025-04-23 15:53:18 +08:00
Raven95676
a23d5be056 refactor: 减少嵌套条件和重复代码 2025-04-23 12:49:27 +08:00
Raven95676
97a6a1fdc2 feat: 保证第一条消息不为model 2025-04-23 12:20:18 +08:00
Raven95676
c8f567347b feat: 修改重排序逻辑为合并连续相同类型的消息 2025-04-23 11:52:22 +08:00
anka
74c1e7f69e fix: ⚒️ 仍然清除聊天增强记录 2025-04-23 11:24:17 +08:00
anka
15a5fc0cae fix: 🧩revert logic of new func 2025-04-23 09:56:48 +08:00
Raven95676
f07c54d47c style: 减少一层 intent
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-04-23 00:48:25 +08:00
Soulter
70446be108 perf: catching a more specific exception type instead
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-04-23 00:08:03 +08:00
Soulter
d6d21fca56 Merge pull request #1347 from kkjzio/master
bug: 修复aiocqhttp平台使用指令组时,如果使用文本中携带网址无法识别指令
2025-04-23 00:00:04 +08:00
Raven95676
8d7273924f fix: Gemini保证偶数索引为用户消息,奇数索引为模型消息 2025-04-22 22:12:03 +08:00
Soulter
ea64afbaa7 docs: Update FUNDING.yml 2025-04-22 19:12:40 +08:00
Soulter
45da9837ec docs: Create FUNDING.yml 2025-04-22 19:12:03 +08:00
Raven95676
8c19b7d163 chore: clean code,format 2025-04-22 17:52:25 +08:00
Raven95676
ab227a08d0 fix: 修复openai source中e的作用域问题 2025-04-22 11:50:47 +08:00
anka
40d6e77964 fix: 🫓使用enum代替字典后的一些修改 2025-04-22 11:16:24 +08:00
anka
9326e3f1b0 refactor: 使用enum代替字典
Co-authored-by: 渡鸦95676 <Raven95676@gmail.com>
2025-04-22 10:55:32 +08:00
kkjz
0e1eb3daf6 fix: 使用join方法优化相邻文本段合并 2025-04-21 20:56:18 +08:00
anka
05daac12ed refactor: 🍔降低复杂性 2025-04-21 12:35:08 +08:00
anka
c5b24b4764 feat: 🤠为new指令增加清理上下文选项, 默认为清理, 更符合直觉 2025-04-21 12:06:20 +08:00
anka
cc16548e5f feat: 😽对reset在不同情况下的权限特殊处理, 使其兼容alter_cmd 2025-04-21 11:56:12 +08:00
Soulter
291d65bb3e release: v3.5.5 2025-04-21 11:09:18 +08:00
Soulter
bd3ad03da6 Merge pull request #1361 from AstrBotDevs/hotfix/webui-mcp
fix: 修复 MCP 页面的一些问题
2025-04-21 10:54:19 +08:00
Soulter
5fa6788357 chore: properly storing interval ID for cleanup. 2025-04-21 10:54:06 +08:00
Soulter
c5c5a98ac4 🐛 fix: 修复 MCP 页面的一些问题 2025-04-21 10:51:01 +08:00
Soulter
a1151143cf Merge pull request #1357 from Raven95676/hotfix/gemini-functool
fix: 修复get_func_desc_google_genai_style未正确转换函数调用的问题
2025-04-21 10:26:44 +08:00
Raven95676
f5024984f7 perf: 移除冗余判断 2025-04-21 00:55:20 +08:00
Raven95676
f4880fd90d fix: 修复get_func_desc_google_genai_style未正确转换函数调用的问题 2025-04-21 00:11:31 +08:00
kkjz
0ae61d5865 fix: 修复生成text的Plain时文本为处理后的文本 2025-04-20 22:11:24 +08:00
kkjz
d3bd775a79 feat: 使用groupby来合并aiocqhttp连续的文本段 2025-04-20 18:09:04 +08:00
Soulter
da546cfe7f 🎈 perf(telegram): 弱化无法注册指令的日志级别 2025-04-20 18:08:52 +08:00
Soulter
a211933e83 📦 release: v3.5.4 2025-04-20 18:01:37 +08:00
Soulter
1d40b5a821 feat(updator): 替换为采用 Semver 语义化版本来比较版本 2025-04-20 17:30:01 +08:00
Soulter
33836daeb7 Merge pull request #1327 from YOOkoishi/tts-feat-branck
TTS : add text output alongside voice (Fix #1085)
2025-04-20 16:07:06 +08:00
Soulter
d921b0f6bd 🎈 perf: 优化 gewechat 的引用消息解析 2025-04-20 16:00:59 +08:00
Soulter
0607b95df6 🎈 perf: 增强异常处理 2025-04-20 15:40:51 +08:00
Soulter
0de6d0e046 Merge pull request #1256 from Raven95676/better-stream
perf: 为不支持流式输出的平台提供fallback。
2025-04-20 15:24:31 +08:00
kkjz
98427345cf bug: 修复aiocqhttp平台使用指令组时,如果使用文本中携带网址无法识别指令 2025-04-20 12:04:02 +08:00
Soulter
9fedaa9f77 🎈perf(webui): 优化了 MCP 页面的效果 2025-04-20 11:26:53 +08:00
Soulter
bf4c2ecd33 feat: MCP 支持 SSE 传输协议连接到服务器 2025-04-20 11:02:28 +08:00
Soulter
f8c18cc1e0 Merge pull request #1341 from AstrBotDevs/fix-dashscope-error-1330
fix: 修复阿里云百炼 TTS 只能发送一次语音,第二次就会报错
2025-04-20 01:17:32 +08:00
Soulter
458b900412 Merge pull request #1340 from AstrBotDevs/perf-wecom-split-long-text
feature: 企业微信添加长文本分割功能以支持发送超过 2048 字符的消息
2025-04-20 01:15:48 +08:00
Soulter
192c776e0b 🐛 fix: 修复阿里云百炼 TTS 只能发送一次语音,第二次就会报错
fixes: #1330
2025-04-20 00:58:37 +08:00
anka
5cdec18863 improvement: 对标点符号分割而不是直接切分 2025-04-19 16:52:30 +00:00
Soulter
15f856f951 perf(wecom): 企业微信添加长文本分割功能以支持发送超过 2048 字符的消息
fixes: #564
2025-04-20 00:27:04 +08:00
Raven95676
01d52cef74 perf: 支持更多参数 2025-04-20 00:12:14 +08:00
XiGuang
95563c8659 bug fix: 更新引用嵌套消息解析逻辑,支持图片处理 2025-04-19 16:15:47 +08:00
YOO_koishi
31d8c40eca tts : add text output alongside voice (Fix #1085) 2025-04-19 14:44:02 +08:00
渡鸦95676
56001ed272 Merge pull request #1326 from Raven95676/session_waiter
perf: 修改默认会话过滤器标识符为umo
2025-04-19 13:45:06 +08:00
XiGuang
d916fda04c feat: 增强消息处理逻辑,支持引用嵌套消息解析 2025-04-19 12:10:51 +08:00
Raven95676
cfae655068 perf: 修改默认会话过滤器标识符为umo 2025-04-19 11:57:22 +08:00
Raven95676
5596565ec4 fix: 若启用Gemini原生工具,构建Content列表时忽略工具调用 2025-04-18 23:36:12 +08:00
XiGuang
afa1aa5d93 🐛 fix: 更新用户真实姓名获取逻辑,改为从用户信息中提取 2025-04-18 21:22:46 +08:00
Raven95676
e98c3d8393 fix: Gemini保证工具间的互斥 2025-04-18 16:19:36 +08:00
渡鸦95676
6687b816f0 Merge pull request #1303 from Raven95676/master
feat: 添加对Gemini原生搜索功能的支持
2025-04-17 20:48:02 +08:00
Raven95676
ea8035e854 feat: 添加对Gemini原生搜索功能的支持 2025-04-17 20:36:22 +08:00
Soulter
54b0171d49 Merge pull request #1296 from AstrBotDevs/feat-mcp-servers-market
[WIP] MCP 服务器市场
2025-04-17 16:26:41 +08:00
Soulter
676d4277b9 chore: 优化样式 2025-04-17 16:26:27 +08:00
Soulter
a4b1da3ca2 perf: 警告 2025-04-17 16:24:50 +08:00
Soulter
9e9c16e770 Merge pull request #1295 from EdelweissHuirh/master
修改分段回复的分割逻辑
2025-04-17 16:11:08 +08:00
Soulter
dc87006fed feat: 分页 2025-04-17 16:07:13 +08:00
Soulter
b9b260f26a perf: 弱化显示 2025-04-17 14:02:40 +08:00
Soulter
33fd6a5016 perf: 优化 MCP 服务器的日志回显 2025-04-17 13:59:10 +08:00
Soulter
97cbccc2ba feat: mcp 服务器市场 2025-04-17 00:41:04 +08:00
Raven95676
1ee4685d5d perf: 允许行级别锚点匹配以保持一致性 2025-04-16 22:13:38 +08:00
Soulter
aba18232b1 perf: docker 镜像自带 node 环境
fixes: #1290
2025-04-16 21:53:27 +08:00
huirh
0a02441b75 修改分段回复逻辑 2025-04-16 21:52:42 +08:00
Raven95676
1be5b4c7ff fix: 兼容旧版本google-genai sdk 2025-04-16 00:34:08 +08:00
Raven95676
a0ce0cf18a fix: 增加更多Gemini不支持多模态输出的情况 2025-04-16 00:11:46 +08:00
Soulter
7c54e5d093 perf: 优化已安装的插件页
fixes: #934
2025-04-15 22:53:40 +08:00
Soulter
b825e51dab chore: clean useless logs 2025-04-15 21:56:23 +08:00
Soulter
589855c393 feat: 支持开关是否忽略自身发送的消息
某些平台如 gewechat 会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人

fixes: #890
2025-04-15 21:55:21 +08:00
渡鸦95676
4c546f2f53 Merge branch 'master' into better-stream 2025-04-15 21:22:08 +08:00
Raven95676
3753fce912 perf: 为发送流式消息的Fallback可选 2025-04-15 21:21:02 +08:00
Soulter
4c02857ec5 🐛 fix: 修复 aiocqhttp 无法发图片
fixes: #1275
2025-04-15 21:15:39 +08:00
Soulter
33f87ff7d7 🎈 perf: enhance metrics tracking with installation ID and sender ID hashing 2025-04-15 21:08:45 +08:00
Soulter
784dcf2a9a Merge pull request #1228 from Raven95676/gemini
refactor: 使用Google官方SDK重构gemini_source
2025-04-15 20:04:20 +08:00
Soulter
43ee943acb 🐛 fix: 多轮函数调用的报错 2025-04-15 10:59:16 +08:00
Soulter
a769fd7d13 chore: add google-genai dependency to project 2025-04-15 10:40:42 +08:00
渡鸦95676
2c4fd00b16 Merge pull request #1276 from Raven95676/master
fix: 移除TG注册命令时的调试信息,注册命令时添加合法性校验
2025-04-14 22:04:11 +08:00
Raven95676
264771fe98 fix: 移除注册时的调试信息,注册命令时添加合法性校验 2025-04-14 21:55:34 +08:00
Soulter
ecd92dafef Merge pull request #1274 from AstrBotDevs/fix-1121
🐛 fix: 修复上下文带图的情况下,对话数据库页无法查看对话详情的问题
2025-04-14 21:35:54 +08:00
Soulter
c8b6e4bea3 🐛 fix: 修复上下文带图的情况下,对话数据库页无法查看对话详情的问题
fixes: 1121
2025-04-14 21:34:11 +08:00
Soulter
3756cb766e 🎈 perf: 支持自定义 PyPI 软件仓库地址
fixes: #1165
2025-04-14 21:19:36 +08:00
Soulter
068d9ca60b Update README.md 2025-04-14 19:57:04 +08:00
Soulter
93f632d8b8 Update README.md 2025-04-14 19:56:32 +08:00
Soulter
bb44ce7e74 Update README.md 2025-04-14 10:30:12 +08:00
Raven95676
6986c8d8f7 fix: clean code,处理Gemini流式输出最后一部分概率性为None的情况 2025-04-13 18:34:57 +08:00
Raven95676
fe95506db4 perf: 添加日志过滤器以抑制非文本部分警告信息 2025-04-13 17:50:44 +08:00
Raven95676
310ed76b18 fix: 仅在确实包含图片模态时降级 2025-04-13 17:28:34 +08:00
Raven95676
98830d147f fix: 限速增加到1.5秒 2025-04-13 17:14:51 +08:00
Raven95676
19c9177d7b chore: 移除对dingtalk、lark、wecom的fallback 2025-04-13 17:03:06 +08:00
渡鸦95676
f41c5f97f6 Merge branch 'master' into better-stream 2025-04-13 16:47:56 +08:00
Raven95676
648c125697 refactor: 提取缓冲处理逻辑到astr_message_event 2025-04-13 15:37:22 +08:00
Soulter
0dc2b89897 Merge pull request #1257 from KimigaiiWuyi/master
🐛 修复飞书适配器转换消息过程中无法正确转化Base64图片
2025-04-13 15:33:02 +08:00
Soulter
83745f83a5 🐛 fix: 对飞书适配器 base64 格式数据先保存到本地 2025-04-13 15:29:56 +08:00
Soulter
2f91fe4535 Merge pull request #1244 from Rail1bc/master
修复:dequeue_context_length的配置项的实际行为与描述不一致;调用函数工具可能导致400错误
2025-04-13 14:41:16 +08:00
Raven95676
739f09059e feat: 为Gemini原生代码执行器提供有限支持 2025-04-13 12:43:25 +08:00
渡鸦95676
c86f9f0f5f Merge pull request #1261 from Raven95676/master
fix: 修复文件不存在的情况
2025-04-13 11:40:33 +08:00
Raven95676
9470ca6bc5 fix: 修复文件不存在的情况 2025-04-13 11:36:06 +08:00
Raven95676
2a92c4d5de fix: 修复导入 2025-04-13 11:22:27 +08:00
Raven95676
bb6e892657 feat: 重构发送流以提高代码可读性 2025-04-13 11:19:40 +08:00
KimigaiiWuyi
c9079b9299 🐛 修复飞书适配器转换消息过程中无法正确转化Base64图片 2025-04-13 06:06:02 +08:00
Raven95676
b6963c1bf9 perf: 为不支持流式输出的平台提供fallback。 2025-04-13 02:21:42 +08:00
Raven95676
9c29df47bb fix: 更新流式输出逻辑,禁用图片模态并添加日志警告。 2025-04-13 01:09:42 +08:00
Soulter
fc146d3d00 Merge pull request #1245 from AstrBotDevs/perf-mcpserver
perf: 适配 MCP 配置文件带 mcpServers 的情况(Cursor)
2025-04-12 23:06:39 +08:00
Soulter
1bf5a21678 Merge pull request #1158 from Jackxwb/master
文件发送时支持路径映射
2025-04-12 21:01:25 +08:00
Soulter
011542dc2b Merge pull request #1247 from Raven95676/shared_preferences
perf: shared_preferences加载失败时自动删除无效文件
2025-04-12 20:04:19 +08:00
Raven95676
489784104e perf: shared_preferences加载失败时自动删除无效文件 2025-04-12 19:31:45 +08:00
Raven95676
3860634fd2 fix: 修复了多模态输出支持判断问题并对只输出图片的情况进行处理。 2025-04-12 19:15:39 +08:00
Soulter
709c324e18 🐛 fix: 修复 MCP 服务器配置处理逻辑,确保正确处理空 mcpServers 情况并优化代码可读性 2025-04-12 18:19:06 +08:00
Soulter
b75d24d92c 🎈 perf: 适配 MCP 配置文件带 mcpServers 的情况(Cursor)
🐛 fix: 关闭/删除 MCP 服务器后 Tools 没有清除的问题
2025-04-12 17:56:23 +08:00
Raila23
ed80e9424c Merge branch 'master' of https://github.com/AstrBotDevs/AstrBot 2025-04-12 16:28:14 +08:00
Raila23
2fe1f2060a 修复:调用函数工具或其他未知情况,可能导致400 BadRequestError 2025-04-12 16:26:02 +08:00
Raila23
c6df820164 修复:每次清除的消息,比实际上期望的多1条 2025-04-12 15:34:35 +08:00
Soulter
d6239822db release: v3.5.3.2 2025-04-12 15:27:33 +08:00
Soulter
bced9ffff9 🐛 fix: 修复zhipu工具调用问题 2025-04-12 15:24:37 +08:00
Soulter
d7d1c1544a 🐛 fix: 修复重启bot时可能发生报错的问题
在 gewechat, wecom 等消息平台没启动成功的情况下重启bot会报错
2025-04-12 15:01:38 +08:00
BigFace123
7c1e8ce48c 添加gewechat被at人wxid获取,AstrBotMessage添加be_at_wxid字段 2025-04-12 10:17:42 +08:00
Soulter
e3b0ca8ef6 🐛 fix: 改进版本号比较逻辑以支持任意长度的版本号 2025-04-12 10:00:25 +08:00
Soulter
9e266eb6d5 release: v3.5.3.1 2025-04-12 09:48:49 +08:00
Soulter
7231403e16 🐛 fix: xai missing field parameters 2025-04-12 09:47:11 +08:00
Soulter
344a486fd7 fix: entites 前向兼容 2025-04-12 09:10:54 +08:00
Soulter
4fd831875d Merge pull request #1237 from AstrBotDevs/release/v3.5.3
📦 release: v3.5.3
2025-04-12 01:04:31 +08:00
Soulter
0988d067ea 📦 release: v3.5.3 2025-04-12 00:58:45 +08:00
Raven95676
44dbe475af refactor: 拆分方法以提高代码可读性 2025-04-12 00:23:57 +08:00
Raven95676
bd24cf3ea4 feat: 初步完成原生流式请求逻辑 2025-04-11 23:45:30 +08:00
Raven95676
b493a808fe fix: 处理更多多模态不支持错误 2025-04-11 20:25:20 +08:00
Raven95676
54035d108d Merge branch 'gemini' of https://github.com/Raven95676/AstrBot-Rdev into gemini 2025-04-11 18:57:55 +08:00
Raven95676
c5e8bc7e20 fix: 修复模型生成内容的重试机制。 2025-04-11 18:55:46 +08:00
渡鸦95676
3bbb4779a3 Merge branch 'master' into gemini 2025-04-11 18:15:44 +08:00
Raven95676
1b3963ebea fix: 更新类型提示,简化代码并修复潜在的空值问题。 2025-04-11 18:07:00 +08:00
Soulter
3b6dd7e15a 🐛 fix: 修复 dify 下删除对话的报错问题
fixes: #1226
2025-04-11 17:27:29 +08:00
Soulter
757d2a3947 🐛 fix: 更新 Dify API 类型提示,增加对 Chatflow 应用类型的说明 2025-04-11 17:23:26 +08:00
Soulter
61b71143f2 Merge pull request #1223 from MR-pofeng/tag-msg-seq
feat:为QQ官方接口需要msg_seq的playload添加随机msg_seq
2025-04-11 16:25:46 +08:00
Soulter
1b343a36c9 Merge pull request #1174 from anka-afk/anka-dev
对关闭的#1167提供完整修复, 修复gemini请求content为空的情况, 增加上下文中验证toolcall逻辑
2025-04-11 16:20:30 +08:00
Soulter
8e94937060 🐛 fix: 修复使用 gemini 时,函数数工具调用会重复调用已经在过去会话中调用过的工具
fixes: #863 #1150
2025-04-11 15:50:36 +08:00
Raven95676
e8ffebc006 fix: 修复消息处理流程中可能出现的空消息 2025-04-11 15:01:20 +08:00
Raven95676
2ca95eaa9f fix: 在设置新key后重新初始化Gemini客户端 2025-04-11 14:42:24 +08:00
Raven95676
0dc5b4cdfc perf: 增加对RECITATION完成原因的处理,提取内容处理逻辑到独立方法 2025-04-11 12:25:44 +08:00
Raven95676
cc6cd96d8e fix: 修复潜在的空消息 2025-04-11 11:03:17 +08:00
Raven95676
4244d37625 chore: 格式化代码,禁用gemini source debug输出 2025-04-11 01:06:20 +08:00
Raven95676
0b766095d4 refactor: 初步完成gemini_source的重写 2025-04-11 01:03:16 +08:00
Soulter
a4f212a18f 🐛 fix: 修复使用 OneAPI + Gemini(openai) 传递空参数函数工具时可能报错的问题
fixes: #1060
2025-04-11 00:20:08 +08:00
Soulter
caafb73190 🐛 fix: 修复函数调用的一些bug 2025-04-10 23:28:51 +08:00
kuangfeng
09482799c9 feat:为需要msg_seq的playload添加随机msg_seq 2025-04-10 21:43:12 +08:00
Soulter
37f93d1760 Merge pull request #1175 from Raven95676/telegram
feat: 自动注册指令到Telegram
2025-04-10 20:26:54 +08:00
Soulter
725f2e5204 Merge pull request #1212 from AstrBotDevs/feat-lark-active-message
 feat: 支持飞书平台下主动消息发送
2025-04-10 17:14:37 +08:00
Soulter
967198fae0 feat: 支持飞书平台下主动消息发送
fixes: #1177

WARNING:
这个修复会导致开启对话隔离下飞书群组的对话记录丢失(但没有被删除)。
2025-04-10 17:12:26 +08:00
Soulter
43d57f6dcb 🎈 perf: Add type validation for configuration items in validate_config function 2025-04-10 15:56:14 +08:00
Soulter
6afa4db577 Merge pull request #1208 from Rail1bc/fix_begin_dialogs
fix:使 begin_dialogs ,预设对话,不会多次插入
2025-04-10 15:32:10 +08:00
Soulter
3b8c3fb29a Merge pull request #1207 from zsbai/patch-1
修复了 `event.get_sender_id()` 返回值与函数注释不一致的问题
2025-04-10 15:27:14 +08:00
Soulter
921c3b0627 Merge pull request #1203 from Rail1bc/master
将一项优化插件的简单逻辑,适配到Core中
2025-04-10 15:25:00 +08:00
Raila23
c0fadb45ab 添加更详细的描述 2025-04-10 15:20:56 +08:00
Raven95676
a1481fb179 群聊场景命令特殊处理 2025-04-10 14:54:25 +08:00
Soulter
987cd972d3 Merge pull request #1180 from Raven95676/reload
perf: 确保完整处理插件所有模块。
2025-04-10 14:45:28 +08:00
anka
bdf25976a3 fix: 少打一个字 2025-04-10 11:28:47 +08:00
anka
87c3aff4ce perf: 简化llm_request工具调用消息成对验证逻辑, 合并两处验证逻辑到一个函数 2025-04-10 11:25:03 +08:00
anka
99350a957a Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-04-10 11:16:49 +08:00
Soulter
319068dc7e Merge pull request #1179 from zhx8702/feat-platform-plugin-control
feat: 添加插件能针对不同消息平台开启关闭的功能
2025-04-10 11:02:09 +08:00
Soulter
cd18806c39 perf: improve platform compatibility checks 2025-04-10 11:01:04 +08:00
Raila23
95b08b2023 fix:使 begin_dialogs ,预设对话,不会多次插入 2025-04-10 09:18:58 +08:00
baiiylu
0e70f76c86 fix: wrong type of sender_id returned in event.get_sender_id() 2025-04-10 08:03:38 +08:00
Raila23
4d414a2994 增加dequeue_context_length的值的判断,只能在1到max_context_length之间 2025-04-09 22:28:33 +08:00
Raila23
3d22772d4e 新增配置项,允许配置:超出最多携带对话数量 时,一次性丢弃多少条旧消息 2025-04-09 22:12:02 +08:00
Raila23
0b381e2570 新增配置项,允许配置:超出最多携带对话数量 时,一次性丢弃多少条旧消息 2025-04-09 22:10:56 +08:00
Raven95676
f2cc4311c5 fix: optional value 2025-04-09 18:55:20 +08:00
Raven95676
e349671fdf format 2025-04-09 18:45:40 +08:00
Raven95676
01c02d5efa perf: 提取模块清理逻辑到 _purge_modules 方法 2025-04-09 18:11:35 +08:00
zhx
b62b1f3870 feat: 添加插件能针对不同消息平台开启关闭的功能
Squashed:

chore: merge master branch

chore: merge from master branch

chore: rename updateAllPlatformCompatibility to update_all_platform_compatibility for consistency

Reviewed by:

@Raven95676 @Soulter
2025-04-09 17:27:44 +08:00
Soulter
8844830859 Merge pull request #1194 from Raven95676/tools
feat: StarTools添加数据目录获取接口
2025-04-09 16:53:22 +08:00
Soulter
0c51ee4b64 chore: 依赖顺序 2025-04-09 16:53:06 +08:00
Soulter
11920d5e31 docs: add a badge to show plugins num 2025-04-09 16:41:32 +08:00
Raven95676
848ea1eb63 提升健壮性 2025-04-09 16:37:19 +08:00
渡鸦95676
a216519486 Merge branch 'AstrBotDevs:master' into tools 2025-04-09 16:16:26 +08:00
Raven95676
b04606c38e 新增获取数据目录的StarTool 2025-04-09 16:13:48 +08:00
Soulter
38072beea7 🎈 perf: 优化插件市场显示 2025-04-09 15:47:44 +08:00
Soulter
b843f1fa03 Update PULL_REQUEST_TEMPLATE.md 2025-04-09 15:28:18 +08:00
Soulter
560d40e571 Merge pull request #1184 from kterna/master
feat:查看本地插件readme和市场插件star数
2025-04-09 15:23:50 +08:00
Soulter
5f0b8161b7 perf: 优化 WebUI Chat 的流式传输性能 2025-04-09 15:22:35 +08:00
kterna
062d482917 fix 2025-04-09 08:43:16 +08:00
Soulter
39693a27e3 Merge branch 'master' into master 2025-04-09 00:30:51 +08:00
anka
7cd1eeac30 fix: 直接把空字符串改为" "一条消息的content是空字符串 2025-04-08 15:57:38 +00:00
Soulter
bafa473c8e Merge pull request #1157 from AstrBotDevs/feat-streaming
feature: 支持流式输出
2025-04-08 22:53:38 +08:00
Soulter
750cf46b2e 🎈 perf: better ChatPage UI 2025-04-08 17:33:46 +08:00
kterna
68885a4bbc Update astrbot/dashboard/routes/plugin.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-04-08 16:30:36 +08:00
Soulter
bcc99a8904 🐛 fix: 修复 permission 过滤算子的 raise_error 参数失效的问题 2025-04-08 14:42:05 +08:00
kterna
59fbd98db3 1 2025-04-08 14:31:35 +08:00
kterna
b70ed425f1 Merge branch 'master' of https://github.com/kterna/AstrBot 2025-04-08 14:05:43 +08:00
kterna
45ef5811c8 1 2025-04-08 14:02:59 +08:00
kterna
3b137ac762 插件管理中查看本地插件的readme 2025-04-08 14:01:14 +08:00
kterna
1ddb0caf73 star显示 2025-04-08 10:47:59 +08:00
Raven95676
ae4c6fe2dd 优化,确保完整处理插件所有模块。为核心方法添加文档。 2025-04-08 10:41:47 +08:00
Jackxwb
b03fe438d0 Merge branch 'master' of https://github.com/AstrBotDevs/AstrBot 2025-04-07 22:50:03 +08:00
Raven95676
db257af58e 提升代码可读性 2025-04-07 22:29:50 +08:00
Raven95676
735368c71b 保证变量名可读性 2025-04-07 22:16:02 +08:00
Raven95676
9e04e3679b 保证内置插件指令被注册 2025-04-07 22:08:29 +08:00
Raven95676
43b8414727 初步实现指令注册 2025-04-07 21:51:41 +08:00
anka
5a00187147 fix: 对历史记录的toolcall验证是否成对, 参考:
https://github.com/run-llama/llama_index/issues/13715
https://github.com/run-llama/llama_index/pull/16214
2025-04-07 18:14:30 +08:00
Raven95676
cb525c7c84 更新下hint( 2025-04-07 17:56:10 +08:00
anka
d88420dd03 fix: 修改获取人类可读的上下文的逻辑, 区分函数调用(无contents)和一般消息 2025-04-07 17:55:12 +08:00
anka
b9a983f8e0 fix: 为函数调用历史记录增加标记, 不读取入上下文 2025-04-07 17:45:35 +08:00
Raven95676
42431ea7db 统一text_chat_stream fallback 2025-04-07 17:43:35 +08:00
Raven95676
f9459e4abb 修复无法通过yield发送消息的问题 2025-04-07 17:38:23 +08:00
anka
72f917d611 fix: gemini只在content不为空的时候加入上下文 2025-04-07 17:31:57 +08:00
Raven95676
9fd1d19e93 分离流式与非流式响应处理 2025-04-07 11:52:29 +08:00
Soulter
062af1ac08 🎈 perf: 优化 WebUI 日志错误处理 2025-04-07 10:38:03 +08:00
Raven95676
41bd76e091 tg适配器最后一次编辑转换markdown 2025-04-07 00:47:52 +08:00
Raven95676
cfd3f4b199 流式输出完成后,将完整的LLM响应设置为事件结果 2025-04-07 00:17:53 +08:00
Soulter
79d38f9597 📦release: v3.5.2 2025-04-06 22:36:31 +08:00
Soulter
b3866559e1 📦release: v3.5.2 2025-04-06 22:35:10 +08:00
Soulter
4d186baa35 Merge pull request #1128 from anka-afk/anka-dev
feature: 实现了 #1127 还有 #1133 还有 #1143
2025-04-06 22:22:01 +08:00
anka
8ed3d5f3db fix: 将openai_source的结果消息链的构造方式和其他统一 2025-04-06 09:12:52 +00:00
anka
f0c8f39b6d 对tg的通过编辑消息的流式传输完善错误捕获 2025-04-06 08:57:18 +00:00
anka
431db8fc9b 对流式输出做错误捕获 2025-04-06 08:47:17 +00:00
anka
ba252c5356 fix: 修正一个偶然发现的命名错误() 2025-04-06 08:12:00 +00:00
Raven95676
a2812c39c0 修正文档注释 2025-04-06 16:05:21 +08:00
Raven95676
0490758820 替换原地修改和删除索引的旧逻辑 2025-04-06 15:36:05 +08:00
Jackxwb
7f56824b42 🐛 修复: 移除路径映射函数中的多余日志记录 2025-04-06 14:52:34 +08:00
Jackxwb
627da3a2bc 分离path_Mapping函数 2025-04-06 14:50:15 +08:00
Soulter
9b36a5c8a6 feat: 增加全平台对流式输出的处理逻辑 2025-04-06 13:43:23 +08:00
Soulter
c1cf2be533 feat: 完善流式处理 2025-04-06 11:56:06 +08:00
Jackxwb
e6b69042de 文件发送时支持路径映射 2025-04-06 01:06:51 +08:00
Soulter
109650faf3 feat: 支持流式输出 2025-04-06 00:56:33 +08:00
Raven95676
e54eaab842 将验证器字典移到类级别,避免重复创建 2025-04-05 21:19:53 +08:00
Raven95676
43b6297b5d reminder将时区设置移入try块,统一为self.timezone 2025-04-05 21:08:52 +08:00
Raven95676
c20f4f5adf 删除默认值,调整logger逻辑 2025-04-05 21:03:02 +08:00
Soulter
dc1f222cd2 fix: 使用 zoneinfo 替代 tzinfo; 默认不设置时区(使用系统默认时区) 2025-04-05 17:27:46 +08:00
Soulter
c2b687212c cleanup 2025-04-05 16:51:06 +08:00
Soulter
849913276d 🎈 perf: 钉钉支持 Markdown 渲染输出
fixes: #1104
2025-04-05 16:29:14 +08:00
Soulter
23579c1e4a 🐛 fix: 阿里百炼应用无法多轮会话
fixes: #1123
2025-04-05 16:21:41 +08:00
Soulter
e031161fd4 🐛 修复: 移除文本输入框的 auto-grow 属性
fixes: #1038
2025-04-05 15:58:17 +08:00
Soulter
4800ee6c0a Merge pull request #1152 from AstrBotDevs/feat-log-filter
 feat: 更新日志发布机制,支持日志级别和内容的字典格式,增加日志筛选功能
2025-04-05 15:49:09 +08:00
Soulter
d3a7fef9b0 🐛 修复: 移除多余的 console 语句 2025-04-05 15:46:45 +08:00
Soulter
40822fe77a feat: 更新日志发布机制,支持日志级别和内容的字典格式,增加日志筛选功能
fixes: #1010
2025-04-05 15:43:40 +08:00
Soulter
837b670213 feat(webui): 支持修改列表项
fixes: #1086
2025-04-05 15:10:44 +08:00
Soulter
57ce69f3fb feat: WebChat 支持语音输出
fixes: #1087
2025-04-05 15:02:34 +08:00
anka
be022c4894 fix: add StarTools to api 2025-04-05 11:55:25 +08:00
anka
8a366964bb feature: 增加时区设置支持 2025-04-05 11:52:51 +08:00
anka
ee86b68470 fix: 漏加classmethod了! 2025-04-05 01:15:56 +08:00
anka
60352307aa fix: 重生之我要苦读设计模式, 终于知道怎么整了哈哈哈: 使用静态类实现工具集合, 并且正确初始化 2025-04-05 01:11:10 +08:00
anka
3ebd2f746f feature: 添加插件工具类, 暂时这么多 2025-04-05 00:51:52 +08:00
anka
1c1a65b637 fix: 全部消息段的检验弄好了! 2025-04-05 00:21:28 +08:00
anka
010e60d029 Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-04-04 23:13:43 +08:00
Soulter
7a25568861 Merge pull request #1131 from AliveGh0st/feature/gemini-safety-settings
feature:增加对Gemini系列模型的安全设置参数支持
2025-04-04 21:22:58 +08:00
AliveGh0st
5f4f913661 feat: 增加对 Gemini 系列模型的输入安全设置参数支持
fixes: #216

Squashed:

Update astrbot/core/config/default.py

描述更正.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

🎨 style: clean up

🐛 fix: 修复安全设置参数的默认值为列表
2025-04-04 21:12:51 +08:00
Soulter
ccd0e34a53 Merge pull request #1145 from AstrBotDevs/feat-telegram-markdownv2
 feat: 支持 Telegram MarkdownV2 渲染
2025-04-04 20:54:04 +08:00
Soulter
72f1ffccd3 feat: 支持 Telegram MarkdownV2 渲染
fixes: #649 #907
2025-04-04 20:52:22 +08:00
Soulter
ea7a52945f Merge pull request #1132 from Captain-Slacker-OwO/dify-md
docs: 更新 Dify 平台链接为官方域名
2025-04-04 01:12:19 +08:00
Soulter
89d4d1351a Merge pull request #1135 from AstrBotDevs/feat-dashscope-tts
feat: 支持阿里云百炼 TTS
2025-04-04 01:03:36 +08:00
Soulter
b757c91d93 🐛 fix: 修复无法识别到函数调用异常的问题 2025-04-04 01:02:39 +08:00
Soulter
27203d7a4d 🐛 fix: update voice key name 2025-04-04 00:47:50 +08:00
Soulter
9ad4e18ac5 feat: 支持阿里云百炼 TTS 2025-04-04 00:32:37 +08:00
anka
fcdc8f3ce7 Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-04-03 21:57:24 +08:00
Captain-Slacker-OwO
78b994b84a docs: 更新 Dify 平台链接为官方域名
将 README 文件中的 Dify 平台链接从旧域名更新为官方域名 dify.ai,确保文档的准确性和权威性。
2025-04-03 19:00:44 +08:00
Soulter
58bfc677e2 🐛 fix: dify error Arg user must be provided
fixes #1073
2025-04-03 16:49:05 +08:00
Soulter
7d17285a0c 🐛 fix: ensure whitelist entries are stripped of whitespace and converted to strings 2025-04-03 16:44:37 +08:00
Soulter
e9eb00a0d4 feat: 插件市场帮助按钮 2025-04-03 16:19:01 +08:00
anka
48d07af574 feature(fix?): 在发送消息之前统一检查消息内容是否为空, 不允许发送空消息, 以解决该消息内容不支持查看以及gemini返回<empty content>问题 2025-04-03 11:50:12 +08:00
Soulter
2fc62efd88 Merge pull request #1116 from AstrBotDevs/feat-log-sse
🏗 refactor: log 通信使用 SSE 替代 Websockets
2025-04-02 21:07:40 +08:00
Soulter
be516d75bd 🐛 fix: upadte method name 2025-04-02 21:06:59 +08:00
Soulter
951d5fde85 🏗 refactor: log 通信使用 SSE 替代 Websockets 2025-04-02 20:59:25 +08:00
Soulter
1389abc052 Merge pull request #1112 from AstrBotDevs/fix-aiocqhttp-empty-plain
修复 aiocqhttp 适配器下空白 plain 导致的报错
2025-04-02 16:27:12 +08:00
Soulter
19ad67a77f 🐛 fix: 修复 aiocqhttp 适配器下空白 plain 导致的 the object is not a proper segment chain 报错问题 2025-04-02 16:24:36 +08:00
Soulter
641f308344 Update README.md 2025-04-01 11:35:56 +08:00
Soulter
9f097fa4d5 Update README.md 2025-04-01 11:33:38 +08:00
Soulter
5ad362c52b Merge pull request #1081 from anka-afk/anka-dev
fix #1074 and add some comment
2025-04-01 10:57:40 +08:00
Soulter
614f238a61 Merge pull request #1072 from zhx8702/feat-add-plugin-md-dialog
feat: 安装完插件后自动弹出插件仓库 README 对话框
2025-04-01 10:56:24 +08:00
zhx
dec91950bc feat: 安装完插件后自动弹出插件仓库 README 对话框 2025-04-01 10:04:04 +08:00
anka
6cef9c23f0 bug fix: #1074 修改最多携带对话数量时出现bug 2025-03-31 22:41:23 +08:00
anka
3f568bf136 Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-03-31 22:32:40 +08:00
anka
5484b421ce perf: 增加部分注释 2025-03-31 22:30:43 +08:00
Soulter
02f21e07d3 📦 release: v3.5.1 2025-03-31 10:59:32 +08:00
Soulter
fff1f23a83 Update README.md 2025-03-31 00:57:23 +08:00
Soulter
a056ec0d38 Merge pull request #1065 from AstrBotDevs/perf-openai-source-balance
🎈 perf: OpenAI sources supports api key load balance(random)
2025-03-30 22:53:27 +08:00
Soulter
2eb9e5dde3 perf: 添加重试等待 2025-03-30 22:51:34 +08:00
渡鸦95676
627d2a4701 新增重试间隔 2025-03-30 22:33:21 +08:00
Soulter
76895fe86d chore: improve variable names 2025-03-30 22:12:34 +08:00
Soulter
64c3c85780 Merge pull request #1056 from Raven95676/master
perf: 优化无对话情况下设置人格的反馈;若禁用提供商,自动切换到另一个可用的提供商
2025-03-30 22:10:23 +08:00
Soulter
7288348857 🎈 perf: OpenAI sources supports api key load balance(random) 2025-03-30 22:00:45 +08:00
Soulter
62e73299b1 🐛 fix: forcely write shared preference data
Note: this is a fast fix for recent feedbacks, we'll improve its performance.
2025-03-30 21:33:41 +08:00
Raven95676
fe76c41ed8 perf: 若禁用提供商,自动切换到另一个可用的提供商 2025-03-30 15:18:48 +08:00
Raven95676
1a92edf8be perf: 优化无对话情况下设置人格的反馈 2025-03-30 14:38:40 +08:00
Soulter
b63b606a4e docs: 推荐使用 uv 进行手动部署 2025-03-30 10:39:14 +08:00
Soulter
8e2ef3d22b Merge pull request #1050 from advent259141/master
回复空@功能的修复
2025-03-30 00:15:26 +08:00
Gao Jinzhe
c6c4a32283 Add files via upload 2025-03-29 22:37:18 +08:00
Soulter
b70b3b158e feat: 支持 gemini-2.0-flash-exp-image-generation 对图片模态的输入 #1017 2025-03-29 20:51:27 +08:00
Soulter
3d59ab8108 fix: conversation and tool use page refresh 404 2025-03-29 19:17:56 +08:00
Soulter
b6c3089510 🎈 perf: 优化空 at 回复 2025-03-29 19:09:35 +08:00
Soulter
bd92aac280 feat: 支持 /llm 指令快捷启停 LLM 功能 #296 2025-03-29 18:31:07 +08:00
Soulter
5299e802e9 Merge pull request #1046 from AstrBotDevs/feat-docker-embedded-ffmpeg
docker 镜像提供内置 ffmpeg
2025-03-29 17:53:40 +08:00
Soulter
8e5a57d7dd Merge pull request #1045 from Raven95676/master
在lifecycle新增插件资源清理逻辑
2025-03-29 17:53:16 +08:00
Soulter
beaa324fb6 Merge pull request #1012 from Zhenyi-Wang/master
feat: gewechat client增加获取通讯录列表接口
2025-03-29 17:51:35 +08:00
Soulter
79e64fe206 Merge pull request #1011 from left666/left666
feat(core): 在 MessageChain 类中添加 at 和 at_all 方法
2025-03-29 17:50:55 +08:00
Soulter
93f525e3fe 🎈 perf: edge tts 支持使用代理;移除了一些不需要的方法 2025-03-29 17:48:22 +08:00
Soulter
aacb803c64 Merge pull request #999 from Futureppo/master
部分api获取不到model导致key泄露,使用正则表达式过滤掉key内容
2025-03-29 17:43:10 +08:00
Soulter
8a0665b222 🎈 feat: 更新 Dockerfile,添加 Node.js 支持并优化依赖安装 2025-03-29 17:42:31 +08:00
Soulter
20e41a7f73 🐛 fix: newgroup 指令名显示错误 2025-03-29 17:42:31 +08:00
Soulter
93a1699a35 Update README.md 2025-03-29 17:42:31 +08:00
Soulter
c33c07e4af Update README.md 2025-03-29 17:42:31 +08:00
Soulter
c7484d0cc9 Update README.md 2025-03-29 17:42:31 +08:00
Soulter
fb85a7bb35 feat: add demo mode 2025-03-29 17:42:31 +08:00
Soulter
42ff9a4d34 Update README.md 2025-03-29 17:42:31 +08:00
Soulter
005e9eae7c 🐛 fix: 插件更新时没有正确应用加速地址 2025-03-29 17:42:31 +08:00
Soulter
3e325debcc Update README.md 2025-03-29 17:42:31 +08:00
Soulter
a221de9a2b 🐛 fix: 修复 LLM 响应后事件钩子无法生效的问题 2025-03-29 17:42:31 +08:00
Soulter
32b0cc1865 Update README.md 2025-03-29 17:42:31 +08:00
Soulter
bbf85f8a12 🐛 fix: remove error logging for empty result and refresh extensions after upload 2025-03-29 17:42:31 +08:00
Soulter
67a0172b28 📦 release: v3.5.0 2025-03-29 17:42:31 +08:00
zhx
fb19d4d45b fix: install_plugin_from_file 方法load传参数改为文件名 2025-03-29 17:42:31 +08:00
Soulter
a156b1af14 feat: 支持通过指令下载插件 /plugin get 2025-03-29 17:42:31 +08:00
Soulter
a604b4943c 🎈 perf: 优化新版本时的信息显示 2025-03-29 17:42:31 +08:00
pre-commit-ci[bot]
3f0b6435d9 🎈 auto fixes by pre-commit hooks 2025-03-29 17:42:31 +08:00
Gao Jinzhe
e0f029e2cb Add files via upload 2025-03-29 17:42:31 +08:00
Soulter
89d3fd5fab 🎈 perf: 优化 WebUI 对话数据库中文历史检索 2025-03-29 17:42:31 +08:00
Soulter
a38b00be6b 🐛 fix: 修复部分可能形成 SQL 注入的风险 2025-03-29 17:42:31 +08:00
Futureppo
0e8d52b591 :ballon: feat: 使用正则表达式过滤掉 /model 可能暴露的 api_key
Squashed:

更新正则表达式

🎈 auto fixes by pre-commit hooks

Update main.py

Update main.py

chore: bugfixes
2025-03-29 17:40:48 +08:00
Soulter
298c77740d feat: docker 镜像提供内置 ffmpeg #979 2025-03-29 17:26:57 +08:00
Raven95676
c681aae8ee 修复日志问题 2025-03-29 17:25:38 +08:00
Raven95676
faef98b089 在lifecycle新增插件资源清理逻辑 2025-03-29 17:07:12 +08:00
Soulter
84a3e0a30b 🎈 feat: 更新 Dockerfile,添加 Node.js 支持并优化依赖安装 2025-03-29 16:36:02 +08:00
Soulter
69bd553ce0 Merge pull request #1035 from AstrBotDevs/fix-1034-bug
🐛 fix: groupnew 指令名显示错误
2025-03-28 23:46:30 +08:00
Soulter
fd0c0f8975 🐛 fix: newgroup 指令名显示错误 2025-03-28 23:45:19 +08:00
Zhenyi-Wang
860ceb06b4 Merge branch 'Soulter:master' into master 2025-03-28 21:27:25 +08:00
anka
ecf501bf72 Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-03-28 19:04:35 +08:00
Soulter
81a2ed1e25 Update README.md 2025-03-28 18:20:33 +08:00
Soulter
76ab28338a Update README.md 2025-03-28 13:24:41 +08:00
Soulter
9a56c9630f Update README.md 2025-03-28 13:23:29 +08:00
anka
53b9497c18 perf: 增加部分注释 2025-03-27 21:32:38 +08:00
Soulter
750b16b6ee feat: add demo mode 2025-03-27 15:54:23 +08:00
anka
0ee3e0779a Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-03-27 15:21:04 +08:00
pre-commit-ci[bot]
333c2d9299 🎈 auto fixes by pre-commit hooks 2025-03-27 03:21:43 +00:00
Zhenyi Wang
ad37ff5048 feat: gewechat client增加获取通讯录列表接口 2025-03-27 11:17:52 +08:00
pre-commit-ci[bot]
33f86f3bde 🎈 auto fixes by pre-commit hooks 2025-03-27 02:56:55 +00:00
Soulter
8acb969a49 Update README.md 2025-03-27 10:39:18 +08:00
left666
b74b5933b8 feat(core): 在 MessageChain 类中添加 at 和 at_all 方法
- 新增 at 方法,用于添加 At 消息到消息链中
- 新增 at_all 方法,用于添加 AtAll 消息到消息链中
2025-03-27 10:30:19 +08:00
Soulter
681c556b7e 🐛 fix: 插件更新时没有正确应用加速地址 2025-03-27 10:04:40 +08:00
anka
1746684e52 perf: 修改部分注释 2025-03-26 23:52:03 +08:00
Soulter
0b93d06555 Update README.md 2025-03-26 20:51:53 +08:00
anka
8a8b8c7c27 Merge remote-tracking branch 'origin/master' into anka-dev 2025-03-26 17:59:53 +08:00
anka
6b6577006d perf: 格式化 2025-03-26 17:59:30 +08:00
Soulter
23ee5e81c9 🐛 fix: 修复 LLM 响应后事件钩子无法生效的问题 2025-03-26 17:56:55 +08:00
Soulter
483f55e4b1 Update README.md 2025-03-26 16:16:03 +08:00
Soulter
1bb1bc2553 🐛 fix: remove error logging for empty result and refresh extensions after upload 2025-03-26 15:43:56 +08:00
Soulter
a4e4e36f94 📦 release: v3.5.0 2025-03-26 15:30:09 +08:00
Soulter
6849415812 Merge pull request #996 from zhx8702/fix-star-manager
fix: install_plugin_from_file 方法load传参数改为文件名
2025-03-26 15:26:53 +08:00
zhx
86f6cb038e fix: install_plugin_from_file 方法load传参数改为文件名 2025-03-26 15:06:33 +08:00
Soulter
7480a1d6ce feat: 支持通过指令下载插件 /plugin get 2025-03-26 14:33:45 +08:00
Soulter
3cd10117dd 🎈 perf: 优化新版本时的信息显示 2025-03-26 14:14:01 +08:00
Soulter
0caf19d390 Merge pull request #937 from advent259141/master
将对只有一个 @ 的消息内容的处理改成调用llm回复
2025-03-26 13:54:43 +08:00
anka
5c14ebb049 Merge remote-tracking branch 'origin/master' into anka-dev 2025-03-26 13:53:21 +08:00
anka
9717a736b1 perf: 更新部分描述 2025-03-26 13:50:54 +08:00
Soulter
9c9ab50d1a 🎈 perf: 优化 WebUI 对话数据库中文历史检索 2025-03-26 13:50:11 +08:00
Soulter
d4bcb8174e 🐛 fix: 修复部分可能形成 SQL 注入的风险 2025-03-26 13:41:18 +08:00
anka
9e7fe773bd perf: 更新部分注释 2025-03-26 11:14:46 +08:00
Soulter
aca18fab0f feat: 优化配置文件中的提示信息,增强可读性 2025-03-26 00:56:51 +08:00
Soulter
691de01b79 feat: 支持设置最多携带对话数量 2025-03-26 00:46:15 +08:00
Soulter
3383f15142 Merge pull request #988 from Soulter/NiceAir/master
 feat: Update UI elements and improve layout in various components
2025-03-25 23:17:11 +08:00
Soulter
84c1593889 feat: Update UI elements and improve layout in various components 2025-03-25 21:52:15 +08:00
Soulter
3c80fa1e33 Update README.md 2025-03-25 21:31:23 +08:00
Soulter
06b16a1deb Merge pull request #983 from Soulter/feat-conversation-webui-mgr
 支持 WebUI 对话管理
2025-03-25 21:26:00 +08:00
Soulter
4c4246fb09 Merge pull request #982 from NiceAir/master
添加对gewe的表情包、引用消息、视频的支持
2025-03-25 21:25:00 +08:00
Soulter
364be1e9f6 🐛 fix: Handle missing defusedxml dependency for Gewechat message parsing 2025-03-25 21:21:38 +08:00
NiceAir
f959ed71aa feat: Gewechat 支持表情包、引用消息、视频
Co-authored-by: Soulter <905617992@qq.com>
2025-03-25 21:00:12 +08:00
anka
5c4326c302 perf: 部分详细注释, 符合PEP8标准 2025-03-25 20:53:23 +08:00
Soulter
125fc3a622 feat: 支持 WebUI 对话管理 2025-03-25 19:44:46 +08:00
Soulter
6b9e785db3 Merge pull request #968 from Soulter/pre-commit-ci-update-config
🎈 pre-commit autoupdate
2025-03-25 15:03:39 +08:00
Soulter
25d34e9a43 Merge pull request #974 from zhx8702/feat-webui-add-search-keys
feat: 插件市场列表卡片过滤条件提出变量保持一致
2025-03-25 15:03:09 +08:00
Soulter
457d4aa1dc Merge pull request #976 from Raven95676/master
Improves Telegram adapter termination
2025-03-25 15:01:04 +08:00
Raven95676
ff0c0992ff Improves Telegram adapter termination 2025-03-25 14:46:20 +08:00
Soulter
d379e012c4 🐛 fix: telegram /start issue #751 2025-03-25 14:03:46 +08:00
zhx
151fff26fd feat: 插件市场列表卡片过滤条件提出变量保持一致 2025-03-25 13:50:16 +08:00
Soulter
3d0d561215 Update compose.yml 2025-03-25 13:24:37 +08:00
Soulter
22d586ed7b Update compose.yml 2025-03-25 13:24:19 +08:00
Soulter
6dc19b29e8 🐛 fix: remove redundant validation call in config validation function #901 2025-03-25 12:56:48 +08:00
Soulter
50975a87d4 🐛 fix: handle message sending failures with error logging 2025-03-25 12:34:43 +08:00
Soulter
ce721d9f0f 🐛 fix: platform adapter server blocks ctrl+c 2025-03-25 11:31:46 +08:00
Soulter
20510a33f7 feat: improve pyproject and use uv as package mgr 2025-03-25 11:07:20 +08:00
pre-commit-ci[bot]
3abd9c8763 🎈 pre-commit autoupdate
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.11.0 → v0.11.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.11.0...v0.11.2)
2025-03-24 17:08:12 +00:00
Soulter
e9eff7420b feat: 更加完善和美观的 本地 Markdown 渲染 2025-03-25 00:56:19 +08:00
Soulter
64c250c9d8 🎈perf: 优化可能的 conversation 为 None 的问题 2025-03-25 00:06:25 +08:00
Soulter
8047f82bfd 🎈perf: 优化删除插件目录的逻辑,抛出异常细节;完善 mcp 未安装时的提示 2025-03-24 23:07:56 +08:00
Soulter
af6467fb3d Merge pull request #962 from zhx8702/feat-webui-add-double-confirm
feat: 删除插件添加二次确认,插件列表添加非空判断
2025-03-24 23:01:43 +08:00
zhx
3ff1664aec feat: 删除多余代码 2025-03-24 20:27:05 +08:00
zhx
34ea2b44b8 Merge remote-tracking branch 'upstream/master' into feat-webui-add-double-confirm 2025-03-24 19:42:47 +08:00
Soulter
6c8d851109 Merge pull request #955 from Raven95676/master
Telegram适配器消息处理功能增强
2025-03-24 18:10:51 +08:00
Soulter
d678299a74 Merge branch 'master' into master 2025-03-24 18:10:27 +08:00
Soulter
7aed0db2b6 Merge pull request #951 from IGCrystal/master
fix: fix SSLCertVerificationError
2025-03-24 18:05:49 +08:00
Soulter
0355524345 Merge branch 'master' into master 2025-03-24 17:58:00 +08:00
Soulter
0a43e4672e style: format codes 2025-03-24 17:57:28 +08:00
zhx
71e0ccdfec feat: 删除插件添加二次确认,插件列表添加非空判断 2025-03-24 16:41:54 +08:00
冰苷晶
1df33ac3c8 fix: fix error 2025-03-24 13:28:14 +08:00
pre-commit-ci[bot]
7334090ac1 🎈 auto fixes by pre-commit hooks 2025-03-24 05:20:37 +00:00
冰苷晶
6b0f044198 fix: fix other errors 2025-03-24 13:20:05 +08:00
pre-commit-ci[bot]
ddf54c9cf8 🎈 auto fixes by pre-commit hooks 2025-03-24 04:32:21 +00:00
IGCrystal
7c64e184e2 Merge branch 'Soulter:master' into master 2025-03-24 12:32:16 +08:00
渡鸦95676
a904db033c Merge branch 'Soulter:master' into master 2025-03-24 12:19:17 +08:00
渡鸦95676
b234856b02 Remove unused variable
移除以通过ruff检查
在Ubuntu24.04LTS中,移除未见对现有功能的影响
2025-03-24 11:36:46 +08:00
Soulter
89d51d2afc 🎈 perf: config UI 2025-03-24 11:36:38 +08:00
Soulter
37cb9678e9 Merge pull request #826 from XuYingJie-cmd/master
新增了关于gewe发送视频的功能
2025-03-24 11:25:24 +08:00
pre-commit-ci[bot]
0500ff333a 🎈 auto fixes by pre-commit hooks 2025-03-24 02:50:28 +00:00
Raven95676
08528510ef Fix incorrect handling of reply messages within topics 2025-03-24 10:41:33 +08:00
Raven95676
ddbd03dc1e Adds sticker handling in Telegram adapter 2025-03-24 10:40:20 +08:00
Soulter
ade87f378a 🎈 perf: UI 优化 2025-03-24 00:32:40 +08:00
冰苷晶
4db14b905f fix: fix error 2025-03-23 23:40:06 +08:00
pre-commit-ci[bot]
b669b31451 🎈 auto fixes by pre-commit hooks 2025-03-23 15:07:22 +00:00
冰苷晶
1cb2b62f81 fix: fix error 2025-03-23 23:02:34 +08:00
Soulter
e5828713cf 🎈 perf: improve ChatPage and ConfigPage UI 2025-03-23 22:57:02 +08:00
冰苷晶
d10cb84068 fix: fix SSLCertVerificationError 2025-03-23 22:55:07 +08:00
Soulter
4222f8516f Merge pull request #844 from AraragiEro/mcp_adapt
支持 MCP 服务并优化函数调用流程
2025-03-23 22:35:35 +08:00
Soulter
7f998c7611 chore: remove useless print output 2025-03-23 22:28:00 +08:00
Soulter
db46000337 🎨 style: format codes 2025-03-23 22:22:11 +08:00
Soulter
1aac8d8041 feat: 适配完整的 function-calling 流程 2025-03-23 22:21:47 +08:00
Soulter
c59c8e05f7 🐛 fix: tools result 2025-03-23 17:03:18 +08:00
Soulter
4942d0a629 feat: 在工具使用页面添加函数调用信息提示和链接功能 2025-03-23 17:00:38 +08:00
Soulter
873b7715f4 🎈 perf: 优化 MCP Client 异步 Event 管理 2025-03-23 16:51:28 +08:00
pre-commit-ci[bot]
98e7ed6920 🎈 auto fixes by pre-commit hooks 2025-03-23 08:34:05 +00:00
Soulter
046f5e645e feat: 完善 MCP 管理和实现 WebUI MCP 相关的页面 2025-03-23 16:33:44 +08:00
pre-commit-ci[bot]
f5e5a7094c 🎈 auto fixes by pre-commit hooks 2025-03-23 06:39:13 +00:00
Gao Jinzhe
154125fee6 Add files via upload 2025-03-23 14:35:44 +08:00
pre-commit-ci[bot]
9f8e960ebe 🎈 auto fixes by pre-commit hooks 2025-03-23 03:31:20 +00:00
Soulter
4179b0be0a chore: 优化注解格式和 requirements.txt 2025-03-23 11:31:10 +08:00
Soulter
28bafa38db Merge branch 'master' into mcp_adapt 2025-03-23 11:01:44 +08:00
Soulter
b07552565e Merge pull request #926 from Soulter/perf-graceful-shutdown
支持所有消息平台的优雅退出
2025-03-23 10:56:56 +08:00
Soulter
c4427471d2 🎨 style: format codes 2025-03-23 00:25:26 +08:00
Soulter
08f81c6784 🐛 fix: 修复图片没有被存储到上下文中的问题 2025-03-23 00:23:42 +08:00
Soulter
a471e98aca 🐛 fix: Telegram 下无法识别图片描述(Caption) #910 2025-03-23 00:23:01 +08:00
Soulter
75a8fcc8a0 🐛 fix: 修复 Telegram 下非默认群组话题引用消息异常 #906 2025-03-22 23:39:21 +08:00
Soulter
46ef76c168 feat: 支持消息平台的热重载 2025-03-22 19:54:54 +08:00
Soulter
66637446c9 Merge remote-tracking branch 'origin/master' into perf-graceful-shutdown 2025-03-22 19:26:35 +08:00
Soulter
21efeb888a Merge pull request #904 from LunarMeal/master
新增了newgroup指令
2025-03-22 19:18:06 +08:00
Soulter
a4ee8b5322 Merge remote-tracking branch 'origin/master' into LunarMeal/master 2025-03-22 19:17:12 +08:00
Soulter
36519ac47e 🐛 fix: groupnew 设置为管理员指令 2025-03-22 19:14:58 +08:00
Soulter
3f514fceca 🎨 style: format codes 2025-03-22 19:07:47 +08:00
pre-commit-ci[bot]
c2249fdfac 🎈 auto fixes by pre-commit hooks 2025-03-22 11:06:42 +00:00
Soulter
c610719a44 feat: 为各平台适配器支持优雅关闭 2025-03-22 19:02:49 +08:00
Soulter
36a6c2461a 🐛 fix: 修复 Telegram Topic 群组下LLM 上下文及主动消息混乱的问题 #908 2025-03-22 18:15:43 +08:00
Soulter
c29f22c39e Update PLUGIN_PUBLISH.yml 2025-03-22 15:51:35 +08:00
Soulter
30d3062944 🎈 perf: 优化钉钉在配置错误之后堵塞整个线程的问题 #885
a.k.a 帮钉钉擦屁股
2025-03-22 15:44:42 +08:00
Soulter
69ba75abf4 Update README.md 2025-03-22 01:26:03 +08:00
Soulter
e4d486fec5 docs: 宝塔面板部署方式 2025-03-22 00:42:04 +08:00
Soulter
f242144dcf 更新 README.md 2025-03-21 19:21:35 +08:00
Soulter
02dee2d664 🎈 perf: add error handling for missing pyffmpeg library in video sending functionality 2025-03-21 16:51:23 +08:00
Soulter
a3dd2c3069 Merge remote-tracking branch 'origin/master' into XuYingJie-cmd/master 2025-03-21 16:49:15 +08:00
Soulter
a23425e8aa Merge pull request #781 from Moyuyanli/master
添加gewe的群相关操作
2025-03-21 16:31:10 +08:00
Moyuyanli
be79ddc9a3 fix:去掉跟post_text功能相同的接口方法 2025-03-21 16:24:31 +08:00
Soulter
7d71015e8c Update README.md 2025-03-21 16:12:25 +08:00
Soulter
ad54549b51 Update README.md 2025-03-21 15:58:40 +08:00
Soulter
6cf032a164 Update compose.yml 2025-03-21 11:06:22 +08:00
Soulter
6390d796ac Update compose.yml 2025-03-21 11:05:44 +08:00
Soulter
98b8411905 Update compose.yml 2025-03-21 10:53:09 +08:00
LunarMeal
ddf1029afa Merge branch 'master' of https://github.com/LunarMeal/AstrBot 2025-03-20 22:53:29 +08:00
LunarMeal
1effbc5cc9 fix 2025-03-20 22:53:21 +08:00
pre-commit-ci[bot]
414b645e9f 🎈 auto fixes by pre-commit hooks 2025-03-20 14:42:37 +00:00
LunarMeal
398c76f496 新增了newgroup指令 2025-03-20 22:39:49 +08:00
Soulter
1bc456dd95 🎈 perf: 改善一些术语描述 2025-03-20 20:31:36 +08:00
Soulter
2e8421884e Merge pull request #864 from Soulter/pre-commit-ci-update-config
🎈 pre-commit autoupdate
2025-03-20 20:23:45 +08:00
Soulter
70d9b193ac 🐛 fix: 修复私聊下 get_group 的一些问题 2025-03-20 20:18:20 +08:00
Moyuyanli
b49c11004a fix:还原回原来的依赖信息 2025-03-20 19:57:35 +08:00
Soulter
34843eea90 🎨 style: format codes 2025-03-20 18:07:24 +08:00
pre-commit-ci[bot]
2d6d7f31e8 🎈 auto fixes by pre-commit hooks 2025-03-20 10:06:11 +00:00
Soulter
7a24cbff1c feat: 支持 aiocqhttp 适配器下的获取群消息 2025-03-20 18:05:44 +08:00
pre-commit-ci[bot]
1e7eb2cf1c 🎈 auto fixes by pre-commit hooks 2025-03-20 09:21:32 +00:00
Soulter
361256e016 chore: 添加了一些 gewechat client 的注释 2025-03-20 17:20:32 +08:00
Soulter
8838dbd003 🎨 style: format codes 2025-03-20 16:54:27 +08:00
pre-commit-ci[bot]
13a95e1f2b 🎈 auto fixes by pre-commit hooks 2025-03-20 08:42:40 +00:00
Soulter
1aaa451a3e Merge branch 'master' into Moyuyanli/master 2025-03-20 16:42:13 +08:00
Soulter
cbba81e54d 🐛 fix: 无法接收图片 aiocqhttp 2025-03-20 16:03:41 +08:00
Soulter
370868dfac 🎈 perf: 消息平台和配置提供商配置页中,自动更新旧的配置,添加新的配置项 2025-03-20 13:22:49 +08:00
Soulter
77f692aae2 🎈 perf: 配置项显示优化 2025-03-20 13:17:27 +08:00
Soulter
9318e205ea feat: 阿里云百炼应用支持 RAG 应用 #878 2025-03-20 13:17:06 +08:00
Soulter
ebcc717c19 🎈 perf: Dify 下支持更多类型的图片输入及提高代码复用性 #893
🐛 fix: 修复飞书下无法进行图片输入的问题
2025-03-20 11:21:45 +08:00
Soulter
4c16b564ee 🎈 perf: 忽略微信团队消息 #859 2025-03-19 01:09:01 +08:00
Soulter
e2283d1453 🐛 fix: 修复 dify 下某些修改了 LLM 响应的插件可能不生效的问题 #876 2025-03-19 01:05:28 +08:00
Soulter
d891801c5a v3.4.39 2025-03-18 22:43:35 +08:00
Soulter
de75386944 🎈 perf: 登录后检查默认密码和弹出修改警告 2025-03-18 22:41:33 +08:00
Soulter
82dc37de50 style: format codes 2025-03-18 22:21:47 +08:00
Soulter
b6fa7f62dc chore: 添加安全提示信息 2025-03-18 22:18:01 +08:00
Soulter
f9e0a95c5e chore: 默认地址改回 0.0.0.0 2025-03-18 22:15:22 +08:00
pre-commit-ci[bot]
b2c6e12647 🎈 auto fixes by pre-commit hooks 2025-03-17 17:10:06 +00:00
pre-commit-ci[bot]
caffb83780 🎈 pre-commit autoupdate
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.9.10 → v0.11.0](https://github.com/astral-sh/ruff-pre-commit/compare/v0.9.10...v0.11.0)
2025-03-17 17:09:59 +00:00
Soulter
8882cb5479 v3.4.38 2025-03-18 00:54:51 +08:00
Soulter
75dace2dee 🎈 perf: 优化配置页的显示 2025-03-18 00:16:47 +08:00
Soulter
ad6487d042 🐛 fix: 修复部分指令可能造成的配置类型问题 2025-03-17 23:44:04 +08:00
Soulter
a91604e8ab Merge pull request #853 from IGCrystal/master
🎈 perf: 优化了iframe窗口,新增跳转按钮
2025-03-17 23:25:26 +08:00
Soulter
c364f7c643 🎈 perf: Dify 下当只有图片输入时的默认 prompt #837 2025-03-17 23:17:07 +08:00
Soulter
53435ba184 🐛 fix: 修复 model_config 中自定义的配置项(如温度)类型自动变回 string #854 2025-03-17 23:11:57 +08:00
Soulter
25f8d5519b 🐛 fix: LLOnebot 合并消息转发错误 #842 2025-03-17 22:42:48 +08:00
Moyuyanli
2e4fef6c66 feat:添加消息记录器 2025-03-17 16:02:55 +08:00
冰苷晶
80b2b7dc00 🎈 perf: 优化了iframe窗口 2025-03-16 21:35:30 +08:00
Alero
8585cd8e21 修复codecheck 2025-03-15 20:26:17 +08:00
Alero
9fa2a7eeea 修复codecheck 2025-03-15 20:24:36 +08:00
pre-commit-ci[bot]
2d1f74228d 🎈 auto fixes by pre-commit hooks 2025-03-15 12:10:17 +00:00
Alero
3d6f7aa0e1 修复codecheck 2025-03-15 20:09:49 +08:00
pre-commit-ci[bot]
3dea60366a 🎈 auto fixes by pre-commit hooks 2025-03-15 11:54:09 +00:00
Alero
d4d9a1df4c feat:新增MCP服务支持并优化工具调用逻辑
引入MCP客户端支持,增加mcp_server.json配置样例,完善工具描述生成及调用逻辑以支持MCP服务工具功能。同时调整相关逻辑以区分本地工具与MCP工具的调用方式,提升扩展性和灵活性。
2025-03-15 19:47:06 +08:00
Soulter
7d6975fd31 Merge pull request #832 from IGCrystal/master
🎈 perf: 优化iframe窗口,加入了关闭按钮
2025-03-15 14:25:16 +08:00
IGCrystal
08be52ed17 Merge branch 'Soulter:master' into master 2025-03-15 12:05:27 +08:00
邹永赫
682a7700c2 Merge pull request #835 from zouyonghe/master
修改注册函数工具时的打印信息
2025-03-15 12:20:32 +09:00
pre-commit-ci[bot]
9d87009216 🎈 auto fixes by pre-commit hooks 2025-03-15 03:16:51 +00:00
邹永赫
ef86838f62 修改注册函数工具时的打印信息 2025-03-15 12:15:05 +09:00
Soulter
35468233f8 🎈 perf: supports for customizing webui host, wecom webhook server host, qq official webhook server host #821 2025-03-15 01:21:36 +08:00
Soulter
26e229867d 🐛fix: 可能的QQ平台回复消息带有末尾空白的问题 #822 2025-03-15 00:57:17 +08:00
Soulter
3a1578b3c6 feat: 支持 Dify 文件、图片、视频、音频输出。#819 2025-03-15 00:51:32 +08:00
冰苷晶
d5e3d2cbbc 🎈 perf: 优化iframe窗口,加入了关闭按钮 2025-03-14 20:23:15 +08:00
Moyuyanli
c095248176 Merge remote-tracking branch 'origin/master' 2025-03-14 18:30:42 +08:00
Moyuyanli
44601c8954 fix:修复gewe的ModContacts消息类型 2025-03-14 18:30:27 +08:00
Soulter
135dbb8f07 style: clean codes 2025-03-14 18:02:00 +08:00
pre-commit-ci[bot]
c95682a0c7 🎈 auto fixes by pre-commit hooks 2025-03-14 09:11:21 +00:00
Moyuyanli
d177b9f7fa feat:添加主动添加好友事件 2025-03-14 17:11:10 +08:00
徐英杰
9b57615d94 新增了关于gewe发送视频的功能 2025-03-14 16:19:41 +08:00
Soulter
c03f3eacd1 Update README.md 2025-03-13 23:03:36 +08:00
Soulter
a26e395932 Merge pull request #817 from Soulter/feat-parse-reply
[Feature] 添加了 LLM 对消息平台引用回复内容的感知
2025-03-13 21:06:44 +08:00
Soulter
0870b87c96 🐛 fix: 获取引用消息失败时没有将引用消息段加入消息链 2025-03-13 20:59:52 +08:00
Soulter
b52a44a7dd 🎨 stype: format codes 2025-03-13 20:44:08 +08:00
Soulter
0a290aafef Merge pull request #815 from diudiu62/perf-gewechat
微信有未处理的消息类型,导致控制台打印太多的日志
2025-03-13 20:39:39 +08:00
Soulter
9014d4c410 🎨 style: format codes 2025-03-13 20:36:41 +08:00
pre-commit-ci[bot]
60e58b4f5f 🎈 auto fixes by pre-commit hooks 2025-03-13 09:52:03 +00:00
Soulter
620e74a6aa Merge branch 'master' into feat-parse-reply 2025-03-13 17:51:12 +08:00
Soulter
efa287ed35 feat: 支持 LLM 对引用消息的感知 #783 2025-03-13 17:40:28 +08:00
Soulter
a24eb9d9b0 🏗 refactor: clean up AstrBotConfig component markup for improved readability 2025-03-13 17:02:58 +08:00
Soulter
bd3dab8aae 🐛 fix: 插件管理的插件简介太长 “帮助”“操作”图标不显示 #790 2025-03-13 17:02:58 +08:00
Soulter
4fe1ebaa5b 🏗 refactor: improve styling and layout of AstrBotConfig component for enhanced readability 2025-03-13 17:02:58 +08:00
Soulter
c5e944744b 🏗 refactor: enhance ConfigPage layout and styling for better user experience 2025-03-13 17:02:58 +08:00
Soulter
0c396181f7 🏗 refactor: 配置页样式重写 2025-03-13 17:02:58 +08:00
Soulter
0034474219 🐛 fix: sent message to wrong topic in topic group #801 2025-03-13 17:02:58 +08:00
shuiping233
8136ad8287 修复命令参数报错信息无法发送至qq官方机器人平台的bug 2025-03-13 17:02:58 +08:00
Soulter
681940d466 🐛 fix: 修复重载插件时函数工具可能多次家在的问题 2025-03-13 17:02:58 +08:00
Soulter
16488506e8 🐛 fix: 修复部分情况下文件无法上传到 Telegram 群组的问题 #601 2025-03-13 17:02:58 +08:00
邹永赫
122fccc041 修复无法发送非嵌套的转发消息的问题 2025-03-13 17:02:58 +08:00
邹永赫
9d0ad35403 支持嵌套转发,里层包含多条信息 2025-03-13 17:02:58 +08:00
邹永赫
f9ec97e026 支持嵌套转发 2025-03-13 17:02:58 +08:00
Soulter
95495a2647 🏗 refactor: clean up AstrBotConfig component markup for improved readability 2025-03-13 16:40:59 +08:00
Soulter
e3310a605c 🐛 fix: 插件管理的插件简介太长 “帮助”“操作”图标不显示 #790 2025-03-13 16:36:35 +08:00
Soulter
b55719bf28 🏗 refactor: improve styling and layout of AstrBotConfig component for enhanced readability 2025-03-13 15:59:20 +08:00
diudiu62
b957b51279 已知消息类型,没有业务处理,只是避免控制台打印太多的日志 2025-03-13 15:55:22 +08:00
Soulter
90bcfab369 🏗 refactor: enhance ConfigPage layout and styling for better user experience 2025-03-13 15:44:52 +08:00
Soulter
f8a8e30641 🏗 refactor: 配置页样式重写 2025-03-13 15:37:53 +08:00
Soulter
25cb98e7a7 🐛 fix: sent message to wrong topic in topic group #801 2025-03-13 13:02:22 +08:00
Soulter
03e1bb7cf9 Merge pull request #807 from shuiping233/fix-#806
修复命令参数报错信息无法发送至qq官方机器人平台的bug
2025-03-13 10:05:24 +08:00
Soulter
85dbb24f3a 🐛 fix: 修复重载插件时函数工具可能多次家在的问题 2025-03-12 23:37:24 +08:00
shuiping233
d817635782 修复命令参数报错信息无法发送至qq官方机器人平台的bug 2025-03-12 18:09:25 +08:00
Soulter
2f4f237810 🐛 fix: 修复部分情况下文件无法上传到 Telegram 群组的问题 #601 2025-03-12 14:14:45 +08:00
邹永赫
5ac94d810f Merge pull request #794 from zouyonghe/dev/nested-forward
修复无法发送非嵌套的转发消息的问题
2025-03-12 12:01:33 +09:00
邹永赫
39dc46dc25 修复无法发送非嵌套的转发消息的问题 2025-03-12 11:59:53 +09:00
邹永赫
0d9cf725f7 Merge pull request #792 from zouyonghe/dev/nested-forward
支持嵌套转发,里层包含多条信息
2025-03-12 11:17:16 +09:00
邹永赫
e55dbead5b 支持嵌套转发,里层包含多条信息 2025-03-12 11:14:54 +09:00
邹永赫
7d046e5b30 Merge pull request #788 from zouyonghe/dev/nested-forward
支持嵌套转发
2025-03-12 08:50:50 +09:00
邹永赫
8b4693cf66 支持嵌套转发 2025-03-12 08:39:54 +09:00
Soulter
a1172c9a82 feat: 支持解析回复消息 #783 2025-03-11 23:27:10 +08:00
Soulter
1ed2bd33f0 🐛 fix: 修复插件更新时显示未知更新的问题 2025-03-11 22:38:25 +08:00
Soulter
4c159bd0ba Merge pull request #785 from shuiping233/fix-qq-offical-image-upload-issue
修复了使用Image.fromBytes等包装的图片消息链无法通过qq官方机器人适配器发送的bug
2025-03-11 22:10:27 +08:00
Soulter
050654b2a9 🐛 fix: 修复 QQ 官方机器人适配器下发送base64图片消息段报错的问题。
Co-authored-by: shuiping233 <1944680304@qq.com>
2025-03-11 22:08:13 +08:00
Soulter
61b261e1b2 Merge pull request #780 from beat4ocean/master
fix: 修复gewechat平台用户本人发消息触发消息回复的bug
2025-03-11 21:55:44 +08:00
shuiping233
017b010206 修复了使用Image.fromBytes等包装的图片消息链无法通过qq官方机器人适配器发送的bug 2025-03-11 21:17:08 +08:00
pre-commit-ci[bot]
00f5189f58 🎈 auto fixes by pre-commit hooks 2025-03-11 09:16:43 +00:00
Moyuyanli
4a8309ed1f style:idea默认格式化了部分代码
feat:添加根据消息事件获取群信息的接口
2025-03-11 17:10:55 +08:00
Moyuyanli
76cfc31a1d feat:添加 Group 类型 2025-03-11 17:10:04 +08:00
Moyuyanli
d9ec434699 feat:gewe的client添加 添加好友接口
feat:gewe的client添加 获取群信息/群成员接口
feat:gewe的client添加 添加群成员为好友接口
2025-03-11 17:08:33 +08:00
Soulter
239f3c40be 🎈 perf: 优化 WebUI 边栏宽度 2025-03-11 16:11:34 +08:00
Soulter
09c8c6e670 🐛 fix: 修复 aiocqhttp 下可能的设置管理员无效的问题 2025-03-11 15:52:30 +08:00
beat4ocean
7e4ad01c94 Merge branch 'Soulter:master' into master 2025-03-11 15:52:23 +08:00
beat4ocean
ed98e269ef Merge remote-tracking branch 'origin/master' 2025-03-11 15:48:44 +08:00
beat4ocean
b47d63334f fix: 修复gewechat平台用户本人发消息触发消息回复的bug 2025-03-11 15:48:28 +08:00
Soulter
5e2a3a5aea fix: 修复部分情况下 EdgeTTS 无法使用的问题
Co-authored-by: 需要哦 <2687427560@qq.com>
2025-03-11 15:29:51 +08:00
Soulter
1a7eb21fc7 Revert "🐛 fix: 修复 gewechat 部分场景下下载图片报错 #700"
This reverts commit c38fa77ce6.
2025-03-11 14:54:41 +08:00
Soulter
834a51cdc9 🐛 fix: 修复 OpenAI TTS API TypeError 报错 #755 2025-03-11 14:30:59 +08:00
Soulter
1b69d99c06 🐛 fix: 修复更新插件后插件重载不完全的问题 2025-03-11 14:20:24 +08:00
Soulter
ad189933c6 Merge pull request #775 from roeseth/master
update compose.yml to mount system time and tz
2025-03-11 12:49:38 +08:00
Soulter
9d86ff32de Merge pull request #774 from Soulter/pre-commit-ci-update-config
🎈 pre-commit autoupdate
2025-03-11 11:40:57 +08:00
Soulter
278bb57a58 Merge pull request #772 from beat4ocean/master
fix: 修复个人微信非第一次登陆情况,已记录gewechat的appid失效设备不存在导致无法重新登陆个人微信的bug
2025-03-11 11:40:07 +08:00
pre-commit-ci[bot]
0ba494e0ba 🎈 auto fixes by pre-commit hooks 2025-03-11 02:11:25 +00:00
roeseth
8b247054bb update compose.yml to mount system time and tz 2025-03-10 19:07:45 -07:00
pre-commit-ci[bot]
7c5c8e4e0d 🎈 auto fixes by pre-commit hooks 2025-03-11 00:55:01 +00:00
beat4ocean
ad106a27f3 Merge branch 'Soulter:master' into master 2025-03-11 08:54:55 +08:00
beat4ocean
9d6f61b49e fix: 修复非第一次登陆情况,已记录的gewechat的appid失效设备不存在导致无法重新登陆的bug 2025-03-11 08:48:37 +08:00
pre-commit-ci[bot]
02368954a0 🎈 auto fixes by pre-commit hooks 2025-03-10 17:09:25 +00:00
pre-commit-ci[bot]
b477a35a01 🎈 pre-commit autoupdate
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.9.9 → v0.9.10](https://github.com/astral-sh/ruff-pre-commit/compare/v0.9.9...v0.9.10)
2025-03-10 17:09:18 +00:00
Soulter
16622887de perf: 在调用插件异常时更完整的报错信息 2025-03-11 00:47:37 +08:00
Soulter
9059d1fb17 feat: 支持在对话隔离情况下可以将群聊加入白名单 #746 2025-03-11 00:34:29 +08:00
Soulter
df2b008d82 Merge pull request #744 from roeseth/fix-local-timezone
Use system local time zone instead of hardcoded UTC+8
2025-03-11 00:21:43 +08:00
Soulter
0da871efd0 chore: 日志完善 2025-03-10 23:58:42 +08:00
Soulter
1c55349f81 fix: 钉钉 webui 文档 2025-03-10 23:58:42 +08:00
Soulter
9309fa1e81 修复fishaudio默认baseurl不可用的问题 2025-03-10 01:32:26 +08:00
Soulter
5996189f91 Update README.md 2025-03-09 22:25:45 +08:00
Soulter
bd2b984bfb v3.4.37 2025-03-09 22:14:23 +08:00
pre-commit-ci[bot]
194409a117 🎈 auto fixes by pre-commit hooks 2025-03-09 13:23:52 +00:00
roeseth
27978b216d use system local timezone instead of hardcoded UTC+8 2025-03-09 06:18:53 -07:00
Soulter
c38fa77ce6 🐛 fix: 修复 gewechat 部分场景下下载图片报错 #700 2025-03-09 18:10:38 +08:00
Soulter
3eb49f7422 feat: 支持设置私聊是否需要唤醒前缀唤醒 #735 2025-03-09 18:03:23 +08:00
Soulter
1989d615d2 🌈 style: format codes 2025-03-09 17:48:59 +08:00
Soulter
239412d265 feat: 支持接入钉钉 #643 2025-03-09 17:47:51 +08:00
Soulter
375a419a9e Merge pull request #732 from xiewoc/master
Update aiocqhttp_platform_adapter.py
2025-03-09 12:36:48 +08:00
Soulter
875c8ab424 ci: upate astrbot webui build cis 2025-03-09 11:31:10 +08:00
Soulter
c9bfc810ce ci: upload astrbot webui build ci 2025-03-09 11:26:10 +08:00
Soulter
46ecb16949 🐛 fix: 无法正常保存插件的 list 类型配置 #737 2025-03-09 11:12:24 +08:00
pre-commit-ci[bot]
d6a785b645 🎈 auto fixes by pre-commit hooks 2025-03-08 04:33:19 +00:00
xiewoc
79db828a01 Update aiocqhttp_platform_adapter.py 2025-03-08 12:30:49 +08:00
258 changed files with 28139 additions and 4396 deletions

View File

@@ -18,3 +18,7 @@ ENV/
README*.md README*.md
dashboard/ dashboard/
data/ data/
changelogs/
tests/
.ruff_cache/
.astrbot

15
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1,15 @@
# These are supported funding model platforms
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: astrbot
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
polar: # Replace with a single Polar username
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
thanks_dev: # Replace with a single thanks.dev username
custom: ['https://afdian.com/a/astrbot_team']

View File

@@ -6,7 +6,7 @@ body:
- type: markdown - type: markdown
attributes: attributes:
value: | value: |
欢迎发布插件到插件市场! 欢迎发布插件到插件市场!请确保您的插件经过**完整的**测试。
- type: textarea - type: textarea
attributes: attributes:
@@ -22,9 +22,10 @@ body:
插件名: 插件名:
插件作者: 插件作者:
插件简介: 插件简介:
标签: (可选) 支持的消息平台:(必填,如 QQ、微信、飞书)
社交链接: (可选, 将会在插件市场作者名称上作为可点击的链接) 标签:(可选)
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。 社交链接:(可选, 将会在插件市场作者名称上作为可点击的链接)
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。如果您不知道支持哪些消息平台,请填写测试过的消息平台。
- type: checkboxes - type: checkboxes
attributes: attributes:

View File

@@ -1,5 +1,5 @@
<!-- 如果有的话,指定这个 PR 要解决的 ISSUE --> <!-- 如果有的话,指定这个 PR 要解决的 ISSUE -->
修复#XYZ 解决#XYZ
### Motivation ### Motivation
@@ -8,3 +8,12 @@
### Modifications ### Modifications
<!--简单解释你的改动--> <!--简单解释你的改动-->
### Check
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容-->
- [ ] 😊 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
- [ ] 👀 我的更改经过良好的测试
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt``pyproject.toml` 文件相应位置。
- [ ] 😮 我的更改没有引入恶意代码

View File

@@ -7,7 +7,7 @@ on:
name: Auto Release name: Auto Release
jobs: jobs:
build: build-and-publish-to-github-release:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:
contents: write contents: write
@@ -28,8 +28,35 @@ jobs:
run: | run: |
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV" echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
- name: Create Release - name: Create GitHub Release
uses: ncipollo/release-action@v1 uses: ncipollo/release-action@v1
with: with:
bodyFile: ${{ env.changelog }} bodyFile: ${{ env.changelog }}
artifacts: "dashboard/dist.zip" artifacts: "dashboard/dist.zip"
build-and-publish-to-pypi:
# 构建并发布到 PyPI
runs-on: ubuntu-latest
needs: build-and-publish-to-github-release
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install uv
run: |
python -m pip install uv
- name: Build package
run: |
uv build
- name: Publish to PyPI
env:
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
run: |
uv publish

31
.github/workflows/dashboard_ci.yml vendored Normal file
View File

@@ -0,0 +1,31 @@
name: AstrBot Dashboard CI
on: [push]
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: npm install, build
run: |
cd dashboard
npm install
npm run build
- name: Inject Commit SHA
id: get_sha
run: |
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
mkdir -p dashboard/dist/assets
echo $COMMIT_SHA > dashboard/dist/assets/version
- name: Archive production artifacts
uses: actions/upload-artifact@v4
with:
name: dist-without-markdown
path: |
dashboard/dist
!dist/**/*.md

5
.gitignore vendored
View File

@@ -1,6 +1,8 @@
__pycache__ __pycache__
botpy.log botpy.log
.vscode .vscode
.venv*
.idea
data_v2.db data_v2.db
data_v3.db data_v3.db
configs/session configs/session
@@ -26,3 +28,6 @@ venv/*
packages/python_interpreter/workplace packages/python_interpreter/workplace
.venv/* .venv/*
.conda/ .conda/
.idea
pytest.ini
.astrbot

View File

@@ -7,7 +7,7 @@ ci:
autoupdate_commit_msg: ":balloon: pre-commit autoupdate" autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
repos: repos:
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.9 rev: v0.11.2
hooks: hooks:
- id: ruff - id: ruff
- id: ruff-format - id: ruff-format

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.10

View File

@@ -4,19 +4,32 @@ WORKDIR /AstrBot
COPY . /AstrBot/ COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
nodejs \
npm \
gcc \ gcc \
build-essential \ build-essential \
python3-dev \ python3-dev \
libffi-dev \ libffi-dev \
libssl-dev \ libssl-dev \
ca-certificates \
bash \
&& apt-get clean \ && apt-get clean \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
RUN python -m pip install -r requirements.txt --no-cache-dir RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
RUN python -m pip install socksio wechatpy cryptography --no-cache-dir # 释出 ffmpeg
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
EXPOSE 6185 EXPOSE 6185
EXPOSE 6186 EXPOSE 6186
CMD [ "python", "main.py" ] CMD [ "python", "main.py" ]

35
Dockerfile_with_node Normal file
View File

@@ -0,0 +1,35 @@
FROM python:3.10-slim
WORKDIR /AstrBot
COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
build-essential \
python3-dev \
libffi-dev \
libssl-dev \
curl \
unzip \
ca-certificates \
bash \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Installation of Node.js
ENV NVM_DIR="/root/.nvm"
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
. "$NVM_DIR/nvm.sh" && \
nvm install 22 && \
nvm use 22
RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system
EXPOSE 6185
EXPOSE 6186
CMD ["python", "main.py"]

145
README.md
View File

@@ -1,6 +1,6 @@
<p align="center"> <p align="center">
![6e1279651f16d7fdf4727558b72bbaf1](https://github.com/user-attachments/assets/ead4c551-fc3c-48f7-a6f7-afbfdb820512) ![yjtp](https://github.com/user-attachments/assets/dcc74009-c57e-4b66-9ae3-0a81fc001255)
</p> </p>
@@ -10,14 +10,14 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a> <a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest) [![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot?style=for-the-badge&color=76bad9)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python"> <img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a> <a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a> <a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e) <a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600) [![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot) ![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B4%BB%E8%B7%83%E9%87%8F&cacheSeconds=3600&style=for-the-badge&color=3b618e)
[![star](https://gitcode.com/Soulter/AstrBot/star/badge.svg)](https://gitcode.com/Soulter/AstrBot) ![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> <a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> <a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
@@ -27,19 +27,34 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型LLM接入功能的聊天机器人及开发框架。 AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型LLM接入功能的聊天机器人及开发框架。
<!-- [![codecov](https://img.shields.io/codecov/c/github/soulter/astrbot?style=for-the-badge)](https://codecov.io/gh/Soulter/AstrBot)
-->
> [!NOTE]
>
> 个人微信接入所依赖的开源项目 Gewechat 近期已停止维护,`v3.5.10` 已经支持接入 WeChatPadPro 替换 gewechat 方式。详见文档 [WeChatPadPro](https://astrbot.app/deploy/platform/wechat/wechatpadpro.html)
## ✨ 近期更新
1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
## ✨ 主要功能 ## ✨ 主要功能
> [!NOTE]
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力支持图片理解、语音转文字Whisper 1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力支持图片理解、语音转文字Whisper
2. **多消息平台接入**。支持接入 QQOneBot、QQ 频道、微信Gewechat、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。 2. **多消息平台接入**。支持接入 QQOneBot、QQ 频道、微信Gewechat、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://astrbot.app/others/dify.html),便捷接入 Dify 智能助手、知识库和 Dify 工作流。 3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。 4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat可在面板上与大模型对话。 5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat可在面板上与大模型对话。
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。 6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
> [!TIP] > [!TIP]
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/) > WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
> >
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM无法在聊天页使用大模型。不要再修改 demo 的登录密码了 😭) > 用户名: `astrbot`, 密码: `astrbot`。
## ✨ 使用方式 ## ✨ 使用方式
@@ -49,30 +64,48 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
#### Windows 一键安装器部署 #### Windows 一键安装器部署
需要电脑上安装有 Python>3.10)。请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。 请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
#### Replit 部署 #### 宝塔面板部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot) 请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
#### CasaOS 部署 #### CasaOS 部署
社区贡献的部署方式。 社区贡献的部署方式。
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/casaos.html) 。 请参阅官方文档 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html) 。
#### 手动部署 #### 手动部署
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) > 推荐使用 `uv`
## 🚀 路线图 首先,安装 uv
### 垂类功能 ```bash
pip install uv
```
1. 更好的上下文管理:限制 token 总数、对话上下文总结 通过 Git Clone 安装 AstrBot
3. AstrBot in Minecraft
### 横功能 ```bash
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
uv run main.py
```
或者,直接通过 uvx 安装 AstrBot
```bash
mkdir astrbot && cd astrbot
uvx astrbot init
# uvx astrbot run
```
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
#### Replit 部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
## ⚡ 消息平台支持情况 ## ⚡ 消息平台支持情况
@@ -80,10 +113,12 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
| -------- | ------- | ------- | ------ | | -------- | ------- | ------- | ------ |
| QQ(官方机器人接口) | ✔ | 私聊、群聊QQ 频道私聊、群聊 | 文字、图片 | | QQ(官方机器人接口) | ✔ | 私聊、群聊QQ 频道私聊、群聊 | 文字、图片 |
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 | | QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 | | 微信个人号 | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 | | Telegram | ✔ | 私聊、群聊 | 文字、图片 |
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 | | 企业微信 | ✔ | 私聊 | 文字、图片、语音 |
| 飞书 | ✔ | 聊 | 文字、图片 | | 微信客服 | ✔ | 聊 | 文字、图片 |
| 飞书 | ✔ | 私聊、群聊 | 文字、图片 |
| 钉钉 | ✔ | 私聊、群聊 | 文字、图片 |
| 微信对话开放平台 | 🚧 | 计划内 | - | | 微信对话开放平台 | 🚧 | 计划内 | - |
| Discord | 🚧 | 计划内 | - | | Discord | 🚧 | 计划内 | - |
| WhatsApp | 🚧 | 计划内 | - | | WhatsApp | 🚧 | 计划内 | - |
@@ -93,20 +128,26 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
| 名称 | 支持性 | 类型 | 备注 | | 名称 | 支持性 | 类型 | 备注 |
| -------- | ------- | ------- | ------- | | -------- | ------- | ------- | ------- |
| OpenAI API | ✔ | 文本生成 | 同时也支持 DeepSeek、Google Gemini、GLM智谱、Moonshot月之暗面、阿里云百炼、硅基流动、xAI 等所有兼容 OpenAI API 的服务 | | OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、xAI 等兼容 OpenAI API 的服务 |
| Claude API | ✔ | 文本生成 | | | Claude API | ✔ | 文本生成 | |
| Google Gemini API | ✔ | 文本生成 | | | Google Gemini API | ✔ | 文本生成 | |
| Dify | ✔ | LLMOps | | | Dify | ✔ | LLMOps | |
| DashScope(阿里云百炼应用) | ✔ | LLMOps | | | 阿里云百炼应用 | ✔ | LLMOps | |
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 | | Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 | | LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 | | LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
| 硅基流动 | ✔ | 模型 API 服务平台 | |
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
| OneAPI | ✔ | LLM 分发系统 | | | OneAPI | ✔ | LLM 分发系统 | |
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 | | Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
| SenseVoice | ✔ | 语音转文本 | 本地部署 | | SenseVoice | ✔ | 语音转文本 | 本地部署 |
| OpenAI TTS API | ✔ | 文本转语音 | | | OpenAI TTS API | ✔ | 文本转语音 | |
| Fishaudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 | | GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
| Edge-TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS | | FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
## ❤️ 贡献 ## ❤️ 贡献
@@ -134,38 +175,45 @@ pre-commit install
## ✨ Demo ## ✨ Demo
> [!NOTE] <details><summary>👉 点击展开多张 Demo 截图 👈</summary>
> 代码执行器的文件输入/输出目前仅测试了 Napcat(QQ), Lagrange(QQ)
<div align='center'> <div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600"> <img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
_✨基于 Docker 的沙箱化代码执行器Beta 测试✨_ _✨基于 Docker 的沙箱化代码执行器Beta 测试✨_
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500> <img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_ _✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
_✨ 自然语言待办事项 ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150> <img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150> <img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ 插件系统——部分插件展示 ✨_ _✨ 插件系统——部分插件展示 ✨_
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600> <img src="https://github.com/user-attachments/assets/0cdbf564-2f59-4da5-b524-ce0e7ef3d978" width=600>
_✨ 管理面板 ✨_ _✨ WebUI ✨_
![webchat](https://drive.soulter.top/f/vlsA/ezgif-5-fb044b2542.gif)
_✨ 内置 Web Chat在线与机器人交互 ✨_
</div> </div>
</details>
## ❤️ Special Thanks
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>
此外,本项目的诞生离不开以下开源项目:
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ)
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
## ⭐ Star History ## ⭐ Star History
> [!TIP] > [!TIP]
@@ -183,16 +231,5 @@ _✨ 内置 Web Chat在线与机器人交互 ✨_
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility. 2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
3. Please ensure compliance with local laws and regulations when using this project. 3. Please ensure compliance with local laws and regulations when using this project.
<!-- ## ✨ ATRI [Beta 测试]
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. 基于《ATRI ~ My Dear Moments》主角 ATRI 角色台词作为微调数据集的 `Qwen1.5-7B-Chat Lora` 微调模型。
2. 长期记忆
3. 表情包理解与回复
4. TTS
-->
_私は、高性能ですから!_ _私は、高性能ですから!_

View File

@@ -28,7 +28,7 @@ AstrBot is a loosely coupled, asynchronous chatbot and development framework tha
1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper). 1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper).
2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation. 2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation.
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://astrbot.app/others/dify.html) for easy access to Dify assistants/knowledge bases/workflows. 3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://dify.ai/) for easy access to Dify assistants/knowledge bases/workflows.
4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins. 4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins.
5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction. 5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction.
6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling. 6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling.

View File

@@ -28,7 +28,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換Whisperをサポートします。 1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換Whisperをサポートします。
2. **複数のメッセージプラットフォームの接続**。QQOneBot、QQ チャンネル、WeChatGewechat、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。 2. **複数のメッセージプラットフォームの接続**。QQOneBot、QQ チャンネル、WeChatGewechat、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://astrbot.app/others/dify.html)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。 3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。 4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。 5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。 6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。

View File

@@ -5,6 +5,7 @@ from astrbot.core.platform import (
MessageMember, MessageMember,
MessageType, MessageType,
PlatformMetadata, PlatformMetadata,
Group,
) )
from astrbot.core.platform.register import register_platform_adapter from astrbot.core.platform.register import register_platform_adapter
@@ -18,4 +19,5 @@ __all__ = [
"MessageType", "MessageType",
"PlatformMetadata", "PlatformMetadata",
"register_platform_adapter", "register_platform_adapter",
"Group",
] ]

View File

@@ -1,5 +1,5 @@
from astrbot.core.provider import Provider, STTProvider, Personality from astrbot.core.provider import Provider, STTProvider, Personality
from astrbot.core.provider.entites import ( from astrbot.core.provider.entities import (
ProviderRequest, ProviderRequest,
ProviderType, ProviderType,
ProviderMetaData, ProviderMetaData,

View File

@@ -2,11 +2,7 @@ from astrbot.core.star.register import (
register_star as register, # 注册插件Star register_star as register, # 注册插件Star
) )
from astrbot.core.star import Context, Star from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star.config import * from astrbot.core.star.config import *
__all__ = [ __all__ = ["register", "Context", "Star", "StarTools"]
"register",
"Context",
"Star",
]

1
astrbot/cli/__init__.py Normal file
View File

@@ -0,0 +1 @@
__version__ = "3.5.8"

59
astrbot/cli/__main__.py Normal file
View File

@@ -0,0 +1,59 @@
"""
AstrBot CLI入口
"""
import click
import sys
from . import __version__
from .commands import init, run, plug, conf
logo_tmpl = r"""
___ _______.___________..______ .______ ______ .___________.
/ \ / | || _ \ | _ \ / __ \ | |
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
/ /_\ \ \ \ | | | / | _ < | | | | | |
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
"""
@click.group()
@click.version_option(__version__, prog_name="AstrBot")
def cli() -> None:
"""The AstrBot CLI"""
click.echo(logo_tmpl)
click.echo("Welcome to AstrBot CLI!")
click.echo(f"AstrBot CLI version: {__version__}")
@click.command()
@click.argument("command_name", required=False, type=str)
def help(command_name: str | None) -> None:
"""显示命令的帮助信息
如果提供了 COMMAND_NAME则显示该命令的详细帮助信息。
否则,显示通用帮助信息。
"""
ctx = click.get_current_context()
if command_name:
# 查找指定命令
command = cli.get_command(ctx, command_name)
if command:
# 显示特定命令的帮助信息
click.echo(command.get_help(ctx))
else:
click.echo(f"Unknown command: {command_name}")
sys.exit(1)
else:
# 显示通用帮助信息
click.echo(cli.get_help(ctx))
cli.add_command(init)
cli.add_command(run)
cli.add_command(help)
cli.add_command(plug)
cli.add_command(conf)
if __name__ == "__main__":
cli()

View File

@@ -0,0 +1,6 @@
from .cmd_init import init
from .cmd_run import run
from .cmd_plug import plug
from .cmd_conf import conf
__all__ = ["init", "run", "plug", "conf"]

View File

@@ -0,0 +1,206 @@
import json
import click
import hashlib
import zoneinfo
from typing import Any, Callable
from ..utils import get_astrbot_root, check_astrbot_root
def _validate_log_level(value: str) -> str:
"""验证日志级别"""
value = value.upper()
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
raise click.ClickException(
"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一"
)
return value
def _validate_dashboard_port(value: str) -> int:
"""验证 Dashboard 端口"""
try:
port = int(value)
if port < 1 or port > 65535:
raise click.ClickException("端口必须在 1-65535 范围内")
return port
except ValueError:
raise click.ClickException("端口必须是数字")
def _validate_dashboard_username(value: str) -> str:
"""验证 Dashboard 用户名"""
if not value:
raise click.ClickException("用户名不能为空")
return value
def _validate_dashboard_password(value: str) -> str:
"""验证 Dashboard 密码"""
if not value:
raise click.ClickException("密码不能为空")
return hashlib.md5(value.encode()).hexdigest()
def _validate_timezone(value: str) -> str:
"""验证时区"""
try:
zoneinfo.ZoneInfo(value)
except Exception:
raise click.ClickException(f"无效的时区: {value}请使用有效的IANA时区名称")
return value
def _validate_callback_api_base(value: str) -> str:
"""验证回调接口基址"""
if not value.startswith("http://") and not value.startswith("https://"):
raise click.ClickException("回调接口基址必须以 http:// 或 https:// 开头")
return value
# 可通过CLI设置的配置项配置键到验证器函数的映射
CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
"timezone": _validate_timezone,
"log_level": _validate_log_level,
"dashboard.port": _validate_dashboard_port,
"dashboard.username": _validate_dashboard_username,
"dashboard.password": _validate_dashboard_password,
"callback_api_base": _validate_callback_api_base,
}
def _load_config() -> dict[str, Any]:
"""加载或初始化配置文件"""
root = get_astrbot_root()
if not check_astrbot_root(root):
raise click.ClickException(
f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
)
config_path = root / "data" / "cmd_config.json"
if not config_path.exists():
from astrbot.core.config.default import DEFAULT_CONFIG
config_path.write_text(
json.dumps(DEFAULT_CONFIG, ensure_ascii=False, indent=2),
encoding="utf-8-sig",
)
try:
return json.loads(config_path.read_text(encoding="utf-8-sig"))
except json.JSONDecodeError as e:
raise click.ClickException(f"配置文件解析失败: {str(e)}")
def _save_config(config: dict[str, Any]) -> None:
"""保存配置文件"""
config_path = get_astrbot_root() / "data" / "cmd_config.json"
config_path.write_text(
json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8-sig"
)
def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
"""设置嵌套字典中的值"""
parts = path.split(".")
for part in parts[:-1]:
if part not in obj:
obj[part] = {}
elif not isinstance(obj[part], dict):
raise click.ClickException(
f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典"
)
obj = obj[part]
obj[parts[-1]] = value
def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
"""获取嵌套字典中的值"""
parts = path.split(".")
for part in parts:
obj = obj[part]
return obj
@click.group(name="conf")
def conf():
"""配置管理命令
支持的配置项:
- timezone: 时区设置 (例如: Asia/Shanghai)
- log_level: 日志级别 (DEBUG/INFO/WARNING/ERROR/CRITICAL)
- dashboard.port: Dashboard 端口
- dashboard.username: Dashboard 用户名
- dashboard.password: Dashboard 密码
- callback_api_base: 回调接口基址
"""
pass
@conf.command(name="set")
@click.argument("key")
@click.argument("value")
def set_config(key: str, value: str):
"""设置配置项的值"""
if key not in CONFIG_VALIDATORS.keys():
raise click.ClickException(f"不支持的配置项: {key}")
config = _load_config()
try:
old_value = _get_nested_item(config, key)
validated_value = CONFIG_VALIDATORS[key](value)
_set_nested_item(config, key, validated_value)
_save_config(config)
click.echo(f"配置已更新: {key}")
if key == "dashboard.password":
click.echo(" 原值: ********")
click.echo(" 新值: ********")
else:
click.echo(f" 原值: {old_value}")
click.echo(f" 新值: {validated_value}")
except KeyError:
raise click.ClickException(f"未知的配置项: {key}")
except Exception as e:
raise click.UsageError(f"设置配置失败: {str(e)}")
@conf.command(name="get")
@click.argument("key", required=False)
def get_config(key: str = None):
"""获取配置项的值不提供key则显示所有可配置项"""
config = _load_config()
if key:
if key not in CONFIG_VALIDATORS.keys():
raise click.ClickException(f"不支持的配置项: {key}")
try:
value = _get_nested_item(config, key)
if key == "dashboard.password":
value = "********"
click.echo(f"{key}: {value}")
except KeyError:
raise click.ClickException(f"未知的配置项: {key}")
except Exception as e:
raise click.UsageError(f"获取配置失败: {str(e)}")
else:
click.echo("当前配置:")
for key in CONFIG_VALIDATORS.keys():
try:
value = (
"********"
if key == "dashboard.password"
else _get_nested_item(config, key)
)
click.echo(f" {key}: {value}")
except (KeyError, TypeError):
pass

View File

@@ -0,0 +1,55 @@
import asyncio
import click
from filelock import FileLock, Timeout
from ..utils import check_dashboard, get_astrbot_root
async def initialize_astrbot(astrbot_root) -> None:
"""执行 AstrBot 初始化逻辑"""
dot_astrbot = astrbot_root / ".astrbot"
if not dot_astrbot.exists():
click.echo(f"Current Directory: {astrbot_root}")
click.echo(
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。"
)
if click.confirm(
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",
default=True,
abort=True,
):
dot_astrbot.touch()
click.echo(f"Created {dot_astrbot}")
paths = {
"data": astrbot_root / "data",
"config": astrbot_root / "data" / "config",
"plugins": astrbot_root / "data" / "plugins",
"temp": astrbot_root / "data" / "temp",
}
for name, path in paths.items():
path.mkdir(parents=True, exist_ok=True)
click.echo(f"{'Created' if not path.exists() else 'Directory exists'}: {path}")
await check_dashboard(astrbot_root / "data")
@click.command()
def init() -> None:
"""初始化 AstrBot"""
click.echo("Initializing AstrBot...")
astrbot_root = get_astrbot_root()
lock_file = astrbot_root / "astrbot.lock"
lock = FileLock(lock_file, timeout=5)
try:
with lock.acquire():
asyncio.run(initialize_astrbot(astrbot_root))
except Timeout:
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
except Exception as e:
raise click.ClickException(f"初始化失败: {e!s}")

View File

@@ -0,0 +1,247 @@
import re
from pathlib import Path
import click
import shutil
from ..utils import (
get_git_repo,
build_plug_list,
manage_plugin,
PluginStatus,
check_astrbot_root,
get_astrbot_root,
)
@click.group()
def plug():
"""插件管理"""
pass
def _get_data_path() -> Path:
base = get_astrbot_root()
if not check_astrbot_root(base):
raise click.ClickException(
f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
)
return (base / "data").resolve()
def display_plugins(plugins, title=None, color=None):
if title:
click.echo(click.style(title, fg=color, bold=True))
click.echo(f"{'名称':<20} {'版本':<10} {'状态':<10} {'作者':<15} {'描述':<30}")
click.echo("-" * 85)
for p in plugins:
desc = p["desc"][:30] + ("..." if len(p["desc"]) > 30 else "")
click.echo(
f"{p['name']:<20} {p['version']:<10} {p['status']:<10} "
f"{p['author']:<15} {desc:<30}"
)
@plug.command()
@click.argument("name")
def new(name: str):
"""创建新插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins" / name
if plug_path.exists():
raise click.ClickException(f"插件 {name} 已存在")
author = click.prompt("请输入插件作者", type=str)
desc = click.prompt("请输入插件描述", type=str)
version = click.prompt("请输入插件版本", type=str)
if not re.match(r"^\d+\.\d+(\.\d+)?$", version.lower().lstrip("v")):
raise click.ClickException("版本号必须为 x.y 或 x.y.z 格式")
repo = click.prompt("请输入插件仓库:", type=str)
if not repo.startswith("http"):
raise click.ClickException("仓库地址必须以 http 开头")
click.echo("下载插件模板...")
get_git_repo(
"https://github.com/Soulter/helloworld",
plug_path,
)
click.echo("重写插件信息...")
# 重写 metadata.yaml
with open(plug_path / "metadata.yaml", "w", encoding="utf-8") as f:
f.write(
f"name: {name}\n"
f"desc: {desc}\n"
f"version: {version}\n"
f"author: {author}\n"
f"repo: {repo}\n"
)
# 重写 README.md
with open(plug_path / "README.md", "w", encoding="utf-8") as f:
f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n")
# 重写 main.py
with open(plug_path / "main.py", "r", encoding="utf-8") as f:
content = f.read()
new_content = content.replace(
'@register("helloworld", "YourName", "一个简单的 Hello World 插件", "1.0.0")',
f'@register("{name}", "{author}", "{desc}", "{version}")',
)
with open(plug_path / "main.py", "w", encoding="utf-8") as f:
f.write(new_content)
click.echo(f"插件 {name} 创建成功")
@plug.command()
@click.option("--all", "-a", is_flag=True, help="列出未安装的插件")
def list(all: bool):
"""列出插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
# 未发布的插件
not_published_plugins = [
p for p in plugins if p["status"] == PluginStatus.NOT_PUBLISHED
]
if not_published_plugins:
display_plugins(not_published_plugins, "未发布的插件", "red")
# 需要更新的插件
need_update_plugins = [
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
]
if need_update_plugins:
display_plugins(need_update_plugins, "需要更新的插件", "yellow")
# 已安装的插件
installed_plugins = [p for p in plugins if p["status"] == PluginStatus.INSTALLED]
if installed_plugins:
display_plugins(installed_plugins, "已安装的插件", "green")
# 未安装的插件
not_installed_plugins = [
p for p in plugins if p["status"] == PluginStatus.NOT_INSTALLED
]
if not_installed_plugins and all:
display_plugins(not_installed_plugins, "未安装的插件", "blue")
if (
not any([not_published_plugins, need_update_plugins, installed_plugins])
and not all
):
click.echo("未安装任何插件")
@plug.command()
@click.argument("name")
@click.option("--proxy", help="代理服务器地址")
def install(name: str, proxy: str | None):
"""安装插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins"
plugins = build_plug_list(base_path / "plugins")
plugin = next(
(
p
for p in plugins
if p["name"] == name and p["status"] == PluginStatus.NOT_INSTALLED
),
None,
)
if not plugin:
raise click.ClickException(f"未找到可安装的插件 {name},可能是不存在或已安装")
manage_plugin(plugin, plug_path, is_update=False, proxy=proxy)
@plug.command()
@click.argument("name")
def remove(name: str):
"""卸载插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
plugin = next((p for p in plugins if p["name"] == name), None)
if not plugin or not plugin.get("local_path"):
raise click.ClickException(f"插件 {name} 不存在或未安装")
plugin_path = plugin["local_path"]
click.confirm(f"确定要卸载插件 {name} 吗?", default=False, abort=True)
try:
shutil.rmtree(plugin_path)
click.echo(f"插件 {name} 已卸载")
except Exception as e:
raise click.ClickException(f"卸载插件 {name} 失败: {e}")
@plug.command()
@click.argument("name", required=False)
@click.option("--proxy", help="Github代理地址")
def update(name: str, proxy: str | None):
"""更新插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins"
plugins = build_plug_list(base_path / "plugins")
if name:
plugin = next(
(
p
for p in plugins
if p["name"] == name and p["status"] == PluginStatus.NEED_UPDATE
),
None,
)
if not plugin:
raise click.ClickException(f"插件 {name} 不需要更新或无法更新")
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
else:
need_update_plugins = [
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
]
if not need_update_plugins:
click.echo("没有需要更新的插件")
return
click.echo(f"发现 {len(need_update_plugins)} 个插件需要更新")
for plugin in need_update_plugins:
plugin_name = plugin["name"]
click.echo(f"正在更新插件 {plugin_name}...")
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
@plug.command()
@click.argument("query")
def search(query: str):
"""搜索插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
matched_plugins = [
p
for p in plugins
if query.lower() in p["name"].lower()
or query.lower() in p["desc"].lower()
or query.lower() in p["author"].lower()
]
if not matched_plugins:
click.echo(f"未找到匹配 '{query}' 的插件")
return
display_plugins(matched_plugins, f"搜索结果: '{query}'", "cyan")

View File

@@ -0,0 +1,63 @@
import os
import sys
from pathlib import Path
import click
import asyncio
import traceback
from filelock import FileLock, Timeout
from ..utils import check_dashboard, check_astrbot_root, get_astrbot_root
async def run_astrbot(astrbot_root: Path):
"""运行 AstrBot"""
from astrbot.core import logger, LogManager, LogBroker, db_helper
from astrbot.core.initial_loader import InitialLoader
await check_dashboard(astrbot_root / "data")
log_broker = LogBroker()
LogManager.set_queue_handler(logger, log_broker)
db = db_helper
core_lifecycle = InitialLoader(db, log_broker)
await core_lifecycle.start()
@click.option("--reload", "-r", is_flag=True, help="插件自动重载")
@click.option("--port", "-p", help="Astrbot Dashboard端口", required=False, type=str)
@click.command()
def run(reload: bool, port: str) -> None:
"""运行 AstrBot"""
try:
os.environ["ASTRBOT_CLI"] = "1"
astrbot_root = get_astrbot_root()
if not check_astrbot_root(astrbot_root):
raise click.ClickException(
f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init"
)
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
sys.path.insert(0, str(astrbot_root))
if port:
os.environ["DASHBOARD_PORT"] = port
if reload:
click.echo("启用插件自动重载")
os.environ["ASTRBOT_RELOAD"] = "1"
lock_file = astrbot_root / "astrbot.lock"
lock = FileLock(lock_file, timeout=5)
with lock.acquire():
asyncio.run(run_astrbot(astrbot_root))
except KeyboardInterrupt:
click.echo("AstrBot 已关闭...")
except Timeout:
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
except Exception as e:
raise click.ClickException(f"运行时出现错误: {e}\n{traceback.format_exc()}")

View File

@@ -0,0 +1,18 @@
from .basic import (
get_astrbot_root,
check_astrbot_root,
check_dashboard,
)
from .plugin import get_git_repo, manage_plugin, build_plug_list, PluginStatus
from .version_comparator import VersionComparator
__all__ = [
"get_astrbot_root",
"check_astrbot_root",
"check_dashboard",
"get_git_repo",
"manage_plugin",
"build_plug_list",
"VersionComparator",
"PluginStatus",
]

View File

@@ -0,0 +1,67 @@
from pathlib import Path
import click
def check_astrbot_root(path: str | Path) -> bool:
"""检查路径是否为 AstrBot 根目录"""
if not isinstance(path, Path):
path = Path(path)
if not path.exists() or not path.is_dir():
return False
if not (path / ".astrbot").exists():
return False
return True
def get_astrbot_root() -> Path:
"""获取Astrbot根目录路径"""
return Path.cwd()
async def check_dashboard(astrbot_root: Path) -> None:
"""检查是否安装了dashboard"""
from astrbot.core.utils.io import get_dashboard_version, download_dashboard
from astrbot.core.config.default import VERSION
from .version_comparator import VersionComparator
try:
dashboard_version = await get_dashboard_version()
match dashboard_version:
case None:
click.echo("未安装管理面板")
if click.confirm(
"是否安装管理面板?",
default=True,
abort=True,
):
click.echo("正在安装管理面板...")
await download_dashboard(
path="data/dashboard.zip", extract_path=str(astrbot_root)
)
click.echo("管理面板安装完成")
case str():
if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
click.echo("管理面板已是最新版本")
return
else:
try:
version = dashboard_version.split("v")[1]
click.echo(f"管理面板版本: {version}")
await download_dashboard(
path="data/dashboard.zip", extract_path=str(astrbot_root)
)
except Exception as e:
click.echo(f"下载管理面板失败: {e}")
return
except FileNotFoundError:
click.echo("初始化管理面板目录...")
try:
await download_dashboard(
path=str(astrbot_root / "dashboard.zip"), extract_path=str(astrbot_root)
)
click.echo("管理面板初始化完成")
except Exception as e:
click.echo(f"下载管理面板失败: {e}")
return

230
astrbot/cli/utils/plugin.py Normal file
View File

@@ -0,0 +1,230 @@
import shutil
import tempfile
import httpx
import yaml
from enum import Enum
from io import BytesIO
from pathlib import Path
from zipfile import ZipFile
import click
from .version_comparator import VersionComparator
class PluginStatus(str, Enum):
INSTALLED = "已安装"
NEED_UPDATE = "需更新"
NOT_INSTALLED = "未安装"
NOT_PUBLISHED = "未发布"
def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
"""从 Git 仓库下载代码并解压到指定路径"""
temp_dir = Path(tempfile.mkdtemp())
try:
# 解析仓库信息
repo_namespace = url.split("/")[-2:]
author = repo_namespace[0]
repo = repo_namespace[1]
# 尝试获取最新的 release
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
try:
with httpx.Client(
proxy=proxy if proxy else None, follow_redirects=True
) as client:
resp = client.get(release_url)
resp.raise_for_status()
releases = resp.json()
if releases:
# 使用最新的 release
download_url = releases[0]["zipball_url"]
else:
# 没有 release使用默认分支
click.echo(f"正在从默认分支下载 {author}/{repo}")
download_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
except Exception as e:
click.echo(f"获取 release 信息失败: {e},将直接使用提供的 URL")
download_url = url
# 应用代理
if proxy:
download_url = f"{proxy}/{download_url}"
# 下载并解压
with httpx.Client(
proxy=proxy if proxy else None, follow_redirects=True
) as client:
resp = client.get(download_url)
if (
resp.status_code == 404
and "archive/refs/heads/master.zip" in download_url
):
alt_url = download_url.replace("master.zip", "main.zip")
click.echo("master 分支不存在,尝试下载 main 分支")
resp = client.get(alt_url)
resp.raise_for_status()
else:
resp.raise_for_status()
zip_content = BytesIO(resp.content)
with ZipFile(zip_content) as z:
z.extractall(temp_dir)
namelist = z.namelist()
root_dir = Path(namelist[0]).parts[0] if namelist else ""
if target_path.exists():
shutil.rmtree(target_path)
shutil.move(temp_dir / root_dir, target_path)
finally:
if temp_dir.exists():
shutil.rmtree(temp_dir, ignore_errors=True)
def load_yaml_metadata(plugin_dir: Path) -> dict:
"""从 metadata.yaml 文件加载插件元数据
Args:
plugin_dir: 插件目录路径
Returns:
dict: 包含元数据的字典,如果读取失败则返回空字典
"""
yaml_path = plugin_dir / "metadata.yaml"
if yaml_path.exists():
try:
return yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
except Exception as e:
click.echo(f"读取 {yaml_path} 失败: {e}", err=True)
return {}
def build_plug_list(plugins_dir: Path) -> list:
"""构建插件列表,包含本地和在线插件信息
Args:
plugins_dir (Path): 插件目录路径
Returns:
list: 包含插件信息的字典列表
"""
# 获取本地插件信息
result = []
if plugins_dir.exists():
for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]:
plugin_dir = plugins_dir / plugin_name
# 从 metadata.yaml 加载元数据
metadata = load_yaml_metadata(plugin_dir)
# 如果成功加载元数据,添加到结果列表
if metadata and all(
k in metadata for k in ["name", "desc", "version", "author", "repo"]
):
result.append({
"name": str(metadata.get("name", "")),
"desc": str(metadata.get("desc", "")),
"version": str(metadata.get("version", "")),
"author": str(metadata.get("author", "")),
"repo": str(metadata.get("repo", "")),
"status": PluginStatus.INSTALLED,
"local_path": str(plugin_dir),
})
# 获取在线插件列表
online_plugins = []
try:
with httpx.Client() as client:
resp = client.get("https://api.soulter.top/astrbot/plugins")
resp.raise_for_status()
data = resp.json()
for plugin_id, plugin_info in data.items():
online_plugins.append({
"name": str(plugin_id),
"desc": str(plugin_info.get("desc", "")),
"version": str(plugin_info.get("version", "")),
"author": str(plugin_info.get("author", "")),
"repo": str(plugin_info.get("repo", "")),
"status": PluginStatus.NOT_INSTALLED,
"local_path": None,
})
except Exception as e:
click.echo(f"获取在线插件列表失败: {e}", err=True)
# 与在线插件比对,更新状态
online_plugin_names = {plugin["name"] for plugin in online_plugins}
for local_plugin in result:
if local_plugin["name"] in online_plugin_names:
# 查找对应的在线插件
online_plugin = next(
p for p in online_plugins if p["name"] == local_plugin["name"]
)
if (
VersionComparator.compare_version(
local_plugin["version"], online_plugin["version"]
)
< 0
):
local_plugin["status"] = PluginStatus.NEED_UPDATE
else:
# 本地插件未在线上发布
local_plugin["status"] = PluginStatus.NOT_PUBLISHED
# 添加未安装的在线插件
for online_plugin in online_plugins:
if not any(plugin["name"] == online_plugin["name"] for plugin in result):
result.append(online_plugin)
return result
def manage_plugin(
plugin: dict, plugins_dir: Path, is_update: bool = False, proxy: str | None = None
) -> None:
"""安装或更新插件
Args:
plugin (dict): 插件信息字典
plugins_dir (Path): 插件目录
is_update (bool, optional): 是否为更新操作. 默认为 False
proxy (str, optional): 代理服务器地址
"""
plugin_name = plugin["name"]
repo_url = plugin["repo"]
# 如果是更新且有本地路径,直接使用本地路径
if is_update and plugin.get("local_path"):
target_path = Path(plugin["local_path"])
else:
target_path = plugins_dir / plugin_name
backup_path = Path(f"{target_path}_backup") if is_update else None
# 检查插件是否存在
if is_update and not target_path.exists():
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
# 备份现有插件
if is_update and backup_path.exists():
shutil.rmtree(backup_path)
if is_update:
shutil.copytree(target_path, backup_path)
try:
click.echo(
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}..."
)
get_git_repo(repo_url, target_path, proxy)
# 更新成功,删除备份
if is_update and backup_path.exists():
shutil.rmtree(backup_path)
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
except Exception as e:
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
if is_update and backup_path.exists():
shutil.move(backup_path, target_path)
raise click.ClickException(
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}"
)

View File

@@ -0,0 +1,92 @@
"""
拷贝自 astrbot.core.utils.version_comparator
"""
import re
class VersionComparator:
@staticmethod
def compare_version(v1: str, v2: str) -> int:
"""根据 Semver 语义版本规范来比较版本号的大小。支持不仅局限于 3 个数字的版本号,并处理预发布标签。
参考: https://semver.org/lang/zh-CN/
返回 1 表示 v1 > v2返回 -1 表示 v1 < v2返回 0 表示 v1 = v2。
"""
v1 = v1.lower().replace("v", "")
v2 = v2.lower().replace("v", "")
def split_version(version):
match = re.match(
r"^([0-9]+(?:\.[0-9]+)*)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(.+))?$",
version,
)
if not match:
return [], None
major_minor_patch = match.group(1).split(".")
prerelease = match.group(2)
# buildmetadata = match.group(3) # 构建元数据在比较时忽略
parts = [int(x) for x in major_minor_patch]
prerelease = VersionComparator._split_prerelease(prerelease)
return parts, prerelease
v1_parts, v1_prerelease = split_version(v1)
v2_parts, v2_prerelease = split_version(v2)
# 比较数字部分
length = max(len(v1_parts), len(v2_parts))
v1_parts.extend([0] * (length - len(v1_parts)))
v2_parts.extend([0] * (length - len(v2_parts)))
for i in range(length):
if v1_parts[i] > v2_parts[i]:
return 1
elif v1_parts[i] < v2_parts[i]:
return -1
# 比较预发布标签
if v1_prerelease is None and v2_prerelease is not None:
return 1 # 没有预发布标签的版本高于有预发布标签的版本
elif v1_prerelease is not None and v2_prerelease is None:
return -1 # 有预发布标签的版本低于没有预发布标签的版本
elif v1_prerelease is not None and v2_prerelease is not None:
len_pre = max(len(v1_prerelease), len(v2_prerelease))
for i in range(len_pre):
p1 = v1_prerelease[i] if i < len(v1_prerelease) else None
p2 = v2_prerelease[i] if i < len(v2_prerelease) else None
if p1 is None and p2 is not None:
return -1
elif p1 is not None and p2 is None:
return 1
elif isinstance(p1, int) and isinstance(p2, str):
return -1
elif isinstance(p1, str) and isinstance(p2, int):
return 1
elif isinstance(p1, int) and isinstance(p2, int):
if p1 > p2:
return 1
elif p1 < p2:
return -1
elif isinstance(p1, str) and isinstance(p2, str):
if p1 > p2:
return 1
elif p1 < p2:
return -1
return 0 # 预发布标签完全相同
return 0 # 数字部分和预发布标签都相同
@staticmethod
def _split_prerelease(prerelease):
if not prerelease:
return None
parts = prerelease.split(".")
result = []
for part in parts:
if part.isdigit():
result.append(int(part))
else:
result.append(part)
return result

View File

@@ -7,20 +7,28 @@ from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config.default import DB_PATH from astrbot.core.config.default import DB_PATH
from astrbot.core.config import AstrBotConfig from astrbot.core.config import AstrBotConfig
from astrbot.core.file_token_service import FileTokenService
from .utils.astrbot_path import get_astrbot_data_path
os.makedirs("data", exist_ok=True) # 初始化数据存储文件夹
os.makedirs(get_astrbot_data_path(), exist_ok=True)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
DEMO_MODE = os.getenv("DEMO_MODE", False)
astrbot_config = AstrBotConfig() astrbot_config = AstrBotConfig()
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img") t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
html_renderer = HtmlRenderer(t2i_base_url) html_renderer = HtmlRenderer(t2i_base_url)
logger = LogManager.GetLogger(log_name="astrbot") logger = LogManager.GetLogger(log_name="astrbot")
if os.environ.get("TESTING", ""):
logger.setLevel("DEBUG")
db_helper = SQLiteDatabase(DB_PATH) db_helper = SQLiteDatabase(DB_PATH)
sp = SharedPreferences() # 简单的偏好设置存储 # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
pip_installer = PipInstaller(astrbot_config.get("pip_install_arg", "")) sp = SharedPreferences()
# 文件令牌服务
file_token_service = FileTokenService()
pip_installer = PipInstaller(
astrbot_config.get("pip_install_arg", ""),
astrbot_config.get("pypi_index_url", None),
)
web_chat_queue = asyncio.Queue(maxsize=32) web_chat_queue = asyncio.Queue(maxsize=32)
web_chat_back_queue = asyncio.Queue(maxsize=32) web_chat_back_queue = asyncio.Queue(maxsize=32)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"

View File

@@ -4,8 +4,9 @@ import logging
import enum import enum
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
from typing import Dict from typing import Dict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
ASTRBOT_CONFIG_PATH = "data/cmd_config.json" ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
logger = logging.getLogger("astrbot") logger = logging.getLogger("astrbot")
@@ -45,8 +46,6 @@ class AstrBotConfig(dict):
with open(config_path, "r", encoding="utf-8-sig") as f: with open(config_path, "r", encoding="utf-8-sig") as f:
conf_str = f.read() conf_str = f.read()
if conf_str.startswith("/ufeff"): # remove BOM
conf_str = conf_str.encode("utf8")[3:].decode("utf8")
conf = json.loads(conf_str) conf = json.loads(conf_str)
# 检查配置完整性,并插入 # 检查配置完整性,并插入

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,10 @@
"""
AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库
在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
"""
import uuid import uuid
import json import json
import asyncio import asyncio
@@ -11,24 +18,34 @@ class ConversationManager:
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
def __init__(self, db_helper: BaseDatabase): def __init__(self, db_helper: BaseDatabase):
# session_conversations 字典记录会话ID-对话ID 映射关系
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {}) self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
self.db = db_helper self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次 self.save_interval = 60 # 每 60 秒保存一次
self._start_periodic_save() self._start_periodic_save()
def _start_periodic_save(self): def _start_periodic_save(self):
"""启动定时保存任务"""
asyncio.create_task(self._periodic_save()) asyncio.create_task(self._periodic_save())
async def _periodic_save(self): async def _periodic_save(self):
"""定时保存会话对话映射关系到存储中"""
while True: while True:
await asyncio.sleep(self.save_interval) await asyncio.sleep(self.save_interval)
self._save_to_storage() self._save_to_storage()
def _save_to_storage(self): def _save_to_storage(self):
"""保存会话对话映射关系到存储中"""
sp.put("session_conversation", self.session_conversations) sp.put("session_conversation", self.session_conversations)
async def new_conversation(self, unified_msg_origin: str) -> str: async def new_conversation(self, unified_msg_origin: str) -> str:
"""新建对话,并将当前会话的对话转移到新对话""" """新建对话,并将当前会话的对话转移到新对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
conversation_id = str(uuid.uuid4()) conversation_id = str(uuid.uuid4())
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id) self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
self.session_conversations[unified_msg_origin] = conversation_id self.session_conversations[unified_msg_origin] = conversation_id
@@ -36,14 +53,24 @@ class ConversationManager:
return conversation_id return conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str): async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
"""切换会话的对话""" """切换会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
self.session_conversations[unified_msg_origin] = conversation_id self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations) sp.put("session_conversation", self.session_conversations)
async def delete_conversation( async def delete_conversation(
self, unified_msg_origin: str, conversation_id: str = None self, unified_msg_origin: str, conversation_id: str = None
): ):
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话""" """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
conversation_id = self.session_conversations.get(unified_msg_origin) conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id: if conversation_id:
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id) self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
@@ -51,23 +78,48 @@ class ConversationManager:
sp.put("session_conversation", self.session_conversations) sp.put("session_conversation", self.session_conversations)
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str: async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
"""获取会话当前的对话 ID""" """获取会话当前的对话 ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
return self.session_conversations.get(unified_msg_origin, None) return self.session_conversations.get(unified_msg_origin, None)
async def get_conversation( async def get_conversation(
self, unified_msg_origin: str, conversation_id: str self, unified_msg_origin: str, conversation_id: str
) -> Conversation: ) -> Conversation:
"""获取会话的对话""" """获取会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
Returns:
conversation (Conversation): 对话对象
"""
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id) return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]: async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
"""获取会话的所有对话""" """获取会话的所有对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversations (List[Conversation]): 对话对象列表
"""
return self.db.get_conversations(unified_msg_origin) return self.db.get_conversations(unified_msg_origin)
async def update_conversation( async def update_conversation(
self, unified_msg_origin: str, conversation_id: str, history: List[Dict] self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
): ):
"""更新会话的对话""" """更新会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
"""
if conversation_id: if conversation_id:
self.db.update_conversation( self.db.update_conversation(
user_id=unified_msg_origin, user_id=unified_msg_origin,
@@ -76,7 +128,12 @@ class ConversationManager:
) )
async def update_conversation_title(self, unified_msg_origin: str, title: str): async def update_conversation_title(self, unified_msg_origin: str, title: str):
"""更新会话的对话标题""" """更新会话的对话标题
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
title (str): 对话标题
"""
conversation_id = self.session_conversations.get(unified_msg_origin) conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id: if conversation_id:
self.db.update_conversation_title( self.db.update_conversation_title(
@@ -86,7 +143,12 @@ class ConversationManager:
async def update_conversation_persona_id( async def update_conversation_persona_id(
self, unified_msg_origin: str, persona_id: str self, unified_msg_origin: str, persona_id: str
): ):
"""更新会话的对话 Persona ID""" """更新会话的对话 Persona ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
persona_id (str): 对话 Persona ID
"""
conversation_id = self.session_conversations.get(unified_msg_origin) conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id: if conversation_id:
self.db.update_conversation_persona_id( self.db.update_conversation_persona_id(
@@ -96,6 +158,14 @@ class ConversationManager:
async def get_human_readable_context( async def get_human_readable_context(
self, unified_msg_origin, conversation_id, page=1, page_size=10 self, unified_msg_origin, conversation_id, page=1, page_size=10
): ):
"""获取人类可读的上下文
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
page (int): 页码
page_size (int): 每页大小
"""
conversation = await self.get_conversation(unified_msg_origin, conversation_id) conversation = await self.get_conversation(unified_msg_origin, conversation_id)
history = json.loads(conversation.history) history = json.loads(conversation.history)
@@ -105,7 +175,15 @@ class ConversationManager:
if record["role"] == "user": if record["role"] == "user":
temp_contexts.append(f"User: {record['content']}") temp_contexts.append(f"User: {record['content']}")
elif record["role"] == "assistant": elif record["role"] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}") if "content" in record and record["content"]:
temp_contexts.append(f"Assistant: {record['content']}")
elif "tool_calls" in record:
tool_calls_str = json.dumps(
record["tool_calls"], ensure_ascii=False
)
temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
else:
temp_contexts.append("Assistant: [未知的内容]")
contexts.insert(0, temp_contexts) contexts.insert(0, temp_contexts)
temp_contexts = [] temp_contexts = []

View File

@@ -1,3 +1,14 @@
"""
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
工作流程:
1. 初始化所有组件
2. 启动事件总线和任务, 所有任务都在这里运行
3. 执行启动完成事件钩子
"""
import traceback import traceback
import asyncio import asyncio
import time import time
@@ -17,39 +28,54 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.config.default import VERSION from astrbot.core.config.default import VERSION
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map from astrbot.core.star.star_handler import star_map
class AstrBotCoreLifecycle: class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase): """
self.log_broker = log_broker AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
self.astrbot_config = astrbot_config 该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、
self.db = db EventBus 等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
"""
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker # 初始化日志代理
self.astrbot_config = astrbot_config # 初始化配置
self.db = db # 初始化数据库
# 根据环境变量设置代理
os.environ["https_proxy"] = self.astrbot_config["http_proxy"] os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
os.environ["http_proxy"] = self.astrbot_config["http_proxy"] os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
os.environ["no_proxy"] = "localhost" os.environ["no_proxy"] = "localhost"
async def initialize(self): async def initialize(self):
"""
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
"""
# 初始化日志代理
logger.info("AstrBot v" + VERSION) logger.info("AstrBot v" + VERSION)
if os.environ.get("TESTING", ""): if os.environ.get("TESTING", ""):
logger.setLevel("DEBUG") logger.setLevel("DEBUG") # 测试模式下设置日志级别为 DEBUG
else: else:
logger.setLevel(self.astrbot_config["log_level"]) logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
self.event_queue = Queue()
self.event_queue.closed = False
# 初始化事件队列
self.event_queue = Queue()
# 初始化供应商管理器
self.provider_manager = ProviderManager(self.astrbot_config, self.db) self.provider_manager = ProviderManager(self.astrbot_config, self.db)
# 初始化平台管理器
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue) self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config) # 初始化对话管理器
self.conversation_manager = ConversationManager(self.db) self.conversation_manager = ConversationManager(self.db)
# 初始化提供给插件的上下文
self.star_context = Context( self.star_context = Context(
self.event_queue, self.event_queue,
self.astrbot_config, self.astrbot_config,
@@ -57,35 +83,51 @@ class AstrBotCoreLifecycle:
self.provider_manager, self.provider_manager,
self.platform_manager, self.platform_manager,
self.conversation_manager, self.conversation_manager,
self.knowledge_db_manager,
) )
# 初始化插件管理器
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config) self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
# 扫描、注册插件、实例化插件类
await self.plugin_manager.reload() await self.plugin_manager.reload()
"""扫描、注册插件、实例化插件类"""
# 根据配置实例化各个 Provider
await self.provider_manager.initialize() await self.provider_manager.initialize()
"""根据配置实例化各个 Provider"""
# 初始化消息事件流水线调度器
self.pipeline_scheduler = PipelineScheduler( self.pipeline_scheduler = PipelineScheduler(
PipelineContext(self.astrbot_config, self.plugin_manager) PipelineContext(self.astrbot_config, self.plugin_manager)
) )
await self.pipeline_scheduler.initialize() await self.pipeline_scheduler.initialize()
"""初始化消息事件流水线调度器"""
self.astrbot_updator = AstrBotUpdator(self.astrbot_config["plugin_repo_mirror"]) # 初始化更新器
self.astrbot_updator = AstrBotUpdator()
# 初始化事件总线
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler) self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
# 记录启动时间
self.start_time = int(time.time()) self.start_time = int(time.time())
# 初始化当前任务列表
self.curr_tasks: List[asyncio.Task] = [] self.curr_tasks: List[asyncio.Task] = []
# 根据配置实例化各个平台适配器
await self.platform_manager.initialize() await self.platform_manager.initialize()
"""根据配置实例化各个平台适配器"""
# 初始化关闭控制面板的事件
self.dashboard_shutdown_event = asyncio.Event()
def _load(self): def _load(self):
"""加载事件总线和任务并初始化"""
# 创建一个异步任务来执行事件总线的 dispatch() 方法
# dispatch是一个无限循环的协程, 从事件队列中获取事件并处理
event_bus_task = asyncio.create_task( event_bus_task = asyncio.create_task(
self.event_bus.dispatch(), name="event_bus" self.event_bus.dispatch(), name="event_bus"
) )
# 把插件中注册的所有协程函数注册到事件总线中并执行
extra_tasks = [] extra_tasks = []
for task in self.star_context._register_tasks: for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__)) extra_tasks.append(asyncio.create_task(task, name=task.__name__))
@@ -99,17 +141,24 @@ class AstrBotCoreLifecycle:
self.start_time = int(time.time()) self.start_time = int(time.time())
async def _task_wrapper(self, task: asyncio.Task): async def _task_wrapper(self, task: asyncio.Task):
"""异步任务包装器, 用于处理异步任务执行中出现的各种异常
Args:
task (asyncio.Task): 要执行的异步任务
"""
try: try:
await task await task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass # 任务被取消, 静默处理
except Exception as e: except Exception as e:
# 获取完整的异常堆栈信息, 按行分割并记录到日志中
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}") logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
for line in traceback.format_exc().split("\n"): for line in traceback.format_exc().split("\n"):
logger.error(f"| {line}") logger.error(f"| {line}")
logger.error("-------") logger.error("-------")
async def start(self): async def start(self):
"""启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子"""
self._load() self._load()
logger.info("AstrBot 启动完成。") logger.info("AstrBot 启动完成。")
@@ -126,15 +175,29 @@ class AstrBotCoreLifecycle:
except BaseException: except BaseException:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
# 同时运行curr_tasks中的所有任务
await asyncio.gather(*self.curr_tasks, return_exceptions=True) await asyncio.gather(*self.curr_tasks, return_exceptions=True)
async def stop(self): async def stop(self):
self.event_queue.closed = True """停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器"""
# 请求停止所有正在运行的异步任务
for task in self.curr_tasks: for task in self.curr_tasks:
task.cancel() task.cancel()
await self.provider_manager.terminate() for plugin in self.plugin_manager.context.get_all_stars():
try:
await self.plugin_manager._terminate_plugin(plugin)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。"
)
await self.provider_manager.terminate()
await self.platform_manager.terminate()
self.dashboard_shutdown_event.set()
# 再次遍历curr_tasks等待每个任务真正结束
for task in self.curr_tasks: for task in self.curr_tasks:
try: try:
await task await task
@@ -143,13 +206,17 @@ class AstrBotCoreLifecycle:
except Exception as e: except Exception as e:
logger.error(f"任务 {task.get_name()} 发生错误: {e}") logger.error(f"任务 {task.get_name()} 发生错误: {e}")
def restart(self): async def restart(self):
self.event_queue.closed = True """重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
await self.provider_manager.terminate()
await self.platform_manager.terminate()
self.dashboard_shutdown_event.set()
threading.Thread( threading.Thread(
target=self.astrbot_updator._reboot, name="restart", daemon=True target=self.astrbot_updator._reboot, name="restart", daemon=True
).start() ).start()
def load_platform(self) -> List[asyncio.Task]: def load_platform(self) -> List[asyncio.Task]:
"""加载平台实例并返回所有平台实例的异步任务列表"""
tasks = [] tasks = []
platform_insts = self.platform_manager.get_insts() platform_insts = self.platform_manager.get_insts()
for platform_inst in platform_insts: for platform_inst in platform_insts:

View File

@@ -1,6 +1,6 @@
import abc import abc
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import List, Dict, Any, Tuple
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
@@ -117,3 +117,45 @@ class BaseDatabase(abc.ABC):
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
"""更新 Conversation Persona ID""" """更新 Conversation Persona ID"""
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页
Args:
page: 页码从1开始
page_size: 每页数量
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
"""
raise NotImplementedError
@abc.abstractmethod
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表
Args:
page: 页码
page_size: 每页数量
platforms: 平台筛选列表
message_types: 消息类型筛选列表
search_query: 搜索关键词
exclude_ids: 排除的用户ID列表
exclude_platforms: 排除的平台列表
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
"""
raise NotImplementedError

View File

@@ -6,6 +6,8 @@ from typing import List
@dataclass @dataclass
class Platform: class Platform:
"""平台使用统计数据"""
name: str name: str
count: int count: int
timestamp: int timestamp: int
@@ -13,6 +15,8 @@ class Platform:
@dataclass @dataclass
class Provider: class Provider:
"""供应商使用统计数据"""
name: str name: str
count: int count: int
timestamp: int timestamp: int
@@ -20,6 +24,8 @@ class Provider:
@dataclass @dataclass
class Plugin: class Plugin:
"""插件使用统计数据"""
name: str name: str
count: int count: int
timestamp: int timestamp: int
@@ -27,6 +33,8 @@ class Plugin:
@dataclass @dataclass
class Command: class Command:
"""命令使用统计数据"""
name: str name: str
count: int count: int
timestamp: int timestamp: int

View File

@@ -3,7 +3,7 @@ import os
import time import time
from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation
from . import BaseDatabase from . import BaseDatabase
from typing import Tuple from typing import Tuple, List, Dict, Any
class SQLiteDatabase(BaseDatabase): class SQLiteDatabase(BaseDatabase):
@@ -128,24 +128,23 @@ class SQLiteDatabase(BaseDatabase):
except sqlite3.ProgrammingError: except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor() c = self._get_conn(self.db_path).cursor()
where_clause = "" conditions = []
if session_id or provider_type: params = []
where_clause += " WHERE "
has = False if session_id:
if session_id: conditions.append("session_id = ?")
where_clause += f"session_id = '{session_id}'" params.append(session_id)
has = True
if provider_type: if provider_type:
if has: conditions.append("provider_type = ?")
where_clause += " AND " params.append(provider_type)
where_clause += f"provider_type = '{provider_type}'"
sql = "SELECT * FROM llm_history"
if conditions:
sql += " WHERE " + " AND ".join(conditions)
c.execute(sql, params)
c.execute(
"""
SELECT * FROM llm_history
"""
+ where_clause
)
res = c.fetchall() res = c.fetchall()
histories = [] histories = []
for row in res: for row in res:
@@ -389,3 +388,178 @@ class SQLiteDatabase(BaseDatabase):
if res: if res:
return ATRIVision(*res) return ATRIVision(*res)
return None return None
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页,按更新时间降序排序"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 获取总记录数
c.execute("""
SELECT COUNT(*) FROM webchat_conversation
""")
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 获取分页数据,按更新时间降序排序
c.execute(
"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
""",
(page_size, offset),
)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型且至少有8个字符否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0确保即使出错也有有效的返回值
return [], 0
finally:
c.close()
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 构建查询条件
where_clauses = []
params = []
# 平台筛选
if platforms and len(platforms) > 0:
platform_conditions = []
for platform in platforms:
platform_conditions.append("user_id LIKE ?")
params.append(f"{platform}:%")
if platform_conditions:
where_clauses.append(f"({' OR '.join(platform_conditions)})")
# 消息类型筛选
if message_types and len(message_types) > 0:
message_type_conditions = []
for msg_type in message_types:
message_type_conditions.append("user_id LIKE ?")
params.append(f"%:{msg_type}:%")
if message_type_conditions:
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
# 搜索关键词
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
where_clauses.append(
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
)
search_param = f"%{search_query}%"
params.extend([search_param, search_param, search_param, search_param])
# 排除特定用户ID
if exclude_ids and len(exclude_ids) > 0:
for exclude_id in exclude_ids:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_id}%")
# 排除特定平台
if exclude_platforms and len(exclude_platforms) > 0:
for exclude_platform in exclude_platforms:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_platform}:%")
# 构建完整的 WHERE 子句
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
# 构建计数查询
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
# 获取总记录数
c.execute(count_sql, params)
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 构建分页数据查询
data_sql = f"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
{where_sql}
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
"""
query_params = params + [page_size, offset]
# 获取分页数据
c.execute(data_sql, query_params)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型,否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0确保即使出错也有有效的返回值
return [], 0
finally:
c.close()

View File

@@ -38,11 +38,13 @@ CREATE TABLE IF NOT EXISTS atri_vision(
); );
CREATE TABLE IF NOT EXISTS webchat_conversation( CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT, user_id TEXT, -- 会话 id
cid TEXT, cid TEXT, -- 对话 id
history TEXT, history TEXT,
created_at INTEGER, created_at INTEGER,
updated_at INTEGER, updated_at INTEGER,
title TEXT, title TEXT,
persona_id TEXT persona_id TEXT
); );
PRAGMA encoding = 'UTF-8';

View File

@@ -0,0 +1,46 @@
import abc
from dataclasses import dataclass
@dataclass
class Result:
similarity: float
data: dict
class BaseVecDB:
async def initialize(self):
"""
初始化向量数据库
"""
pass
@abc.abstractmethod
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
...
@abc.abstractmethod
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
"""
搜索最相似的文档。
Args:
query (str): 查询文本
top_k (int): 返回的最相似文档的数量
Returns:
List[Result]: 查询结果
"""
...
@abc.abstractmethod
async def delete(self, doc_id: str) -> bool:
"""
删除指定文档。
Args:
doc_id (str): 要删除的文档 ID
Returns:
bool: 删除是否成功
"""
...

View File

@@ -0,0 +1,3 @@
from .vec_db import FaissVecDB
__all__ = ["FaissVecDB"]

View File

@@ -0,0 +1,121 @@
import aiosqlite
import os
class DocumentStorage:
def __init__(self, db_path: str):
self.db_path = db_path
self.connection = None
self.sqlite_init_path = os.path.join(
os.path.dirname(__file__), "sqlite_init.sql"
)
async def initialize(self):
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
if not os.path.exists(self.db_path):
await self.connect()
async with self.connection.cursor() as cursor:
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
sql_script = f.read()
await cursor.executescript(sql_script)
await self.connection.commit()
else:
await self.connect()
async def connect(self):
"""Connect to the SQLite database."""
self.connection = await aiosqlite.connect(self.db_path)
async def get_documents(self, metadata_filters: dict, ids: list = None):
"""Retrieve documents by metadata filters and ids.
Args:
metadata_filters (dict): The metadata filters to apply.
Returns:
list: The list of document IDs(primary key, not doc_id) that match the filters.
"""
# metadata filter -> SQL WHERE clause
where_clauses = []
values = []
for key, val in metadata_filters.items():
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
values.append(val)
if ids is not None and len(ids) > 0:
ids = [str(i) for i in ids if i != -1]
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
values.extend(ids)
where_sql = " AND ".join(where_clauses) or "1=1"
result = []
async with self.connection.cursor() as cursor:
sql = "SELECT * FROM documents WHERE " + where_sql
await cursor.execute(sql, values)
for row in await cursor.fetchall():
result.append(await self.tuple_to_dict(row))
return result
async def get_document_by_doc_id(self, doc_id: str):
"""Retrieve a document by its doc_id.
Args:
doc_id (str): The doc_id of the document to retrieve.
Returns:
dict: The document data.
"""
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
row = await cursor.fetchone()
if row:
return await self.tuple_to_dict(row)
else:
return None
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
"""Retrieve a document by its doc_id.
Args:
doc_id (str): The doc_id.
new_text (str): The new text to update the document with.
"""
async with self.connection.cursor() as cursor:
await cursor.execute(
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
)
await self.connection.commit()
async def get_user_ids(self) -> list[str]:
"""Retrieve all user IDs from the documents table.
Returns:
list: A list of user IDs.
"""
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT DISTINCT user_id FROM documents")
rows = await cursor.fetchall()
return [row[0] for row in rows]
async def tuple_to_dict(self, row):
"""Convert a tuple to a dictionary.
Args:
row (tuple): The row to convert.
Returns:
dict: The converted dictionary.
"""
return {
"id": row[0],
"doc_id": row[1],
"text": row[2],
"metadata": row[3],
"created_at": row[4],
"updated_at": row[5],
}
async def close(self):
"""Close the connection to the SQLite database."""
if self.connection:
await self.connection.close()
self.connection = None

View File

@@ -0,0 +1,59 @@
try:
import faiss
except ModuleNotFoundError:
raise ImportError(
"faiss 未安装。请使用 'pip install faiss-cpu''pip install faiss-gpu' 安装。"
)
import os
import numpy as np
class EmbeddingStorage:
def __init__(self, dimension: int, path: str = None):
self.dimension = dimension
self.path = path
self.index = None
if path and os.path.exists(path):
self.index = faiss.read_index(path)
else:
base_index = faiss.IndexFlatL2(dimension)
self.index = faiss.IndexIDMap(base_index)
self.storage = {}
async def insert(self, vector: np.ndarray, id: int):
"""插入向量
Args:
vector (np.ndarray): 要插入的向量
id (int): 向量的ID
Raises:
ValueError: 如果向量的维度与存储的维度不匹配
"""
if vector.shape[0] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
)
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
self.storage[id] = vector
await self.save_index()
async def search(self, vector: np.ndarray, k: int) -> tuple:
"""搜索最相似的向量
Args:
vector (np.ndarray): 查询向量
k (int): 返回的最相似向量的数量
Returns:
tuple: (距离, 索引)
"""
faiss.normalize_L2(vector)
distances, indices = self.index.search(vector, k)
return distances, indices
async def save_index(self):
"""保存索引
Args:
path (str): 保存索引的路径
"""
faiss.write_index(self.index, self.path)

View File

@@ -0,0 +1,17 @@
-- 创建文档存储表,包含 faiss 中文档的 id文档文本create_atupdated_at
CREATE TABLE documents (
id INTEGER PRIMARY KEY AUTOINCREMENT,
doc_id TEXT NOT NULL,
text TEXT NOT NULL,
metadata TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
ALTER TABLE documents
ADD COLUMN group_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.group_id')) STORED;
ALTER TABLE documents
ADD COLUMN user_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED;
CREATE INDEX idx_documents_user_id ON documents(user_id);
CREATE INDEX idx_documents_group_id ON documents(group_id);

View File

@@ -0,0 +1,117 @@
import uuid
import json
import numpy as np
from .document_storage import DocumentStorage
from .embedding_storage import EmbeddingStorage
from ..base import Result, BaseVecDB
from astrbot.core.provider.provider import EmbeddingProvider
class FaissVecDB(BaseVecDB):
"""
A class to represent a vector database.
"""
def __init__(
self,
doc_store_path: str,
index_store_path: str,
embedding_provider: EmbeddingProvider,
):
self.doc_store_path = doc_store_path
self.index_store_path = index_store_path
self.embedding_provider = embedding_provider
self.document_storage = DocumentStorage(doc_store_path)
self.embedding_storage = EmbeddingStorage(
embedding_provider.get_dim(), index_store_path
)
self.embedding_provider = embedding_provider
async def initialize(self):
await self.document_storage.initialize()
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
metadata = metadata or {}
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
vector = await self.embedding_provider.get_embedding(content)
vector = np.array(vector, dtype=np.float32)
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute(
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
(str_id, content, json.dumps(metadata)),
)
await self.document_storage.connection.commit()
result = await self.document_storage.get_document_by_doc_id(str_id)
int_id = result["id"]
# 插入向量到 FAISS
await self.embedding_storage.insert(vector, int_id)
return int_id
async def retrieve(
self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None
) -> list[Result]:
"""
搜索最相似的文档。
Args:
query (str): 查询文本
k (int): 返回的最相似文档的数量
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
metadata_filters (dict): 元数据过滤器
Returns:
List[Result]: 查询结果
"""
embedding = await self.embedding_provider.get_embedding(query)
scores, indices = await self.embedding_storage.search(
vector=np.array([embedding]).astype("float32"),
k=fetch_k if metadata_filters else k,
)
# TODO: rerank
if len(indices[0]) == 0 or indices[0][0] == -1:
return []
# normalize scores
scores[0] = 1.0 - (scores[0] / 2.0)
# NOTE: maybe the size is less than k.
fetched_docs = await self.document_storage.get_documents(
metadata_filters=metadata_filters or {}, ids=indices[0]
)
if not fetched_docs:
return []
result_docs = []
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
for i, indice_idx in enumerate(indices[0]):
pos = idx_pos.get(indice_idx)
if pos is None:
continue
fetch_doc = fetched_docs[pos]
score = scores[0][i]
result_docs.append(Result(similarity=float(score), data=fetch_doc))
return result_docs[:k]
async def delete(self, doc_id: int):
"""
删除一条文档
"""
await self.document_storage.connection.execute(
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
)
await self.document_storage.connection.commit()
async def close(self):
await self.document_storage.close()
async def count_documents(self) -> int:
"""
计算文档数量
"""
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute("SELECT COUNT(*) FROM documents")
count = await cursor.fetchone()
return count[0] if count else 0

View File

@@ -1,3 +1,16 @@
"""
事件总线, 用于处理事件的分发和处理
事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑
class:
EventBus: 事件总线, 用于处理事件的分发和处理
工作流程:
1. 维护一个异步队列, 来接受各种消息事件
2. 无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑
"""
import asyncio import asyncio
from asyncio import Queue from asyncio import Queue
from astrbot.core.pipeline.scheduler import PipelineScheduler from astrbot.core.pipeline.scheduler import PipelineScheduler
@@ -6,21 +19,38 @@ from .platform import AstrMessageEvent
class EventBus: class EventBus:
"""事件总线: 用于处理事件的分发和处理
维护一个异步队列, 来接受各种消息事件
"""
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler): def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
self.event_queue = event_queue self.event_queue = event_queue # 事件队列
self.pipeline_scheduler = pipeline_scheduler self.pipeline_scheduler = pipeline_scheduler # 管道调度器
async def dispatch(self): async def dispatch(self):
"""无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑"""
while True: while True:
event: AstrMessageEvent = await self.event_queue.get() event: AstrMessageEvent = (
self._print_event(event) await self.event_queue.get()
asyncio.create_task(self.pipeline_scheduler.execute(event)) ) # 从事件队列中获取新的事件
self._print_event(event) # 打印日志
asyncio.create_task(
self.pipeline_scheduler.execute(event)
) # 创建新的异步任务来执行管道调度器的处理逻辑
def _print_event(self, event: AstrMessageEvent): def _print_event(self, event: AstrMessageEvent):
"""用于记录事件信息
Args:
event (AstrMessageEvent): 事件对象
"""
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
if event.get_sender_name(): if event.get_sender_name():
logger.info( logger.info(
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}" f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
) )
# 没有发送者名称: [平台名] 发送者ID: 消息概要
else: else:
logger.info( logger.info(
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}" f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"

View File

@@ -0,0 +1,68 @@
import asyncio
import os
import uuid
import time
class FileTokenService:
"""维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。"""
def __init__(self, default_timeout: float = 300):
self.lock = asyncio.Lock()
self.staged_files = {} # token: (file_path, expire_time)
self.default_timeout = default_timeout
async def _cleanup_expired_tokens(self):
"""清理过期的令牌"""
now = time.time()
expired_tokens = [token for token, (_, expire) in self.staged_files.items() if expire < now]
for token in expired_tokens:
self.staged_files.pop(token, None)
async def register_file(self, file_path: str, timeout: float = None) -> str:
"""向令牌服务注册一个文件。
Args:
file_path(str): 文件路径
timeout(float): 超时时间,单位秒(可选)
Returns:
str: 一个单次令牌
Raises:
FileNotFoundError: 当路径不存在时抛出
"""
async with self.lock:
await self._cleanup_expired_tokens()
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
file_token = str(uuid.uuid4())
expire_time = time.time() + (timeout if timeout is not None else self.default_timeout)
self.staged_files[file_token] = (file_path, expire_time)
return file_token
async def handle_file(self, file_token: str) -> str:
"""根据令牌获取文件路径,使用后令牌失效。
Args:
file_token(str): 注册时返回的令牌
Returns:
str: 文件路径
Raises:
KeyError: 当令牌不存在或已过期时抛出
FileNotFoundError: 当文件本身已被删除时抛出
"""
async with self.lock:
await self._cleanup_expired_tokens()
if file_token not in self.staged_files:
raise KeyError(f"无效或过期的文件 token: {file_token}")
file_path, _ = self.staged_files.pop(file_token)
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
return file_path

View File

@@ -1,35 +1,49 @@
"""
AstrBot 启动器负责初始化和启动核心组件和仪表板服务器
工作流程:
1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期
2. 运行核心生命周期任务和仪表板服务器
"""
import asyncio import asyncio
import traceback import traceback
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .server import AstrBotDashboard
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
from astrbot.core import LogBroker from astrbot.core import LogBroker
from astrbot.dashboard.server import AstrBotDashboard
class AstrBotDashBoardLifecycle: class InitialLoader:
"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。"""
def __init__(self, db: BaseDatabase, log_broker: LogBroker): def __init__(self, db: BaseDatabase, log_broker: LogBroker):
self.db = db self.db = db
self.logger = logger self.logger = logger
self.log_broker = log_broker self.log_broker = log_broker
self.dashboard_server = None
async def start(self): async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db) core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
core_task = []
try: try:
await core_lifecycle.initialize() await core_lifecycle.initialize()
core_task = core_lifecycle.start()
except Exception as e: except Exception as e:
logger.critical(traceback.format_exc()) logger.critical(traceback.format_exc())
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!") logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
return
self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db) core_task = core_lifecycle.start()
task = asyncio.gather(core_task, self.dashboard_server.run())
self.dashboard_server = AstrBotDashboard(
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
)
task = asyncio.gather(
core_task, self.dashboard_server.run()
) # 启动核心任务和仪表板服务器
try: try:
await task await task # 整个AstrBot在这里运行
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("🌈 正在关闭 AstrBot...") logger.info("🌈 正在关闭 AstrBot...")
await core_lifecycle.stop() await core_lifecycle.stop()

View File

@@ -1,12 +1,38 @@
"""
日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
const:
CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量
log_color_config: 日志颜色配置, 定义了不同日志级别的颜色
class:
LogBroker: 日志代理类, 用于缓存和分发日志消息
LogQueueHandler: 日志处理器, 用于将日志消息发送到 LogBroker
LogManager: 日志管理器, 用于创建和配置日志记录器
function:
is_plugin_path: 检查文件路径是否来自插件目录
get_short_level_name: 将日志级别名称转换为四个字母的缩写
工作流程:
1. 通过 LogManager.GetLogger() 获取日志器, 配置了控制台输出和多个格式化过滤器
2. 通过 set_queue_handler() 设置日志处理器, 将日志消息发送到 LogBroker
3. logBroker 维护一个订阅者列表, 负责将日志分发给所有订阅者
4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流
"""
import logging import logging
import colorlog import colorlog
import asyncio import asyncio
import os import os
import sys
from collections import deque from collections import deque
from asyncio import Queue from asyncio import Queue
from typing import List from typing import List
# 日志缓存大小
CACHED_SIZE = 200 CACHED_SIZE = 200
# 日志颜色配置
log_color_config = { log_color_config = {
"DEBUG": "green", "DEBUG": "green",
"INFO": "bold_cyan", "INFO": "bold_cyan",
@@ -19,8 +45,13 @@ log_color_config = {
def is_plugin_path(pathname): def is_plugin_path(pathname):
""" """检查文件路径是否来自插件目录
检查文件路径是否来自插件目录
Args:
pathname (str): 文件路径
Returns:
bool: 如果路径来自插件目录,则返回 True否则返回 False
""" """
if not pathname: if not pathname:
return False return False
@@ -30,8 +61,13 @@ def is_plugin_path(pathname):
def get_short_level_name(level_name): def get_short_level_name(level_name):
""" """将日志级别名称转换为四个字母的缩写
将日志级别名称转换为四个字母的缩写
Args:
level_name (str): 日志级别名称, 如 "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
Returns:
str: 四个字母的日志级别缩写
""" """
level_map = { level_map = {
"DEBUG": "DBUG", "DEBUG": "DBUG",
@@ -44,12 +80,21 @@ def get_short_level_name(level_name):
class LogBroker: class LogBroker:
"""日志代理类, 用于缓存和分发日志消息
发布-订阅模式
"""
def __init__(self): def __init__(self):
self.log_cache = deque(maxlen=CACHED_SIZE) self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
self.subscribers: List[Queue] = [] self.subscribers: List[Queue] = [] # 订阅者列表
def register(self) -> Queue: def register(self) -> Queue:
"""给每个订阅者返回一个带有日志缓存的队列""" """注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
Returns:
Queue: 订阅者的队列, 可用于接收日志消息
"""
q = Queue(maxsize=CACHED_SIZE + 10) q = Queue(maxsize=CACHED_SIZE + 10)
for log in self.log_cache: for log in self.log_cache:
q.put_nowait(log) q.put_nowait(log)
@@ -57,11 +102,20 @@ class LogBroker:
return q return q
def unregister(self, q: Queue): def unregister(self, q: Queue):
"""取消订阅""" """取消订阅
Args:
q (Queue): 需要取消订阅的队列
"""
self.subscribers.remove(q) self.subscribers.remove(q)
def publish(self, log_entry: str): def publish(self, log_entry: dict):
"""发布消息""" """发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统
Args:
log_entry (dict): 日志消息, 包含日志级别和日志内容.
example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
"""
self.log_cache.append(log_entry) self.log_cache.append(log_entry)
for q in self.subscribers: for q in self.subscribers:
try: try:
@@ -71,24 +125,61 @@ class LogBroker:
class LogQueueHandler(logging.Handler): class LogQueueHandler(logging.Handler):
"""日志处理器, 用于将日志消息发送到 LogBroker
继承自 logging.Handler
"""
def __init__(self, log_broker: LogBroker): def __init__(self, log_broker: LogBroker):
super().__init__() super().__init__()
self.log_broker = log_broker self.log_broker = log_broker
def emit(self, record): def emit(self, record):
"""日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布
这个方法会在每次日志记录时被调用
Args:
record (logging.LogRecord): 日志记录对象, 包含日志信息
"""
log_entry = self.format(record) log_entry = self.format(record)
self.log_broker.publish(log_entry) self.log_broker.publish(
{
"level": record.levelname,
"time": record.asctime,
"data": log_entry,
}
)
class LogManager: class LogManager:
"""日志管理器, 用于创建和配置日志记录器
提供了获取默认日志记录器logger和设置队列处理器的方法
"""
@classmethod @classmethod
def GetLogger(cls, log_name: str = "default"): def GetLogger(cls, log_name: str = "default"):
"""获取指定名称的日志记录器logger
Args:
log_name (str): 日志记录器的名称, 默认为 "default"
Returns:
logging.Logger: 返回配置好的日志记录器
"""
logger = logging.getLogger(log_name) logger = logging.getLogger(log_name)
# 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置
if logger.hasHandlers(): if logger.hasHandlers():
return logger return logger
console_handler = logging.StreamHandler() # 如果logger没有处理器
console_handler.setLevel(logging.DEBUG) console_handler = logging.StreamHandler(
sys.stdout
) # 创建一个StreamHandler用于控制台输出
console_handler.setLevel(
logging.DEBUG
) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
# 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息
console_formatter = colorlog.ColoredFormatter( console_formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s", fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s",
datefmt="%H:%M:%S", datefmt="%H:%M:%S",
@@ -96,6 +187,8 @@ class LogManager:
) )
class PluginFilter(logging.Filter): class PluginFilter(logging.Filter):
"""插件过滤器类, 用于标记日志来源是插件还是核心组件"""
def filter(self, record): def filter(self, record):
record.plugin_tag = ( record.plugin_tag = (
"[Plug]" if is_plugin_path(record.pathname) else "[Core]" "[Plug]" if is_plugin_path(record.pathname) else "[Core]"
@@ -103,6 +196,9 @@ class LogManager:
return True return True
class FileNameFilter(logging.Filter): class FileNameFilter(logging.Filter):
"""文件名过滤器类, 用于修改日志记录的文件名格式
例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式"""
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py # 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
def filter(self, record): def filter(self, record):
dirname = os.path.dirname(record.pathname) dirname = os.path.dirname(record.pathname)
@@ -114,22 +210,30 @@ class LogManager:
return True return True
class LevelNameFilter(logging.Filter): class LevelNameFilter(logging.Filter):
"""短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写"""
# 添加短日志级别名称 # 添加短日志级别名称
def filter(self, record): def filter(self, record):
record.short_levelname = get_short_level_name(record.levelname) record.short_levelname = get_short_level_name(record.levelname)
return True return True
console_handler.setFormatter(console_formatter) console_handler.setFormatter(console_formatter) # 设置处理器的格式化器
logger.addFilter(PluginFilter()) logger.addFilter(PluginFilter()) # 添加插件过滤器
logger.addFilter(FileNameFilter()) logger.addFilter(FileNameFilter()) # 添加文件名过滤器
logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器 logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG) # 设置日志级别为DEBUG
logger.addHandler(console_handler) logger.addHandler(console_handler) # 添加处理器到logger
return logger return logger
@classmethod @classmethod
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker): def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
"""设置队列处理器, 用于将日志消息发送到 LogBroker
Args:
logger (logging.Logger): 日志记录器
log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息
"""
handler = LogQueueHandler(log_broker) handler = LogQueueHandler(log_broker)
handler.setLevel(logging.DEBUG) handler.setLevel(logging.DEBUG)
if logger.handlers: if logger.handlers:

View File

@@ -22,13 +22,20 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.
""" """
import asyncio
import base64 import base64
import json import json
import os import os
import typing as T import typing as T
import uuid
from enum import Enum from enum import Enum
from pydantic.v1 import BaseModel from pydantic.v1 import BaseModel
from astrbot.core import astrbot_config, file_token_service, logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
class ComponentType(Enum): class ComponentType(Enum):
Plain = "Plain" # 纯文本消息 Plain = "Plain" # 纯文本消息
@@ -59,6 +66,8 @@ class ComponentType(Enum):
TTS = "TTS" TTS = "TTS"
Unknown = "Unknown" Unknown = "Unknown"
WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包
class BaseMessageComponent(BaseModel): class BaseMessageComponent(BaseModel):
type: ComponentType type: ComponentType
@@ -93,6 +102,10 @@ class BaseMessageComponent(BaseModel):
data[k] = v data[k] = v
return {"type": self.type.lower(), "data": data} return {"type": self.type.lower(), "data": data}
async def to_dict(self) -> dict:
# 默认情况下,回退到旧的同步 toDict()
return self.toDict()
class Plain(BaseMessageComponent): class Plain(BaseMessageComponent):
type: ComponentType = "Plain" type: ComponentType = "Plain"
@@ -109,6 +122,9 @@ class Plain(BaseMessageComponent):
self.text.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;") self.text.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;")
) )
def toDict(self):
return {"type": "text", "data": {"text": self.text.strip()}}
class Face(BaseMessageComponent): class Face(BaseMessageComponent):
type: ComponentType = "Face" type: ComponentType = "Face"
@@ -146,6 +162,76 @@ class Record(BaseMessageComponent):
return Record(file=url, **_) return Record(file=url, **_)
raise Exception("not a valid url") raise Exception("not a valid url")
async def convert_to_file_path(self) -> str:
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
Returns:
str: 语音的本地路径,以绝对路径表示。
"""
if self.file and self.file.startswith("file:///"):
file_path = self.file[8:]
return file_path
elif self.file and self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
return os.path.abspath(file_path)
elif self.file and self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
with open(file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(file_path)
elif os.path.exists(self.file):
file_path = self.file
return os.path.abspath(file_path)
else:
raise Exception(f"not a valid file: {self.file}")
async def convert_to_base64(self) -> str:
"""将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。
Returns:
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
if self.file and self.file.startswith("file:///"):
bs64_data = file_to_base64(self.file[8:])
elif self.file and self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
bs64_data = file_to_base64(file_path)
elif self.file and self.file.startswith("base64://"):
bs64_data = self.file
elif os.path.exists(self.file):
bs64_data = file_to_base64(self.file)
else:
raise Exception(f"not a valid file: {self.file}")
bs64_data = bs64_data.removeprefix("base64://")
return bs64_data
async def register_to_file_service(self) -> str:
"""
将语音注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
if not callback_host:
raise Exception("未配置 callback_api_base文件服务不可用")
file_path = await self.convert_to_file_path()
token = await file_token_service.register_file(file_path)
logger.debug(f"已注册:{callback_host}/api/file/{token}")
return f"{callback_host}/api/file/{token}"
class Video(BaseMessageComponent): class Video(BaseMessageComponent):
type: ComponentType = "Video" type: ComponentType = "Video"
@@ -156,9 +242,6 @@ class Video(BaseMessageComponent):
path: T.Optional[str] = "" path: T.Optional[str] = ""
def __init__(self, file: str, **_): def __init__(self, file: str, **_):
# for k in _.keys():
# if k == "c" and _[k] not in [2, 3]:
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
super().__init__(file=file, **_) super().__init__(file=file, **_)
@staticmethod @staticmethod
@@ -171,6 +254,70 @@ class Video(BaseMessageComponent):
return Video(file=url, **_) return Video(file=url, **_)
raise Exception("not a valid url") raise Exception("not a valid url")
async def convert_to_file_path(self) -> str:
"""将这个视频统一转换为本地文件路径。这个方法避免了手动判断视频数据类型,直接返回视频数据的本地路径(如果是网络 URL则会自动进行下载
Returns:
str: 视频的本地路径,以绝对路径表示。
"""
url = self.file
if url and url.startswith("file:///"):
return url[8:]
elif url and url.startswith("http"):
download_dir = os.path.join(get_astrbot_data_path(), "temp")
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
await download_file(url, video_file_path)
if os.path.exists(video_file_path):
return os.path.abspath(video_file_path)
else:
raise Exception(f"download failed: {url}")
elif os.path.exists(url):
return os.path.abspath(url)
else:
raise Exception(f"not a valid file: {url}")
async def register_to_file_service(self):
"""
将视频注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
if not callback_host:
raise Exception("未配置 callback_api_base文件服务不可用")
file_path = await self.convert_to_file_path()
token = await file_token_service.register_file(file_path)
logger.debug(f"已注册:{callback_host}/api/file/{token}")
return f"{callback_host}/api/file/{token}"
async def to_dict(self):
"""需要和 toDict 区分开toDict 是同步方法"""
url_or_path = self.file
if url_or_path.startswith("http"):
payload_file = url_or_path
elif callback_host := astrbot_config.get("callback_api_base"):
callback_host = str(callback_host).removesuffix("/")
token = await file_token_service.register_file(url_or_path)
payload_file = f"{callback_host}/api/file/{token}"
logger.debug(f"Generated video file callback link: {payload_file}")
else:
payload_file = url_or_path
return {
"type": "video",
"data": {
"file": payload_file,
},
}
class At(BaseMessageComponent): class At(BaseMessageComponent):
type: ComponentType = "At" type: ComponentType = "At"
@@ -180,6 +327,12 @@ class At(BaseMessageComponent):
def __init__(self, **_): def __init__(self, **_):
super().__init__(**_) super().__init__(**_)
def toDict(self):
return {
"type": "at",
"data": {"qq": str(self.qq)},
}
class AtAll(At): class AtAll(At):
qq: str = "all" qq: str = "all"
@@ -279,10 +432,6 @@ class Image(BaseMessageComponent):
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识 file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
def __init__(self, file: T.Optional[str], **_): def __init__(self, file: T.Optional[str], **_):
# for k in _.keys():
# if (k == "_type" and _[k] not in ["flash", "show", None]) or \
# (k == "c" and _[k] not in [2, 3]):
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
super().__init__(file=file, **_) super().__init__(file=file, **_)
@staticmethod @staticmethod
@@ -307,14 +456,100 @@ class Image(BaseMessageComponent):
def fromIO(IO): def fromIO(IO):
return Image.fromBytes(IO.read()) return Image.fromBytes(IO.read())
async def convert_to_file_path(self) -> str:
"""将这个图片统一转换为本地文件路径。这个方法避免了手动判断图片数据类型,直接返回图片数据的本地路径(如果是网络 URL, 则会自动进行下载)。
Returns:
str: 图片的本地路径,以绝对路径表示。
"""
url = self.url if self.url else self.file
if url and url.startswith("file:///"):
image_file_path = url[8:]
return image_file_path
elif url and url.startswith("http"):
image_file_path = await download_image_by_url(url)
return os.path.abspath(image_file_path)
elif url and url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
with open(image_file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(image_file_path)
elif os.path.exists(url):
image_file_path = url
return os.path.abspath(image_file_path)
else:
raise Exception(f"not a valid file: {url}")
async def convert_to_base64(self) -> str:
"""将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。
Returns:
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
url = self.url if self.url else self.file
if url and url.startswith("file:///"):
bs64_data = file_to_base64(url[8:])
elif url and url.startswith("http"):
image_file_path = await download_image_by_url(url)
bs64_data = file_to_base64(image_file_path)
elif url and url.startswith("base64://"):
bs64_data = url
elif os.path.exists(url):
bs64_data = file_to_base64(url)
else:
raise Exception(f"not a valid file: {url}")
bs64_data = bs64_data.removeprefix("base64://")
return bs64_data
async def register_to_file_service(self) -> str:
"""
将图片注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
if not callback_host:
raise Exception("未配置 callback_api_base文件服务不可用")
file_path = await self.convert_to_file_path()
token = await file_token_service.register_file(file_path)
logger.debug(f"已注册:{callback_host}/api/file/{token}")
return f"{callback_host}/api/file/{token}"
class Reply(BaseMessageComponent): class Reply(BaseMessageComponent):
type: ComponentType = "Reply" type: ComponentType = "Reply"
id: T.Union[str, int] id: T.Union[str, int]
text: T.Optional[str] = "" """所引用的消息 ID"""
qq: T.Optional[int] = 0 chain: T.Optional[T.List["BaseMessageComponent"]] = []
"""被引用的消息段列表"""
sender_id: T.Optional[int] | T.Optional[str] = 0
"""被引用的消息对应的发送者的 ID"""
sender_nickname: T.Optional[str] = ""
"""被引用的消息对应的发送者的昵称"""
time: T.Optional[int] = 0 time: T.Optional[int] = 0
"""被引用的消息发送时间"""
message_str: T.Optional[str] = ""
"""被引用的消息解析后的纯文本消息字符串"""
text: T.Optional[str] = ""
"""deprecated"""
qq: T.Optional[int] = 0
"""deprecated"""
seq: T.Optional[int] = 0 seq: T.Optional[int] = 0
"""deprecated"""
def __init__(self, **_): def __init__(self, **_):
super().__init__(**_) super().__init__(**_)
@@ -352,22 +587,48 @@ class Node(BaseMessageComponent):
type: ComponentType = "Node" type: ComponentType = "Node"
id: T.Optional[int] = 0 # 忽略 id: T.Optional[int] = 0 # 忽略
name: T.Optional[str] = "" # qq昵称 name: T.Optional[str] = "" # qq昵称
uin: T.Optional[int] = 0 # qq号 uin: T.Optional[str] = "0" # qq号
content: T.Optional[T.Union[str, list]] = "" # 子消息段列表 content: T.Optional[list[BaseMessageComponent]] = []
seq: T.Optional[T.Union[str, list]] = "" # 忽略 seq: T.Optional[T.Union[str, list]] = "" # 忽略
time: T.Optional[int] = 0 time: T.Optional[int] = 0 # 忽略
def __init__(self, content: T.Union[str, list], **_): def __init__(self, content: list[BaseMessageComponent], **_):
if isinstance(content, list): if isinstance(content, Node):
_content = "" # back
for chain in content: content = [content]
_content += chain.toString()
content = _content
super().__init__(content=content, **_) super().__init__(content=content, **_)
def toString(self): async def to_dict(self):
# logger.warn("Protocol: node doesn't support stringify") data_content = []
return "" for comp in self.content:
if isinstance(comp, (Image, Record)):
# For Image and Record segments, we convert them to base64
bs64 = await comp.convert_to_base64()
data_content.append(
{
"type": comp.type.lower(),
"data": {"file": f"base64://{bs64}"},
}
)
elif isinstance(comp, File):
# For File segments, we need to handle the file differently
d = await comp.to_dict()
data_content.append(d)
elif isinstance(comp, (Node, Nodes)):
# For Node segments, we recursively convert them to dict
d = await comp.to_dict()
data_content.append(d)
else:
d = comp.toDict()
data_content.append(d)
return {
"type": "node",
"data": {
"user_id": str(self.uin),
"nickname": self.name,
"content": data_content,
},
}
class Nodes(BaseMessageComponent): class Nodes(BaseMessageComponent):
@@ -378,7 +639,22 @@ class Nodes(BaseMessageComponent):
super().__init__(nodes=nodes, **_) super().__init__(nodes=nodes, **_)
def toDict(self): def toDict(self):
return {"messages": [node.toDict() for node in self.nodes]} """Deprecated. Use to_dict instead"""
ret = {
"messages": [],
}
for node in self.nodes:
d = node.toDict()
ret["messages"].append(d)
return ret
async def to_dict(self):
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
ret = {"messages": []}
for node in self.nodes:
d = await node.to_dict()
ret["messages"].append(d)
return ret
class Xml(BaseMessageComponent): class Xml(BaseMessageComponent):
@@ -438,15 +714,146 @@ class Unknown(BaseMessageComponent):
class File(BaseMessageComponent): class File(BaseMessageComponent):
""" """
目前此消息段只适配了 Napcat。 文件消息段
""" """
type: ComponentType = "File" type: ComponentType = "File"
name: T.Optional[str] = "" # 名字 name: T.Optional[str] = "" # 名字
file: T.Optional[str] = "" # url本地路径 file_: T.Optional[str] = "" # 本地路径
url: T.Optional[str] = "" # url
def __init__(self, name: str, file: str): def __init__(self, name: str, file: str = "", url: str = ""):
super().__init__(name=name, file=file) """文件消息段。"""
super().__init__(name=name, file_=file, url=url)
@property
def file(self) -> str:
"""
获取文件路径如果文件不存在但有URL则同步下载文件
Returns:
str: 文件路径
"""
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
if self.url:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
logger.warning(
(
"不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段"
)
)
return ""
else:
# 等待下载完成
loop.run_until_complete(self._download_file())
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
except Exception as e:
logger.error(f"文件下载失败: {e}")
return ""
@file.setter
def file(self, value: str):
"""
向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
Args:
value (str): 文件路径或URL
"""
if value.startswith("http://") or value.startswith("https://"):
self.url = value
else:
self.file_ = value
async def get_file(self, allow_return_url: bool = False) -> str:
"""异步获取文件。请注意在使用后清理下载的文件, 以免占用过多空间
Args:
allow_return_url: 是否允许以文件 http 下载链接的形式返回,这允许您自行控制是否需要下载文件。
注意,如果为 True也可能返回文件路径。
Returns:
str: 文件路径或者 http 下载链接
"""
if allow_return_url and self.url:
return self.url
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
if self.url:
await self._download_file()
return os.path.abspath(self.file_)
return ""
async def _download_file(self):
"""下载文件"""
download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True)
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path)
async def register_to_file_service(self):
"""
将文件注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
if not callback_host:
raise Exception("未配置 callback_api_base文件服务不可用")
file_path = await self.get_file()
token = await file_token_service.register_file(file_path)
logger.debug(f"已注册:{callback_host}/api/file/{token}")
return f"{callback_host}/api/file/{token}"
async def to_dict(self):
"""需要和 toDict 区分开toDict 是同步方法"""
url_or_path = await self.get_file(allow_return_url=True)
if url_or_path.startswith("http"):
payload_file = url_or_path
elif callback_host := astrbot_config.get("callback_api_base"):
callback_host = str(callback_host).removesuffix("/")
token = await file_token_service.register_file(url_or_path)
payload_file = f"{callback_host}/api/file/{token}"
logger.debug(f"Generated file callback link: {payload_file}")
else:
payload_file = url_or_path
return {
"type": "file",
"data": {
"name": self.name,
"file": payload_file,
},
}
class WechatEmoji(BaseMessageComponent):
type: ComponentType = "WechatEmoji"
md5: T.Optional[str] = ""
md5_len: T.Optional[int] = 0
cdnurl: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
ComponentTypes = { ComponentTypes = {
@@ -477,4 +884,5 @@ ComponentTypes = {
"tts": TTS, "tts": TTS,
"unknown": Unknown, "unknown": Unknown,
"file": File, "file": File,
"WechatEmoji": WechatEmoji,
} }

View File

@@ -1,8 +1,14 @@
import enum import enum
from typing import List, Optional from typing import List, Optional, Union, AsyncGenerator
from dataclasses import dataclass, field from dataclasses import dataclass, field
from astrbot.core.message.components import BaseMessageComponent, Plain, Image from astrbot.core.message.components import (
BaseMessageComponent,
Plain,
Image,
At,
AtAll,
)
from typing_extensions import deprecated from typing_extensions import deprecated
@@ -31,6 +37,30 @@ class MessageChain:
self.chain.append(Plain(message)) self.chain.append(Plain(message))
return self return self
def at(self, name: str, qq: Union[str, int]):
"""添加一条 At 消息到消息链 `chain` 中。
Example:
CommandResult().at("张三", "12345678910")
# 输出 @张三
"""
self.chain.append(At(name=name, qq=qq))
return self
def at_all(self):
"""添加一条 AtAll 消息到消息链 `chain` 中。
Example:
CommandResult().at_all()
# 输出 @所有人
"""
self.chain.append(AtAll())
return self
@deprecated("请使用 message 方法代替。") @deprecated("请使用 message 方法代替。")
def error(self, message: str): def error(self, message: str):
"""添加一条错误消息到消息链 `chain` 中 """添加一条错误消息到消息链 `chain` 中
@@ -77,6 +107,34 @@ class MessageChain:
self.use_t2i_ = use_t2i self.use_t2i_ = use_t2i
return self return self
def get_plain_text(self) -> str:
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
def squash_plain(self):
"""将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。"""
if not self.chain:
return
new_chain = []
first_plain = None
plain_texts = []
for comp in self.chain:
if isinstance(comp, Plain):
if first_plain is None:
first_plain = comp
new_chain.append(comp)
plain_texts.append(comp.text)
else:
new_chain.append(comp)
if first_plain is not None:
first_plain.text = "".join(plain_texts)
self.chain = new_chain
return self
class EventResultType(enum.Enum): class EventResultType(enum.Enum):
"""用于描述事件处理的结果类型。 """用于描述事件处理的结果类型。
@@ -97,6 +155,10 @@ class ResultContentType(enum.Enum):
"""调用 LLM 产生的结果""" """调用 LLM 产生的结果"""
GENERAL_RESULT = enum.auto() GENERAL_RESULT = enum.auto()
"""普通的消息结果""" """普通的消息结果"""
STREAMING_RESULT = enum.auto()
"""调用 LLM 产生的流式结果"""
STREAMING_FINISH= enum.auto()
"""流式输出完成"""
@dataclass @dataclass
@@ -118,6 +180,9 @@ class MessageEventResult(MessageChain):
default_factory=lambda: ResultContentType.GENERAL_RESULT default_factory=lambda: ResultContentType.GENERAL_RESULT
) )
async_stream: Optional[AsyncGenerator] = None
"""异步流"""
def stop_event(self) -> "MessageEventResult": def stop_event(self) -> "MessageEventResult":
"""终止事件传播。""" """终止事件传播。"""
self.result_type = EventResultType.STOP self.result_type = EventResultType.STOP
@@ -134,6 +199,11 @@ class MessageEventResult(MessageChain):
""" """
return self.result_type == EventResultType.STOP return self.result_type == EventResultType.STOP
def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult":
"""设置异步流。"""
self.async_stream = stream
return self
def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult": def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult":
"""设置事件处理的结果类型。 """设置事件处理的结果类型。
@@ -147,9 +217,6 @@ class MessageEventResult(MessageChain):
"""是否为 LLM 结果。""" """是否为 LLM 结果。"""
return self.result_content_type == ResultContentType.LLM_RESULT return self.result_content_type == ResultContentType.LLM_RESULT
def get_plain_text(self) -> str:
"""获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
# 为了兼容旧版代码,保留 CommandResult 的别名
CommandResult = MessageEventResult CommandResult = MessageEventResult

View File

@@ -7,16 +7,19 @@ from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage from .whitelist_check.stage import WhitelistCheckStage
from .rate_limit_check.stage import RateLimitStage from .rate_limit_check.stage import RateLimitStage
from .content_safety_check.stage import ContentSafetyCheckStage from .content_safety_check.stage import ContentSafetyCheckStage
from .platform_compatibility.stage import PlatformCompatibilityStage
from .preprocess_stage.stage import PreProcessStage from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage from .result_decorate.stage import ResultDecorateStage
from .respond.stage import RespondStage from .respond.stage import RespondStage
# 管道阶段顺序
STAGES_ORDER = [ STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒 "WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单 "WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"RateLimitStage", # 检查会话是否超过频率限制 "RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全 "ContentSafetyCheckStage", # 检查内容安全
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
"PreProcessStage", # 预处理 "PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理a.k.a 插件),或者 LLM 调用 "ProcessStage", # 交由 Stars 处理a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果比如添加回复前缀、t2i、转换为语音 等 "ResultDecorateStage", # 处理结果比如添加回复前缀、t2i、转换为语音 等
@@ -28,6 +31,7 @@ __all__ = [
"WhitelistCheckStage", "WhitelistCheckStage",
"RateLimitStage", "RateLimitStage",
"ContentSafetyCheckStage", "ContentSafetyCheckStage",
"PlatformCompatibilityStage",
"PreProcessStage", "PreProcessStage",
"ProcessStage", "ProcessStage",
"ResultDecorateStage", "ResultDecorateStage",

View File

@@ -5,5 +5,7 @@ from astrbot.core.star import PluginManager
@dataclass @dataclass
class PipelineContext: class PipelineContext:
astrbot_config: AstrBotConfig """上下文对象,包含管道执行所需的上下文信息"""
plugin_manager: PluginManager
astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: PluginManager # 插件管理器对象

View File

@@ -0,0 +1,56 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core import logger
@register_stage
class PlatformCompatibilityStage(Stage):
"""检查所有处理器的平台兼容性。
这个阶段会检查所有处理器是否在当前平台启用如果未启用则设置platform_compatible属性为False。
"""
async def initialize(self, ctx: PipelineContext) -> None:
"""初始化平台兼容性检查阶段
Args:
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
"""
self.ctx = ctx
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
# 获取当前平台ID
platform_id = event.get_platform_id()
# 获取已激活的处理器
activated_handlers = event.get_extra("activated_handlers")
if activated_handlers is None:
activated_handlers = []
# 标记不兼容的处理器
for handler in activated_handlers:
if not isinstance(handler, StarHandlerMetadata):
continue
# 检查处理器是否在当前平台启用
enabled = handler.is_enabled_for_platform(platform_id)
if not enabled:
if handler.handler_module_path in star_map:
plugin_name = star_map[handler.handler_module_path].name
logger.debug(
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
)
# 设置处理器为平台不兼容状态
# TODO: 更好的标记方式
handler.platform_compatible = False
else:
# 确保处理器为平台兼容状态
handler.platform_compatible = True
# 更新已激活的处理器列表
event.set_extra("activated_handlers", activated_handlers)

View File

@@ -46,28 +46,29 @@ class PreProcessStage(Stage):
stt_provider = ( stt_provider = (
self.plugin_manager.context.provider_manager.curr_stt_provider_inst self.plugin_manager.context.provider_manager.curr_stt_provider_inst
) )
if stt_provider: if not stt_provider:
message_chain = event.get_messages() return
for idx, component in enumerate(message_chain): message_chain = event.get_messages()
if isinstance(component, Record) and component.url: for idx, component in enumerate(message_chain):
path = component.url.removeprefix("file://") if isinstance(component, Record) and component.url:
retry = 5 path = component.url.removeprefix("file://")
for i in range(retry): retry = 5
try: for i in range(retry):
result = await stt_provider.get_text(audio_url=path) try:
if result: result = await stt_provider.get_text(audio_url=path)
logger.info("语音转文本结果: " + result) if result:
message_chain[idx] = Plain(result) logger.info("语音转文本结果: " + result)
event.message_str += result message_chain[idx] = Plain(result)
event.message_obj.message_str += result event.message_str += result
break event.message_obj.message_str += result
except FileNotFoundError as e: break
# napcat workaround except FileNotFoundError as e:
logger.warning(e) # napcat workaround
logger.warning(f"重试中: {i + 1}/{retry}") logger.warning(e)
await asyncio.sleep(0.5) logger.warning(f"重试中: {i + 1}/{retry}")
continue await asyncio.sleep(0.5)
except BaseException as e: continue
logger.error(traceback.format_exc()) except BaseException as e:
logger.error(f"语音转文本失败: {e}") logger.error(traceback.format_exc())
break logger.error(f"语音转文本失败: {e}")
break

View File

@@ -12,13 +12,27 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import ( from astrbot.core.message.message_event_result import (
MessageEventResult, MessageEventResult,
ResultContentType, ResultContentType,
MessageChain,
) )
from astrbot.core.message.components import Image from astrbot.core.message.components import Image
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.utils.metrics import Metric from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest, LLMResponse from astrbot.core.provider.entities import (
ProviderRequest,
LLMResponse,
ToolCallMessageSegment,
AssistantMessageSegment,
ToolCallsResult,
)
from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map from astrbot.core.star.star import star_map
from mcp.types import (
TextContent,
ImageContent,
EmbeddedResource,
TextResourceContents,
BlobResourceContents,
)
class LLMRequestSubStage(Stage): class LLMRequestSubStage(Stage):
@@ -28,6 +42,16 @@ class LLMRequestSubStage(Stage):
self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][ self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][
"wake_prefix" "wake_prefix"
] # str ] # str
self.max_context_length = ctx.astrbot_config["provider_settings"][
"max_context_length"
] # int
self.dequeue_context_length = min(
max(1, ctx.astrbot_config["provider_settings"]["dequeue_context_length"]),
self.max_context_length - 1,
) # int
self.streaming_response = ctx.astrbot_config["provider_settings"][
"streaming_response"
] # bool
for bwp in self.bot_wake_prefixs: for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp): if self.provider_wake_prefix.startswith(bwp):
@@ -43,6 +67,10 @@ class LLMRequestSubStage(Stage):
) -> Union[None, AsyncGenerator[None, None]]: ) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None req: ProviderRequest = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
provider = self.ctx.plugin_manager.context.get_using_provider() provider = self.ctx.plugin_manager.context.get_using_provider()
if provider is None: if provider is None:
return return
@@ -54,7 +82,11 @@ class LLMRequestSubStage(Stage):
) )
if req.conversation: if req.conversation:
req.contexts = json.loads(req.conversation.history) all_contexts = json.loads(req.conversation.history)
req.contexts = self._process_tool_message_pairs(
all_contexts, remove_tags=True
)
else: else:
req = ProviderRequest(prompt="", image_urls=[]) req = ProviderRequest(prompt="", image_urls=[])
if self.provider_wake_prefix: if self.provider_wake_prefix:
@@ -64,8 +96,8 @@ class LLMRequestSubStage(Stage):
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message: for comp in event.message_obj.message:
if isinstance(comp, Image): if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file image_path = await comp.convert_to_file_path()
req.image_urls.append(image_url) req.image_urls.append(image_path)
# 获取对话上下文 # 获取对话上下文
conversation_id = await self.conv_manager.get_curr_conversation_id( conversation_id = await self.conv_manager.get_curr_conversation_id(
@@ -75,10 +107,16 @@ class LLMRequestSubStage(Stage):
conversation_id = await self.conv_manager.new_conversation( conversation_id = await self.conv_manager.new_conversation(
event.unified_msg_origin event.unified_msg_origin
) )
req.session_id = event.unified_msg_origin
conversation = await self.conv_manager.get_conversation( conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, conversation_id event.unified_msg_origin, conversation_id
) )
if not conversation:
conversation_id = await self.conv_manager.new_conversation(
event.unified_msg_origin
)
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, conversation_id
)
req.conversation = conversation req.conversation = conversation
req.contexts = json.loads(conversation.history) req.contexts = json.loads(conversation.history)
@@ -89,8 +127,10 @@ class LLMRequestSubStage(Stage):
# 执行请求 LLM 前事件钩子。 # 执行请求 LLM 前事件钩子。
# 装饰 system_prompt 等功能 # 装饰 system_prompt 等功能
# 获取当前平台ID
platform_id = event.get_platform_id()
handlers = star_handlers_registry.get_handlers_by_event_type( handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMRequestEvent EventType.OnLLMRequestEvent, platform_id=platform_id
) )
for handler in handlers: for handler in handlers:
try: try:
@@ -110,110 +150,373 @@ class LLMRequestSubStage(Stage):
if isinstance(req.contexts, str): if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts) req.contexts = json.loads(req.contexts)
try: # max context length
logger.debug(f"提供商请求 Payload: {req}") if (
if _nested: self.max_context_length != -1 # -1 为不限制
req.func_tool = None # 暂时不支持递归工具调用 and len(req.contexts) // 2 > self.max_context_length
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM ):
logger.debug("上下文长度超过限制,将截断。")
# 执行 LLM 响应后的事件钩子。 req.contexts = req.contexts[
handlers = star_handlers_registry.get_handlers_by_event_type( -(self.max_context_length - self.dequeue_context_length + 1) * 2 :
EventType.OnLLMResponseEvent ]
# 找到第一个role 为 user 的索引,确保上下文格式正确
index = next(
(
i
for i, item in enumerate(req.contexts)
if item.get("role") == "user"
),
None,
) )
for handler in handlers: if index is not None and index > 0:
try: req.contexts = req.contexts[index:]
logger.debug(
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" # session_id
if not req.session_id:
req.session_id = event.unified_msg_origin
async def requesting(req: ProviderRequest):
try:
need_loop = True
while need_loop:
need_loop = False
logger.debug(f"提供商请求 Payload: {req}")
final_llm_response = None
if self.streaming_response:
stream = provider.text_chat_stream(**req.__dict__)
async for llm_response in stream:
if llm_response.is_chunk:
if llm_response.result_chain:
yield llm_response.result_chain # MessageChain
else:
yield MessageChain().message(
llm_response.completion_text
)
else:
final_llm_response = llm_response
else:
final_llm_response = await provider.text_chat(
**req.__dict__
) # 请求 LLM
if not final_llm_response:
raise Exception("LLM response is None.")
# 执行 LLM 响应后的事件钩子。
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMResponseEvent
) )
await handler.handler(event, llm_response) for handler in handlers:
except BaseException: try:
logger.error(traceback.format_exc()) logger.debug(
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, final_llm_response)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped(): if event.is_stopped():
logger.info( logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
if self.streaming_response:
# 流式输出的处理
async for result in self._handle_llm_stream_response(
event, req, final_llm_response
):
if isinstance(result, ProviderRequest):
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
req = result
need_loop = True
else:
yield
else:
# 非流式输出的处理
async for result in self._handle_llm_response(
event, req, final_llm_response
):
if isinstance(result, ProviderRequest):
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
req = result
need_loop = True
else:
yield
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=provider.get_model(),
provider_type=provider.meta().type,
) )
return
# 保存到历史记录
await self._save_to_history(event, req, llm_response)
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=provider.get_model(),
provider_type=provider.meta().type,
) )
)
if llm_response.role == "assistant": # 保存到历史记录
# text completion await self._save_to_history(event, req, final_llm_response)
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
)
)
if not self.streaming_response:
event.set_extra("tool_call_result", None)
async for _ in requesting(req):
yield
else:
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(requesting(req))
)
# 这里使用yield来暂停当前阶段等待流式输出完成后继续处理
yield
if event.get_extra("tool_call_result"):
event.set_result(event.get_extra("tool_call_result"))
event.set_extra("tool_call_result", None)
yield
# 暂时直接发出去
if img_b64 := event.get_extra("tool_call_img_respond"):
await event.send(MessageChain(chain=[Image.fromBase64(img_b64)]))
event.set_extra("tool_call_img_respond", None)
yield
async def _handle_llm_response(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse,
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
"""处理非流式 LLM 响应。
Returns:
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest表示需要再次调用 LLM
Yields:
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
"""
if llm_response.role == "assistant":
# text completion
if llm_response.result_chain:
event.set_result(
MessageEventResult(
chain=llm_response.result_chain.chain
).set_result_content_type(ResultContentType.LLM_RESULT)
)
else:
event.set_result( event.set_result(
MessageEventResult() MessageEventResult()
.message(llm_response.completion_text) .message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT) .set_result_content_type(ResultContentType.LLM_RESULT)
) )
elif llm_response.role == "err": elif llm_response.role == "err":
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
)
)
elif llm_response.role == "tool":
# 处理函数工具调用
async for result in self._handle_function_tools(event, req, llm_response):
yield result
async def _handle_llm_stream_response(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse,
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
"""处理流式 LLM 响应。
专门用于处理流式输出完成后的响应,与非流式响应处理分离。
Returns:
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest表示需要再次调用 LLM
Yields:
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
"""
if llm_response.role == "assistant":
# text completion
if llm_response.result_chain:
event.set_result( event.set_result(
MessageEventResult().message( MessageEventResult(
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}" chain=llm_response.result_chain.chain
).set_result_content_type(ResultContentType.STREAMING_FINISH)
)
else:
event.set_result(
MessageEventResult()
.message(llm_response.completion_text)
.set_result_content_type(ResultContentType.STREAMING_FINISH)
)
elif llm_response.role == "err":
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
)
)
elif llm_response.role == "tool":
# 处理函数工具调用
async for result in self._handle_function_tools(event, req, llm_response):
yield result
async def _handle_function_tools(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse,
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
"""处理函数工具调用。
Returns:
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest表示需要再次调用 LLM
"""
# function calling
tool_call_result: list[ToolCallMessageSegment] = []
logger.info(
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
)
for func_tool_name, func_tool_args, func_tool_id in zip(
llm_response.tools_call_name,
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
try:
func_tool = req.func_tool.get_func(func_tool_name)
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
) )
) client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
elif llm_response.role == "tool": res = await client.session.call_tool(func_tool.name, func_tool_args)
# function calling if res:
function_calling_result = {} # TODO 仅对ImageContent | EmbeddedResource进行了简单的Fallback
logger.info( if isinstance(res.content[0], TextContent):
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}" tool_call_result.append(
) ToolCallMessageSegment(
for func_tool_name, func_tool_args in zip( role="tool",
llm_response.tools_call_name, llm_response.tools_call_args tool_call_id=func_tool_id,
): content=res.content[0].text,
func_tool = req.func_tool.get_func(func_tool_name) )
)
elif isinstance(res.content[0], ImageContent):
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
event.set_extra(
"tool_call_img_respond",
res.content[0].data,
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resource.text,
)
)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
event.set_extra(
"tool_call_img_respond",
res.content[0].data,
)
else:
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回的数据类型不受支持",
)
)
else:
# 获取处理器,过滤掉平台不兼容的处理器
platform_id = event.get_platform_id()
star_md = star_map.get(func_tool.handler_module_path)
if (
star_md
and platform_id in star_md.supported_platforms
and not star_md.supported_platforms[platform_id]
):
logger.debug(
f"处理器 {func_tool_name}({star_md.name}) 在当前平台不兼容或者被禁用,跳过执行"
)
# 直接跳过不添加任何消息到tool_call_result
continue
logger.info( logger.info(
f"调用工具函数:{func_tool_name},参数:{func_tool_args}" f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
) )
try: # 尝试调用工具函数
# 尝试调用工具函数 wrapper = self._call_handler(
wrapper = self._call_handler( self.ctx, event, func_tool.handler, **func_tool_args
self.ctx, event, func_tool.handler, **func_tool_args )
) async for resp in wrapper:
async for resp in wrapper: if resp is not None: # 有 return 返回
if resp is not None: # 有 return 返回 tool_call_result.append(
function_calling_result[func_tool_name] = resp ToolCallMessageSegment(
else: role="tool",
yield # 有生成器返回 tool_call_id=func_tool_id,
event.clear_result() # 清除上一个 handler 的结果 content=resp,
except BaseException as e: )
logger.warning(traceback.format_exc()) )
function_calling_result[func_tool_name] = ( else:
"When calling the function, an error occurred: " + str(e) res = event.get_result()
) if res and res.chain:
if function_calling_result: event.set_extra("tool_call_result", res)
# 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。 yield # 有生成器返回
# 我们重新执行一遍这个 stage event.clear_result() # 清除上一个 handler 的结果
req.func_tool = None # 暂时不支持递归工具调用 except BaseException as e:
extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n" logger.warning(traceback.format_exc())
for tool_name, tool_result in function_calling_result.items(): tool_call_result.append(
extra_prompt += ( ToolCallMessageSegment(
f"Tool: {tool_name}\nTool Result: {tool_result}\n" role="tool",
) tool_call_id=func_tool_id,
req.prompt += extra_prompt content=f"error: {str(e)}",
async for _ in self.process(event, _nested=True): )
yield
else:
if llm_response.completion_text:
event.set_result(
MessageEventResult().message(llm_response.completion_text)
)
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
) )
if tool_call_result:
# 函数调用结果
req.func_tool = None # 暂时不支持递归工具调用
assistant_msg_seg = AssistantMessageSegment(
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
) )
return # 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
req.tool_calls_result = ToolCallsResult(
tool_calls_info=assistant_msg_seg,
tool_calls_result=tool_call_result,
)
yield req # 再次执行 LLM 请求
else:
if llm_response.completion_text:
event.set_result(
MessageEventResult().message(llm_response.completion_text)
)
async def _save_to_history( async def _save_to_history(
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
@@ -223,9 +526,23 @@ class LLMRequestSubStage(Stage):
if llm_response.role == "assistant": if llm_response.role == "assistant":
# 文本回复 # 文本回复
contexts = req.contexts contexts = req.contexts.copy()
new_record = {"role": "user", "content": req.prompt} contexts.append(await req.assemble_context())
contexts.append(new_record)
# 记录并标记函数调用结果
if req.tool_calls_result:
tool_calls_messages = req.tool_calls_result.to_openai_messages()
# 添加标记
for message in tool_calls_messages:
message["_tool_call_history"] = True
processed_tool_messages = self._process_tool_message_pairs(
tool_calls_messages, remove_tags=False
)
contexts.extend(processed_tool_messages)
contexts.append( contexts.append(
{"role": "assistant", "content": llm_response.completion_text} {"role": "assistant", "content": llm_response.completion_text}
) )
@@ -235,3 +552,59 @@ class LLMRequestSubStage(Stage):
await self.conv_manager.update_conversation( await self.conv_manager.update_conversation(
event.unified_msg_origin, req.conversation.cid, history=contexts_to_save event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
) )
def _process_tool_message_pairs(self, messages, remove_tags=True):
"""处理工具调用消息确保assistant和tool消息成对出现
Args:
messages (list): 消息列表
remove_tags (bool): 是否移除_tool_call_history标记
Returns:
list: 处理后的消息列表保证了assistant和对应tool消息的成对出现
"""
result = []
i = 0
while i < len(messages):
current_msg = messages[i]
# 普通消息直接添加
if "_tool_call_history" not in current_msg:
result.append(current_msg.copy() if remove_tags else current_msg)
i += 1
continue
# 工具调用消息成对处理
if current_msg.get("role") == "assistant" and "tool_calls" in current_msg:
assistant_msg = current_msg.copy()
if remove_tags and "_tool_call_history" in assistant_msg:
del assistant_msg["_tool_call_history"]
related_tools = []
j = i + 1
while (
j < len(messages)
and messages[j].get("role") == "tool"
and "_tool_call_history" in messages[j]
):
tool_msg = messages[j].copy()
if remove_tags:
del tool_msg["_tool_call_history"]
related_tools.append(tool_msg)
j += 1
# 成对的时候添加到结果
if related_tools:
result.append(assistant_msg)
result.extend(related_tools)
i = j # 跳过已处理
else:
# 单独的tool消息
i += 1
return result

View File

@@ -31,7 +31,18 @@ class StarRequestSubStage(Stage):
) )
if not handlers_parsed_params: if not handlers_parsed_params:
handlers_parsed_params = {} handlers_parsed_params = {}
for handler in activated_handlers: for handler in activated_handlers:
# 检查处理器是否在当前平台兼容
if (
hasattr(handler, "platform_compatible")
and handler.platform_compatible is False
):
logger.debug(
f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行"
)
continue
params = handlers_parsed_params.get(handler.handler_full_name, {}) params = handlers_parsed_params.get(handler.handler_full_name, {})
try: try:
if handler.handler_module_path not in star_map: if handler.handler_module_path not in star_map:

View File

@@ -5,7 +5,7 @@ from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage from .method.star_request import StarRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import StarHandlerMetadata from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.provider.entites import ProviderRequest from astrbot.core.provider.entities import ProviderRequest
from astrbot.core import logger from astrbot.core import logger

View File

@@ -58,33 +58,30 @@ class RateLimitStage(Stage):
now = datetime.now() now = datetime.now()
async with self.locks[session_id]: # 确保同一会话不会并发修改队列 async with self.locks[session_id]: # 确保同一会话不会并发修改队列
timestamps = self.event_timestamps[session_id] # 检查并处理限流,可能需要多次检查直到满足条件
while True:
timestamps = self.event_timestamps[session_id]
self._remove_expired_timestamps(timestamps, now)
self._remove_expired_timestamps(timestamps, now) if len(timestamps) < self.rate_limit_count:
timestamps.append(now)
break
else:
next_window_time = timestamps[0] + self.rate_limit_time
stall_duration = (next_window_time - now).total_seconds() + 0.3
if len(timestamps) >= self.rate_limit_count: match self.rl_strategy:
# 达到限流阈值,计算下一个窗口的时间 case RateLimitStrategy.STALL.value:
next_window_time = timestamps[0] + self.rate_limit_time logger.info(
stall_duration = (next_window_time - now).total_seconds() f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
)
match self.rl_strategy: await asyncio.sleep(stall_duration)
case RateLimitStrategy.STALL.value: now = datetime.now()
logger.info( case RateLimitStrategy.DISCARD.value:
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。" logger.info(
) f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
await asyncio.sleep(stall_duration) )
case RateLimitStrategy.DISCARD.value: return event.stop_event()
# event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
logger.info(
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
)
return event.stop_event()
self._remove_expired_timestamps(
timestamps, now + timedelta(seconds=stall_duration)
)
timestamps.append(now)
def _remove_expired_timestamps( def _remove_expired_timestamps(
self, timestamps: Deque[datetime], now: datetime self, timestamps: Deque[datetime], now: datetime

View File

@@ -2,22 +2,42 @@ import random
import asyncio import asyncio
import math import math
import traceback import traceback
import astrbot.core.message.components as Comp
from typing import Union, AsyncGenerator from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage from ..stage import register_stage, Stage
from ..context import PipelineContext from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.message.message_event_result import BaseMessageComponent from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map from astrbot.core.star.star import star_map
from astrbot.core.message.components import Plain, Reply, At from astrbot.core.utils.path_util import path_Mapping
@register_stage @register_stage
class RespondStage(Stage): class RespondStage(Stage):
# 组件类型到其非空判断函数的映射
_component_validators = {
Comp.Plain: lambda comp: bool(
comp.text and comp.text.strip()
), # 纯文本消息需要strip
Comp.Face: lambda comp: comp.id is not None, # QQ表情
Comp.Record: lambda comp: bool(comp.file), # 语音
Comp.Video: lambda comp: bool(comp.file), # 视频
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
Comp.Image: lambda comp: bool(comp.file), # 图片
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
Comp.Node: lambda comp: bool(comp.content), # 转发节点
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
Comp.File: lambda comp: bool(comp.file_ or comp.url),
}
async def initialize(self, ctx: PipelineContext): async def initialize(self, ctx: PipelineContext):
self.ctx = ctx self.ctx = ctx
self.config = ctx.astrbot_config
self.platform_settings: dict = self.config.get("platform_settings", {})
self.reply_with_mention = ctx.astrbot_config["platform_settings"][ self.reply_with_mention = ctx.astrbot_config["platform_settings"][
"reply_with_mention" "reply_with_mention"
@@ -62,7 +82,7 @@ class RespondStage(Stage):
async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float: async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float:
"""分段回复 计算间隔时间""" """分段回复 计算间隔时间"""
if self.interval_method == "log": if self.interval_method == "log":
if isinstance(comp, Plain): if isinstance(comp, Comp.Plain):
wc = await self._word_cnt(comp.text) wc = await self._word_cnt(comp.text)
i = math.log(wc + 1, self.log_base) i = math.log(wc + 1, self.log_base)
return random.uniform(i, i + 0.5) return random.uniform(i, i + 0.5)
@@ -72,15 +92,70 @@ class RespondStage(Stage):
# random # random
return random.uniform(self.interval[0], self.interval[1]) return random.uniform(self.interval[0], self.interval[1])
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]):
"""检查消息链是否为空
Args:
chain (list[BaseMessageComponent]): 包含消息对象的列表
"""
if not chain:
return True
for comp in chain:
comp_type = type(comp)
# 检查组件类型是否在字典中
if comp_type in self._component_validators:
if self._component_validators[comp_type](comp):
return False
# 如果所有组件都为空
return True
async def process( async def process(
self, event: AstrMessageEvent self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]: ) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result() result = event.get_result()
if result is None: if result is None:
return return
if result.result_content_type == ResultContentType.STREAMING_FINISH:
return
if len(result.chain) > 0: if result.result_content_type == ResultContentType.STREAMING_RESULT:
# 流式结果直接交付平台适配器处理
use_fallback = self.config.get("provider_settings", {}).get(
"streaming_segmented", False
)
logger.info(f"应用流式输出({event.get_platform_name()})")
await event._pre_send() await event._pre_send()
await event.send_streaming(result.async_stream, use_fallback)
await event._post_send()
return
elif len(result.chain) > 0:
# 检查路径映射
if mappings := self.platform_settings.get("path_mapping", []):
for idx, component in enumerate(result.chain):
if isinstance(component, Comp.File) and component.file:
# 支持 File 消息段的路径映射。
component.file = path_Mapping(mappings, component.file)
event.get_result().chain[idx] = component
await event._pre_send()
# 检查消息链是否为空
try:
if await self._is_empty_message_chain(result.chain):
logger.info("消息为空,跳过发送阶段")
event.clear_result()
event.stop_event()
return
except Exception as e:
logger.warning(f"空内容检查异常: {e}")
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
non_record_comps = [
c for c in result.chain if not isinstance(c, Comp.Record)
]
if self.enable_seg and ( if self.enable_seg and (
(self.only_llm_result and result.is_llm_result()) (self.only_llm_result and result.is_llm_result())
@@ -89,30 +164,55 @@ class RespondStage(Stage):
decorated_comps = [] decorated_comps = []
if self.reply_with_mention: if self.reply_with_mention:
for comp in result.chain: for comp in result.chain:
if isinstance(comp, At): if isinstance(comp, Comp.At):
decorated_comps.append(comp) decorated_comps.append(comp)
result.chain.remove(comp) result.chain.remove(comp)
break break
if self.reply_with_quote: if self.reply_with_quote:
for comp in result.chain: for comp in result.chain:
if isinstance(comp, Reply): if isinstance(comp, Comp.Reply):
decorated_comps.append(comp) decorated_comps.append(comp)
result.chain.remove(comp) result.chain.remove(comp)
break break
for rcomp in record_comps:
i = await self._calc_comp_interval(rcomp)
await asyncio.sleep(i)
try:
await event.send(MessageChain([rcomp]))
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
# 分段回复 # 分段回复
for comp in result.chain: for comp in non_record_comps:
i = await self._calc_comp_interval(comp) i = await self._calc_comp_interval(comp)
await asyncio.sleep(i) await asyncio.sleep(i)
await event.send(MessageChain([*decorated_comps, comp])) try:
await event.send(MessageChain([*decorated_comps, comp]))
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
else: else:
await event.send(result) for rcomp in record_comps:
try:
await event.send(MessageChain([rcomp]))
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
try:
await event.send(MessageChain(non_record_comps))
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"发送消息失败: {e} chain: {result.chain}")
await event._post_send() await event._post_send()
logger.info( logger.info(
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}" f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
) )
handlers = star_handlers_registry.get_handlers_by_event_type( handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAfterMessageSentEvent EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
) )
for handler in handlers: for handler in handlers:
try: try:

View File

@@ -1,16 +1,18 @@
import time
import re import re
import time
import traceback import traceback
from typing import Union, AsyncGenerator from typing import AsyncGenerator, Union
from ..stage import Stage, register_stage, registered_stages
from ..context import PipelineContext from astrbot.core import html_renderer, logger, file_token_service
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType from astrbot.core.platform.message_type import MessageType
from astrbot.core import logger
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
from astrbot.core import html_renderer
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
from ..context import PipelineContext
from ..stage import Stage, register_stage, registered_stages
@register_stage @register_stage
@@ -31,6 +33,8 @@ class ResultDecorateStage(Stage):
self.t2i_word_threshold = 50 self.t2i_word_threshold = 50
except BaseException: except BaseException:
self.t2i_word_threshold = 150 self.t2i_word_threshold = 150
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
self.t2i_use_network = self.t2i_strategy == "remote"
self.forward_threshold = ctx.astrbot_config["platform_settings"][ self.forward_threshold = ctx.astrbot_config["platform_settings"][
"forward_threshold" "forward_threshold"
@@ -70,11 +74,17 @@ class ResultDecorateStage(Stage):
if result is None or not result.chain: if result is None or not result.chain:
return return
if result.result_content_type == ResultContentType.STREAMING_RESULT:
return
is_stream = result.result_content_type == ResultContentType.STREAMING_FINISH
# 回复时检查内容安全 # 回复时检查内容安全
if ( if (
self.content_safe_check_reply self.content_safe_check_reply
and self.content_safe_check_stage and self.content_safe_check_stage
and result.is_llm_result() and result.is_llm_result()
and not is_stream # 流式输出不检查内容安全
): ):
text = "" text = ""
for comp in result.chain: for comp in result.chain:
@@ -87,13 +97,17 @@ class ResultDecorateStage(Stage):
# 发送消息前事件钩子 # 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type( handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnDecoratingResultEvent EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id()
) )
for handler in handlers: for handler in handlers:
try: try:
logger.debug( logger.debug(
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
) )
if is_stream:
logger.warning(
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作"
)
await handler.handler(event) await handler.handler(event)
if event.get_result() is None or not event.get_result().chain: if event.get_result() is None or not event.get_result().chain:
logger.debug( logger.debug(
@@ -108,6 +122,11 @@ class ResultDecorateStage(Stage):
) )
return return
# 流式输出不执行下面的逻辑
if is_stream:
logger.info("流式输出已启用,跳过结果装饰阶段")
return
# 需要再获取一次。插件可能直接对 chain 进行了替换。 # 需要再获取一次。插件可能直接对 chain 进行了替换。
result = event.get_result() result = event.get_result()
if result is None: if result is None:
@@ -133,9 +152,9 @@ class ResultDecorateStage(Stage):
# 不分段回复 # 不分段回复
new_chain.append(comp) new_chain.append(comp)
continue continue
split_response = [] split_response = re.findall(
for line in comp.text.split("\n"): self.regex, comp.text, re.DOTALL | re.MULTILINE
split_response.extend(re.findall(self.regex, line)) )
if not split_response: if not split_response:
new_chain.append(comp) new_chain.append(comp)
continue continue
@@ -150,28 +169,55 @@ class ResultDecorateStage(Stage):
result.chain = new_chain result.chain = new_chain
# TTS # TTS
tts_provider = (
self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
)
if ( if (
self.ctx.astrbot_config["provider_tts_settings"]["enable"] self.ctx.astrbot_config["provider_tts_settings"]["enable"]
and result.is_llm_result() and result.is_llm_result()
and tts_provider
): ):
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
new_chain = [] new_chain = []
for comp in result.chain: for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1: if isinstance(comp, Plain) and len(comp.text) > 1:
try: try:
logger.info("TTS 请求: " + comp.text) logger.info(f"TTS 请求: {comp.text}")
audio_path = await tts_provider.get_audio(comp.text) audio_path = await tts_provider.get_audio(comp.text)
logger.info("TTS 结果: " + audio_path) logger.info(f"TTS 结果: {audio_path}")
if audio_path: if not audio_path:
new_chain.append(
Record(file=audio_path, url=audio_path)
)
else:
logger.error( logger.error(
f"由于 TTS 音频文件找到,消息段转语音失败: {comp.text}" f"由于 TTS 音频文件找到,消息段转语音失败: {comp.text}"
) )
new_chain.append(comp) new_chain.append(comp)
except BaseException: continue
use_file_service = self.ctx.astrbot_config[
"provider_tts_settings"
]["use_file_service"]
callback_api_base = self.ctx.astrbot_config[
"callback_api_base"
]
dual_output = self.ctx.astrbot_config[
"provider_tts_settings"
]["dual_output"]
url = None
if use_file_service and callback_api_base:
token = await file_token_service.register_file(
audio_path
)
url = f"{callback_api_base}/api/file/{token}"
logger.debug(f"已注册:{url}")
new_chain.append(
Record(
file=url or audio_path,
url=url or audio_path,
)
)
if dual_output:
new_chain.append(comp)
except Exception:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
logger.error("TTS 失败,使用文本发送。") logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp) new_chain.append(comp)
@@ -192,7 +238,9 @@ class ResultDecorateStage(Stage):
if plain_str and len(plain_str) > self.t2i_word_threshold: if plain_str and len(plain_str) > self.t2i_word_threshold:
render_start = time.time() render_start = time.time()
try: try:
url = await html_renderer.render_t2i(plain_str, return_url=True) url = await html_renderer.render_t2i(
plain_str, return_url=True, use_network=self.t2i_use_network
)
except BaseException: except BaseException:
logger.error("文本转图片失败,使用文本发送。") logger.error("文本转图片失败,使用文本发送。")
return return
@@ -201,7 +249,18 @@ class ResultDecorateStage(Stage):
"文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。" "文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。"
) )
if url: if url:
result.chain = [Image.fromURL(url)] if url.startswith("http"):
result.chain = [Image.fromURL(url)]
elif (
self.ctx.astrbot_config["t2i_use_file_service"]
and self.ctx.astrbot_config["callback_api_base"]
):
token = await file_token_service.register_file(url)
url = f"{self.ctx.astrbot_config['callback_api_base']}/api/file/{token}"
logger.debug(f"已注册:{url}")
result.chain = [Image.fromURL(url)]
else:
result.chain = [Image.fromFileSystem(url)]
# 触发转发消息 # 触发转发消息
has_forwarded = False has_forwarded = False

View File

@@ -7,49 +7,72 @@ from astrbot.core import logger
class PipelineScheduler: class PipelineScheduler:
"""管道调度器,负责调度各个阶段的执行"""
def __init__(self, context: PipelineContext): def __init__(self, context: PipelineContext):
registered_stages.sort(key=lambda x: STAGES_ORDER.index(x.__class__.__name__)) registered_stages.sort(
self.ctx = context key=lambda x: STAGES_ORDER.index(x.__class__.__name__)
) # 按照顺序排序
self.ctx = context # 上下文对象
async def initialize(self): async def initialize(self):
"""初始化管道调度器时, 初始化所有阶段"""
for stage in registered_stages: for stage in registered_stages:
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}") # logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
await stage.initialize(self.ctx) await stage.initialize(self.ctx)
async def _process_stages(self, event: AstrMessageEvent, from_stage=0): async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
"""依次执行各个阶段
Args:
event (AstrMessageEvent): 事件对象
from_stage (int): 从第几个阶段开始执行, 默认从0开始
"""
for i in range(from_stage, len(registered_stages)): for i in range(from_stage, len(registered_stages)):
stage = registered_stages[i] stage = registered_stages[i] # 获取当前要执行的阶段
# logger.debug(f"执行阶段 {stage.__class__ .__name__}") # logger.debug(f"执行阶段 {stage.__class__ .__name__}")
coro = stage.process(event) coroutine = stage.process(
if isinstance(coro, AsyncGenerator): event
async for _ in coro: ) # 调用阶段的process方法, 返回协程或者异步生成器
if isinstance(coroutine, AsyncGenerator):
# 如果返回的是异步生成器, 实现洋葱模型的核心
async for _ in coroutine:
# 此处是前置处理完成后的暂停点(yield), 下面开始执行后续阶段
if event.is_stopped(): if event.is_stopped():
logger.debug( logger.debug(
f"阶段 {stage.__class__.__name__} 已终止事件传播。" f"阶段 {stage.__class__.__name__} 已终止事件传播。"
) )
break break
# 递归调用, 处理所有后续阶段
await self._process_stages(event, i + 1) await self._process_stages(event, i + 1)
# 此处是后续所有阶段处理完毕后返回的点, 执行后置处理
if event.is_stopped(): if event.is_stopped():
logger.debug( logger.debug(
f"阶段 {stage.__class__.__name__} 已终止事件传播。" f"阶段 {stage.__class__.__name__} 已终止事件传播。"
) )
break break
else: else:
await coro # 如果返回的是普通协程(不含yield的async函数), 则不进入下一层(基线条件)
# 简单地等待它执行完成, 然后继续执行下一个阶段
await coroutine
if event.is_stopped(): if event.is_stopped():
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
break break
if event.is_stopped():
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
break
async def execute(self, event: AstrMessageEvent): async def execute(self, event: AstrMessageEvent):
"""执行 pipeline""" """执行 pipeline
Args:
event (AstrMessageEvent): 事件对象
"""
await self._process_stages(event) await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if not event._has_send_oper and event.get_platform_name() == "webchat": if not event._has_send_oper and event.get_platform_name() == "webchat":
await event.send(None) await event.send(None)

View File

@@ -1,14 +1,14 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import inspect import inspect
import traceback
from astrbot.api import logger from astrbot.api import logger
from typing import List, AsyncGenerator, Union, Awaitable from typing import List, AsyncGenerator, Union, Awaitable
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext from .context import PipelineContext
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
registered_stages: List[Stage] = [] registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
"""维护了所有已注册的 Stage 实现类"""
def register_stage(cls): def register_stage(cls):
@@ -22,14 +22,24 @@ class Stage(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
async def initialize(self, ctx: PipelineContext) -> None: async def initialize(self, ctx: PipelineContext) -> None:
"""初始化阶段""" """初始化阶段
Args:
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
"""
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def process( async def process(
self, event: AstrMessageEvent self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]: ) -> Union[None, AsyncGenerator[None, None]]:
"""处理事件""" """处理事件
Args:
event (AstrMessageEvent): 事件对象,包含事件的相关信息
Returns:
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
"""
raise NotImplementedError raise NotImplementedError
async def _call_handler( async def _call_handler(
@@ -40,33 +50,61 @@ class Stage(abc.ABC):
*args, *args,
**kwargs, **kwargs,
) -> AsyncGenerator[None, None]: ) -> AsyncGenerator[None, None]:
"""调用 Handler。""" """执行事件处理函数并处理其返回结果
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None 该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
1. 异步生成器: 实现洋葱模型每次yield都会将控制权交回上层
2. 协程: 执行一次并处理返回值
Args:
ctx (PipelineContext): 消息管道上下文对象
event (AstrMessageEvent): 待处理的事件对象
handler (Awaitable): 事件处理函数
*args: 传递给handler的位置参数
**kwargs: 传递给handler的关键字参数
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器(async def)
trace_ = None
try: try:
ready_to_call = handler(event, *args, **kwargs) ready_to_call = handler(event, *args, **kwargs)
except TypeError as e: except TypeError as _:
# 向下兼容 # 向下兼容
logger.debug(str(e)) trace_ = traceback.format_exc()
# 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs) ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
if isinstance(ready_to_call, AsyncGenerator): if isinstance(ready_to_call, AsyncGenerator):
_has_yielded = False # 如果是一个异步生成器, 进入洋葱模型
async for ret in ready_to_call: _has_yielded = False # 是否返回过值
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None无返回值 try:
_has_yielded = True async for ret in ready_to_call:
if isinstance(ret, (MessageEventResult, CommandResult)): # 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码
event.set_result(ret) # 返回值只能是 MessageEventResult 或者 None无返回值
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield # 传递控制权给上一层的process函数
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret # 传递控制权给上一层的process函数
if not _has_yielded:
# 如果这个异步生成器没有执行到yield分支
yield yield
else: except Exception as e:
yield ret logger.error(f"Previous Error: {trace_}")
if not _has_yielded: raise e
yield
elif inspect.iscoroutine(ready_to_call): elif inspect.iscoroutine(ready_to_call):
# 如果只是一个 coroutine # 如果只是一个协程, 直接执行
ret = await ready_to_call ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)): if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret) event.set_result(ret)
yield yield # 传递控制权给上一层的process函数
else: else:
yield ret yield ret # 传递控制权给上一层的process函数

View File

@@ -1,5 +1,6 @@
from ..stage import Stage, register_stage from ..stage import Stage, register_stage
from ..context import PipelineContext from ..context import PipelineContext
from astrbot import logger
from typing import Union, AsyncGenerator from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
@@ -21,18 +22,38 @@ class WakingCheckStage(Stage):
""" """
async def initialize(self, ctx: PipelineContext) -> None: async def initialize(self, ctx: PipelineContext) -> None:
"""初始化唤醒检查阶段
Args:
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
"""
self.ctx = ctx self.ctx = ctx
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get( self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
"no_permission_reply", True "no_permission_reply", True
) )
# 私聊是否需要 wake_prefix 才能唤醒机器人
self.friend_message_needs_wake_prefix = self.ctx.astrbot_config[
"platform_settings"
].get("friend_message_needs_wake_prefix", False)
# 是否忽略机器人自己发送的消息
self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get(
"ignore_bot_self_message", False
)
async def process( async def process(
self, event: AstrMessageEvent self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]: ) -> Union[None, AsyncGenerator[None, None]]:
if (
self.ignore_bot_self_message
and event.get_self_id() == event.get_sender_id()
):
# 忽略机器人自己发送的消息
event.stop_event()
return
# 设置 sender 身份 # 设置 sender 身份
event.message_str = event.message_str.strip() event.message_str = event.message_str.strip()
for admin_id in self.ctx.astrbot_config["admins_id"]: for admin_id in self.ctx.astrbot_config["admins_id"]:
if event.get_sender_id() == admin_id: if str(event.get_sender_id()) == admin_id:
event.role = "admin" event.role = "admin"
break break
@@ -68,7 +89,7 @@ class WakingCheckStage(Stage):
event.is_at_or_wake_command = True event.is_at_or_wake_command = True
break break
# 检查是否是私聊 # 检查是否是私聊
if event.is_private_chat(): if event.is_private_chat() and not self.friend_message_needs_wake_prefix:
is_wake = True is_wake = True
event.is_wake = True event.is_wake = True
event.is_at_or_wake_command = True event.is_at_or_wake_command = True
@@ -84,6 +105,7 @@ class WakingCheckStage(Stage):
# filter 需满足 AND 逻辑关系 # filter 需满足 AND 逻辑关系
passed = True passed = True
permission_not_pass = False permission_not_pass = False
permission_filter_raise_error = False
if len(handler.event_filters) == 0: if len(handler.event_filters) == 0:
continue continue
@@ -92,6 +114,7 @@ class WakingCheckStage(Stage):
if isinstance(filter, PermissionTypeFilter): if isinstance(filter, PermissionTypeFilter):
if not filter.filter(event, self.ctx.astrbot_config): if not filter.filter(event, self.ctx.astrbot_config):
permission_not_pass = True permission_not_pass = True
permission_filter_raise_error = filter.raise_error
else: else:
if not filter.filter(event, self.ctx.astrbot_config): if not filter.filter(event, self.ctx.astrbot_config):
passed = False passed = False
@@ -102,17 +125,25 @@ class WakingCheckStage(Stage):
f"插件 {star_map[handler.handler_module_path].name}: {e}" f"插件 {star_map[handler.handler_module_path].name}: {e}"
) )
) )
await event._post_send()
event.stop_event() event.stop_event()
passed = False passed = False
break break
if passed: if passed:
if permission_not_pass: if permission_not_pass:
if not permission_filter_raise_error:
# 跳过
continue
if self.no_permission_reply: if self.no_permission_reply:
await event.send( await event.send(
MessageChain().message( MessageChain().message(
f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。" f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。"
) )
) )
await event._post_send()
logger.info(
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
)
event.stop_event() event.stop_event()
return return

View File

@@ -15,6 +15,9 @@ class WhitelistCheckStage(Stage):
"enable_id_white_list" "enable_id_white_list"
] ]
self.whitelist = ctx.astrbot_config["platform_settings"]["id_whitelist"] self.whitelist = ctx.astrbot_config["platform_settings"]["id_whitelist"]
self.whitelist = [
str(i).strip() for i in self.whitelist if str(i).strip() != ""
]
self.wl_ignore_admin_on_group = ctx.astrbot_config["platform_settings"][ self.wl_ignore_admin_on_group = ctx.astrbot_config["platform_settings"][
"wl_ignore_admin_on_group" "wl_ignore_admin_on_group"
] ]
@@ -51,7 +54,10 @@ class WhitelistCheckStage(Stage):
and event.get_message_type() == MessageType.FRIEND_MESSAGE and event.get_message_type() == MessageType.FRIEND_MESSAGE
): ):
return return
if event.unified_msg_origin not in self.whitelist: if (
event.unified_msg_origin not in self.whitelist
and str(event.get_group_id()).strip() not in self.whitelist
):
if self.wl_log: if self.wl_log:
logger.info( logger.info(
f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。" f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。"

View File

@@ -1,7 +1,7 @@
from .platform import Platform from .platform import Platform
from .astr_message_event import AstrMessageEvent from .astr_message_event import AstrMessageEvent
from .platform_metadata import PlatformMetadata from .platform_metadata import PlatformMetadata
from .astrbot_message import AstrBotMessage, MessageMember, MessageType from .astrbot_message import AstrBotMessage, MessageMember, MessageType, Group
__all__ = [ __all__ = [
"Platform", "Platform",
@@ -10,4 +10,5 @@ __all__ = [
"AstrBotMessage", "AstrBotMessage",
"MessageMember", "MessageMember",
"MessageType", "MessageType",
"Group",
] ]

View File

@@ -1,11 +1,12 @@
import abc import abc
import asyncio import asyncio
import re
import hashlib
import uuid
from dataclasses import dataclass from dataclasses import dataclass
from .astrbot_message import AstrBotMessage from typing import List, Union, Optional, AsyncGenerator
from .platform_metadata import PlatformMetadata
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain from astrbot.core.db.po import Conversation
from astrbot.core.platform.message_type import MessageType
from typing import List, Union
from astrbot.core.message.components import ( from astrbot.core.message.components import (
Plain, Plain,
Image, Image,
@@ -14,10 +15,14 @@ from astrbot.core.message.components import (
At, At,
AtAll, AtAll,
Forward, Forward,
Reply,
) )
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.platform.message_type import MessageType
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.utils.metrics import Metric from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest from .astrbot_message import AstrBotMessage, Group
from astrbot.core.db.po import Conversation from .platform_metadata import PlatformMetadata
@dataclass @dataclass
@@ -79,6 +84,9 @@ class AstrMessageEvent(abc.ABC):
def get_platform_name(self): def get_platform_name(self):
return self.platform_meta.name return self.platform_meta.name
def get_platform_id(self):
return self.platform_meta.id
def get_message_str(self) -> str: def get_message_str(self) -> str:
""" """
获取消息字符串。 获取消息字符串。
@@ -101,8 +109,15 @@ class AstrMessageEvent(abc.ABC):
elif isinstance(i, Forward): elif isinstance(i, Forward):
# 转发消息 # 转发消息
outline += "[转发消息]" outline += "[转发消息]"
elif isinstance(i, Reply):
# 引用回复
if i.message_str:
outline += f"[引用消息({i.sender_nickname}: {i.message_str})]"
else:
outline += "[引用消息]"
else: else:
outline += f"[{i.type}]" outline += f"[{i.type}]"
outline += " "
return outline return outline
def get_message_outline(self) -> str: def get_message_outline(self) -> str:
@@ -193,9 +208,26 @@ class AstrMessageEvent(abc.ABC):
""" """
return self.role == "admin" return self.role == "admin"
async def send(self, message: MessageChain): async def process_buffer(self, buffer: str, pattern: re.Pattern) -> str:
""" """
发送消息到消息平台 将消息缓冲区中的文本按指定正则表达式分割后发送至消息平台作为不支持流式输出平台的Fallback
"""
while True:
match = re.search(pattern, buffer)
if not match:
break
matched_text = match.group()
await self.send(MessageChain([Plain(matched_text)]))
buffer = buffer[match.end() :]
await asyncio.sleep(1.5) # 限速
return buffer
async def send_streaming(
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
):
"""发送流式消息到消息平台,使用异步生成器。
目前仅支持: telegramqq official 私聊。
Fallback仅支持 aiocqhttp, gewechat。
""" """
asyncio.create_task( asyncio.create_task(
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name) Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
@@ -363,3 +395,31 @@ class AstrMessageEvent(abc.ABC):
system_prompt=system_prompt, system_prompt=system_prompt,
conversation=conversation, conversation=conversation,
) )
"""平台适配器"""
async def send(self, message: MessageChain):
"""发送消息到消息平台。
Args:
message (MessageChain): 消息链,具体使用方式请参考文档。
"""
# Leverage BLAKE2 hash function to generate a non-reversible hash of the sender ID for privacy.
hash_obj = hashlib.blake2b(self.get_sender_id().encode("utf-8"), digest_size=16)
sid = str(uuid.UUID(bytes=hash_obj.digest()))
asyncio.create_task(
Metric.upload(
msg_event_tick=1, adapter_name=self.platform_meta.name, sid=sid
)
)
self._has_send_oper = True
async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]:
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息返回当前群聊的数据。
适配情况:
- gewechat
- aiocqhttp(OneBotv11)
"""
...

View File

@@ -10,6 +10,41 @@ class MessageMember:
user_id: str # 发送者id user_id: str # 发送者id
nickname: str = None nickname: str = None
def __str__(self):
# 使用 f-string 来构建返回的字符串表示形式
return (
f"User ID: {self.user_id},"
f"Nickname: {self.nickname if self.nickname else 'N/A'}"
)
@dataclass
class Group:
group_id: str
"""群号"""
group_name: str = None
"""群名称"""
group_avatar: str = None
"""群头像"""
group_owner: str = None
"""群主 id"""
group_admins: List[str] = None
"""群管理员 id"""
members: List[MessageMember] = None
"""所有群成员"""
def __str__(self):
# 使用 f-string 来构建返回的字符串表示形式
return (
f"Group ID: {self.group_id}\n"
f"Name: {self.group_name if self.group_name else 'N/A'}\n"
f"Avatar: {self.group_avatar if self.group_avatar else 'N/A'}\n"
f"Owner ID: {self.group_owner if self.group_owner else 'N/A'}\n"
f"Admin IDs: {self.group_admins if self.group_admins else 'N/A'}\n"
f"Members Len: {len(self.members) if self.members else 0}\n"
f"First Member: {self.members[0] if self.members else 'N/A'}\n"
)
class AstrBotMessage: class AstrBotMessage:
""" """

View File

@@ -62,12 +62,22 @@ class PlatformManager:
from .sources.gewechat.gewechat_platform_adapter import ( from .sources.gewechat.gewechat_platform_adapter import (
GewechatPlatformAdapter, # noqa: F401 GewechatPlatformAdapter, # noqa: F401
) )
case "wechatpadpro":
from .sources.wechatpadpro.wechatpadpro_adapter import (
WeChatPadProAdapter, # noqa: F401
)
case "lark": case "lark":
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401 from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
case "dingtalk":
from .sources.dingtalk.dingtalk_adapter import (
DingtalkPlatformAdapter, # noqa: F401
)
case "telegram": case "telegram":
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401 from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
case "wecom": case "wecom":
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401 from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
case "weixin_official_account":
from .sources.weixin_official_account.weixin_offacc_adapter import WeixinOfficialAccountPlatformAdapter # noqa
except (ImportError, ModuleNotFoundError) as e: except (ImportError, ModuleNotFoundError) as e:
logger.error( logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。" f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
@@ -81,14 +91,18 @@ class PlatformManager:
) )
return return
cls_type = platform_cls_map[platform_config["type"]] cls_type = platform_cls_map[platform_config["type"]]
inst = cls_type(platform_config, self.settings, self.event_queue) inst: Platform = cls_type(platform_config, self.settings, self.event_queue)
self._inst_map[platform_config["id"]] = inst self._inst_map[platform_config["id"]] = {
"inst": inst,
"client_id": inst.client_self_id,
}
self.platform_insts.append(inst) self.platform_insts.append(inst)
asyncio.create_task( asyncio.create_task(
self._task_wrapper( self._task_wrapper(
asyncio.create_task( asyncio.create_task(
inst.run(), name=platform_config["id"] + "_platform" inst.run(),
name=f"platform_{platform_config['type']}_{platform_config['id']}",
) )
) )
) )
@@ -105,38 +119,42 @@ class PlatformManager:
logger.error("-------") logger.error("-------")
async def reload(self, platform_config: dict): async def reload(self, platform_config: dict):
# 还未实现完成,不要调用此方法 await self.terminate_platform(platform_config["id"])
if platform_config["enable"]:
if platform_config["id"] in self._inst_map:
# 正在运行
if getattr(self._inst_map[platform_config["id"]], "terminate", None):
logger.info(f"正在尝试终止 {platform_config['id']} 平台适配器 ...")
await self._inst_map[platform_config["id"]].terminate()
logger.info(f"{platform_config['id']} 平台适配器已终止。")
del self._inst_map[platform_config["id"]]
self.platform_insts.remove(self._inst_map[platform_config["id"]])
else:
logger.warning(f"可能无法正常终止 {platform_config['id']} 平台适配器。")
# 再启动新的实例
await self.load_platform(platform_config) await self.load_platform(platform_config)
else: # 和配置文件保持同步
# 先将 _inst_map 中在 platform_config 中不存在的实例删除 config_ids = [provider["id"] for provider in self.platforms_config]
config_ids = [platform["id"] for platform in self.platforms_config] for key in list(self._inst_map.keys()):
for key in list(self._inst_map.keys()): if key not in config_ids:
if key not in config_ids: await self.terminate_platform(key)
if getattr(self._inst_map[key], "terminate", None):
logger.info(f"正在尝试终止 {key} 平台适配器 ...")
await self._inst_map[key].terminate()
logger.info(f"{key} 平台适配器已终止。")
del self._inst_map[key]
self.platform_insts.remove(self._inst_map[key])
else:
logger.warning(f"可能无法正常终止 {key} 平台适配器。")
# 再启动新的实例 async def terminate_platform(self, platform_id: str):
await self.load_platform(platform_config) if platform_id in self._inst_map:
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
# client_id = self._inst_map.pop(platform_id, None)
info = self._inst_map.pop(platform_id, None)
client_id = info["client_id"]
inst = info["inst"]
try:
self.platform_insts.remove(
next(
inst
for inst in self.platform_insts
if inst.client_self_id == client_id
)
)
except Exception:
logger.warning(f"可能未完全移除 {platform_id} 平台适配器")
if getattr(inst, "terminate", None):
await inst.terminate()
async def terminate(self):
for inst in self.platform_insts:
if getattr(inst, "terminate", None):
await inst.terminate()
def get_insts(self): def get_insts(self):
return self.platform_insts return self.platform_insts

View File

@@ -1,4 +1,5 @@
import abc import abc
import uuid
from typing import Awaitable, Any from typing import Awaitable, Any
from asyncio import Queue from asyncio import Queue
from .platform_metadata import PlatformMetadata from .platform_metadata import PlatformMetadata
@@ -13,6 +14,7 @@ class Platform(abc.ABC):
super().__init__() super().__init__()
# 维护了消息平台的事件队列EventBus 会从这里取出事件并处理。 # 维护了消息平台的事件队列EventBus 会从这里取出事件并处理。
self._event_queue = event_queue self._event_queue = event_queue
self.client_self_id = uuid.uuid4().hex
@abc.abstractmethod @abc.abstractmethod
def run(self) -> Awaitable[Any]: def run(self) -> Awaitable[Any]:
@@ -25,7 +27,7 @@ class Platform(abc.ABC):
""" """
终止一个平台的运行实例。 终止一个平台的运行实例。
""" """
pass ...
@abc.abstractmethod @abc.abstractmethod
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:

View File

@@ -7,6 +7,8 @@ class PlatformMetadata:
"""平台的名称""" """平台的名称"""
description: str description: str
"""平台的描述""" """平台的描述"""
id: str = None
"""平台的唯一标识符,用于配置中识别特定平台"""
default_config_tmpl: dict = None default_config_tmpl: dict = None
"""平台的默认配置模板""" """平台的默认配置模板"""

View File

@@ -1,9 +1,19 @@
import asyncio import asyncio
import re
from astrbot.api.event import AstrMessageEvent, MessageChain from typing import AsyncGenerator, Dict, List
from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes
from aiocqhttp import CQHttp from aiocqhttp import CQHttp
from astrbot.core.utils.io import file_to_base64, download_image_by_url from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import (
Image,
Node,
Nodes,
Plain,
Record,
Video,
File,
BaseMessageComponent,
)
from astrbot.api.platform import Group, MessageMember
class AiocqhttpMessageEvent(AstrMessageEvent): class AiocqhttpMessageEvent(AstrMessageEvent):
@@ -13,51 +23,57 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
super().__init__(message_str, message_obj, platform_meta, session_id) super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot self.bot = bot
@staticmethod
async def _from_segment_to_dict(segment: BaseMessageComponent) -> dict:
"""修复部分字段"""
if isinstance(segment, (Image, Record)):
# For Image and Record segments, we convert them to base64
bs64 = await segment.convert_to_base64()
return {
"type": segment.type.lower(),
"data": {
"file": f"base64://{bs64}",
},
}
elif isinstance(segment, File):
# For File segments, we need to handle the file differently
d = await segment.to_dict()
return d
elif isinstance(segment, Video):
d = await segment.to_dict()
return d
else:
# For other segments, we simply convert them to a dict by calling toDict
return segment.toDict()
@staticmethod @staticmethod
async def _parse_onebot_json(message_chain: MessageChain): async def _parse_onebot_json(message_chain: MessageChain):
"""解析成 OneBot json 格式""" """解析成 OneBot json 格式"""
ret = [] ret = []
for segment in message_chain.chain: for segment in message_chain.chain:
d = segment.toDict()
if isinstance(segment, Plain): if isinstance(segment, Plain):
d["type"] = "text" if not segment.text.strip():
elif isinstance(segment, (Image, Record)): continue
# convert to base64 d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
if segment.file and segment.file.startswith("file:///"):
bs64_data = file_to_base64(segment.file[8:])
image_file_path = segment.file[8:]
elif segment.file and segment.file.startswith("http"):
image_file_path = await download_image_by_url(segment.file)
bs64_data = file_to_base64(image_file_path)
elif segment.file and segment.file.startswith("base64://"):
bs64_data = segment.file
else:
bs64_data = file_to_base64(segment.file)
d["data"] = {
"file": bs64_data,
}
elif isinstance(segment, At):
d["data"] = {
"qq": str(segment.qq) # 转换为字符串
}
ret.append(d) ret.append(d)
return ret return ret
async def send(self, message: MessageChain): async def send(self, message: MessageChain):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message) # 转发消息、文件消息不能和普通消息混在一起发送
send_one_by_one = any(
send_one_by_one = False isinstance(seg, (Node, Nodes, File)) for seg in message.chain
for seg in message.chain: )
if isinstance(seg, (Node, Nodes)):
# 转发消息不能和普通消息混在一起发送
send_one_by_one = True
break
if send_one_by_one: if send_one_by_one:
for seg in message.chain: for seg in message.chain:
if isinstance(seg, Nodes): if isinstance(seg, (Node, Nodes)):
# 带有多个节点的合并转发消息 # 合并转发消息
payload = seg.toDict()
if isinstance(seg, Node):
nodes = Nodes([seg])
seg = nodes
payload = await seg.to_dict()
if self.get_group_id(): if self.get_group_id():
payload["group_id"] = self.get_group_id() payload["group_id"] = self.get_group_id()
await self.bot.call_action("send_group_forward_msg", **payload) await self.bot.call_action("send_group_forward_msg", **payload)
@@ -66,6 +82,12 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
await self.bot.call_action( await self.bot.call_action(
"send_private_forward_msg", **payload "send_private_forward_msg", **payload
) )
elif isinstance(seg, File):
d = await AiocqhttpMessageEvent._from_segment_to_dict(seg)
await self.bot.send(
self.message_obj.raw_message,
[d],
)
else: else:
await self.bot.send( await self.bot.send(
self.message_obj.raw_message, self.message_obj.raw_message,
@@ -75,6 +97,86 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
) )
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
else: else:
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
if not ret:
return
await self.bot.send(self.message_obj.raw_message, ret) await self.bot.send(self.message_obj.raw_message, ret)
await super().send(message) await super().send(message)
async def send_streaming(
self, generator: AsyncGenerator, use_fallback: bool = False
):
if not use_fallback:
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
buffer = ""
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
async for chain in generator:
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
buffer = await self.process_buffer(buffer, pattern)
else:
await self.send(MessageChain(chain=[comp]))
await asyncio.sleep(1.5) # 限速
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator, use_fallback)
async def get_group(self, group_id=None, **kwargs):
if isinstance(group_id, str) and group_id.isdigit():
group_id = int(group_id)
elif self.get_group_id():
group_id = int(self.get_group_id())
else:
return None
info: dict = await self.bot.call_action(
"get_group_info",
group_id=group_id,
)
members: List[Dict] = await self.bot.call_action(
"get_group_member_list",
group_id=group_id,
)
owner_id = None
admin_ids = []
for member in members:
if member["role"] == "owner":
owner_id = member["user_id"]
if member["role"] == "admin":
admin_ids.append(member["user_id"])
group = Group(
group_id=str(group_id),
group_name=info.get("group_name"),
group_avatar="",
group_admins=admin_ids,
group_owner=str(owner_id),
members=[
MessageMember(
user_id=member["user_id"],
nickname=member.get("nickname") or member.get("card"),
)
for member in members
],
)
return group

View File

@@ -1,8 +1,8 @@
import os
import time import time
import asyncio import asyncio
import logging import logging
import uuid import uuid
import itertools
from typing import Awaitable, Any from typing import Awaitable, Any
from aiocqhttp import CQHttp, Event from aiocqhttp import CQHttp, Event
from astrbot.api.platform import ( from astrbot.api.platform import (
@@ -20,7 +20,6 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter from ...register import register_platform_adapter
from aiocqhttp.exceptions import ActionFailed from aiocqhttp.exceptions import ActionFailed
from astrbot.core.utils.io import download_file
@register_platform_adapter( @register_platform_adapter(
@@ -39,14 +38,18 @@ class AiocqhttpAdapter(Platform):
self.port = platform_config["ws_reverse_port"] self.port = platform_config["ws_reverse_port"]
self.metadata = PlatformMetadata( self.metadata = PlatformMetadata(
"aiocqhttp", name="aiocqhttp",
"适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。", description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
id=self.config.get("id"),
) )
self.stop = False
self.bot = CQHttp( self.bot = CQHttp(
use_ws_reverse=True, import_name="aiocqhttp", api_timeout_sec=180 use_ws_reverse=True,
import_name="aiocqhttp",
api_timeout_sec=180,
access_token=platform_config.get(
"ws_reverse_token"
), # 以防旧版本配置不存在
) )
@self.bot.on_request() @self.bot.on_request()
@@ -100,6 +103,9 @@ class AiocqhttpAdapter(Platform):
if event["post_type"] == "message": if event["post_type"] == "message":
abm = await self._convert_handle_message_event(event) abm = await self._convert_handle_message_event(event)
if abm.sender.user_id == "2854196310":
# 屏蔽 QQ 管家的消息
return
elif event["post_type"] == "notice": elif event["post_type"] == "notice":
abm = await self._convert_handle_notice_event(event) abm = await self._convert_handle_notice_event(event)
elif event["post_type"] == "request": elif event["post_type"] == "request":
@@ -111,7 +117,7 @@ class AiocqhttpAdapter(Platform):
"""OneBot V11 请求类事件""" """OneBot V11 请求类事件"""
abm = AstrBotMessage() abm = AstrBotMessage()
abm.self_id = str(event.self_id) abm.self_id = str(event.self_id)
abm.sender = MessageMember(user_id=event.user_id, nickname=event.user_id) abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
abm.type = MessageType.OTHER_MESSAGE abm.type = MessageType.OTHER_MESSAGE
if "group_id" in event and event["group_id"]: if "group_id" in event and event["group_id"]:
abm.type = MessageType.GROUP_MESSAGE abm.type = MessageType.GROUP_MESSAGE
@@ -120,6 +126,12 @@ class AiocqhttpAdapter(Platform):
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE: if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id) abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id)
else:
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)
abm.message_str = "" abm.message_str = ""
abm.message = [] abm.message = []
abm.timestamp = int(time.time()) abm.timestamp = int(time.time())
@@ -131,7 +143,7 @@ class AiocqhttpAdapter(Platform):
"""OneBot V11 通知类事件""" """OneBot V11 通知类事件"""
abm = AstrBotMessage() abm = AstrBotMessage()
abm.self_id = str(event.self_id) abm.self_id = str(event.self_id)
abm.sender = MessageMember(user_id=event.user_id, nickname=event.user_id) abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
abm.type = MessageType.OTHER_MESSAGE abm.type = MessageType.OTHER_MESSAGE
if "group_id" in event and event["group_id"]: if "group_id" in event and event["group_id"]:
abm.group_id = str(event.group_id) abm.group_id = str(event.group_id)
@@ -140,7 +152,7 @@ class AiocqhttpAdapter(Platform):
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE: if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = ( abm.session_id = (
abm.sender.user_id + "_" + str(event.group_id) str(abm.sender.user_id) + "_" + str(event.group_id)
) # 也保留群组 id ) # 也保留群组 id
else: else:
abm.session_id = ( abm.session_id = (
@@ -156,12 +168,20 @@ class AiocqhttpAdapter(Platform):
if "sub_type" in event: if "sub_type" in event:
if event["sub_type"] == "poke" and "target_id" in event: if event["sub_type"] == "poke" and "target_id" in event:
abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) # noqa: F405 abm.message.append(
Poke(qq=str(event["target_id"]), type="poke")
) # noqa: F405
return abm return abm
async def _convert_handle_message_event(self, event: Event) -> AstrBotMessage: async def _convert_handle_message_event(
"""OneBot V11 消息类事件""" self, event: Event, get_reply=True
) -> AstrBotMessage:
"""OneBot V11 消息类事件
@param event: 事件对象
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
"""
abm = AstrBotMessage() abm = AstrBotMessage()
abm.self_id = str(event.self_id) abm.self_id = str(event.self_id)
abm.sender = MessageMember( abm.sender = MessageMember(
@@ -197,52 +217,119 @@ class AiocqhttpAdapter(Platform):
return return
# 按消息段类型类型适配 # 按消息段类型类型适配
for m in event.message: for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
t = m["type"]
a = None a = None
if t == "text": if t == "text":
message_str += m["data"]["text"].strip() current_text = "".join(m["data"]["text"] for m in m_group).strip()
a = ComponentTypes[t](**m["data"]) # noqa: F405 message_str += current_text
a = ComponentTypes[t](text=current_text) # noqa: F405
abm.message.append(a) abm.message.append(a)
elif t == "file": elif t == "file":
if m["data"].get("url") and m["data"].get("url").startswith("http"): for m in m_group:
# Lagrange if m["data"].get("url") and m["data"].get("url").startswith("http"):
logger.info("guessing lagrange") # Lagrange
logger.info("guessing lagrange")
file_name = m["data"].get("file_name", "file")
abm.message.append(File(name=file_name, url=m["data"]["url"]))
else:
try:
# Napcat
ret = None
if abm.type == MessageType.GROUP_MESSAGE:
ret = await self.bot.call_action(
action="get_group_file_url",
file_id=event.message[0]["data"]["file_id"],
group_id=event.group_id,
)
elif abm.type == MessageType.FRIEND_MESSAGE:
ret = await self.bot.call_action(
action="get_private_file_url",
file_id=event.message[0]["data"]["file_id"],
)
if ret and "url" in ret:
file_url = ret["url"] # https
a = File(name="", url=file_url)
abm.message.append(a)
else:
logger.error(f"获取文件失败: {ret}")
file_name = m["data"].get("file_name", "file") except ActionFailed as e:
path = os.path.join("data/temp", file_name) logger.error(f"获取文件失败: {e},此消息段将被忽略。")
await download_file(m["data"]["url"], path) except BaseException as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
m["data"] = {"file": path, "name": file_name} elif t == "reply":
a = ComponentTypes[t](**m["data"]) # noqa: F405 for m in m_group:
abm.message.append(a) if not get_reply:
else:
try:
# Napcat, LLBot
ret = await self.bot.call_action(
action="get_file",
file_id=event.message[0]["data"]["file_id"],
)
if not ret.get("file", None):
raise ValueError(f"无法解析文件响应: {ret}")
if not os.path.exists(ret["file"]):
raise FileNotFoundError(
f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot"
)
m["data"] = {"file": ret["file"], "name": ret["file_name"]}
a = ComponentTypes[t](**m["data"]) # noqa: F405 a = ComponentTypes[t](**m["data"]) # noqa: F405
abm.message.append(a) abm.message.append(a)
except ActionFailed as e: else:
logger.error(f"获取文件失败: {e},此消息段将被忽略。") try:
except BaseException as e: reply_event_data = await self.bot.call_action(
logger.error(f"获取文件失败: {e},此消息段将被忽略。") action="get_msg",
message_id=int(m["data"]["id"]),
)
abm_reply = await self._convert_handle_message_event(
Event.from_payload(reply_event_data), get_reply=False
)
reply_seg = Reply(
id=abm_reply.message_id,
chain=abm_reply.message,
sender_id=abm_reply.sender.user_id,
sender_nickname=abm_reply.sender.nickname,
time=abm_reply.timestamp,
message_str=abm_reply.message_str,
text=abm_reply.message_str, # for compatibility
qq=abm_reply.sender.user_id, # for compatibility
)
abm.message.append(reply_seg)
except BaseException as e:
logger.error(f"获取引用消息失败: {e}")
a = ComponentTypes[t](**m["data"]) # noqa: F405
abm.message.append(a)
elif t == "at":
first_at_self_processed = False
for m in m_group:
try:
if m["data"]["qq"] == "all":
abm.message.append(At(qq="all", name="全体成员"))
continue
at_info = await self.bot.call_action(
action="get_stranger_info",
user_id=int(m["data"]["qq"]),
)
if at_info:
nickname = at_info.get("nick", "")
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
abm.message.append(
At(
qq=m["data"]["qq"],
name=nickname,
)
)
if is_at_self and not first_at_self_processed:
# 第一个@是机器人不添加到message_str
first_at_self_processed = True
else:
# 非第一个@机器人或@其他用户添加到message_str
message_str += f" @{nickname} "
else:
abm.message.append(At(qq=str(m["data"]["qq"]), name=""))
except ActionFailed as e:
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
except BaseException as e:
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
else: else:
a = ComponentTypes[t](**m["data"]) # noqa: F405 for m in m_group:
abm.message.append(a) a = ComponentTypes[t](**m["data"]) # noqa: F405
abm.message.append(a)
abm.timestamp = int(time.time()) abm.timestamp = int(time.time())
abm.message_str = message_str abm.message_str = message_str
@@ -267,22 +354,19 @@ class AiocqhttpAdapter(Platform):
for handler in logging.root.handlers[:]: for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler) logging.root.removeHandler(handler)
logging.getLogger("aiocqhttp").setLevel(logging.ERROR) logging.getLogger("aiocqhttp").setLevel(logging.ERROR)
self.shutdown_event = asyncio.Event()
return coro return coro
async def terminate(self): async def terminate(self):
self.stop = True self.shutdown_event.set()
await asyncio.sleep(1)
async def shutdown_trigger_placeholder(self):
await self.shutdown_event.wait()
logger.info("aiocqhttp 适配器已被优雅地关闭")
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:
return self.metadata return self.metadata
async def shutdown_trigger_placeholder(self):
# TODO: use asyncio.Event
while not self._event_queue.closed and not self.stop: # noqa: ASYNC110
await asyncio.sleep(1)
logger.info("aiocqhttp 适配器已关闭。")
async def handle_msg(self, message: AstrBotMessage): async def handle_msg(self, message: AstrBotMessage):
message_event = AiocqhttpMessageEvent( message_event = AiocqhttpMessageEvent(
message_str=message.message_str, message_str=message.message_str,

View File

@@ -0,0 +1,231 @@
import asyncio
import os
import uuid
import aiohttp
import dingtalk_stream
import threading
from astrbot.api.platform import (
Platform,
AstrBotMessage,
MessageMember,
MessageType,
PlatformMetadata,
)
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Image, Plain, At
from astrbot.core.platform.astr_message_event import MessageSesion
from .dingtalk_event import DingtalkMessageEvent
from ...register import register_platform_adapter
from astrbot import logger
from dingtalk_stream import AckMessage
from astrbot.core.utils.io import download_file
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class MyEventHandler(dingtalk_stream.EventHandler):
async def process(self, event: dingtalk_stream.EventMessage):
print(
"2",
event.headers.event_type,
event.headers.event_id,
event.headers.event_born_time,
event.data,
)
return AckMessage.STATUS_OK, "OK"
@register_platform_adapter("dingtalk", "钉钉机器人官方 API 适配器")
class DingtalkPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.unique_session = platform_settings["unique_session"]
self.client_id = platform_config["client_id"]
self.client_secret = platform_config["client_secret"]
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
async def process(self_, message: dingtalk_stream.CallbackMessage):
logger.debug(f"dingtalk: {message.data}")
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
abm = await self.convert_msg(im)
await self.handle_msg(abm)
return AckMessage.STATUS_OK, "OK"
self.client = AstrCallbackClient()
credential = dingtalk_stream.Credential(self.client_id, self.client_secret)
client = dingtalk_stream.DingTalkStreamClient(credential, logger=logger)
client.register_all_event_handler(MyEventHandler())
client.register_callback_handler(
dingtalk_stream.ChatbotMessage.TOPIC, self.client
)
self.client_ = client # 用于 websockets 的 client
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
raise NotImplementedError("钉钉机器人适配器不支持 send_by_session")
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="dingtalk",
description="钉钉机器人官方 API 适配器",
id=self.config.get("id"),
)
async def convert_msg(
self, message: dingtalk_stream.ChatbotMessage
) -> AstrBotMessage:
abm = AstrBotMessage()
abm.message = []
abm.message_str = ""
abm.timestamp = int(message.create_at / 1000)
abm.type = (
MessageType.GROUP_MESSAGE
if message.conversation_type == "2"
else MessageType.FRIEND_MESSAGE
)
abm.sender = MessageMember(
user_id=message.sender_id, nickname=message.sender_nick
)
abm.self_id = message.chatbot_user_id
abm.message_id = message.message_id
abm.raw_message = message
if abm.type == MessageType.GROUP_MESSAGE:
if message.is_in_at_list:
abm.message.append(At(qq=abm.self_id))
abm.group_id = message.conversation_id
if self.unique_session:
abm.session_id = abm.sender.user_id
else:
abm.session_id = abm.group_id
else:
abm.session_id = abm.sender.user_id
message_type: str = message.message_type
match message_type:
case "text":
abm.message_str = message.text.content.strip()
abm.message.append(Plain(abm.message_str))
case "richText":
rtc: dingtalk_stream.RichTextContent = message.rich_text_content
contents: list[dict] = rtc.rich_text_list
for content in contents:
plains = ""
if "text" in content:
plains += content["text"]
abm.message.append(Plain(plains))
elif "type" in content and content["type"] == "picture":
f_path = await self.download_ding_file(
content["downloadCode"],
message.robot_code,
"jpg",
)
abm.message.append(Image.fromFileSystem(f_path))
case "audio":
pass
return abm # 别忘了返回转换后的消息对象
async def download_ding_file(
self, download_code: str, robot_code: str, ext: str
) -> str:
"""下载钉钉文件
:param access_token: 钉钉机器人的 access_token
:param download_code: 下载码
:param robot_code: 机器人码
:param ext: 文件后缀
:return: 文件路径
"""
access_token = await self.get_access_token()
headers = {
"x-acs-dingtalk-access-token": access_token,
}
payload = {
"downloadCode": download_code,
"robotCode": robot_code,
}
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
f_path = os.path.join(temp_dir, f"dingtalk_file_{uuid.uuid4()}.{ext}")
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.dingtalk.com/v1.0/robot/messageFiles/download",
headers=headers,
json=payload,
) as resp:
if resp.status != 200:
logger.error(
f"下载钉钉文件失败: {resp.status}, {await resp.text()}"
)
return None
resp_data = await resp.json()
download_url = resp_data["data"]["downloadUrl"]
await download_file(download_url, f_path)
return f_path
async def get_access_token(self) -> str:
payload = {
"appKey": self.client_id,
"appSecret": self.client_secret,
}
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.dingtalk.com/v1.0/oauth2/accessToken",
json=payload,
) as resp:
if resp.status != 200:
logger.error(
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}"
)
return None
return (await resp.json())["data"]["accessToken"]
async def handle_msg(self, abm: AstrBotMessage):
event = DingtalkMessageEvent(
message_str=abm.message_str,
message_obj=abm,
platform_meta=self.meta(),
session_id=abm.session_id,
client=self.client,
)
self._event_queue.put_nowait(event)
async def run(self):
# await self.client_.start()
# 钉钉的 SDK 并没有实现真正的异步start() 里面有堵塞方法。
def start_client(loop: asyncio.AbstractEventLoop):
try:
self._shutdown_event = threading.Event()
task = loop.create_task(self.client_.start())
self._shutdown_event.wait()
if task.done():
task.result()
except Exception as e:
if "Graceful shutdown" in str(e):
logger.info("钉钉适配器已被优雅地关闭")
return
logger.error(f"钉钉机器人启动失败: {e}")
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, start_client, loop)
async def terminate(self):
def monkey_patch_close():
raise Exception("Graceful shutdown")
self.client_.open_connection = monkey_patch_close
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
self._shutdown_event.set()
def get_client(self):
return self.client

View File

@@ -0,0 +1,75 @@
import asyncio
import dingtalk_stream
import astrbot.api.message_components as Comp
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot import logger
class DingtalkMessageEvent(AstrMessageEvent):
def __init__(
self,
message_str,
message_obj,
platform_meta,
session_id,
client: dingtalk_stream.ChatbotHandler,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
async def send_with_client(
self, client: dingtalk_stream.ChatbotHandler, message: MessageChain
):
for segment in message.chain:
if isinstance(segment, Comp.Plain):
segment.text = segment.text.strip()
await asyncio.get_event_loop().run_in_executor(
None,
client.reply_markdown,
"AstrBot",
segment.text,
self.message_obj.raw_message,
)
elif isinstance(segment, Comp.Image):
markdown_str = ""
if segment.file and segment.file.startswith("file:///"):
logger.warning(
"dingtalk only support url image, not: " + segment.file
)
continue
elif segment.file and segment.file.startswith("http"):
markdown_str += f"![image]({segment.file})\n\n"
elif segment.file and segment.file.startswith("base64://"):
logger.warning("dingtalk only support url image, not base64")
continue
else:
logger.warning(
"dingtalk only support url image, not: " + segment.file
)
continue
ret = await asyncio.get_event_loop().run_in_executor(
None,
client.reply_markdown,
"😄",
markdown_str,
self.message_obj.raw_message,
)
logger.debug(f"send image: {ret}")
async def send(self, message: MessageChain):
await self.send_with_client(self.client, message)
await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)

View File

@@ -1,17 +1,28 @@
import threading
import asyncio import asyncio
import aiohttp
import quart
import base64 import base64
import datetime import datetime
import re
import os import os
import re
import uuid
import threading
import aiohttp
import anyio import anyio
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType import quart
from astrbot.api.message_components import Plain, Image, At, Record
from astrbot.api import logger, sp from astrbot.api import logger, sp
from .downloader import GeweDownloader from astrbot.api.message_components import Plain, Image, At, Record, Video
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
from .downloader import GeweDownloader
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
try:
from .xml_data_parser import GeweDataParser
except (ImportError, ModuleNotFoundError) as e:
logger.warning(
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
)
class SimpleGewechatClient: class SimpleGewechatClient:
@@ -51,11 +62,11 @@ class SimpleGewechatClient:
self.server = quart.Quart(__name__) self.server = quart.Quart(__name__)
self.server.add_url_rule( self.server.add_url_rule(
"/astrbot-gewechat/callback", view_func=self.callback, methods=["POST"] "/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
) )
self.server.add_url_rule( self.server.add_url_rule(
"/astrbot-gewechat/file/<file_id>", "/astrbot-gewechat/file/<file_token>",
view_func=self.handle_file, view_func=self._handle_file,
methods=["GET"], methods=["GET"],
) )
@@ -70,9 +81,15 @@ class SimpleGewechatClient:
self.userrealnames = {} self.userrealnames = {}
self.stop = False self.shutdown_event = asyncio.Event()
self.staged_files = {}
"""存储了允许外部访问的文件列表。auth_token: file_path。通过 register_file 方法注册。"""
self.lock = asyncio.Lock()
async def get_token_id(self): async def get_token_id(self):
"""获取 Gewechat Token。"""
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post(f"{self.base_url}/tools/getTokenId") as resp: async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
json_blob = await resp.json() json_blob = await resp.json()
@@ -87,6 +104,15 @@ class SimpleGewechatClient:
type_name = data["type_name"] type_name = data["type_name"]
else: else:
raise Exception("无法识别的消息类型") raise Exception("无法识别的消息类型")
# 以下没有业务处理,只是避免控制台打印太多的日志
if type_name == "ModContacts":
logger.info("gewechat下发ModContacts消息通知。")
return
if type_name == "DelContacts":
logger.info("gewechat下发DelContacts消息通知。")
return
if type_name == "Offline": if type_name == "Offline":
logger.critical("收到 gewechat 下线通知。") logger.critical("收到 gewechat 下线通知。")
return return
@@ -124,18 +150,25 @@ class SimpleGewechatClient:
content = d["Content"]["string"] # 消息内容 content = d["Content"]["string"] # 消息内容
at_me = False at_me = False
at_wxids = []
if "@chatroom" in from_user_name: if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE abm.type = MessageType.GROUP_MESSAGE
_t = content.split(":\n") _t = content.split(":\n")
user_id = _t[0] user_id = _t[0]
content = _t[1] content = _t[1]
# at
msg_source = d["MsgSource"]
if "\u2005" in content: if "\u2005" in content:
# at # at
# content = content.split('\u2005')[1] # content = content.split('\u2005')[1]
content = re.sub(r"@[^\u2005]*\u2005", "", content) content = re.sub(r"@[^\u2005]*\u2005", "", content)
at_wxids = re.findall(
r"<atuserlist><!\[CDATA\[.*?(?:,|\b)([^,]+?)(?=,|\]\]></atuserlist>)",
msg_source,
)
abm.group_id = from_user_name abm.group_id = from_user_name
# at
msg_source = d["MsgSource"]
if ( if (
f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source
or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source
@@ -147,9 +180,13 @@ class SimpleGewechatClient:
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
user_id = from_user_name user_id = from_user_name
# 检查消息是否由自己发送,若是则忽略
# 已经有可配置项专门配置是否需要响应自己的消息,因此这里注释掉。
# if user_id == abm.self_id:
# logger.info("忽略自己发送的消息")
# return None
abm.message = [] abm.message = []
if at_me:
abm.message.insert(0, At(qq=abm.self_id))
# 解析用户真实名字 # 解析用户真实名字
user_real_name = "unknown" user_real_name = "unknown"
@@ -173,11 +210,28 @@ class SimpleGewechatClient:
else: else:
user_real_name = self.userrealnames[abm.group_id][user_id] user_real_name = self.userrealnames[abm.group_id][user_id]
else: else:
user_real_name = d.get("PushContent", "unknown : ").split(" : ")[0] try:
info = (await self.get_user_or_group_info(user_id))["data"][0]
user_real_name = info["nickName"]
except Exception as e:
logger.debug(f"获取用户 {user_id} 昵称失败: {e}")
user_real_name = user_id
if at_me:
abm.message.insert(0, At(qq=abm.self_id, name=self.nickname))
for wxid in at_wxids:
# 群聊里 At 其他人的列表
_username = self.userrealnames.get(abm.group_id, {}).get(wxid, wxid)
abm.message.append(At(qq=wxid, name=_username))
abm.sender = MessageMember(user_id, user_real_name) abm.sender = MessageMember(user_id, user_real_name)
abm.raw_message = d abm.raw_message = d
abm.message_str = "" abm.message_str = ""
if user_id == "weixin":
# 忽略微信团队消息
return
# 不同消息类型 # 不同消息类型
match d["MsgType"]: match d["MsgType"]:
case 1: case 1:
@@ -195,18 +249,48 @@ class SimpleGewechatClient:
case 34: case 34:
# 语音消息 # 语音消息
# data = await self.multimedia_downloader.download_voice(
# self.appid,
# content,
# abm.message_id
# )
# print(data)
if "ImgBuf" in d and "buffer" in d["ImgBuf"]: if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
voice_data = base64.b64decode(d["ImgBuf"]["buffer"]) voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk" temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(
temp_dir, f"gewe_voice_{abm.message_id}.silk"
)
async with await anyio.open_file(file_path, "wb") as f: async with await anyio.open_file(file_path, "wb") as f:
await f.write(voice_data) await f.write(voice_data)
abm.message.append(Record(file=file_path, url=file_path)) abm.message.append(Record(file=file_path, url=file_path))
# 以下已知消息类型,没有业务处理,只是避免控制台打印太多的日志
case 37: # 好友申请
logger.info("消息类型(37):好友申请")
case 42: # 名片
logger.info("消息类型(42):名片")
case 43: # 视频
video = Video(file="", cover=content)
abm.message.append(video)
case 47: # emoji
data_parser = GeweDataParser(content, abm.group_id == "")
emoji = data_parser.parse_emoji()
abm.message.append(emoji)
case 48: # 地理位置
logger.info("消息类型(48):地理位置")
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
data_parser = GeweDataParser(content, abm.group_id == "")
segments = data_parser.parse_mutil_49()
if segments:
abm.message.extend(segments)
for seg in segments:
if isinstance(seg, Plain):
abm.message_str += seg.text
case 51: # 帐号消息同步?
logger.info("消息类型(51):帐号消息同步?")
case 10000: # 被踢出群聊/更换群主/修改群名称
logger.info("消息类型(10000):被踢出群聊/更换群主/修改群名称")
case 10002: # 撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办
logger.info(
"消息类型(10002):撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办"
)
case _: case _:
logger.info(f"未实现的消息类型: {d['MsgType']}") logger.info(f"未实现的消息类型: {d['MsgType']}")
abm.raw_message = d abm.raw_message = d
@@ -214,7 +298,7 @@ class SimpleGewechatClient:
logger.debug(f"abm: {abm}") logger.debug(f"abm: {abm}")
return abm return abm
async def callback(self): async def _callback(self):
data = await quart.request.json data = await quart.request.json
logger.debug(f"收到 gewechat 回调: {data}") logger.debug(f"收到 gewechat 回调: {data}")
@@ -236,9 +320,33 @@ class SimpleGewechatClient:
return quart.jsonify({"r": "AstrBot ACK"}) return quart.jsonify({"r": "AstrBot ACK"})
async def handle_file(self, file_id): async def _register_file(self, file_path: str) -> str:
file_path = f"data/temp/{file_id}" """向 AstrBot 回调服务器 注册一个允许外部访问的文件。
return await quart.send_file(file_path)
Args:
file_path (str): 文件路径。
Returns:
str: 返回一个 auth_token文件路径为 file_path。通过 /astrbot-gewechat/file/auth_token 得到文件。
"""
async with self.lock:
if not os.path.exists(file_path):
raise Exception(f"文件不存在: {file_path}")
file_token = str(uuid.uuid4())
self.staged_files[file_token] = file_path
return file_token
async def _handle_file(self, file_token):
async with self.lock:
if file_token not in self.staged_files:
logger.warning(f"请求的文件 {file_token} 不存在。")
return quart.abort(404)
if not os.path.exists(self.staged_files[file_token]):
logger.warning(f"请求的文件 {self.staged_files[file_token]} 不存在。")
return quart.abort(404)
file_path = self.staged_files[file_token]
self.staged_files.pop(file_token, None)
return await quart.send_file(file_path)
async def _set_callback_url(self): async def _set_callback_url(self):
logger.info("设置回调,请等待...") logger.info("设置回调,请等待...")
@@ -262,17 +370,14 @@ class SimpleGewechatClient:
await self.server.run_task( await self.server.run_task(
host="0.0.0.0", host="0.0.0.0",
port=self.port, port=self.port,
shutdown_trigger=self.shutdown_trigger_placeholder, shutdown_trigger=self.shutdown_trigger,
) )
async def shutdown_trigger_placeholder(self): async def shutdown_trigger(self):
# TODO: use asyncio.Event await self.shutdown_event.wait()
while not self.event_queue.closed and not self.stop: # noqa: ASYNC110
await asyncio.sleep(1)
logger.info("gewechat 适配器已关闭。")
async def check_online(self, appid: str): async def check_online(self, appid: str):
# /login/checkOnline """检查 APPID 对应的设备是否在线。"""
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post( async with session.post(
f"{self.base_url}/login/checkOnline", f"{self.base_url}/login/checkOnline",
@@ -283,6 +388,7 @@ class SimpleGewechatClient:
return json_blob["data"] return json_blob["data"]
async def logout(self): async def logout(self):
"""登出 gewechat。"""
if self.appid: if self.appid:
online = await self.check_online(self.appid) online = await self.check_online(self.appid)
if online: if online:
@@ -296,6 +402,7 @@ class SimpleGewechatClient:
logger.info(f"登出结果: {json_blob}") logger.info(f"登出结果: {json_blob}")
async def login(self): async def login(self):
"""登录 gewechat。一般来说插件用不到这个方法。"""
if self.token is None: if self.token is None:
await self.get_token_id() await self.get_token_id()
@@ -304,32 +411,49 @@ class SimpleGewechatClient:
) )
if self.appid: if self.appid:
online = await self.check_online(self.appid) try:
if online: online = await self.check_online(self.appid)
logger.info(f"APPID: {self.appid} 已在线") if online:
return logger.info(f"APPID: {self.appid} 已在线")
return
except Exception as e:
logger.error(f"检查在线状态失败: {e}")
sp.put(f"gewechat-appid-{self.nickname}", "")
self.appid = None
payload = {"appId": self.appid} payload = {"appId": self.appid}
if self.appid: if self.appid:
logger.info(f"使用 APPID: {self.appid}, {self.nickname}") logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
async with aiohttp.ClientSession() as session: try:
async with session.post( async with aiohttp.ClientSession() as session:
f"{self.base_url}/login/getLoginQrCode", async with session.post(
headers=self.headers, f"{self.base_url}/login/getLoginQrCode",
json=payload, headers=self.headers,
) as resp: json=payload,
json_blob = await resp.json() ) as resp:
if json_blob["ret"] != 200: json_blob = await resp.json()
raise Exception(f"获取二维码失败: {json_blob}") if json_blob["ret"] != 200:
qr_data = json_blob["data"]["qrData"] error_msg = json_blob.get("data", {}).get("msg", "")
qr_uuid = json_blob["data"]["uuid"] if "设备不存在" in error_msg:
appid = json_blob["data"]["appId"] logger.error(
logger.info(f"APPID: {appid}") f"检测到无效的appid: {self.appid},将清除并重新登录。"
logger.warning( )
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}" sp.put(f"gewechat-appid-{self.nickname}", "")
) self.appid = None
return await self.login()
else:
raise Exception(f"获取二维码失败: {json_blob}")
qr_data = json_blob["data"]["qrData"]
qr_uuid = json_blob["data"]["uuid"]
appid = json_blob["data"]["appId"]
logger.info(f"APPID: {appid}")
logger.warning(
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
)
except Exception as e:
raise e
# 执行登录 # 执行登录
retry_cnt = 64 retry_cnt = 64
@@ -338,8 +462,10 @@ class SimpleGewechatClient:
retry_cnt -= 1 retry_cnt -= 1
# 需要验证码 # 需要验证码
if os.path.exists("data/temp/gewe_code"): temp_dir = os.path.join(get_astrbot_data_path(), "temp")
with open("data/temp/gewe_code", "r") as f: code_file_path = os.path.join(temp_dir, "gewe_code")
if os.path.exists(code_file_path):
with open(code_file_path, "r") as f:
code = f.read().strip() code = f.read().strip()
if not code: if not code:
logger.warning( logger.warning(
@@ -350,9 +476,9 @@ class SimpleGewechatClient:
payload["captchCode"] = code payload["captchCode"] = code
logger.info(f"使用验证码: {code}") logger.info(f"使用验证码: {code}")
try: try:
os.remove("data/temp/gewe_code") os.remove(code_file_path)
except Exception: except Exception:
logger.warning("删除验证码文件 data/temp/gewe_code 失败。") logger.warning(f"删除验证码文件 {code_file_path} 失败。")
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post( async with session.post(
@@ -372,17 +498,18 @@ class SimpleGewechatClient:
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456" "此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
) )
else: else:
status = json_blob["data"]["status"] if "status" in json_blob["data"]:
nickname = json_blob["data"].get("nickName", "") status = json_blob["data"]["status"]
if status == 1: nickname = json_blob["data"].get("nickName", "")
logger.info(f"等待确认...{nickname}") if status == 1:
elif status == 2: logger.info(f"等待确认...{nickname}")
logger.info(f"绿泡泡平台登录成功: {nickname}") elif status == 2:
break logger.info(f"绿泡泡平台登录成功: {nickname}")
elif status == 0: break
logger.info("等待扫码...") elif status == 0:
else: logger.info("等待扫码...")
logger.warning(f"未知状态: {status}") else:
logger.warning(f"未知状态: {status}")
await asyncio.sleep(5) await asyncio.sleep(5)
if appid: if appid:
@@ -390,9 +517,18 @@ class SimpleGewechatClient:
self.appid = appid self.appid = appid
logger.info(f"已保存 APPID: {appid}") logger.info(f"已保存 APPID: {appid}")
"""API""" """API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1
"""
async def get_chatroom_member_list(self, chatroom_wxid: str): async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict:
"""获取群成员列表。
Args:
chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。
Returns:
dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。
"""
payload = {"appId": self.appid, "chatroomId": chatroom_wxid} payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
@@ -405,6 +541,7 @@ class SimpleGewechatClient:
return json_blob["data"] return json_blob["data"]
async def post_text(self, to_wxid, content: str, ats: str = ""): async def post_text(self, to_wxid, content: str, ats: str = ""):
"""发送纯文本消息"""
payload = { payload = {
"appId": self.appid, "appId": self.appid,
"toWxid": to_wxid, "toWxid": to_wxid,
@@ -421,6 +558,7 @@ class SimpleGewechatClient:
logger.debug(f"发送消息结果: {json_blob}") logger.debug(f"发送消息结果: {json_blob}")
async def post_image(self, to_wxid, image_url: str): async def post_image(self, to_wxid, image_url: str):
"""发送图片消息"""
payload = { payload = {
"appId": self.appid, "appId": self.appid,
"toWxid": to_wxid, "toWxid": to_wxid,
@@ -434,7 +572,79 @@ class SimpleGewechatClient:
json_blob = await resp.json() json_blob = await resp.json()
logger.debug(f"发送图片结果: {json_blob}") logger.debug(f"发送图片结果: {json_blob}")
async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""):
"""发送emoji消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"emojiMd5": emoji_md5,
"emojiSize": emoji_size,
}
# 优先表情包若拿不到表情包的md5就用当作图片发
try:
if emoji_md5 != "" and emoji_size != "":
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postEmoji",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.info(
f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}"
)
else:
await self.post_image(to_wxid, cdnurl)
except Exception as e:
logger.error(e)
async def post_video(
self, to_wxid, video_url: str, thumb_url: str, video_duration: int
):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"videoUrl": video_url,
"thumbUrl": thumb_url,
"videoDuration": video_duration,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postVideo", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送视频结果: {json_blob}")
async def forward_video(self, to_wxid, cnd_xml: str):
"""转发视频
Args:
to_wxid (str): 发送给谁
cnd_xml (str): 视频消息的cdn信息
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"xml": cnd_xml,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/forwardVideo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"转发视频结果: {json_blob}")
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int): async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
"""发送语音信息
Args:
voice_url (str): 语音文件的网络链接
voice_duration (int): 语音时长,毫秒
"""
payload = { payload = {
"appId": self.appid, "appId": self.appid,
"toWxid": to_wxid, "toWxid": to_wxid,
@@ -449,9 +659,16 @@ class SimpleGewechatClient:
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
) as resp: ) as resp:
json_blob = await resp.json() json_blob = await resp.json()
logger.debug(f"发送语音结果: {json_blob}") logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}")
async def post_file(self, to_wxid, file_url: str, file_name: str): async def post_file(self, to_wxid, file_url: str, file_name: str):
"""发送文件
Args:
to_wxid (string): 微信ID
file_url (str): 文件的网络链接
file_name (str): 文件名
"""
payload = { payload = {
"appId": self.appid, "appId": self.appid,
"toWxid": to_wxid, "toWxid": to_wxid,
@@ -465,3 +682,131 @@ class SimpleGewechatClient:
) as resp: ) as resp:
json_blob = await resp.json() json_blob = await resp.json()
logger.debug(f"发送文件结果: {json_blob}") logger.debug(f"发送文件结果: {json_blob}")
async def add_friend(self, v3: str, v4: str, content: str):
"""申请添加好友"""
payload = {
"appId": self.appid,
"scene": 3,
"content": content,
"v4": v4,
"v3": v3,
"option": 2,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/addContacts",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"申请添加好友结果: {json_blob}")
return json_blob
async def get_group(self, group_id: str):
payload = {
"appId": self.appid,
"chatroomId": group_id,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomInfo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_group_member(self, group_id: str):
payload = {
"appId": self.appid,
"chatroomId": group_id,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomMemberList",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def accept_group_invite(self, url: str):
"""同意进群"""
payload = {"appId": self.appid, "url": url}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/agreeJoinRoom",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def add_group_member_to_friend(
self, group_id: str, to_wxid: str, content: str
):
payload = {
"appId": self.appid,
"chatroomId": group_id,
"content": content,
"memberWxid": to_wxid,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/addGroupMemberAsFriend",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_user_or_group_info(self, *ids):
"""
获取用户或群组信息。
:param ids: 可变数量的 wxid 参数
"""
wxids_str = list(ids)
payload = {
"appId": self.appid,
"wxids": wxids_str, # 使用逗号分隔的字符串
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/getDetailInfo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_contacts_list(self):
"""
获取通讯录列表
见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504
"""
payload = {"appId": self.appid}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/fetchContactsList",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取通讯录列表结果: {json_blob}")
return json_blob

View File

@@ -39,3 +39,17 @@ class GeweDownloader:
continue continue
raise Exception("无法下载图片") raise Exception("无法下载图片")
async def download_emoji_md5(self, app_id, emoji_md5):
"""下载emoji"""
try:
payload = {"appId": app_id, "emojiMd5": emoji_md5}
# gewe 计划中的接口暂时没有实现。返回代码404
data = await self._post_json(
self.base_url, "/message/downloadEmojiMd5", payload
)
json_blob = json.loads(data)
return json_blob
except BaseException as e:
logger.error(f"gewe download emoji: {e}")

View File

@@ -1,14 +1,27 @@
import asyncio
import re
import wave import wave
import uuid import uuid
import traceback import traceback
import os import os
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
from typing import AsyncGenerator
from astrbot.core.utils.io import download_file
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
from astrbot.api import logger from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember
from astrbot.api.message_components import Plain, Image, Record, At, File from astrbot.api.message_components import (
Plain,
Image,
Record,
At,
File,
Video,
WechatEmoji as Emoji,
)
from .client import SimpleGewechatClient from .client import SimpleGewechatClient
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
def get_wav_duration(file_path): def get_wav_duration(file_path):
@@ -70,39 +83,84 @@ class GewechatPlatformEvent(AstrMessageEvent):
await client.post_text(**payload) await client.post_text(**payload)
elif isinstance(comp, Image): elif isinstance(comp, Image):
img_url = comp.file img_path = await comp.convert_to_file_path()
img_path = "" # 为了安全,向 AstrBot 回调服务注册可被 gewechat 访问的文件,并获得文件 token
if img_url.startswith("file:///"): token = await client._register_file(img_path)
img_path = img_url[8:] img_url = f"{client.file_server_url}/{token}"
elif comp.file and comp.file.startswith("http"):
img_path = await download_image_by_url(comp.file)
else:
img_path = img_url
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
temp_directory = os.path.abspath("data/temp")
img_path = os.path.abspath(img_path)
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
with open(img_path, "rb") as f:
img_path = save_temp_img(f.read())
file_id = os.path.basename(img_path)
img_url = f"{client.file_server_url}/{file_id}"
logger.debug(f"gewe callback img url: {img_url}") logger.debug(f"gewe callback img url: {img_url}")
await client.post_image(to_wxid, img_url) await client.post_image(to_wxid, img_url)
elif isinstance(comp, Video):
if comp.cover != "":
await client.forward_video(to_wxid, comp.cover)
else:
try:
from pyffmpeg import FFmpeg
except (ImportError, ModuleNotFoundError):
logger.error(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
)
raise ModuleNotFoundError(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
)
video_url = comp.file
# 根据 url 下载视频
if video_url.startswith("http"):
video_filename = f"{uuid.uuid4()}.mp4"
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
video_path = os.path.join(temp_dir, video_filename)
await download_file(video_url, video_path)
else:
video_path = video_url
video_token = await client._register_file(video_path)
video_callback_url = f"{client.file_server_url}/{video_token}"
# 获取视频第一帧
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
thumb_path = os.path.join(
temp_dir, f"gewechat_video_thumb_{uuid.uuid4()}.jpg"
)
video_path = video_path.replace(" ", "\\ ")
try:
ff = FFmpeg()
command = f"-i {video_path} -ss 0 -vframes 1 {thumb_path}"
ff.options(command)
thumb_token = await client._register_file(thumb_path)
thumb_url = f"{client.file_server_url}/{thumb_token}"
except Exception as e:
logger.error(f"获取视频第一帧失败: {e}")
# 获取视频时长
try:
from pyffmpeg import FFprobe
# 创建 FFprobe 实例
ffprobe = FFprobe(video_url)
# 获取时长字符串
duration_str = ffprobe.duration
# 处理时长字符串
video_duration = float(duration_str.replace(":", ""))
except Exception as e:
logger.error(f"获取时长失败: {e}")
video_duration = 10
# 发送视频
await client.post_video(
to_wxid, video_callback_url, thumb_url, video_duration
)
# 删除临时缩略图文件
if os.path.exists(thumb_path):
os.remove(thumb_path)
elif isinstance(comp, Record): elif isinstance(comp, Record):
# 默认已经存在 data/temp 中 # 默认已经存在 data/temp 中
record_url = comp.file record_url = comp.file
record_path = "" record_path = await comp.convert_to_file_path()
if record_url.startswith("file:///"): temp_dir = os.path.join(get_astrbot_data_path(), "temp")
record_path = record_url[8:] silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
elif record_url.startswith("http"):
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
else:
record_path = record_url
silk_path = f"data/temp/{uuid.uuid4()}.silk"
try: try:
duration = await wav_to_tencent_silk(record_path, silk_path) duration = await wav_to_tencent_silk(record_path, silk_path)
except Exception as e: except Exception as e:
@@ -111,8 +169,8 @@ class GewechatPlatformEvent(AstrMessageEvent):
logger.info("Silk 语音文件格式转换至: " + record_path) logger.info("Silk 语音文件格式转换至: " + record_path)
if duration == 0: if duration == 0:
duration = get_wav_duration(record_path) duration = get_wav_duration(record_path)
file_id = os.path.basename(silk_path) token = await client._register_file(silk_path)
record_url = f"{client.file_server_url}/{file_id}" record_url = f"{client.file_server_url}/{token}"
logger.debug(f"gewe callback record url: {record_url}") logger.debug(f"gewe callback record url: {record_url}")
await client.post_voice(to_wxid, record_url, duration * 1000) await client.post_voice(to_wxid, record_url, duration * 1000)
elif isinstance(comp, File): elif isinstance(comp, File):
@@ -121,14 +179,19 @@ class GewechatPlatformEvent(AstrMessageEvent):
if file_path.startswith("file:///"): if file_path.startswith("file:///"):
file_path = file_path[8:] file_path = file_path[8:]
elif file_path.startswith("http"): elif file_path.startswith("http"):
await download_file(file_path, f"data/temp/{file_name}") temp_dir = os.path.join(get_astrbot_data_path(), "temp")
temp_file_path = os.path.join(temp_dir, file_name)
await download_file(file_path, temp_file_path)
file_path = temp_file_path
else: else:
file_path = file_path file_path = file_path
file_id = os.path.basename(file_path) token = await client._register_file(file_path)
file_url = f"{client.file_server_url}/{file_id}" file_url = f"{client.file_server_url}/{token}"
logger.debug(f"gewe callback file url: {file_url}") logger.debug(f"gewe callback file url: {file_url}")
await client.post_file(to_wxid, file_url, file_id) await client.post_file(to_wxid, file_url, file_name)
elif isinstance(comp, Emoji):
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
elif isinstance(comp, At): elif isinstance(comp, At):
pass pass
else: else:
@@ -138,3 +201,64 @@ class GewechatPlatformEvent(AstrMessageEvent):
to_wxid = self.message_obj.raw_message.get("to_wxid", None) to_wxid = self.message_obj.raw_message.get("to_wxid", None)
await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client) await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client)
await super().send(message) await super().send(message)
async def get_group(self, group_id=None, **kwargs):
# 确定有效的 group_id
if group_id is None:
group_id = self.get_group_id()
if not group_id:
return None
res = await self.client.get_group(group_id)
data: dict = res["data"]
if not data["chatroomId"]:
return None
members = [
MessageMember(user_id=member["wxid"], nickname=member["nickName"])
for member in data.get("memberList", [])
]
return Group(
group_id=data["chatroomId"],
group_name=data.get("nickName"),
group_avatar=data.get("smallHeadImgUrl"),
group_owner=data.get("chatRoomOwner"),
members=members,
)
async def send_streaming(
self, generator: AsyncGenerator, use_fallback: bool = False
):
if not use_fallback:
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
buffer = ""
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
async for chain in generator:
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
buffer = await self.process_buffer(buffer, pattern)
else:
await self.send(MessageChain(chain=[comp]))
await asyncio.sleep(1.5) # 限速
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator, use_fallback)

View File

@@ -8,6 +8,7 @@ from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter from ...register import register_platform_adapter
from .gewechat_event import GewechatPlatformEvent from .gewechat_event import GewechatPlatformEvent
from .client import SimpleGewechatClient from .client import SimpleGewechatClient
from astrbot import logger
if sys.version_info >= (3, 12): if sys.version_info >= (3, 12):
from typing import override from typing import override
@@ -59,13 +60,18 @@ class GewechatPlatformAdapter(Platform):
@override @override
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:
return PlatformMetadata( return PlatformMetadata(
"gewechat", name="gewechat",
"基于 gewechat 的 Wechat 适配器", description="基于 gewechat 的 Wechat 适配器",
id=self.config.get("id"),
) )
async def terminate(self): async def terminate(self):
self.client.stop = True self.client.shutdown_event.set()
await asyncio.sleep(1) try:
await self.client.server.shutdown()
except Exception as _:
pass
logger.info("Gewechat 适配器已被优雅地关闭。")
async def logout(self): async def logout(self):
await self.client.logout() await self.client.logout()

View File

@@ -0,0 +1,110 @@
from defusedxml import ElementTree as eT
from astrbot.api import logger
from astrbot.api.message_components import (
WechatEmoji as Emoji,
Reply,
Plain,
BaseMessageComponent,
)
class GeweDataParser:
def __init__(self, data, is_private_chat):
self.data = data
self.is_private_chat = is_private_chat
def _format_to_xml(self):
return eT.fromstring(self.data)
def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
appmsg_type = self._format_to_xml().find(".//appmsg/type")
if appmsg_type is None:
return
match appmsg_type.text:
case "57":
return self.parse_reply()
def parse_emoji(self) -> Emoji | None:
try:
emoji_element = self._format_to_xml().find(".//emoji")
# 提取 md5 和 len 属性
if emoji_element is not None:
md5_value = emoji_element.get("md5")
emoji_size = emoji_element.get("len")
cdnurl = emoji_element.get("cdnurl")
return Emoji(md5=md5_value, md5_len=emoji_size, cdnurl=cdnurl)
except Exception as e:
logger.error(f"gewechat: parse_emoji failed, {e}")
def parse_reply(self) -> list[Reply, Plain] | None:
"""解析引用消息
Returns:
list[Reply, Plain]: 一个包含两个元素的列表。Reply 消息对象和引用者说的文本内容。微信平台下引用消息时只能发送文本消息。
"""
try:
replied_id = -1
replied_uid = 0
replied_nickname = ""
replied_content = "" # 被引用者说的内容
content = "" # 引用者说的内容
root = self._format_to_xml()
refermsg = root.find(".//refermsg")
if refermsg is not None:
# 被引用的信息
svrid = refermsg.find("svrid")
fromusr = refermsg.find("fromusr")
displayname = refermsg.find("displayname")
refermsg_content = refermsg.find("content")
if svrid is not None:
replied_id = svrid.text
if fromusr is not None:
replied_uid = fromusr.text
if displayname is not None:
replied_nickname = displayname.text
if refermsg_content is not None:
# 处理引用嵌套,包括嵌套公众号消息
if refermsg_content.text.startswith(
"<msg>"
) or refermsg_content.text.startswith("<?xml"):
try:
logger.debug("gewechat: Reference message is nested")
refer_root = eT.fromstring(refermsg_content.text)
img = refer_root.find("img")
if img is not None:
replied_content = "[图片]"
else:
app_msg = refer_root.find("appmsg")
refermsg_content_title = app_msg.find("title")
logger.debug(
f"gewechat: Reference message nesting: {refermsg_content_title.text}"
)
replied_content = refermsg_content_title.text
except Exception as e:
logger.error(f"gewechat: nested failed, {e}")
# 处理异常情况
replied_content = refermsg_content.text
else:
replied_content = refermsg_content.text
# 提取引用者说的内容
title = root.find(".//appmsg/title")
if title is not None:
content = title.text
reply_seg = Reply(
id=replied_id,
chain=[Plain(replied_content)],
sender_id=replied_uid,
sender_nickname=replied_nickname,
message_str=replied_content,
)
plain_seg = Plain(content)
return [reply_seg, plain_seg]
except Exception as e:
logger.error(f"gewechat: parse_reply failed, {e}")

View File

@@ -2,6 +2,8 @@ import base64
import asyncio import asyncio
import json import json
import re import re
import uuid
import astrbot.api.message_components as Comp
from astrbot.api.platform import ( from astrbot.api.platform import (
Platform, Platform,
@@ -11,7 +13,6 @@ from astrbot.api.platform import (
PlatformMetadata, PlatformMetadata,
) )
from astrbot.api.event import MessageChain from astrbot.api.event import MessageChain
from astrbot.api.message_components import Image, Plain, At
from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.astr_message_event import MessageSesion
from .lark_event import LarkMessageEvent from .lark_event import LarkMessageEvent
from ...register import register_platform_adapter from ...register import register_platform_adapter
@@ -66,12 +67,47 @@ class LarkPlatformAdapter(Platform):
async def send_by_session( async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain self, session: MessageSesion, message_chain: MessageChain
): ):
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
wrapped = {
"zh_cn": {
"title": "",
"content": res,
}
}
if session.message_type == MessageType.GROUP_MESSAGE:
id_type = "chat_id"
if "%" in session.session_id:
session.session_id = session.session_id.split("%")[1]
else:
id_type = "open_id"
request = (
CreateMessageRequest.builder()
.receive_id_type(id_type)
.request_body(
CreateMessageRequestBody.builder()
.receive_id(session.session_id)
.content(json.dumps(wrapped))
.msg_type("post")
.uuid(str(uuid.uuid4()))
.build()
)
.build()
)
response = await self.lark_api.im.v1.message.acreate(request)
if not response.success():
logger.error(f"发送飞书消息失败({response.code}): {response.msg}")
await super().send_by_session(session, message_chain)
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:
return PlatformMetadata( return PlatformMetadata(
"lark", name="lark",
"飞书机器人官方 API 适配器", description="飞书机器人官方 API 适配器",
id=self.config.get("id"),
) )
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
@@ -92,7 +128,7 @@ class LarkPlatformAdapter(Platform):
at_list = {} at_list = {}
if message.mentions: if message.mentions:
for m in message.mentions: for m in message.mentions:
at_list[m.key] = At(qq=m.id.open_id, name=m.name) at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name)
if m.name == self.bot_name: if m.name == self.bot_name:
abm.self_id = m.id.open_id abm.self_id = m.id.open_id
@@ -111,7 +147,7 @@ class LarkPlatformAdapter(Platform):
if s in at_list: if s in at_list:
abm.message.append(at_list[s]) abm.message.append(at_list[s])
else: else:
abm.message.append(Plain(parts[i].strip())) abm.message.append(Comp.Plain(parts[i].strip()))
elif message.message_type == "post": elif message.message_type == "post":
_ls = [] _ls = []
@@ -132,7 +168,7 @@ class LarkPlatformAdapter(Platform):
if comp["tag"] == "at": if comp["tag"] == "at":
abm.message.append(at_list[comp["user_id"]]) abm.message.append(at_list[comp["user_id"]])
elif comp["tag"] == "text" and comp["text"].strip(): elif comp["tag"] == "text" and comp["text"].strip():
abm.message.append(Plain(comp["text"].strip())) abm.message.append(Comp.Plain(comp["text"].strip()))
elif comp["tag"] == "img": elif comp["tag"] == "img":
image_key = comp["image_key"] image_key = comp["image_key"]
request = ( request = (
@@ -147,10 +183,10 @@ class LarkPlatformAdapter(Platform):
logger.error(f"无法下载飞书图片: {image_key}") logger.error(f"无法下载飞书图片: {image_key}")
image_bytes = response.file.read() image_bytes = response.file.read()
image_base64 = base64.b64encode(image_bytes).decode() image_base64 = base64.b64encode(image_bytes).decode()
abm.message.append(Image.fromBase64(image_base64)) abm.message.append(Comp.Image.fromBase64(image_base64))
for comp in abm.message: for comp in abm.message:
if isinstance(comp, Plain): if isinstance(comp, Comp.Plain):
abm.message_str += comp.text abm.message_str += comp.text
abm.message_id = message.message_id abm.message_id = message.message_id
abm.raw_message = message abm.raw_message = message
@@ -165,7 +201,10 @@ class LarkPlatformAdapter(Platform):
else: else:
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
else: else:
abm.session_id = abm.sender.user_id if abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = f"{abm.sender.user_id}%{abm.group_id}" # 也保留群组id
else:
abm.session_id = abm.sender.user_id
logger.debug(abm) logger.debug(abm)
await self.handle_msg(abm) await self.handle_msg(abm)
@@ -185,5 +224,9 @@ class LarkPlatformAdapter(Platform):
# self.client.start() # self.client.start()
await self.client._connect() await self.client._connect()
async def terminate(self):
await self.client._disconnect()
logger.info("飞书(Lark) 适配器已被优雅地关闭")
def get_client(self) -> lark.Client: def get_client(self) -> lark.Client:
return self.client return self.client

View File

@@ -1,12 +1,16 @@
import json import json
import os
import uuid import uuid
import base64
import lark_oapi as lark import lark_oapi as lark
from io import BytesIO
from typing import List from typing import List
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image as AstrBotImage, At from astrbot.api.message_components import Plain, Image as AstrBotImage, At
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
from lark_oapi.api.im.v1 import * from lark_oapi.api.im.v1 import *
from astrbot import logger from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class LarkMessageEvent(AstrMessageEvent): class LarkMessageEvent(AstrMessageEvent):
@@ -27,22 +31,33 @@ class LarkMessageEvent(AstrMessageEvent):
_stage.append({"tag": "at", "user_id": comp.qq, "style": []}) _stage.append({"tag": "at", "user_id": comp.qq, "style": []})
elif isinstance(comp, AstrBotImage): elif isinstance(comp, AstrBotImage):
file_path = "" file_path = ""
image_file = None
if comp.file and comp.file.startswith("file:///"): if comp.file and comp.file.startswith("file:///"):
file_path = comp.file.replace("file:///", "") file_path = comp.file.replace("file:///", "")
elif comp.file and comp.file.startswith("http"): elif comp.file and comp.file.startswith("http"):
image_file_path = await download_image_by_url(comp.file) image_file_path = await download_image_by_url(comp.file)
file_path = image_file_path file_path = image_file_path
elif comp.file and comp.file.startswith("base64://"): elif comp.file and comp.file.startswith("base64://"):
pass base64_str = comp.file.removeprefix("base64://")
image_data = base64.b64decode(base64_str)
# save as temp file
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}_test.jpg")
with open(file_path, "wb") as f:
f.write(BytesIO(image_data).getvalue())
else: else:
file_path = comp.file file_path = comp.file
if image_file is None:
image_file = open(file_path, "rb")
request = ( request = (
CreateImageRequest.builder() CreateImageRequest.builder()
.request_body( .request_body(
CreateImageRequestBody.builder() CreateImageRequestBody.builder()
.image_type("message") .image_type("message")
.image(open(file_path, "rb")) .image(image_file)
.build() .build()
) )
.build() .build()
@@ -51,7 +66,7 @@ class LarkMessageEvent(AstrMessageEvent):
if not response.success(): if not response.success():
logger.error(f"无法上传飞书图片({response.code}): {response.msg}") logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
image_key = response.data.image_key image_key = response.data.image_key
print(image_key) logger.debug(image_key)
ret.append(_stage) ret.append(_stage)
ret.append([{"tag": "img", "image_key": image_key}]) ret.append([{"tag": "img", "image_key": image_key}])
_stage.clear() _stage.clear()
@@ -91,3 +106,16 @@ class LarkMessageEvent(AstrMessageEvent):
logger.error(f"回复飞书消息失败({response.code}): {response.msg}") logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
await super().send(message) await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)

View File

@@ -2,6 +2,7 @@ import botpy
import botpy.message import botpy.message
import botpy.types import botpy.types
import botpy.types.message import botpy.types.message
import asyncio
from astrbot.core.utils.io import file_to_base64, download_image_by_url from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.api.platform import AstrBotMessage, PlatformMetadata
@@ -9,6 +10,8 @@ from astrbot.api.message_components import Plain, Image
from botpy import Client from botpy import Client
from botpy.http import Route from botpy.http import Route
from astrbot.api import logger from astrbot.api import logger
from botpy.types import message
import random
class QQOfficialMessageEvent(AstrMessageEvent): class QQOfficialMessageEvent(AstrMessageEvent):
@@ -30,8 +33,45 @@ class QQOfficialMessageEvent(AstrMessageEvent):
else: else:
self.send_buffer.chain.extend(message.chain) self.send_buffer.chain.extend(message.chain)
async def _post_send(self): async def send_streaming(self, generator, use_fallback: bool = False):
"""QQ 官方 API 仅支持回复一次""" """流式输出仅支持消息列表私聊"""
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
last_edit_time = 0 # 上次编辑消息的时间
throttle_interval = 1 # 编辑消息的间隔时间 (秒)
try:
async for chain in generator:
source = self.message_obj.raw_message
if not self.send_buffer:
self.send_buffer = chain
else:
self.send_buffer.chain.extend(chain.chain)
if isinstance(source, botpy.message.C2CMessage):
# 真流式传输
current_time = asyncio.get_event_loop().time()
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
ret = await self._post_send(stream=stream_payload)
stream_payload["index"] += 1
stream_payload["id"] = ret["id"]
last_edit_time = asyncio.get_event_loop().time()
if isinstance(source, botpy.message.C2CMessage):
# 结束流式对话,并且传输 buffer 中剩余的消息
stream_payload["state"] = 10
ret = await self._post_send(stream=stream_payload)
except Exception as e:
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
self.send_buffer = None
return await super().send_streaming(generator, use_fallback)
async def _post_send(self, stream: dict = None):
if not self.send_buffer:
return
source = self.message_obj.raw_message source = self.message_obj.raw_message
assert isinstance( assert isinstance(
source, source,
@@ -57,6 +97,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
"msg_id": self.message_obj.message_id, "msg_id": self.message_obj.message_id,
} }
if not isinstance(source, (botpy.message.Message, botpy.message.DirectMessage)):
payload["msg_seq"] = random.randint(1, 10000)
match type(source): match type(source):
case botpy.message.GroupMessage: case botpy.message.GroupMessage:
if image_base64: if image_base64:
@@ -65,7 +108,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
) )
payload["media"] = media payload["media"] = media
payload["msg_type"] = 7 payload["msg_type"] = 7
await self.bot.api.post_group_message( ret = await self.bot.api.post_group_message(
group_openid=source.group_openid, **payload group_openid=source.group_openid, **payload
) )
case botpy.message.C2CMessage: case botpy.message.C2CMessage:
@@ -75,22 +118,34 @@ class QQOfficialMessageEvent(AstrMessageEvent):
) )
payload["media"] = media payload["media"] = media
payload["msg_type"] = 7 payload["msg_type"] = 7
await self.bot.api.post_c2c_message( if stream:
openid=source.author.user_openid, **payload ret = await self.post_c2c_message(
) openid=source.author.user_openid,
**payload,
stream=stream,
)
else:
ret = await self.post_c2c_message(
openid=source.author.user_openid, **payload
)
logger.debug(f"Message sent to C2C: {ret}")
case botpy.message.Message: case botpy.message.Message:
if image_path: if image_path:
payload["file_image"] = image_path payload["file_image"] = image_path
await self.bot.api.post_message(channel_id=source.channel_id, **payload) ret = await self.bot.api.post_message(
channel_id=source.channel_id, **payload
)
case botpy.message.DirectMessage: case botpy.message.DirectMessage:
if image_path: if image_path:
payload["file_image"] = image_path payload["file_image"] = image_path
await self.bot.api.post_dms(guild_id=source.guild_id, **payload) ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
await super().send(self.send_buffer) await super().send(self.send_buffer)
self.send_buffer = None self.send_buffer = None
return ret
async def upload_group_and_c2c_image( async def upload_group_and_c2c_image(
self, image_base64: str, file_type: int, **kwargs self, image_base64: str, file_type: int, **kwargs
) -> botpy.types.message.Media: ) -> botpy.types.message.Media:
@@ -112,6 +167,27 @@ class QQOfficialMessageEvent(AstrMessageEvent):
) )
return await self.bot.api._http.request(route, json=payload) return await self.bot.api._http.request(route, json=payload)
async def post_c2c_message(
self,
openid: str,
msg_type: int = 0,
content: str = None,
embed: message.Embed = None,
ark: message.Ark = None,
message_reference: message.Reference = None,
media: message.Media = None,
msg_id: str = None,
msg_seq: str = 1,
event_id: str = None,
markdown: message.MarkdownPayload = None,
keyboard: message.Keyboard = None,
stream: dict = None,
) -> message.Message:
payload = locals()
payload.pop("self", None)
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
return await self.bot.api._http.request(route, json=payload)
@staticmethod @staticmethod
async def _parse_to_qqofficial(message: MessageChain): async def _parse_to_qqofficial(message: MessageChain):
plain_text = "" plain_text = ""
@@ -122,16 +198,16 @@ class QQOfficialMessageEvent(AstrMessageEvent):
plain_text += i.text plain_text += i.text
elif isinstance(i, Image) and not image_base64: elif isinstance(i, Image) and not image_base64:
if i.file and i.file.startswith("file:///"): if i.file and i.file.startswith("file:///"):
image_base64 = file_to_base64(i.file[8:]).replace("base64://", "") image_base64 = file_to_base64(i.file[8:])
image_file_path = i.file[8:] image_file_path = i.file[8:]
elif i.file and i.file.startswith("http"): elif i.file and i.file.startswith("http"):
image_file_path = await download_image_by_url(i.file) image_file_path = await download_image_by_url(i.file)
image_base64 = file_to_base64(image_file_path).replace( image_base64 = file_to_base64(image_file_path)
"base64://", "" elif i.file and i.file.startswith("base64://"):
) image_base64 = i.file
else: else:
image_base64 = file_to_base64(i.file).replace("base64://", "") image_base64 = file_to_base64(i.file)
image_file_path = i.file image_base64 = image_base64.removeprefix("base64://")
else: else:
logger.debug(f"qq_official 忽略 {i.type}") logger.debug(f"qq_official 忽略 {i.type}")
return plain_text, image_base64, image_file_path return plain_text, image_base64, image_file_path

View File

@@ -17,6 +17,7 @@ from astrbot.api.platform import (
MessageType, MessageType,
PlatformMetadata, PlatformMetadata,
) )
from astrbot import logger
from astrbot.api.event import MessageChain from astrbot.api.event import MessageChain
from typing import Union, List from typing import Union, List
from astrbot.api.message_components import Image, Plain, At from astrbot.api.message_components import Image, Plain, At
@@ -125,8 +126,9 @@ class QQOfficialPlatformAdapter(Platform):
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:
return PlatformMetadata( return PlatformMetadata(
"qq_official", name="qq_official",
"QQ 机器人官方 API 适配器", description="QQ 机器人官方 API 适配器",
id=self.config.get("id"),
) )
@staticmethod @staticmethod
@@ -204,3 +206,7 @@ class QQOfficialPlatformAdapter(Platform):
def get_client(self) -> botClient: def get_client(self) -> botClient:
return self.client return self.client
async def terminate(self):
await self.client.close()
logger.info("QQ 官方机器人接口 适配器已被优雅地关闭")

View File

@@ -13,6 +13,7 @@ from .qo_webhook_event import QQOfficialWebhookMessageEvent
from ...register import register_platform_adapter from ...register import register_platform_adapter
from .qo_webhook_server import QQOfficialWebhook from .qo_webhook_server import QQOfficialWebhook
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
from astrbot import logger
# remove logger handler # remove logger handler
for handler in logging.root.handlers[:]: for handler in logging.root.handlers[:]:
@@ -98,8 +99,9 @@ class QQOfficialWebhookPlatformAdapter(Platform):
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:
return PlatformMetadata( return PlatformMetadata(
"qq_official_webhook", name="qq_official_webhook",
"QQ 机器人官方 API 适配器", description="QQ 机器人官方 API 适配器",
id=self.config.get("id"),
) )
async def run(self): async def run(self):
@@ -111,3 +113,12 @@ class QQOfficialWebhookPlatformAdapter(Platform):
def get_client(self) -> botClient: def get_client(self) -> botClient:
return self.client return self.client
async def terminate(self):
self.webhook_helper.shutdown_event.set()
await self.client.close()
try:
await self.webhook_helper.server.shutdown()
except Exception as _:
pass
logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭")

View File

@@ -15,6 +15,7 @@ class QQOfficialWebhook:
self.appid = config["appid"] self.appid = config["appid"]
self.secret = config["secret"] self.secret = config["secret"]
self.port = config.get("port", 6196) self.port = config.get("port", 6196)
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
if isinstance(self.port, str): if isinstance(self.port, str):
self.port = int(self.port) self.port = int(self.port)
@@ -29,6 +30,7 @@ class QQOfficialWebhook:
) )
self.client = botpy_client self.client = botpy_client
self.event_queue = event_queue self.event_queue = event_queue
self.shutdown_event = asyncio.Event()
async def initialize(self): async def initialize(self):
logger.info("正在登录到 QQ 官方机器人...") logger.info("正在登录到 QQ 官方机器人...")
@@ -95,13 +97,14 @@ class QQOfficialWebhook:
return {"opcode": 12} return {"opcode": 12}
async def start_polling(self): async def start_polling(self):
logger.info(
f"将在 {self.callback_server_host}:{self.port} 端口启动 QQ 官方机器人 webhook 适配器。"
)
await self.server.run_task( await self.server.run_task(
host="0.0.0.0", host=self.callback_server_host,
port=self.port, port=self.port,
shutdown_trigger=self.shutdown_trigger_placeholder, shutdown_trigger=self.shutdown_trigger,
) )
async def shutdown_trigger_placeholder(self): async def shutdown_trigger(self):
while not self.event_queue.closed: # noqa: ASYNC110 await self.shutdown_event.wait()
await asyncio.sleep(1)
logger.info("qq_official_webhook 适配器已关闭。")

View File

@@ -1,33 +1,32 @@
import asyncio
import re
import sys import sys
import uuid import uuid
import asyncio
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from telegram import BotCommand, Update
from telegram.constants import ChatType
from telegram.ext import ApplicationBuilder, ContextTypes, ExtBot, filters
from telegram.ext import MessageHandler as TelegramMessageHandler
import astrbot.api.message_components as Comp
from astrbot.api import logger
from astrbot.api.event import MessageChain
from astrbot.api.platform import ( from astrbot.api.platform import (
Platform,
AstrBotMessage, AstrBotMessage,
MessageMember, MessageMember,
PlatformMetadata,
MessageType, MessageType,
) Platform,
from astrbot.api.event import MessageChain PlatformMetadata,
from astrbot.api.message_components import ( register_platform_adapter,
Plain,
Image,
Record,
File as AstrBotFile,
Video,
At,
) )
from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.api.platform import register_platform_adapter from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import star_handlers_registry
from telegram import Update
from telegram.ext import ApplicationBuilder, ContextTypes, filters
from telegram.constants import ChatType
from telegram.ext import MessageHandler as TelegramMessageHandler
from .tg_event import TelegramPlatformEvent from .tg_event import TelegramPlatformEvent
from astrbot.api import logger
from telegram.ext import ExtBot
if sys.version_info >= (3, 12): if sys.version_info >= (3, 12):
from typing import override from typing import override
@@ -59,6 +58,14 @@ class TelegramPlatformAdapter(Platform):
self.base_url = base_url self.base_url = base_url
self.enable_command_register = self.config.get(
"telegram_command_register", True
)
self.enable_command_refresh = self.config.get(
"telegram_command_auto_refresh", True
)
self.last_command_hash = None
self.application = ( self.application = (
ApplicationBuilder() ApplicationBuilder()
.token(self.config["telegram_token"]) .token(self.config["telegram_token"])
@@ -68,12 +75,14 @@ class TelegramPlatformAdapter(Platform):
) )
message_handler = TelegramMessageHandler( message_handler = TelegramMessageHandler(
filters=filters.ALL, # receive all messages filters=filters.ALL, # receive all messages
callback=self.convert_message, callback=self.message_handler,
) )
self.application.add_handler(message_handler) self.application.add_handler(message_handler)
self.client = self.application.bot self.client = self.application.bot
logger.debug(f"Telegram base url: {self.client.base_url}") logger.debug(f"Telegram base url: {self.client.base_url}")
self.scheduler = AsyncIOScheduler()
@override @override
async def send_by_session( async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain self, session: MessageSesion, message_chain: MessageChain
@@ -87,94 +96,250 @@ class TelegramPlatformAdapter(Platform):
@override @override
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:
return PlatformMetadata( return PlatformMetadata(
"telegram", name="telegram", description="telegram 适配器", id=self.config.get("id")
"telegram 适配器",
) )
@override @override
async def run(self): async def run(self):
await self.application.initialize() await self.application.initialize()
await self.application.start() await self.application.start()
if self.enable_command_register:
await self.register_commands()
if self.enable_command_refresh and self.enable_command_register:
self.scheduler.add_job(
self.register_commands,
"interval",
seconds=self.config.get("telegram_command_register_interval", 300),
id="telegram_command_register",
misfire_grace_time=60,
)
self.scheduler.start()
queue = self.application.updater.start_polling() queue = self.application.updater.start_polling()
logger.info("Telegram Platform Adapter is running.") logger.info("Telegram Platform Adapter is running.")
await queue await queue
async def register_commands(self):
"""收集所有注册的指令并注册到 Telegram"""
try:
commands = self.collect_commands()
if commands:
current_hash = hash(
tuple((cmd.command, cmd.description) for cmd in commands)
)
if current_hash == self.last_command_hash:
return
self.last_command_hash = current_hash
await self.client.delete_my_commands()
await self.client.set_my_commands(commands)
except Exception as e:
logger.error(f"向 Telegram 注册指令时发生错误: {e!s}")
def collect_commands(self) -> list[BotCommand]:
"""从注册的处理器中收集所有指令"""
command_dict = {}
skip_commands = {"start"}
for handler_md in star_handlers_registry:
handler_metadata = handler_md
if not star_map[handler_metadata.handler_module_path].activated:
continue
for event_filter in handler_metadata.event_filters:
cmd_info = self._extract_command_info(
event_filter, handler_metadata, skip_commands
)
if cmd_info:
cmd_name, description = cmd_info
command_dict.setdefault(cmd_name, description)
commands_a = sorted(command_dict.keys())
return [BotCommand(cmd, command_dict[cmd]) for cmd in commands_a]
@staticmethod
def _extract_command_info(
event_filter, handler_metadata, skip_commands: set
) -> tuple[str, str] | None:
"""从事件过滤器中提取指令信息"""
cmd_name = None
is_group = False
if isinstance(event_filter, CommandFilter) and event_filter.command_name:
if (
event_filter.parent_command_names
and event_filter.parent_command_names != [""]
):
return None
cmd_name = event_filter.command_name
elif isinstance(event_filter, CommandGroupFilter):
if event_filter.parent_group:
return None
cmd_name = event_filter.group_name
is_group = True
if not cmd_name or cmd_name in skip_commands:
return None
if not re.match(r"^[a-z0-9_]+$", cmd_name) or len(cmd_name) > 32:
logger.debug(f"跳过无法注册的命令: {cmd_name}")
return None
# Build description.
description = handler_metadata.desc or (
f"指令组: {cmd_name} (包含多个子指令)" if is_group else f"指令: {cmd_name}"
)
if len(description) > 30:
description = description[:30] + "..."
return cmd_name, description
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE): async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
await context.bot.send_message( await context.bot.send_message(
chat_id=update.effective_chat.id, text=self.config["start_message"] chat_id=update.effective_chat.id, text=self.config["start_message"]
) )
async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
logger.debug(f"Telegram message: {update.message}")
abm = await self.convert_message(update, context)
if abm:
await self.handle_msg(abm)
async def convert_message( async def convert_message(
self, update: Update, context: ContextTypes.DEFAULT_TYPE self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True
) -> AstrBotMessage: ) -> AstrBotMessage:
"""转换 Telegram 的消息对象为 AstrBotMessage 对象。
@param update: Telegram 的 Update 对象。
@param context: Telegram 的 Context 对象。
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
"""
message = AstrBotMessage() message = AstrBotMessage()
message.session_id = str(update.message.chat.id)
# 获得是群聊还是私聊 # 获得是群聊还是私聊
if update.effective_chat.type == ChatType.PRIVATE: if update.message.chat.type == ChatType.PRIVATE:
message.type = MessageType.FRIEND_MESSAGE message.type = MessageType.FRIEND_MESSAGE
else: else:
message.type = MessageType.GROUP_MESSAGE message.type = MessageType.GROUP_MESSAGE
message.group_id = update.effective_chat.id message.group_id = str(update.message.chat.id)
if update.message.message_thread_id:
# Topic Group
message.group_id += "#" + str(update.message.message_thread_id)
message.session_id = message.group_id
message.message_id = str(update.message.message_id) message.message_id = str(update.message.message_id)
message.session_id = str(update.effective_chat.id)
message.sender = MessageMember( message.sender = MessageMember(
str(update.effective_user.id), update.effective_user.username str(update.message.from_user.id), update.message.from_user.username
) )
message.self_id = str(context.bot.username) message.self_id = str(context.bot.username)
message.raw_message = update message.raw_message = update
message.message_str = "" message.message_str = ""
message.message = [] message.message = []
logger.debug(f"Telegram message: {update.message}") if update.message.reply_to_message and not (
update.message.is_topic_message
and update.message.message_thread_id
== update.message.reply_to_message.message_id
):
# 获取回复消息
reply_update = Update(
update_id=1,
message=update.message.reply_to_message,
)
reply_abm = await self.convert_message(reply_update, context, False)
message.message.append(
Comp.Reply(
id=reply_abm.message_id,
chain=reply_abm.message,
sender_id=reply_abm.sender.user_id,
sender_nickname=reply_abm.sender.nickname,
time=reply_abm.timestamp,
message_str=reply_abm.message_str,
text=reply_abm.message_str,
qq=reply_abm.sender.user_id,
)
)
if update.message.text: if update.message.text:
# 处理文本消息
plain_text = update.message.text plain_text = update.message.text
# 群聊场景命令特殊处理
if plain_text.startswith("/"):
command_parts = plain_text.split(" ", 1)
if "@" in command_parts[0]:
command, bot_name = command_parts[0].split("@")
if bot_name == self.client.username:
plain_text = command + (
f" {command_parts[1]}" if len(command_parts) > 1 else ""
)
if update.message.entities: if update.message.entities:
for entity in update.message.entities: for entity in update.message.entities:
if entity.type == "mention": if entity.type == "mention":
name = plain_text[ name = plain_text[
entity.offset + 1 : entity.offset + entity.length entity.offset + 1 : entity.offset + entity.length
] ]
message.message.append(At(qq=name, name=name)) message.message.append(Comp.At(qq=name, name=name))
plain_text = ( # 如果mention是当前bot则移除否则保留
plain_text[: entity.offset] if name.lower() == context.bot.username.lower():
+ plain_text[entity.offset + entity.length :] plain_text = (
) plain_text[: entity.offset]
+ plain_text[entity.offset + entity.length :]
)
if plain_text: if plain_text:
message.message.append(Plain(plain_text)) message.message.append(Comp.Plain(plain_text))
message.message_str = plain_text message.message_str = plain_text
if message.message_str == "/start": if message.message_str.strip() == "/start":
await self.start(update, context) await self.start(update, context)
return return
elif update.message.voice: elif update.message.voice:
file = await update.message.voice.get_file() file = await update.message.voice.get_file()
message.message = [ message.message = [
Record(file=file.file_path, url=file.file_path), Comp.Record(file=file.file_path, url=file.file_path),
] ]
elif update.message.photo: elif update.message.photo:
photo = update.message.photo[-1] # get the largest photo photo = update.message.photo[-1] # get the largest photo
file = await photo.get_file() file = await photo.get_file()
message.message.append(Image(file=file.file_path, url=file.file_path)) message.message.append(Comp.Image(file=file.file_path, url=file.file_path))
if update.message.caption:
message.message_str = update.message.caption
message.message.append(Comp.Plain(message.message_str))
if update.message.caption_entities:
for entity in update.message.caption_entities:
if entity.type == "mention":
name = message.message_str[
entity.offset + 1 : entity.offset + entity.length
]
message.message.append(Comp.At(qq=name, name=name))
elif update.message.sticker:
# 将sticker当作图片处理
file = await update.message.sticker.get_file()
message.message.append(Comp.Image(file=file.file_path, url=file.file_path))
if update.message.sticker.emoji:
sticker_text = f"Sticker: {update.message.sticker.emoji}"
message.message_str = sticker_text
message.message.append(Comp.Plain(sticker_text))
elif update.message.document: elif update.message.document:
file = await update.message.document.get_file() file = await update.message.document.get_file()
message.message = [ message.message = [
AstrBotFile( Comp.File(file=file.file_path, name=update.message.document.file_name),
file=file.file_path, name=update.message.document.file_name
),
] ]
elif update.message.video: elif update.message.video:
file = await update.message.video.get_file() file = await update.message.video.get_file()
message.message = [ message.message = [
Video(file=file.file_path, path=file.file_path), Comp.Video(file=file.file_path, path=file.file_path),
] ]
await self.handle_msg(message) return message
async def handle_msg(self, message: AstrBotMessage): async def handle_msg(self, message: AstrBotMessage):
message_event = TelegramPlatformEvent( message_event = TelegramPlatformEvent(
@@ -188,3 +353,21 @@ class TelegramPlatformAdapter(Platform):
def get_client(self) -> ExtBot: def get_client(self) -> ExtBot:
return self.client return self.client
async def terminate(self):
try:
if self.scheduler.running:
self.scheduler.shutdown()
await self.application.stop()
if self.enable_command_register:
await self.client.delete_my_commands()
# 保险起见先判断是否存在updater对象
if self.application.updater is not None:
await self.application.updater.stop()
logger.info("Telegram 适配器已被优雅地关闭")
except Exception as e:
logger.error(f"Telegram 适配器关闭时出错: {e}")

View File

@@ -1,10 +1,34 @@
import os
import re
import asyncio
import telegramify_markdown
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType
from astrbot.api.message_components import Plain, Image, Reply, At, File, Record from astrbot.api.message_components import (
Plain,
Image,
Reply,
At,
File,
Record,
)
from telegram.ext import ExtBot from telegram.ext import ExtBot
from astrbot.core.utils.io import download_file
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class TelegramPlatformEvent(AstrMessageEvent): class TelegramPlatformEvent(AstrMessageEvent):
# Telegram 的最大消息长度限制
MAX_MESSAGE_LENGTH = 4096
SPLIT_PATTERNS = {
"paragraph": re.compile(r"\n\n"),
"line": re.compile(r"\n"),
"sentence": re.compile(r"[.!?。!?]"),
"word": re.compile(r"\s"),
}
def __init__( def __init__(
self, self,
message_str: str, message_str: str,
@@ -16,8 +40,33 @@ class TelegramPlatformEvent(AstrMessageEvent):
super().__init__(message_str, message_obj, platform_meta, session_id) super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client self.client = client
@staticmethod def _split_message(self, text: str) -> list[str]:
async def send_with_client(client: ExtBot, message: MessageChain, user_name: str): if len(text) <= self.MAX_MESSAGE_LENGTH:
return [text]
chunks = []
while text:
if len(text) <= self.MAX_MESSAGE_LENGTH:
chunks.append(text)
break
split_point = self.MAX_MESSAGE_LENGTH
segment = text[: self.MAX_MESSAGE_LENGTH]
for _, pattern in self.SPLIT_PATTERNS.items():
if matches := list(pattern.finditer(segment)):
last_match = matches[-1]
split_point = last_match.end()
break
chunks.append(text[:split_point])
text = text[split_point:].lstrip()
return chunks
async def send_with_client(
self, client: ExtBot, message: MessageChain, user_name: str
):
image_path = None image_path = None
has_reply = False has_reply = False
@@ -31,36 +80,51 @@ class TelegramPlatformEvent(AstrMessageEvent):
at_user_id = i.name at_user_id = i.name
at_flag = False at_flag = False
message_thread_id = None
if "#" in user_name:
# it's a supergroup chat with message_thread_id
user_name, message_thread_id = user_name.split("#")
for i in message.chain: for i in message.chain:
payload = { payload = {
"chat_id": user_name, "chat_id": user_name,
} }
if has_reply: if has_reply:
payload["reply_to_message_id"] = reply_message_id payload["reply_to_message_id"] = reply_message_id
if message_thread_id:
payload["message_thread_id"] = message_thread_id
if isinstance(i, Plain): if isinstance(i, Plain):
if at_user_id and not at_flag: if at_user_id and not at_flag:
i.text = f"@{at_user_id} " + i.text i.text = f"@{at_user_id} {i.text}"
at_flag = True at_flag = True
await client.send_message(text=i.text, **payload) chunks = self._split_message(i.text)
for chunk in chunks:
try:
md_text = telegramify_markdown.markdownify(
chunk, max_line_length=None, normalize_whitespace=False
)
await client.send_message(
text=md_text, parse_mode="MarkdownV2", **payload
)
except Exception as e:
logger.warning(
f"MarkdownV2 send failed: {e}. Using plain text instead."
)
await client.send_message(text=chunk, **payload)
elif isinstance(i, Image): elif isinstance(i, Image):
if i.path: image_path = await i.convert_to_file_path()
image_path = i.path await client.send_photo(photo=image_path, **payload)
else:
image_path = i.file
if image_path.startswith("base64://"):
import base64
base64_data = image_path[9:]
image_bytes = base64.b64decode(base64_data)
await client.send_photo(photo=image_bytes, **payload)
else:
await client.send_photo(photo=image_path, **payload)
elif isinstance(i, File): elif isinstance(i, File):
if i.file.startswith("https://"):
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, i.name)
await download_file(i.file, path)
i.file = path
await client.send_document(document=i.file, filename=i.name, **payload) await client.send_document(document=i.file, filename=i.name, **payload)
elif isinstance(i, Record): elif isinstance(i, Record):
await client.send_voice(voice=i.file, **payload) path = await i.convert_to_file_path()
await client.send_voice(voice=path, **payload)
async def send(self, message: MessageChain): async def send(self, message: MessageChain):
if self.get_message_type() == MessageType.GROUP_MESSAGE: if self.get_message_type() == MessageType.GROUP_MESSAGE:
@@ -68,3 +132,110 @@ class TelegramPlatformEvent(AstrMessageEvent):
else: else:
await self.send_with_client(self.client, message, self.get_sender_id()) await self.send_with_client(self.client, message, self.get_sender_id())
await super().send(message) await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False):
message_thread_id = None
if self.get_message_type() == MessageType.GROUP_MESSAGE:
user_name = self.message_obj.group_id
else:
user_name = self.get_sender_id()
if "#" in user_name:
# it's a supergroup chat with message_thread_id
user_name, message_thread_id = user_name.split("#")
payload = {
"chat_id": user_name,
}
if message_thread_id:
payload["reply_to_message_id"] = message_thread_id
delta = ""
current_content = ""
message_id = None
last_edit_time = 0 # 上次编辑消息的时间
throttle_interval = 0.6 # 编辑消息的间隔时间 (秒)
async for chain in generator:
if isinstance(chain, MessageChain):
# 处理消息链中的每个组件
for i in chain.chain:
if isinstance(i, Plain):
delta += i.text
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await self.client.send_photo(photo=image_path, **payload)
continue
elif isinstance(i, File):
if i.file.startswith("https://"):
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, i.name)
await download_file(i.file, path)
i.file = path
await self.client.send_document(
document=i.file, filename=i.name, **payload
)
continue
elif isinstance(i, Record):
path = await i.convert_to_file_path()
await self.client.send_voice(voice=path, **payload)
continue
else:
logger.warning(f"不支持的消息类型: {type(i)}")
continue
# Plain
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
current_time = asyncio.get_event_loop().time()
time_since_last_edit = current_time - last_edit_time
# 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间
if time_since_last_edit >= throttle_interval:
# 编辑消息
try:
await self.client.edit_message_text(
text=delta,
chat_id=payload["chat_id"],
message_id=message_id,
)
current_content = delta
except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e!s}")
last_edit_time = (
asyncio.get_event_loop().time()
) # 更新上次编辑的时间
else:
# delta 长度一般不会大于 4096因此这里直接发送
try:
msg = await self.client.send_message(text=delta, **payload)
current_content = delta
delta = ""
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id
last_edit_time = (
asyncio.get_event_loop().time()
) # 记录初始消息发送时间
try:
if delta and current_content != delta:
try:
markdown_text = telegramify_markdown.markdownify(
delta, max_line_length=None, normalize_whitespace=False
)
await self.client.edit_message_text(
text=markdown_text,
chat_id=payload["chat_id"],
message_id=message_id,
parse_mode="MarkdownV2",
)
except Exception as e:
logger.warning(f"Markdown转换失败使用普通文本: {e!s}")
await self.client.edit_message_text(
text=delta, chat_id=payload["chat_id"], message_id=message_id
)
except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e!s}")
return await super().send_streaming(generator, use_fallback)

View File

@@ -17,6 +17,7 @@ from astrbot.core import web_chat_queue
from .webchat_event import WebChatMessageEvent from .webchat_event import WebChatMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter from ...register import register_platform_adapter
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class QueueListener: class QueueListener:
@@ -40,11 +41,11 @@ class WebChatAdapter(Platform):
self.config = platform_config self.config = platform_config
self.settings = platform_settings self.settings = platform_settings
self.unique_session = platform_settings["unique_session"] self.unique_session = platform_settings["unique_session"]
self.imgs_dir = "data/webchat/imgs" self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
os.makedirs(self.imgs_dir, exist_ok=True)
self.metadata = PlatformMetadata( self.metadata = PlatformMetadata(
"webchat", name="webchat", description="webchat", id=self.config.get("id")
"webchat",
) )
async def send_by_session( async def send_by_session(
@@ -119,3 +120,7 @@ class WebChatAdapter(Platform):
) )
self.commit_event(message_event) self.commit_event(message_event)
async def terminate(self):
# Do nothing
pass

View File

@@ -3,11 +3,12 @@ import uuid
import base64 import base64
from astrbot.api import logger from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image from astrbot.api.message_components import Plain, Image, Record
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
from astrbot.core import web_chat_back_queue from astrbot.core import web_chat_back_queue
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
imgs_dir = "data/webchat/imgs" imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
class WebChatMessageEvent(AstrMessageEvent): class WebChatMessageEvent(AstrMessageEvent):
@@ -16,16 +17,26 @@ class WebChatMessageEvent(AstrMessageEvent):
os.makedirs(imgs_dir, exist_ok=True) os.makedirs(imgs_dir, exist_ok=True)
@staticmethod @staticmethod
async def _send(message: MessageChain, session_id: str): async def _send(message: MessageChain, session_id: str, streaming: bool = False):
if not message: if not message:
web_chat_back_queue.put_nowait(None) await web_chat_back_queue.put(
return {"type": "end", "data": "", "streaming": False}
)
return ""
cid = session_id.split("!")[-1] cid = session_id.split("!")[-1]
data = ""
for comp in message.chain: for comp in message.chain:
if isinstance(comp, Plain): if isinstance(comp, Plain):
web_chat_back_queue.put_nowait((comp.text, cid)) data = comp.text
await web_chat_back_queue.put(
{
"type": "plain",
"cid": cid,
"data": data,
"streaming": streaming,
}
)
elif isinstance(comp, Image): elif isinstance(comp, Image):
# save image to local # save image to local
filename = str(uuid.uuid4()) + ".jpg" filename = str(uuid.uuid4()) + ".jpg"
@@ -46,11 +57,69 @@ class WebChatMessageEvent(AstrMessageEvent):
with open(path, "wb") as f: with open(path, "wb") as f:
with open(comp.file, "rb") as f2: with open(comp.file, "rb") as f2:
f.write(f2.read()) f.write(f2.read())
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid)) data = f"[IMAGE]{filename}"
await web_chat_back_queue.put(
{
"type": "image",
"cid": cid,
"data": data,
"streaming": streaming,
}
)
elif isinstance(comp, Record):
# save record to local
filename = str(uuid.uuid4()) + ".wav"
path = os.path.join(imgs_dir, filename)
if comp.file and comp.file.startswith("file:///"):
ph = comp.file[8:]
with open(path, "wb") as f:
with open(ph, "rb") as f2:
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
else:
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
data = f"[RECORD]{filename}"
await web_chat_back_queue.put(
{
"type": "record",
"cid": cid,
"data": data,
"streaming": streaming,
}
)
else: else:
logger.debug(f"webchat 忽略: {comp.type}") logger.debug(f"webchat 忽略: {comp.type}")
web_chat_back_queue.put_nowait(None)
return data
async def send(self, message: MessageChain): async def send(self, message: MessageChain):
await WebChatMessageEvent._send(message, session_id=self.session_id) await WebChatMessageEvent._send(message, session_id=self.session_id)
await web_chat_back_queue.put(
{
"type": "end",
"data": "",
"streaming": False,
"cid": self.session_id.split("!")[-1],
}
)
await super().send(message) await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False):
final_data = ""
async for chain in generator:
final_data += await WebChatMessageEvent._send(
chain, session_id=self.session_id, streaming=True
)
await web_chat_back_queue.put(
{
"type": "end",
"data": final_data,
"streaming": True,
"cid": self.session_id.split("!")[-1],
}
)
await super().send_streaming(generator, use_fallback)

View File

@@ -0,0 +1,707 @@
import asyncio
import json
import os
import time
from typing import Optional
import aiohttp
import websockets
from astrbot import logger
from astrbot.api.message_components import Plain, Image
from astrbot.api.platform import Platform, PlatformMetadata
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.astrbot_message import (
AstrBotMessage,
MessageMember,
MessageType,
)
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from .wechatpadpro_message_event import WeChatPadProMessageEvent
@register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器")
class WeChatPadProAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self._shutdown_event = None
self.wxnewpass = None
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings.get("unique_session", False)
self.metadata = PlatformMetadata(
name="wechatpadpro",
description="WeChatPadPro 消息平台适配器",
id=self.config.get("id", "wechatpadpro"),
)
# 保存配置信息
self.admin_key = self.config.get("admin_key")
self.host = self.config.get("host")
self.port = self.config.get("port")
self.active_mesasge_poll: bool = self.config.get(
"wpp_active_message_poll", False
)
self.active_message_poll_interval: int = self.config.get(
"wpp_active_message_poll_interval", 5
)
self.base_url = f"http://{self.host}:{self.port}"
self.auth_key = None # 用于保存生成的授权码
self.wxid = None # 用于保存登录成功后的 wxid
self.credentials_file = os.path.join(
get_astrbot_data_path(), "wechatpadpro_credentials.json"
) # 持久化文件路径
self.ws_handle_task = None
async def run(self) -> None:
"""
启动平台适配器的运行实例。
"""
logger.info("WeChatPadPro 适配器正在启动...")
if loaded_credentials := self.load_credentials():
self.auth_key = loaded_credentials.get("auth_key")
self.wxid = loaded_credentials.get("wxid")
isLoginIn = await self.check_online_status()
# 检查在线状态
if self.auth_key and isLoginIn:
logger.info("WeChatPadPro 设备已在线,凭据存在,跳过扫码登录。")
# 如果在线,连接 WebSocket 接收消息
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
else:
# 1. 生成授权码
if not self.auth_key:
logger.info("WeChatPadPro 无可用凭据,将生成新的授权码。")
await self.generate_auth_key()
# 2. 获取登录二维码
if not isLoginIn:
logger.info("WeChatPadPro 设备已离线,开始扫码登录。")
qr_code_url = await self.get_login_qr_code()
if qr_code_url:
logger.info(f"请扫描以下二维码登录: {qr_code_url}")
else:
logger.error("无法获取登录二维码。")
return
# 3. 检测扫码状态
login_successful = await self.check_login_status()
if login_successful:
logger.info("登录成功WeChatPadPro适配器已连接。")
else:
logger.warning("登录失败或超时WeChatPadPro 适配器将关闭。")
await self.terminate()
return
# 登录成功后,连接 WebSocket 接收消息
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
self._shutdown_event = asyncio.Event()
await self._shutdown_event.wait()
logger.info("WeChatPadPro 适配器已停止。")
def load_credentials(self):
"""
从文件中加载 auth_key 和 wxid。
"""
if os.path.exists(self.credentials_file):
try:
with open(self.credentials_file, "r") as f:
credentials = json.load(f)
logger.info("成功加载 WeChatPadPro 凭据。")
return credentials
except Exception as e:
logger.error(f"加载 WeChatPadPro 凭据失败: {e}")
return None
def save_credentials(self):
"""
将 auth_key 和 wxid 保存到文件。
"""
credentials = {
"auth_key": self.auth_key,
"wxid": self.wxid,
}
try:
# 确保数据目录存在
data_dir = os.path.dirname(self.credentials_file)
os.makedirs(data_dir, exist_ok=True)
with open(self.credentials_file, "w") as f:
json.dump(credentials, f)
logger.info("成功保存 WeChatPadPro 凭据。")
except Exception as e:
logger.error(f"保存 WeChatPadPro 凭据失败: {e}")
async def check_online_status(self):
"""
检查 WeChatPadPro 设备是否在线。
"""
url = f"{self.base_url}/login/GetLoginStatus"
params = {"key": self.auth_key}
async with aiohttp.ClientSession() as session:
try:
async with session.get(url, params=params) as response:
response_data = await response.json()
# 根据提供的在线接口返回示例,成功状态码是 200loginState 为 1 表示在线
if response.status == 200 and response_data.get("Code") == 200:
login_state = response_data.get("Data", {}).get("loginState")
if login_state == 1:
logger.info("WeChatPadPro 设备当前在线。")
return True
# login_state == 3 为离线状态
elif login_state == 3:
logger.info(
"WeChatPadPro 设备不在线。"
)
return False
else:
logger.error(
f"未知的在线状态: {login_state:}"
)
return False
# Code == 300 为微信退出状态。
elif response.status == 200 and response_data.get("Code") == 300:
logger.info(
"WeChatPadPro 设备已退出。"
)
return False
else:
logger.error(
f"检查在线状态失败: {response.status}, {response_data}"
)
return False
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return False
except Exception as e:
logger.error(f"检查在线状态时发生错误: {e}")
return False
async def generate_auth_key(self):
"""
生成授权码。
"""
url = f"{self.base_url}/admin/GenAuthKey1"
params = {"key": self.admin_key}
payload = {"Count": 1, "Days": 365} # 生成一个有效期365天的授权码
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
response_data = await response.json()
# 修正成功判断条件和授权码提取路径
if response.status == 200 and response_data.get("Code") == 200:
# 授权码在 Data 字段的列表中
if (
response_data.get("Data")
and isinstance(response_data["Data"], list)
and len(response_data["Data"]) > 0
):
self.auth_key = response_data["Data"][0]
logger.info("成功获取授权码")
else:
logger.error(
f"生成授权码成功但未找到授权码: {response_data}"
)
else:
logger.error(
f"生成授权码失败: {response.status}, {response_data}"
)
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
except Exception as e:
logger.error(f"生成授权码时发生错误: {e}")
async def get_login_qr_code(self):
"""
获取登录二维码地址。
"""
url = f"{self.base_url}/login/GetLoginQrCodeNew"
params = {"key": self.auth_key}
payload = {} # 根据文档,这个接口的 body 可以为空
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
response_data = await response.json()
# 修正成功判断条件和数据提取路径
if response.status == 200 and response_data.get("Code") == 200:
# 二维码地址在 Data.QrCodeUrl 字段中
if response_data.get("Data") and response_data["Data"].get(
"QrCodeUrl"
):
return response_data["Data"]["QrCodeUrl"]
else:
logger.error(
f"获取登录二维码成功但未找到二维码地址: {response_data}"
)
return None
else:
logger.error(
f"获取登录二维码失败: {response.status}, {response_data}"
)
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"获取登录二维码时发生错误: {e}")
return None
async def check_login_status(self):
"""
循环检测扫码状态。
尝试 6 次后跳出循环,添加倒计时。
返回 True 如果登录成功,否则返回 False。
"""
url = f"{self.base_url}/login/CheckLoginStatus"
params = {"key": self.auth_key}
attempts = 0 # 初始化尝试次数
max_attempts = 36 # 最大尝试次数
countdown = 180 # 倒计时时长
logger.info(f"请在 {countdown} 秒内扫码登录。")
while attempts < max_attempts:
async with aiohttp.ClientSession() as session:
try:
async with session.get(url, params=params) as response:
response_data = await response.json()
# 成功判断条件和数据提取路径
if response.status == 200 and response_data.get("Code") == 200:
if (
response_data.get("Data")
and response_data["Data"].get("state") is not None
):
status = response_data["Data"]["state"]
logger.info(
f"{attempts + 1} 次尝试,当前登录状态: {status},还剩{countdown - attempts * 5}"
)
if status == 2: # 状态 2 表示登录成功
self.wxid = response_data["Data"].get("wxid")
self.wxnewpass = response_data["Data"].get(
"wxnewpass"
)
logger.info(
f"登录成功wxid: {self.wxid}, wxnewpass: {self.wxnewpass}"
)
self.save_credentials() # 登录成功后保存凭据
return True
elif status == -2: # 二维码过期
logger.error("二维码已过期,请重新获取。")
return False
else:
logger.error(
f"检测登录状态成功但未找到登录状态: {response_data}"
)
elif response_data.get("Code") == 300:
# "不存在状态"
pass
else:
logger.info(
f"检测登录状态失败: {response.status}, {response_data}"
)
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
await asyncio.sleep(5)
attempts += 1
continue
except Exception as e:
logger.error(f"检测登录状态时发生错误: {e}")
attempts += 1
continue
attempts += 1
await asyncio.sleep(5) # 每隔5秒检测一次
logger.warning("登录检测超过最大尝试次数,退出检测。")
return False
async def connect_websocket(self):
"""
建立 WebSocket 连接并处理接收到的消息。
"""
os.environ["no_proxy"] = f"localhost,127.0.0.1,{self.host}"
ws_url = f"ws://{self.host}:{self.port}/ws/GetSyncMsg?key={self.auth_key}"
logger.info(
f"正在连接 WebSocket: ws://{self.host}:{self.port}/ws/GetSyncMsg?key=***"
)
while True:
try:
async with websockets.connect(ws_url) as websocket:
logger.info("WebSocket 连接成功。")
# 设置空闲超时重连
wait_time = (
self.active_message_poll_interval
if self.active_mesasge_poll
else 120
)
while True:
try:
message = await asyncio.wait_for(
websocket.recv(), timeout=wait_time
)
# logger.debug(message) # 不显示原始消息内容
asyncio.create_task(self.handle_websocket_message(message))
except asyncio.TimeoutError:
logger.warning(f"WebSocket 连接空闲超过 {wait_time} s")
break
except websockets.exceptions.ConnectionClosedOK:
logger.info("WebSocket 连接正常关闭。")
break
except Exception as e:
logger.error(f"处理 WebSocket 消息时发生错误: {e}")
break
except Exception as e:
logger.error(f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态或尝试重启WeChatPadPro适配器。")
await asyncio.sleep(5)
async def handle_websocket_message(self, message: str):
"""
处理从 WebSocket 接收到的消息。
"""
logger.debug(f"收到 WebSocket 消息: {message}")
try:
message_data = json.loads(message)
if (
message_data.get("msg_id") is not None
and message_data.get("from_user_name") is not None
):
abm = await self.convert_message(message_data)
if abm:
# 创建 WeChatPadProMessageEvent 实例
message_event = WeChatPadProMessageEvent(
message_str=abm.message_str,
message_obj=abm,
platform_meta=self.meta(),
session_id=abm.session_id,
# 传递适配器实例,以便在事件中调用 send 方法
adapter=self,
)
# 提交事件到事件队列
self.commit_event(message_event)
else:
logger.warning(f"收到未知结构的 WebSocket 消息: {message_data}")
except json.JSONDecodeError:
logger.error(f"无法解析 WebSocket 消息为 JSON: {message}")
except Exception as e:
logger.error(f"处理 WebSocket 消息时发生错误: {e}")
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
"""
将 WeChatPadPro 原始消息转换为 AstrBotMessage。
"""
abm = AstrBotMessage()
abm.raw_message = raw_message
abm.message_id = str(raw_message.get("msg_id"))
abm.timestamp = raw_message.get("create_time")
abm.self_id = self.wxid
if int(time.time()) - abm.timestamp > 180:
logger.warning(
f"忽略 3 分钟前的旧消息:消息时间戳 {abm.timestamp} 超过当前时间 {int(time.time())}"
)
return None
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
content = raw_message.get("content", {}).get("str", "")
push_content = raw_message.get("push_content", "")
msg_type = raw_message.get("msg_type")
abm.message_str = ""
abm.message = []
# 如果是机器人自己发送的消息、回显消息或系统消息,忽略
if from_user_name == self.wxid:
logger.info("忽略来自自己的消息。")
return None
if from_user_name in ["weixin", "newsapp", "newsapp_wechat"]:
logger.info("忽略来自微信团队的消息。")
return None
# 先判断群聊/私聊并设置基本属性
if await self._process_chat_type(
abm, raw_message, from_user_name, to_user_name, content, push_content
):
# 再根据消息类型处理消息内容
await self._process_message_content(abm, raw_message, msg_type, content)
return abm
return None
async def _process_chat_type(
self,
abm: AstrBotMessage,
raw_message: dict,
from_user_name: str,
to_user_name: str,
content: str,
push_content: str,
):
"""
判断消息是群聊还是私聊,并设置 AstrBotMessage 的基本属性。
"""
if from_user_name == "weixin":
return False
if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = from_user_name
parts = content.split(":\n", 1)
sender_wxid = parts[0] if len(parts) == 2 else ""
abm.sender = MessageMember(user_id=sender_wxid, nickname="")
# 获取群聊发送者的nickname
if sender_wxid:
accurate_nickname = await self._get_group_member_nickname(
abm.group_id, sender_wxid
)
if accurate_nickname:
abm.sender.nickname = accurate_nickname
# 对于群聊session_id 可以是群聊 ID 或发送者 ID + 群聊 ID (如果 unique_session 为 True)
if self.unique_session:
abm.session_id = f"{from_user_name}_{to_user_name}"
else:
abm.session_id = from_user_name
else:
abm.type = MessageType.FRIEND_MESSAGE
abm.group_id = ""
nick_name = ""
if push_content and " : " in push_content:
nick_name = push_content.split(" : ")[0]
abm.sender = MessageMember(user_id=from_user_name, nickname=nick_name)
abm.session_id = from_user_name
return True
async def _get_group_member_nickname(
self, group_id: str, member_wxid: str
) -> Optional[str]:
"""
通过接口获取群成员的昵称。
"""
url = f"{self.base_url}/group/GetChatroomMemberDetail"
params = {"key": self.auth_key}
payload = {
"ChatRoomName": group_id,
}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
response_data = await response.json()
if response.status == 200 and response_data.get("Code") == 200:
# 从返回数据中查找对应成员的昵称
member_list = (
response_data.get("Data", {})
.get("member_data", {})
.get("chatroom_member_list", [])
)
for member in member_list:
if member.get("user_name") == member_wxid:
return member.get("nick_name")
logger.warning(
f"在群 {group_id} 中未找到成员 {member_wxid} 的昵称"
)
else:
logger.error(
f"获取群成员详情失败: {response.status}, {response_data}"
)
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"获取群成员详情时发生错误: {e}")
return None
async def _download_raw_image(
self, from_user_name: str, to_user_name: str, msg_id: int
):
"""下载原始图片。"""
url = f"{self.base_url}/message/GetMsgBigImg"
params = {"key": self.auth_key}
payload = {
"CompressType": 0,
"FromUserName": from_user_name,
"MsgId": msg_id,
"Section": {"DataLen": 61440, "StartPos": 0},
"ToUserName": to_user_name,
"TotalLen": 0,
}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status == 200:
return await response.json()
else:
logger.error(f"下载图片失败: {response.status}")
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"下载图片时发生错误: {e}")
return None
async def _process_message_content(
self, abm: AstrBotMessage, raw_message: dict, msg_type: int, content: str
):
"""
根据消息类型处理消息内容,填充 AstrBotMessage 的 message 列表。
"""
if msg_type == 1: # 文本消息
abm.message_str = content
if abm.type == MessageType.GROUP_MESSAGE:
parts = content.split(":\n", 1)
if len(parts) == 2:
abm.message_str = parts[1]
abm.message.append(Plain(abm.message_str))
else:
abm.message.append(Plain(abm.message_str))
else: # 私聊消息
abm.message.append(Plain(abm.message_str))
elif msg_type == 3:
# 图片消息
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
msg_id = raw_message.get("msg_id")
image_resp = await self._download_raw_image(
from_user_name, to_user_name, msg_id
)
image_bs64_data = (
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
)
if image_bs64_data:
abm.message.append(Image.fromBase64(image_bs64_data))
elif msg_type == 47:
# 视频消息 (注意:表情消息也是 47需要区分)
logger.warning("收到视频消息,待实现。")
elif msg_type == 50:
# 语音/视频
logger.warning("收到语音/视频消息,待实现。")
elif msg_type == 49:
# 引用消息
logger.warning("收到引用消息,待实现。")
else:
logger.warning(f"收到未处理的消息类型: {msg_type}")
async def terminate(self):
"""
终止一个平台的运行实例。
"""
logger.info("终止 WeChatPadPro 适配器。")
try:
if self.ws_handle_task:
self.ws_handle_task.cancel()
self._shutdown_event.set()
except Exception:
pass
def meta(self) -> PlatformMetadata:
"""
得到一个平台的元数据。
"""
return self.metadata
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
dummy_message_obj = AstrBotMessage()
dummy_message_obj.session_id = session.session_id
# 根据 session_id 判断消息类型
if "@chatroom" in session.session_id:
dummy_message_obj.type = MessageType.GROUP_MESSAGE
dummy_message_obj.group_id = session.session_id
dummy_message_obj.sender = MessageMember(user_id="", nickname="")
else:
dummy_message_obj.type = MessageType.FRIEND_MESSAGE
dummy_message_obj.group_id = ""
dummy_message_obj.sender = MessageMember(user_id="", nickname="")
sending_event = WeChatPadProMessageEvent(
message_str="",
message_obj=dummy_message_obj,
platform_meta=self.meta(),
session_id=session.session_id,
adapter=self,
)
# 调用实例方法 send
await sending_event.send(message_chain)
async def get_contact_list(self):
"""
获取联系人列表。
"""
url = f"{self.base_url}/friend/GetContactList"
params = {"key": self.auth_key}
payload = {"CurrentChatRoomContactSeq": 0, "CurrentWxcontactSeq": 0}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status != 200:
logger.error(f"获取联系人列表失败: {response.status}")
return None
result = await response.json()
if result.get("Code") == 200 and result.get("Data"):
contact_list = (
result.get("Data", {})
.get("ContactList", {})
.get("contactUsernameList", [])
)
return contact_list
else:
logger.error(f"获取联系人列表失败: {result}")
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"获取联系人列表时发生错误: {e}")
return None
async def get_contact_details_list(
self, room_wx_id_list: list[str] = None, user_names: list[str] = None
) -> Optional[dict]:
"""
获取联系人详情列表。
"""
if room_wx_id_list is None:
room_wx_id_list = []
if user_names is None:
user_names = []
url = f"{self.base_url}/friend/GetContactDetailsList"
params = {"key": self.auth_key}
payload = {"RoomWxIDList": room_wx_id_list, "UserNames": user_names}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status != 200:
logger.error(f"获取联系人详情列表失败: {response.status}")
return None
result = await response.json()
if result.get("Code") == 200 and result.get("Data"):
contact_list = result.get("Data", {}).get("contactList", {})
return contact_list
else:
logger.error(f"获取联系人详情列表失败: {result}")
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"获取联系人详情列表时发生错误: {e}")
return None

View File

@@ -0,0 +1,117 @@
import asyncio
import base64
import io
from typing import TYPE_CHECKING
import aiohttp
from PIL import Image as PILImage # 使用别名避免冲突
from astrbot import logger
from astrbot.core.message.components import Image, Plain # Import Image
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType
from astrbot.core.platform.platform_metadata import PlatformMetadata
if TYPE_CHECKING:
from .wechatpadpro_adapter import WeChatPadProAdapter
class WeChatPadProMessageEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
adapter: "WeChatPadProAdapter", # 传递适配器实例
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.message_obj = message_obj # Save the full message object
self.adapter = adapter # Save the adapter instance
async def send(self, message: MessageChain):
async with aiohttp.ClientSession() as session:
for comp in message.chain:
await asyncio.sleep(1)
if isinstance(comp, Plain):
await self._send_text(session, comp.text)
elif isinstance(comp, Image):
await self._send_image(session, comp)
await super().send(message)
async def _send_image(self, session: aiohttp.ClientSession, comp: Image):
b64 = await comp.convert_to_base64()
raw = self._validate_base64(b64)
b64c = self._compress_image(raw)
payload = {
"MsgItem": [
{"ImageContent": b64c, "MsgType": 3, "ToUserName": self.session_id}
]
}
url = f"{self.adapter.base_url}/message/SendImageNewMessage"
await self._post(session, url, payload)
async def _send_text(self, session: aiohttp.ClientSession, text: str):
if (
self.message_obj.type == MessageType.GROUP_MESSAGE # 确保是群聊消息
and self.adapter.settings.get(
"reply_with_mention", False
) # 检查适配器设置是否启用 reply_with_mention
and self.message_obj.sender # 确保有发送者信息
and (
self.message_obj.sender.user_id or self.message_obj.sender.nickname
) # 确保发送者有 ID 或昵称
):
# 优先使用 nickname如果没有则使用 user_id
mention_text = (
self.message_obj.sender.nickname or self.message_obj.sender.user_id
)
message_text = f"@{mention_text} {text}"
# logger.info(f"已添加 @ 信息: {message_text}")
else:
message_text = text
payload = {
"MsgItem": [
{"MsgType": 1, "TextContent": message_text, "ToUserName": self.session_id}
]
}
url = f"{self.adapter.base_url}/message/SendTextMessage"
await self._post(session, url, payload)
@staticmethod
def _validate_base64(b64: str) -> bytes:
return base64.b64decode(b64, validate=True)
@staticmethod
def _compress_image(data: bytes) -> str:
img = PILImage.open(io.BytesIO(data))
buf = io.BytesIO()
if img.format == "JPEG":
img.save(buf, "JPEG", quality=80)
else:
if img.mode in ("RGBA", "P"):
img = img.convert("RGB")
img.save(buf, "JPEG", quality=80)
# logger.info("图片处理完成!!!")
return base64.b64encode(buf.getvalue()).decode()
async def _post(self, session, url, payload):
params = {"key": self.adapter.auth_key}
try:
async with session.post(url, params=params, json=payload) as resp:
data = await resp.json()
if resp.status != 200 or data.get("Code") != 200:
logger.error(f"{url} failed: {resp.status} {data}")
except Exception as e:
logger.error(f"{url} error: {e}")
# TODO: 添加对其他消息组件类型的处理 (Record, Video, At等)
# elif isinstance(component, Record):
# pass
# elif isinstance(component, Video):
# pass
# elif isinstance(component, At):
# pass
# ...

View File

@@ -1,28 +1,33 @@
import asyncio
import os
import sys import sys
import uuid import uuid
import asyncio
import quart
import quart
from requests import Response
from wechatpy.enterprise import WeChatClient, parse_message
from wechatpy.enterprise.crypto import WeChatCrypto
from wechatpy.enterprise.messages import ImageMessage, TextMessage, VoiceMessage
from wechatpy.exceptions import InvalidSignatureException
from wechatpy.messages import BaseMessage
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Image, Plain, Record
from astrbot.api.platform import ( from astrbot.api.platform import (
Platform,
AstrBotMessage, AstrBotMessage,
MessageMember, MessageMember,
PlatformMetadata,
MessageType, MessageType,
Platform,
PlatformMetadata,
register_platform_adapter,
) )
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain, Image, Record
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.api.platform import register_platform_adapter
from astrbot.core import logger from astrbot.core import logger
from requests import Response from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from wechatpy.enterprise.crypto import WeChatCrypto
from wechatpy.enterprise import WeChatClient
from wechatpy.enterprise.messages import TextMessage, ImageMessage, VoiceMessage
from wechatpy.exceptions import InvalidSignatureException
from wechatpy.enterprise import parse_message
from .wecom_event import WecomPlatformEvent from .wecom_event import WecomPlatformEvent
from .wecom_kf import WeChatKF
from .wecom_kf_message import WeChatKFMessage
if sys.version_info >= (3, 12): if sys.version_info >= (3, 12):
from typing import override from typing import override
@@ -34,6 +39,7 @@ class WecomServer:
def __init__(self, event_queue: asyncio.Queue, config: dict): def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__) self.server = quart.Quart(__name__)
self.port = int(config.get("port")) self.port = int(config.get("port"))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.server.add_url_rule( self.server.add_url_rule(
"/callback/command", view_func=self.verify, methods=["GET"] "/callback/command", view_func=self.verify, methods=["GET"]
) )
@@ -49,6 +55,7 @@ class WecomServer:
) )
self.callback = None self.callback = None
self.shutdown_event = asyncio.Event()
async def verify(self): async def verify(self):
logger.info(f"验证请求有效性: {quart.request.args}") logger.info(f"验证请求有效性: {quart.request.args}")
@@ -86,17 +93,17 @@ class WecomServer:
return "success" return "success"
async def start_polling(self): async def start_polling(self):
logger.info(f"将在 0.0.0.0:{self.port} 端口启动 企业微信 适配器。") logger.info(
f"将在 {self.callback_server_host}:{self.port} 端口启动 企业微信 适配器。"
)
await self.server.run_task( await self.server.run_task(
host="0.0.0.0", host=self.callback_server_host,
port=self.port, port=self.port,
shutdown_trigger=self.shutdown_trigger_placeholder, shutdown_trigger=self.shutdown_trigger,
) )
async def shutdown_trigger_placeholder(self): async def shutdown_trigger(self):
while not self.event_queue.closed: # noqa: ASYNC110 await self.shutdown_event.wait()
await asyncio.sleep(1)
logger.info("企业微信 适配器已关闭。")
@register_platform_adapter("wecom", "wecom 适配器") @register_platform_adapter("wecom", "wecom 适配器")
@@ -129,9 +136,40 @@ class WecomPlatformAdapter(Platform):
self.config["corpid"].strip(), self.config["corpid"].strip(),
self.config["secret"].strip(), self.config["secret"].strip(),
) )
# 微信客服
self.kf_name = self.config.get("kf_name", None)
if self.kf_name:
# inject
self.wechat_kf_api = WeChatKF(client=self.client)
self.wechat_kf_message_api = WeChatKFMessage(self.client)
self.client.kf = self.wechat_kf_api
self.client.kf_message = self.wechat_kf_message_api
self.client.API_BASE_URL = self.api_base_url self.client.API_BASE_URL = self.api_base_url
async def callback(msg): async def callback(msg: BaseMessage):
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
def get_latest_msg_item() -> dict | None:
token = msg._data["Token"]
kfid = msg._data["OpenKfId"]
has_more = 1
ret = {}
while has_more:
ret = self.wechat_kf_api.sync_msg(token, kfid)
has_more = ret["has_more"]
msg_list = ret.get("msg_list", [])
if msg_list:
return msg_list[-1]
return None
msg_new = await asyncio.get_event_loop().run_in_executor(
None, get_latest_msg_item
)
if msg_new:
await self.convert_wechat_kf_message(msg_new)
return
await self.convert_message(msg) await self.convert_message(msg)
self.server.callback = callback self.server.callback = callback
@@ -151,9 +189,39 @@ class WecomPlatformAdapter(Platform):
@override @override
async def run(self): async def run(self):
loop = asyncio.get_event_loop()
if self.kf_name:
try:
acc_list = (
await loop.run_in_executor(
None, self.wechat_kf_api.get_account_list
)
).get("account_list", [])
logger.debug(f"获取到微信客服列表: {str(acc_list)}")
for acc in acc_list:
name = acc.get("name", None)
if name != self.kf_name:
continue
open_kfid = acc.get("open_kfid", None)
if not open_kfid:
logger.error("获取微信客服失败open_kfid 为空。")
logger.debug(f"Found open_kfid: {str(open_kfid)}")
kf_url = (
await loop.run_in_executor(
None,
self.wechat_kf_api.add_contact_way,
open_kfid,
"astrbot_placeholder",
)
).get("url", "")
logger.info(
f"请打开以下链接,在微信扫码以获取客服微信: https://api.cl2wm.cn/api/qrcode/code?text={kf_url}"
)
except Exception as e:
logger.error(e)
await self.server.start_polling() await self.server.start_polling()
async def convert_message(self, msg): async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
abm = AstrBotMessage() abm = AstrBotMessage()
if msg.type == "text": if msg.type == "text":
assert isinstance(msg, TextMessage) assert isinstance(msg, TextMessage)
@@ -189,14 +257,15 @@ class WecomPlatformAdapter(Platform):
resp: Response = await asyncio.get_event_loop().run_in_executor( resp: Response = await asyncio.get_event_loop().run_in_executor(
None, self.client.media.download, msg.media_id None, self.client.media.download, msg.media_id
) )
path = f"data/temp/wecom_{msg.media_id}.amr" temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr")
with open(path, "wb") as f: with open(path, "wb") as f:
f.write(resp.content) f.write(resp.content)
try: try:
from pydub import AudioSegment from pydub import AudioSegment
path_wav = f"data/temp/wecom_{msg.media_id}.wav" path_wav = os.path.join(temp_dir, f"wecom_{msg.media_id}.wav")
audio = AudioSegment.from_file(path) audio = AudioSegment.from_file(path)
audio.export(path_wav, format="wav") audio.export(path_wav, format="wav")
except Exception as e: except Exception as e:
@@ -216,10 +285,43 @@ class WecomPlatformAdapter(Platform):
abm.timestamp = msg.time abm.timestamp = msg.time
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
abm.raw_message = msg abm.raw_message = msg
else:
logger.warning(f"暂未实现的事件: {msg.type}")
return
logger.info(f"abm: {abm}") logger.info(f"abm: {abm}")
await self.handle_msg(abm) await self.handle_msg(abm)
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
msgtype = msg.get("msgtype", None)
external_userid = msg.get("external_userid", None)
abm = AstrBotMessage()
abm.raw_message = msg
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
abm.self_id = msg["open_kfid"]
abm.sender = MessageMember(external_userid, external_userid)
abm.session_id = external_userid
abm.type = MessageType.FRIEND_MESSAGE
abm.message_id = msg.get("msgid", uuid.uuid4().hex[:8])
if msgtype == "text":
text = msg.get("text", {}).get("content", "").strip()
abm.message = [Plain(text=text)]
abm.message_str = text
elif msgtype == "image":
media_id = msg.get("image", {}).get("media_id", "")
resp: Response = await asyncio.get_event_loop().run_in_executor(
None, self.client.media.download, media_id
)
path = f"data/temp/wechat_kf_{media_id}.jpg"
with open(path, "wb") as f:
f.write(resp.content)
abm.message = [Image(file=path, url=path)]
abm.message_str = "[图片]"
else:
logger.warning(f"未实现的微信客服消息事件: {msg}")
return
await self.handle_msg(abm)
async def handle_msg(self, message: AstrBotMessage): async def handle_msg(self, message: AstrBotMessage):
message_event = WecomPlatformEvent( message_event = WecomPlatformEvent(
message_str=message.message_str, message_str=message.message_str,
@@ -232,3 +334,11 @@ class WecomPlatformAdapter(Platform):
def get_client(self) -> WeChatClient: def get_client(self) -> WeChatClient:
return self.client return self.client
async def terminate(self):
self.server.shutdown_event.set()
try:
await self.server.server.shutdown()
except Exception as _:
pass
logger.info("企业微信 适配器已被优雅地关闭")

View File

@@ -1,11 +1,14 @@
import os
import uuid import uuid
import asyncio
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record from astrbot.api.message_components import Plain, Image, Record
from wechatpy.enterprise import WeChatClient from wechatpy.enterprise import WeChatClient
from astrbot.core.utils.io import download_image_by_url, download_file from .wecom_kf_message import WeChatKFMessage
from astrbot.api import logger from astrbot.api import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
try: try:
import pydub import pydub
@@ -34,70 +37,158 @@ class WecomPlatformEvent(AstrMessageEvent):
): ):
pass pass
async def split_plain(self, plain: str) -> list[str]:
"""将长文本分割成多个小文本, 每个小文本长度不超过 2048 字符
Args:
plain (str): 要分割的长文本
Returns:
list[str]: 分割后的文本列表
"""
if len(plain) <= 2048:
return [plain]
else:
result = []
start = 0
while start < len(plain):
# 剩下的字符串长度<2048时结束
if start + 2048 >= len(plain):
result.append(plain[start:])
break
# 向前搜索分割标点符号
end = min(start + 2048, len(plain))
cut_position = end
for i in range(end, start, -1):
if i < len(plain) and plain[i - 1] in [
"",
"",
"",
".",
"!",
"?",
"\n",
";",
"",
]:
cut_position = i
break
# 没找到合适的位置分割, 直接切分
if cut_position == end and end < len(plain):
cut_position = end
result.append(plain[start:cut_position])
start = cut_position
return result
async def send(self, message: MessageChain): async def send(self, message: MessageChain):
message_obj = self.message_obj message_obj = self.message_obj
for comp in message.chain: is_wechat_kf = hasattr(self.client, "kf_message")
if isinstance(comp, Plain): if is_wechat_kf:
self.client.message.send_text( # 微信客服
message_obj.self_id, message_obj.session_id, comp.text kf_message_api = getattr(self.client, "kf_message", None)
) if not kf_message_api:
elif isinstance(comp, Image): logger.warning("未找到微信客服发送消息方法。")
img_url = comp.file return
img_path = "" assert isinstance(kf_message_api, WeChatKFMessage)
if img_url.startswith("file:///"): user_id = self.get_sender_id()
img_path = img_url[8:] for comp in message.chain:
elif comp.file and comp.file.startswith("http"): if isinstance(comp, Plain):
img_path = await download_image_by_url(comp.file) # Split long text messages if needed
else: plain_chunks = await self.split_plain(comp.text)
img_path = img_url for chunk in plain_chunks:
kf_message_api.send_text(user_id, self.get_self_id(), chunk)
await asyncio.sleep(0.5) # Avoid sending too fast
elif isinstance(comp, Image):
img_path = await comp.convert_to_file_path()
with open(img_path, "rb") as f: with open(img_path, "rb") as f:
try: try:
response = self.client.media.upload("image", f) response = self.client.media.upload("image", f)
except Exception as e: except Exception as e:
logger.error(f"企业微信上传图片失败: {e}") logger.error(f"微信客服上传图片失败: {e}")
await self.send( await self.send(
MessageChain().message(f"企业微信上传图片失败: {e}") MessageChain().message(f"微信客服上传图片失败: {e}")
)
return
logger.debug(f"微信客服上传图片返回: {response}")
kf_message_api.send_image(
user_id,
self.get_self_id(),
response["media_id"],
) )
return
logger.info(f"企业微信上传图片返回: {response}")
self.client.message.send_image(
message_obj.self_id,
message_obj.session_id,
response["media_id"],
)
elif isinstance(comp, Record):
record_url = comp.file
record_path = ""
if record_url.startswith("file:///"):
record_path = record_url[8:]
elif record_url.startswith("http"):
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
else: else:
record_path = record_url logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}")
else:
# 转成amr # 企业微信应用
record_path_amr = f"data/temp/{uuid.uuid4()}.amr" for comp in message.chain:
pydub.AudioSegment.from_wav(record_path).export( if isinstance(comp, Plain):
record_path_amr, format="amr" # Split long text messages if needed
) plain_chunks = await self.split_plain(comp.text)
for chunk in plain_chunks:
with open(record_path_amr, "rb") as f: self.client.message.send_text(
try: message_obj.self_id, message_obj.session_id, chunk
response = self.client.media.upload("voice", f)
except Exception as e:
logger.error(f"企业微信上传语音失败: {e}")
await self.send(
MessageChain().message(f"企业微信上传语音失败: {e}")
) )
return await asyncio.sleep(0.5) # Avoid sending too fast
logger.info(f"企业微信上传语音返回: {response}") elif isinstance(comp, Image):
self.client.message.send_voice( img_path = await comp.convert_to_file_path()
message_obj.self_id,
message_obj.session_id, with open(img_path, "rb") as f:
response["media_id"], try:
response = self.client.media.upload("image", f)
except Exception as e:
logger.error(f"企业微信上传图片失败: {e}")
await self.send(
MessageChain().message(f"企业微信上传图片失败: {e}")
)
return
logger.debug(f"企业微信上传图片返回: {response}")
self.client.message.send_image(
message_obj.self_id,
message_obj.session_id,
response["media_id"],
)
elif isinstance(comp, Record):
record_path = await comp.convert_to_file_path()
# 转成amr
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
record_path_amr = os.path.join(temp_dir, f"{uuid.uuid4()}.amr")
pydub.AudioSegment.from_wav(record_path).export(
record_path_amr, format="amr"
) )
with open(record_path_amr, "rb") as f:
try:
response = self.client.media.upload("voice", f)
except Exception as e:
logger.error(f"企业微信上传语音失败: {e}")
await self.send(
MessageChain().message(f"企业微信上传语音失败: {e}")
)
return
logger.info(f"企业微信上传语音返回: {response}")
self.client.message.send_voice(
message_obj.self_id,
message_obj.session_id,
response["media_id"],
)
else:
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}")
await super().send(message) await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)

View File

@@ -0,0 +1,278 @@
# -*- coding: utf-8 -*-
"""
The MIT License (MIT)
Copyright (c) 2014-2020 messense
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from wechatpy.client.api.base import BaseWeChatAPI
class WeChatKF(BaseWeChatAPI):
"""
微信客服接口
https://work.weixin.qq.com/api/doc/90000/90135/94670
"""
def sync_msg(self, token, open_kfid, cursor="", limit=1000):
"""
微信客户发送的消息、接待人员在企业微信回复的消息、发送消息接口发送失败事件(如被用户拒收)
、客户点击菜单消息的回复消息,可以通过该接口获取具体的消息内容和事件。不支持读取通过发送消息接口发送的消息。
支持的消息类型:文本、图片、语音、视频、文件、位置、链接、名片、小程序、事件。
:param token: 回调事件返回的token字段10分钟内有效可不填如果不填接口有严格的频率限制。不多于128字节
:param open_kfid: 客服帐号ID
:param cursor: 上一次调用时返回的next_cursor第一次拉取可以不填。不多于64字节
:param limit: 期望请求的数据量默认值和最大值都为1000。
注意可能会出现返回条数少于limit的情况需结合返回的has_more字段判断是否继续请求。
:return: 接口调用结果
"""
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid}
return self._post("kf/sync_msg", data=data)
def get_service_state(self, open_kfid, external_userid):
"""
获取会话状态
ID 状态 说明
0 未处理 新会话接入。可选择1.直接用API自动回复消息。2.放进待接入池等待接待人员接待。3.指定接待人员进行接待
1 由智能助手接待 可使用API回复消息。可选择转入待接入池或者指定接待人员处理。
2 待接入池排队中 在待接入池中排队等待接待人员接入。可选择转为指定人员接待
3 由人工接待 人工接待中。可选择结束会话
4 已结束 会话已经结束。不允许变更会话状态,等待用户重新发起咨询
:param open_kfid: 客服帐号ID
:param external_userid: 微信客户的external_userid
:return: 接口调用结果
"""
data = {
"open_kfid": open_kfid,
"external_userid": external_userid,
}
return self._post("kf/service_state/get", data=data)
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""):
"""
变更会话状态
:param open_kfid: 客服帐号ID
:param external_userid: 微信客户的external_userid
:param service_state: 当前的会话状态,状态定义参考概述中的表格
:return: 接口调用结果
"""
data = {
"open_kfid": open_kfid,
"external_userid": external_userid,
"service_state": service_state,
}
if servicer_userid:
data["servicer_userid"] = servicer_userid
return self._post("kf/service_state/trans", data=data)
def get_servicer_list(self, open_kfid):
"""
获取接待人员列表
:param open_kfid: 客服帐号ID
:return: 接口调用结果
"""
data = {
"open_kfid": open_kfid,
}
return self._get("kf/servicer/list", params=data)
def add_servicer(self, open_kfid, userid_list):
"""
添加接待人员
添加指定客服帐号的接待人员。
:param open_kfid: 客服帐号ID
:param userid_list: 接待人员userid列表
:return: 接口调用结果
"""
if not isinstance(userid_list, list):
userid_list = [userid_list]
data = {
"open_kfid": open_kfid,
"userid_list": userid_list,
}
return self._post("kf/servicer/add", data=data)
def del_servicer(self, open_kfid, userid_list):
"""
删除接待人员
从客服帐号删除接待人员
:param open_kfid: 客服帐号ID
:param userid_list: 接待人员userid列表
:return: 接口调用结果
"""
if not isinstance(userid_list, list):
userid_list = [userid_list]
data = {
"open_kfid": open_kfid,
"userid_list": userid_list,
}
return self._post("kf/servicer/del", data=data)
def batchget_customer(self, external_userid_list):
"""
客户基本信息获取
:param external_userid_list: external_userid列表
:return: 接口调用结果
"""
if not isinstance(external_userid_list, list):
external_userid_list = [external_userid_list]
data = {
"external_userid_list": external_userid_list,
}
return self._post("kf/customer/batchget", data=data)
def get_account_list(self):
"""
获取客服帐号列表
:return: 接口调用结果
"""
return self._get("kf/account/list")
def add_contact_way(self, open_kfid, scene):
"""
获取客服帐号链接
:param open_kfid: 客服帐号ID
:param scene: 场景值字符串类型由开发者自定义。不多于32字节;字符串取值范围(正则表达式)[0-9a-zA-Z_-]*
:return: 接口调用结果
"""
data = {"open_kfid": open_kfid, "scene": scene}
return self._post("kf/add_contact_way", data=data)
def get_upgrade_service_config(self):
"""
获取配置的专员与客户群
:return: 接口调用结果
"""
return self._get("kf/customer/get_upgrade_service_config")
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None):
"""
为客户升级为专员或客户群服务
:param open_kfid: 客服帐号ID
:param external_userid: 微信客户的external_userid
:param service_type: 表示是升级到专员服务还是客户群服务。1:专员服务。2:客户群服务
:param member: 推荐的服务专员type等于1时有效
:param groupchat: 推荐的客户群type等于2时有效
:return: 接口调用结果
"""
data = {
"open_kfid": open_kfid,
"external_userid": external_userid,
"type": service_type,
}
if service_type == 1:
data["member"] = member
else:
data["groupchat"] = groupchat
return self._post("kf/customer/upgrade_service", data=data)
def cancel_upgrade_service(self, open_kfid, external_userid):
"""
为客户取消推荐
:param open_kfid: 客服帐号ID
:param external_userid: 微信客户的external_userid
:return: 接口调用结果
"""
data = {"open_kfid": open_kfid, "external_userid": external_userid}
return self._post("kf/customer/cancel_upgrade_service", data=data)
def send_msg_on_event(self, code, msgtype, msg_content, msgid=None):
"""
当特定的事件回调消息包含code字段可以此code为凭证调用该接口给用户发送相应事件场景下的消息如客服欢迎语。
支持发送消息类型:文本、菜单消息。
:param code: 事件响应消息对应的code。通过事件回调下发仅可使用一次。
:param msgtype: 消息类型。对不同的msgtype有相应的结构描述详见消息类型
:param msg_content: 目前支持文本与菜单消息,具体查看文档
:param msgid: 消息ID。如果请求参数指定了msgid则原样返回否则系统自动生成并返回。不多于32字节
字符串取值范围(正则表达式)[0-9a-zA-Z_-]*
:return: 接口调用结果
"""
data = {"code": code, "msgtype": msgtype}
if msgid:
data["msgid"] = msgid
data.update(msg_content)
return self._post("kf/send_msg_on_event", data=data)
def get_corp_statistic(self, start_time, end_time, open_kfid=None):
"""
获取「客户数据统计」企业汇总数据
:param start_time: 开始时间
:param end_time: 结束时间
:param open_kfid: 客服帐号ID
:return: 接口调用结果
"""
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
return self._post("kf/get_corp_statistic", data=data)
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None):
"""
获取「客户数据统计」接待人员明细数据
:param start_time: 开始时间
:param end_time: 结束时间
:param open_kfid: 客服帐号ID
:param servicer_userid: 接待人员
:return: 接口调用结果
"""
data = {
"open_kfid": open_kfid,
"servicer_userid": servicer_userid,
"start_time": start_time,
"end_time": end_time,
}
return self._post("kf/get_servicer_statistic", data=data)
def account_update(self, open_kfid, name, media_id):
"""
修改客服账号
:param open_kfid: 客服帐号ID
:param name: 客服名称
:param media_id: 客服头像临时素材
:return: 接口调用结果
"""
data = {"open_kfid": open_kfid, "name": name, "media_id": media_id}
return self._post("kf/account/update", data=data)

View File

@@ -0,0 +1,159 @@
"""
The MIT License (MIT)
Copyright (c) 2014-2020 messense
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from optionaldict import optionaldict
from wechatpy.client.api.base import BaseWeChatAPI
class WeChatKFMessage(BaseWeChatAPI):
"""
发送微信客服消息
https://work.weixin.qq.com/api/doc/90000/90135/94677
支持:
* 文本消息
* 图片消息
* 语音消息
* 视频消息
* 文件消息
* 图文链接
* 小程序
* 菜单消息
* 地理位置
"""
def send(self, user_id, open_kfid, msgid="", msg=None):
"""
当微信客户处于“新接入待处理”或“由智能助手接待”状态下,可调用该接口给用户发送消息。
注意仅当微信客户在主动发送消息给客服后的48小时内企业可发送消息给客户最多可发送5条消息若用户继续发送消息企业可再次下发消息。
支持发送消息类型:文本、图片、语音、视频、文件、图文、小程序、菜单消息、地理位置。
:param user_id: 指定接收消息的客户UserID
:param open_kfid: 指定发送消息的客服帐号ID
:param msgid: 指定消息ID
:param tag_ids: 标签ID列表。
:param msg: 发送消息的 dict 对象
:type msg: dict | None
:return: 接口调用结果
"""
msg = msg or {}
data = {
"touser": user_id,
"open_kfid": open_kfid,
}
if msgid:
data["msgid"] = msgid
data.update(msg)
return self._post("kf/send_msg", data=data)
def send_text(self, user_id, open_kfid, content, msgid=""):
return self.send(
user_id,
open_kfid,
msgid,
msg={"msgtype": "text", "text": {"content": content}},
)
def send_image(self, user_id, open_kfid, media_id, msgid=""):
return self.send(
user_id,
open_kfid,
msgid,
msg={"msgtype": "image", "image": {"media_id": media_id}},
)
def send_voice(self, user_id, open_kfid, media_id, msgid=""):
return self.send(
user_id,
open_kfid,
msgid,
msg={"msgtype": "voice", "voice": {"media_id": media_id}},
)
def send_video(self, user_id, open_kfid, media_id, msgid=""):
video_data = optionaldict()
video_data["media_id"] = media_id
return self.send(
user_id,
open_kfid,
msgid,
msg={"msgtype": "video", "video": dict(video_data)},
)
def send_file(self, user_id, open_kfid, media_id, msgid=""):
return self.send(
user_id,
open_kfid,
msgid,
msg={"msgtype": "file", "file": {"media_id": media_id}},
)
def send_articles_link(self, user_id, open_kfid, article, msgid=""):
articles_data = {
"title": article["title"],
"desc": article["desc"],
"url": article["url"],
"thumb_media_id": article["thumb_media_id"],
}
return self.send(
user_id,
open_kfid,
msgid,
msg={"msgtype": "news", "link": {"link": articles_data}},
)
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""):
return self.send(
user_id,
open_kfid,
msgid,
msg={
"msgtype": "msgmenu",
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content},
},
)
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""):
return self.send(
user_id,
open_kfid,
msgid,
msg={
"msgtype": "location",
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude},
},
)
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""):
return self.send(
user_id,
open_kfid,
msgid,
msg={
"msgtype": "miniprogram",
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath},
},
)

View File

@@ -0,0 +1,286 @@
import sys
import uuid
import asyncio
import quart
from astrbot.api.platform import (
Platform,
AstrBotMessage,
MessageMember,
PlatformMetadata,
MessageType,
)
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain, Image, Record
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.api.platform import register_platform_adapter
from astrbot.core import logger
from requests import Response
from wechatpy.utils import check_signature
from wechatpy.crypto import WeChatCrypto
from wechatpy import WeChatClient
from wechatpy.messages import TextMessage, ImageMessage, VoiceMessage, BaseMessage
from wechatpy.exceptions import InvalidSignatureException
from wechatpy import parse_message
from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class WecomServer:
def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__)
self.port = int(config.get("port"))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.token = config.get("token")
self.encoding_aes_key = config.get("encoding_aes_key")
self.appid = config.get("appid")
self.server.add_url_rule(
"/callback/command", view_func=self.verify, methods=["GET"]
)
self.server.add_url_rule(
"/callback/command", view_func=self.callback_command, methods=["POST"]
)
self.crypto = WeChatCrypto(self.token, self.encoding_aes_key, self.appid)
self.event_queue = event_queue
self.callback = None
self.shutdown_event = asyncio.Event()
async def verify(self):
logger.info(f"验证请求有效性: {quart.request.args}")
args = quart.request.args
if not args.get("signature", None):
logger.error("未知的响应,请检查回调地址是否填写正确。")
return "err"
try:
check_signature(
self.token,
args.get("signature"),
args.get("timestamp"),
args.get("nonce"),
)
logger.info("验证请求有效性成功。")
return args.get("echostr", "empty")
except InvalidSignatureException:
logger.error("验证请求有效性失败,签名异常,请检查配置。")
return "err"
async def callback_command(self):
data = await quart.request.get_data()
msg_signature = quart.request.args.get("msg_signature")
timestamp = quart.request.args.get("timestamp")
nonce = quart.request.args.get("nonce")
try:
xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce)
except InvalidSignatureException:
logger.error("解密失败,签名异常,请检查配置。")
raise
else:
msg = parse_message(xml)
logger.info(f"解析成功: {msg}")
if self.callback:
result_xml = await self.callback(msg)
if not result_xml:
return "success"
if isinstance(result_xml, str):
return result_xml
return "success"
async def start_polling(self):
logger.info(
f"将在 {self.callback_server_host}:{self.port} 端口启动 微信公众平台 适配器。"
)
await self.server.run_task(
host=self.callback_server_host,
port=self.port,
shutdown_trigger=self.shutdown_trigger,
)
async def shutdown_trigger(self):
await self.shutdown_event.wait()
@register_platform_adapter("weixin_official_account", "微信公众平台 适配器")
class WeixinOfficialAccountPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settingss = platform_settings
self.client_self_id = uuid.uuid4().hex[:8]
self.api_base_url = platform_config.get(
"api_base_url", "https://api.weixin.qq.com/cgi-bin/"
)
self.active_send_mode = self.config.get("active_send_mode", False)
if not self.api_base_url:
self.api_base_url = "https://api.weixin.qq.com/cgi-bin/"
if self.api_base_url.endswith("/"):
self.api_base_url = self.api_base_url[:-1]
if not self.api_base_url.endswith("/cgi-bin"):
self.api_base_url += "/cgi-bin"
if not self.api_base_url.endswith("/"):
self.api_base_url += "/"
self.server = WecomServer(self._event_queue, self.config)
self.client = WeChatClient(
self.config["appid"].strip(),
self.config["secret"].strip(),
)
self.client.API_BASE_URL = self.api_base_url
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
# msgid -> Future
self.wexin_event_workers: dict[str, asyncio.Future] = {}
async def callback(msg: BaseMessage):
try:
if self.active_send_mode:
await self.convert_message(msg, None)
else:
if msg.id in self.wexin_event_workers:
future = self.wexin_event_workers[msg.id]
logger.debug(f"duplicate message id checked: {msg.id}")
else:
future = asyncio.get_event_loop().create_future()
self.wexin_event_workers[msg.id] = future
await self.convert_message(msg, future)
# I love shield so much!
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
logger.debug(f"Got future result: {result}")
self.wexin_event_workers.pop(msg.id, None)
return result # xml. see weixin_offacc_event.py
except asyncio.TimeoutError:
pass
except Exception as e:
logger.error(f"转换消息时出现异常: {e}")
self.server.callback = callback
@override
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
await super().send_by_session(session, message_chain)
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"weixin_official_account",
"微信公众平台 适配器",
)
@override
async def run(self):
await self.server.start_polling()
async def convert_message(
self, msg, future: asyncio.Future = None
) -> AstrBotMessage | None:
abm = AstrBotMessage()
if isinstance(msg, TextMessage):
abm.message_str = msg.content
abm.self_id = str(msg.target)
abm.message = [Plain(msg.content)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.session_id = abm.sender.user_id
elif msg.type == "image":
assert isinstance(msg, ImageMessage)
abm.message_str = "[图片]"
abm.self_id = str(msg.target)
abm.message = [Image(file=msg.image, url=msg.image)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.session_id = abm.sender.user_id
elif msg.type == "voice":
assert isinstance(msg, VoiceMessage)
resp: Response = await asyncio.get_event_loop().run_in_executor(
None, self.client.media.download, msg.media_id
)
path = f"data/temp/wecom_{msg.media_id}.amr"
with open(path, "wb") as f:
f.write(resp.content)
try:
from pydub import AudioSegment
path_wav = f"data/temp/wecom_{msg.media_id}.wav"
audio = AudioSegment.from_file(path)
audio.export(path_wav, format="wav")
except Exception as e:
logger.error(
f"转换音频失败: {e}。如果没有安装 pydub 和 ffmpeg 请先安装。"
)
path_wav = path
return
abm.message_str = ""
abm.self_id = str(msg.target)
abm.message = [Record(file=path_wav, url=path_wav)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.session_id = abm.sender.user_id
else:
logger.warning(f"暂未实现的事件: {msg.type}")
future.set_result(None)
return
# 很不优雅 :(
abm.raw_message = {
"message": msg,
"future": future,
"active_send_mode": self.active_send_mode,
}
logger.info(f"abm: {abm}")
await self.handle_msg(abm)
async def handle_msg(self, message: AstrBotMessage):
message_event = WeixinOfficialAccountPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client,
)
self.commit_event(message_event)
def get_client(self) -> WeChatClient:
return self.client
async def terminate(self):
self.server.shutdown_event.set()
try:
await self.server.server.shutdown()
except Exception as _:
pass
logger.info("微信公众平台 适配器已被优雅地关闭")

View File

@@ -0,0 +1,185 @@
import uuid
import asyncio
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record
from wechatpy import WeChatClient
from wechatpy.replies import TextReply, ImageReply, VoiceReply
from astrbot.api import logger
try:
import pydub
except Exception:
logger.warning(
"检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。"
)
pass
class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client: WeChatClient,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
@staticmethod
async def send_with_client(
client: WeChatClient, message: MessageChain, user_name: str
):
pass
async def split_plain(self, plain: str) -> list[str]:
"""将长文本分割成多个小文本, 每个小文本长度不超过 2048 字符
Args:
plain (str): 要分割的长文本
Returns:
list[str]: 分割后的文本列表
"""
if len(plain) <= 2048:
return [plain]
else:
result = []
start = 0
while start < len(plain):
# 剩下的字符串长度<2048时结束
if start + 2048 >= len(plain):
result.append(plain[start:])
break
# 向前搜索分割标点符号
end = min(start + 2048, len(plain))
cut_position = end
for i in range(end, start, -1):
if i < len(plain) and plain[i - 1] in [
"",
"",
"",
".",
"!",
"?",
"\n",
";",
"",
]:
cut_position = i
break
# 没找到合适的位置分割, 直接切分
if cut_position == end and end < len(plain):
cut_position = end
result.append(plain[start:cut_position])
start = cut_position
return result
async def send(self, message: MessageChain):
message_obj = self.message_obj
active_send_mode = message_obj.raw_message.get("active_send_mode", False)
for comp in message.chain:
if isinstance(comp, Plain):
# Split long text messages if needed
plain_chunks = await self.split_plain(comp.text)
for chunk in plain_chunks:
if active_send_mode:
self.client.message.send_text(message_obj.sender.user_id, chunk)
else:
reply = TextReply(
content=chunk,
message=self.message_obj.raw_message["message"],
)
xml = reply.render()
future = self.message_obj.raw_message["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
await asyncio.sleep(0.5) # Avoid sending too fast
elif isinstance(comp, Image):
img_path = await comp.convert_to_file_path()
with open(img_path, "rb") as f:
try:
response = self.client.media.upload("image", f)
except Exception as e:
logger.error(f"微信公众平台上传图片失败: {e}")
await self.send(
MessageChain().message(f"微信公众平台上传图片失败: {e}")
)
return
logger.debug(f"微信公众平台上传图片返回: {response}")
if active_send_mode:
self.client.message.send_image(
message_obj.sender.user_id,
response["media_id"],
)
else:
reply = ImageReply(
media_id=response["media_id"],
message=self.message_obj.raw_message["message"],
)
xml = reply.render()
future = self.message_obj.raw_message["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
elif isinstance(comp, Record):
record_path = await comp.convert_to_file_path()
# 转成amr
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
pydub.AudioSegment.from_wav(record_path).export(
record_path_amr, format="amr"
)
with open(record_path_amr, "rb") as f:
try:
response = self.client.media.upload("voice", f)
except Exception as e:
logger.error(f"微信公众平台上传语音失败: {e}")
await self.send(
MessageChain().message(f"微信公众平台上传语音失败: {e}")
)
return
logger.info(f"微信公众平台上传语音返回: {response}")
if active_send_mode:
self.client.message.send_voice(
message_obj.sender.user_id,
response["media_id"],
)
else:
reply = VoiceReply(
media_id=response["media_id"],
message=self.message_obj.raw_message["message"],
)
xml = reply.render()
future = self.message_obj.raw_message["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
else:
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}")
await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)

View File

@@ -1,5 +1,5 @@
from .provider import Provider, Personality, STTProvider from .provider import Provider, Personality, STTProvider
from .entites import ProviderMetaData from .entities import ProviderMetaData
__all__ = ["Provider", "Personality", "ProviderMetaData", "STTProvider"] __all__ = ["Provider", "Personality", "ProviderMetaData", "STTProvider"]

View File

@@ -1,67 +1,19 @@
import enum from astrbot.core.provider.entities import (
from dataclasses import dataclass, field ProviderRequest,
from typing import List, Dict, Type ProviderType,
from .func_tool_manager import FuncCall ProviderMetaData,
from openai.types.chat.chat_completion import ChatCompletion ToolCallsResult,
from astrbot.core.db.po import Conversation AssistantMessageSegment,
ToolCallMessageSegment,
LLMResponse,
)
__all__ = [
class ProviderType(enum.Enum): "ProviderRequest",
CHAT_COMPLETION = "chat_completion" "ProviderType",
SPEECH_TO_TEXT = "speech_to_text" "ProviderMetaData",
TEXT_TO_SPEECH = "text_to_speech" "ToolCallsResult",
"AssistantMessageSegment",
"ToolCallMessageSegment",
@dataclass "LLMResponse",
class ProviderMetaData: ]
type: str
"""提供商适配器名称,如 openai, ollama"""
desc: str = ""
"""提供商适配器描述."""
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: Type = None
default_config_tmpl: dict = None
"""平台的默认配置模板"""
provider_display_name: str = None
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
@dataclass
class ProviderRequest:
prompt: str
"""提示词"""
session_id: str = ""
"""会话 ID"""
image_urls: List[str] = None
"""图片 URL 列表"""
func_tool: FuncCall = None
"""工具"""
contexts: List = None
"""上下文。格式与 openai 的上下文格式一致:
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
"""
system_prompt: str = ""
"""系统提示词"""
conversation: Conversation = None
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self.contexts}, system_prompt={self.system_prompt.strip()})"
def __str__(self):
return self.__repr__()
@dataclass
class LLMResponse:
role: str
"""角色, assistant, tool, err"""
completion_text: str = ""
"""LLM 返回的文本"""
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
"""工具调用参数"""
tools_call_name: List[str] = field(default_factory=list)
"""工具调用名称"""
raw_completion: ChatCompletion = None
_new_record: Dict[str, any] = None

View File

@@ -0,0 +1,284 @@
import enum
import base64
import json
from astrbot.core.utils.io import download_image_by_url
from astrbot import logger
from dataclasses import dataclass, field
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
from astrbot.core.db.po import Conversation
from astrbot.core.message.message_event_result import MessageChain
import astrbot.core.message.components as Comp
class ProviderType(enum.Enum):
CHAT_COMPLETION = "chat_completion"
SPEECH_TO_TEXT = "speech_to_text"
TEXT_TO_SPEECH = "text_to_speech"
EMBEDDING = "embedding"
@dataclass
class ProviderMetaData:
type: str
"""提供商适配器名称,如 openai, ollama"""
desc: str = ""
"""提供商适配器描述."""
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: Type = None
default_config_tmpl: dict = None
"""平台的默认配置模板"""
provider_display_name: str = None
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
@dataclass
class ToolCallMessageSegment:
"""OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
tool_call_id: str
content: str
role: str = "tool"
def to_dict(self):
return {
"tool_call_id": self.tool_call_id,
"content": self.content,
"role": self.role,
}
@dataclass
class AssistantMessageSegment:
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
content: str = None
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
role: str = "assistant"
def to_dict(self):
ret = {
"role": self.role,
}
if self.content:
ret["content"] = self.content
elif self.tool_calls:
ret["tool_calls"] = self.tool_calls
return ret
@dataclass
class ToolCallsResult:
"""工具调用结果"""
tool_calls_info: AssistantMessageSegment
"""函数调用的信息"""
tool_calls_result: List[ToolCallMessageSegment]
"""函数调用的结果"""
def to_openai_messages(self) -> List[Dict]:
ret = [
self.tool_calls_info.to_dict(),
*[item.to_dict() for item in self.tool_calls_result],
]
return ret
@dataclass
class ProviderRequest:
prompt: str
"""提示词"""
session_id: str = ""
"""会话 ID"""
image_urls: List[str] = None
"""图片 URL 列表"""
func_tool: FuncCall = None
"""可用的函数工具"""
contexts: List = None
"""上下文。格式与 openai 的上下文格式一致:
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
"""
system_prompt: str = ""
"""系统提示词"""
conversation: Conversation = None
tool_calls_result: ToolCallsResult = None
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
def __str__(self):
return self.__repr__()
def _print_friendly_context(self):
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
if not self.contexts:
return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}"
result_parts = []
for ctx in self.contexts:
role = ctx.get("role", "unknown")
content = ctx.get("content", "")
if isinstance(content, str):
result_parts.append(f"{role}: {content}")
elif isinstance(content, list):
msg_parts = []
image_count = 0
for item in content:
item_type = item.get("type", "")
if item_type == "text":
msg_parts.append(item.get("text", ""))
elif item_type == "image_url":
image_count += 1
if image_count > 0:
if msg_parts:
msg_parts.append(f"[+{image_count} images]")
else:
msg_parts.append(f"[{image_count} images]")
result_parts.append(f"{role}: {''.join(msg_parts)}")
return result_parts
async def assemble_context(self) -> Dict:
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
if self.image_urls:
user_content = {
"role": "user",
"content": [
{"type": "text", "text": self.prompt if self.prompt else "[图片]"}
],
}
for image_url in self.image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self._encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self._encode_image_bs64(image_path)
else:
image_data = await self._encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
{"type": "image_url", "image_url": {"url": image_data}}
)
return user_content
else:
return {"role": "user", "content": self.prompt}
async def _encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
return ""
@dataclass
class LLMResponse:
role: str
"""角色, assistant, tool, err"""
result_chain: MessageChain = None
"""返回的消息链"""
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
"""工具调用参数"""
tools_call_name: List[str] = field(default_factory=list)
"""工具调用名称"""
tools_call_ids: List[str] = field(default_factory=list)
"""工具调用 ID"""
raw_completion: ChatCompletion = None
_new_record: Dict[str, any] = None
_completion_text: str = ""
is_chunk: bool = False
"""是否是流式输出的单个 Chunk"""
def __init__(
self,
role: str,
completion_text: str = "",
result_chain: MessageChain = None,
tools_call_args: List[Dict[str, any]] = None,
tools_call_name: List[str] = None,
tools_call_ids: List[str] = None,
raw_completion: ChatCompletion = None,
_new_record: Dict[str, any] = None,
is_chunk: bool = False,
):
"""初始化 LLMResponse
Args:
role (str): 角色, assistant, tool, err
completion_text (str, optional): 返回的结果文本,已经过时,推荐使用 result_chain. Defaults to "".
result_chain (MessageChain, optional): 返回的消息链. Defaults to None.
tools_call_args (List[Dict[str, any]], optional): 工具调用参数. Defaults to None.
tools_call_name (List[str], optional): 工具调用名称. Defaults to None.
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
"""
if tools_call_args is None:
tools_call_args = []
if tools_call_name is None:
tools_call_name = []
if tools_call_ids is None:
tools_call_ids = []
self.role = role
self.completion_text = completion_text
self.result_chain = result_chain
self.tools_call_args = tools_call_args
self.tools_call_name = tools_call_name
self.tools_call_ids = tools_call_ids
self.raw_completion = raw_completion
self._new_record = _new_record
self.is_chunk = is_chunk
@property
def completion_text(self):
if self.result_chain:
return self.result_chain.get_plain_text()
return self._completion_text
@completion_text.setter
def completion_text(self, value):
if self.result_chain:
self.result_chain.chain = [
comp
for comp in self.result_chain.chain
if not isinstance(comp, Comp.Plain)
] # 清空 Plain 组件
self.result_chain.chain.insert(0, Comp.Plain(value))
else:
self._completion_text = value
def to_openai_tool_calls(self) -> List[Dict]:
"""将工具调用信息转换为 OpenAI 格式"""
ret = []
for idx, tool_call_arg in enumerate(self.tools_call_args):
ret.append(
{
"id": self.tools_call_ids[idx],
"function": {
"name": self.tools_call_name[idx],
"arguments": json.dumps(tool_call_arg),
},
"type": "function",
}
)
return ret

View File

@@ -1,7 +1,42 @@
from __future__ import annotations
import json import json
import textwrap import textwrap
from typing import Dict, List, Awaitable import os
import asyncio
import logging
from datetime import timedelta
from typing import Dict, List, Awaitable, Literal, Any
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
from contextlib import AsyncExitStack
from astrbot import logger
from astrbot.core.utils.log_pipe import LogPipe
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
try:
import mcp
from mcp.client.sse import sse_client
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
try:
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
)
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
SUPPORTED_TYPES = [
"string",
"number",
"object",
"array",
"boolean",
] # json schema 支持的数据类型
@dataclass @dataclass
@@ -13,28 +48,162 @@ class FuncTool:
name: str name: str
parameters: Dict parameters: Dict
description: str description: str
handler: Awaitable handler: Awaitable = None
handler_module_path: str = None # 必须要保留这个handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools """处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
"""
active: bool = True active: bool = True
"""是否激活""" """是否激活"""
origin: Literal["local", "mcp"] = "local"
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
# MCP 相关字段
mcp_server_name: str = None
"""MCP 服务名称,当 origin 为 mcp 时有效"""
mcp_client: MCPClient = None
"""MCP 客户端,当 origin 为 mcp 时有效"""
def __repr__(self): def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}), active={self.active})" return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
async def execute(self, **args) -> Any:
"""执行函数调用"""
if self.origin == "local":
if not self.handler:
raise Exception(f"Local function {self.name} has no handler")
return await self.handler(**args)
elif self.origin == "mcp":
if not self.mcp_client or not self.mcp_client.session:
raise Exception(f"MCP client for {self.name} is not available")
# 使用name属性而不是额外的mcp_tool_name
if ":" in self.name:
# 如果名字是格式为 mcp:server:tool_name提取实际的工具名
actual_tool_name = self.name.split(":")[-1]
return await self.mcp_client.session.call_tool(actual_tool_name, args)
else:
return await self.mcp_client.session.call_tool(self.name, args)
else:
raise Exception(f"Unknown function origin: {self.origin}")
SUPPORTED_TYPES = [ class MCPClient:
"string", def __init__(self):
"number", # Initialize session and client objects
"object", self.session: Optional[mcp.ClientSession] = None
"array", self.exit_stack = AsyncExitStack()
"boolean",
] # json schema 支持的数据类型 self.name = None
self.active: bool = True
self.tools: List[mcp.Tool] = []
self.server_errlogs: List[str] = []
async def connect_to_server(self, mcp_server_config: dict, name: str):
"""连接到 MCP 服务器
如果 `url` 参数存在:
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
cfg = mcp_server_config.copy()
if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
key_0 = list(cfg["mcpServers"].keys())[0]
cfg = cfg["mcpServers"][key_0]
cfg.pop("active", None) # Remove active flag from config
if "url" in cfg:
is_sse = True
if cfg.get("transport") == "streamable_http":
is_sse = False
if is_sse:
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=cfg.get("timeout", 5),
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
)
streams = await self._streams_context.__aenter__()
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*streams)
)
else:
timeout = timedelta(seconds=cfg.get("timeout", 30))
sse_read_timeout = timedelta(
seconds=cfg.get("sse_read_timeout", 60 * 5)
)
self._streams_context = streamablehttp_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
read_s, write_s, _ = await self._streams_context.__aenter__()
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(read_stream=read_s, write_stream=write_s)
)
else:
server_params = mcp.StdioServerParameters(
**cfg,
)
def callback(msg: str):
# 处理 MCP 服务的错误日志
self.server_errlogs.append(msg)
stdio_transport = await self.exit_stack.enter_async_context(
mcp.stdio_client(
server_params,
errlog=LogPipe(
level=logging.ERROR,
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
),
),
)
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport)
)
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
response = await self.session.list_tools()
logger.debug(f"MCP server {self.name} list tools response: {response}")
self.tools = response.tools
return response
async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()
class FuncCall: class FuncCall:
def __init__(self) -> None: def __init__(self) -> None:
self.func_list: List[FuncTool] = [] self.func_list: List[FuncTool] = []
"""内部加载的 func tools"""
self.mcp_client_dict: Dict[str, MCPClient] = {}
"""MCP 服务列表"""
self.mcp_service_queue = asyncio.Queue()
"""用于外部控制 MCP 服务的启停"""
self.mcp_client_event: Dict[str, asyncio.Event] = {}
def empty(self) -> bool: def empty(self) -> bool:
return len(self.func_list) == 0 return len(self.func_list) == 0
@@ -46,14 +215,16 @@ class FuncCall:
desc: str, desc: str,
handler: Awaitable, handler: Awaitable,
) -> None: ) -> None:
""" """添加函数调用工具
为函数调用function-calling / tools-use添加工具。
@param name: 函数名 @param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述 @param desc: 函数描述
@param func_obj: 处理函数 @param func_obj: 处理函数
""" """
# check if the tool has been added before
self.remove_func(name)
params = { params = {
"type": "object", # hard-coded here "type": "object", # hard-coded here
"properties": {}, "properties": {},
@@ -70,13 +241,14 @@ class FuncCall:
handler=handler, handler=handler,
) )
self.func_list.append(_func) self.func_list.append(_func)
logger.info(f"添加函数调用工具: {name}")
def remove_func(self, name: str) -> None: def remove_func(self, name: str) -> None:
""" """
删除一个函数调用工具。 删除一个函数调用工具。
""" """
for i, f in enumerate(self.func_list): for i, f in enumerate(self.func_list):
if f["name"] == name: if f.name == name:
self.func_list.pop(i) self.func_list.pop(i)
break break
@@ -86,24 +258,195 @@ class FuncCall:
return f return f
return None return None
def get_func_desc_openai_style(self) -> list: async def _init_mcp_clients(self) -> None:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
```
{
"mcpServers": {
"weather": {
"command": "uv",
"args": [
"--directory",
"/ABSOLUTE/PATH/TO/PARENT/FOLDER/weather",
"run",
"weather.py"
]
}
}
...
}
```
"""
data_dir = get_astrbot_data_path()
mcp_json_file = os.path.join(data_dir, "mcp_server.json")
if not os.path.exists(mcp_json_file):
# 配置文件不存在错误处理
with open(mcp_json_file, "w", encoding="utf-8") as f:
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
return
mcp_server_json_obj: Dict[str, Dict] = json.load(
open(mcp_json_file, "r", encoding="utf-8")
)["mcpServers"]
for name in mcp_server_json_obj.keys():
cfg = mcp_server_json_obj[name]
if cfg.get("active", True):
event = asyncio.Event()
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, cfg, event)
)
self.mcp_client_event[name] = event
async def mcp_service_selector(self):
"""为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
{"type": "init"} 初始化所有MCP客户端
{"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
{"type": "terminate"} 终止所有MCP客户端
{"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
"""
while True:
data = await self.mcp_service_queue.get()
if data["type"] == "init":
if "name" in data:
event = asyncio.Event()
asyncio.create_task(
self._init_mcp_client_task_wrapper(
data["name"], data["cfg"], event
)
)
self.mcp_client_event[data["name"]] = event
else:
await self._init_mcp_clients()
elif data["type"] == "terminate":
if "name" in data:
# await self._terminate_mcp_client(data["name"])
if data["name"] in self.mcp_client_event:
self.mcp_client_event[data["name"]].set()
self.mcp_client_event.pop(data["name"], None)
self.func_list = [
f
for f in self.func_list
if not (
f.origin == "mcp" and f.mcp_server_name == data["name"]
)
]
else:
for name in self.mcp_client_dict.keys():
# await self._terminate_mcp_client(name)
# self.mcp_client_event[name].set()
if name in self.mcp_client_event:
self.mcp_client_event[name].set()
self.mcp_client_event.pop(name, None)
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
async def _init_mcp_client_task_wrapper(
self, name: str, cfg: dict, event: asyncio.Event
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:
await self._init_mcp_client(name, cfg)
await event.wait()
logger.info(f"收到 MCP 客户端 {name} 终止信号")
await self._terminate_mcp_client(name)
except Exception as e:
import traceback
traceback.print_exc()
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
async def _init_mcp_client(self, name: str, config: dict) -> None:
"""初始化单个MCP客户端"""
try:
# 先清理之前的客户端,如果存在
if name in self.mcp_client_dict:
await self._terminate_mcp_client(name)
mcp_client = MCPClient()
mcp_client.name = name
self.mcp_client_dict[name] = mcp_client
await mcp_client.connect_to_server(config, name)
tools_res = await mcp_client.list_tools_and_save()
tool_names = [tool.name for tool in tools_res.tools]
# 移除该MCP服务之前的工具如有
self.func_list = [
f
for f in self.func_list
if not (f.origin == "mcp" and f.mcp_server_name == name)
]
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
for tool in mcp_client.tools:
func_tool = FuncTool(
name=tool.name,
parameters=tool.inputSchema,
description=tool.description,
origin="mcp",
mcp_server_name=name,
mcp_client=mcp_client,
)
self.func_list.append(func_tool)
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
return
except Exception as e:
import traceback
logger.error(traceback.format_exc())
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
# 发生错误时确保客户端被清理
if name in self.mcp_client_dict:
await self._terminate_mcp_client(name)
return
async def _terminate_mcp_client(self, name: str) -> None:
"""关闭并清理MCP客户端"""
if name in self.mcp_client_dict:
try:
# 关闭MCP连接
await self.mcp_client_dict[name].cleanup()
del self.mcp_client_dict[name]
except Exception as e:
logger.info(f"清空 MCP 客户端资源 {name}: {e}")
# 移除关联的FuncTool
self.func_list = [
f
for f in self.func_list
if not (f.origin == "mcp" and f.mcp_server_name == name)
]
logger.info(f"已关闭 MCP 服务 {name}")
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
""" """
获得 OpenAI API 风格的**已经激活**的工具描述 获得 OpenAI API 风格的**已经激活**的工具描述
""" """
_l = [] _l = []
# 处理所有工具包括本地和MCP工具
for f in self.func_list: for f in self.func_list:
if not f.active: if not f.active:
continue continue
_l.append( func_ = {
{ "type": "function",
"type": "function", "function": {
"function": { "name": f.name,
"name": f.name, # "parameters": f.parameters,
"parameters": f.parameters, "description": f.description,
"description": f.description, },
}, }
} func_["function"]["parameters"] = f.parameters
) if not f.parameters.get("properties") and omit_empty_parameter_field:
# 如果 properties 为空,并且 omit_empty_parameter_field 为 True则删除 parameters 字段
del func_["function"]["parameters"]
_l.append(func_)
return _l return _l
def get_func_desc_anthropic_style(self) -> list: def get_func_desc_anthropic_style(self) -> list:
@@ -129,22 +472,86 @@ class FuncCall:
tools.append(tool) tools.append(tool)
return tools return tools
def get_func_desc_google_genai_style(self) -> Dict: def get_func_desc_google_genai_style(self) -> dict:
"""
获得 Google GenAI API 风格的**已经激活**的工具描述
"""
# Gemini API 支持的数据类型和格式
supported_types = {
"string",
"number",
"integer",
"boolean",
"array",
"object",
"null",
}
supported_formats = {
"string": {"enum", "date-time"},
"integer": {"int32", "int64"},
"number": {"float", "double"},
}
def convert_schema(schema: dict) -> dict:
"""转换 schema 为 Gemini API 格式"""
# 如果 schema 包含 anyOf则只返回 anyOf 字段
if "anyOf" in schema:
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
result = {}
if "type" in schema and schema["type"] in supported_types:
result["type"] = schema["type"]
if "format" in schema and schema["format"] in supported_formats.get(
result["type"], set()
):
result["format"] = schema["format"]
else:
# 暂时指定默认为null
result["type"] = "null"
support_fields = {
"title",
"description",
"enum",
"minimum",
"maximum",
"maxItems",
"minItems",
"nullable",
"required",
}
result.update({k: schema[k] for k in support_fields if k in schema})
if "properties" in schema:
properties = {}
for key, value in schema["properties"].items():
prop_value = convert_schema(value)
if "default" in prop_value:
del prop_value["default"]
properties[key] = prop_value
if properties: # 只在有非空属性时添加
result["properties"] = properties
if "items" in schema:
result["items"] = convert_schema(schema["items"])
return result
tools = [
{
"name": f.name,
"description": f.description,
**({"parameters": convert_schema(f.parameters)}),
}
for f in self.func_list
if f.active
]
declarations = {} declarations = {}
tools = []
for f in self.func_list:
if not f.active:
continue
func_declaration = {"name": f.name, "description": f.description}
# 检查并添加非空的properties参数
params = f.parameters if isinstance(f.parameters, dict) else {}
if params.get("properties", {}):
func_declaration["parameters"] = params
tools.append(func_declaration)
if tools: if tools:
declarations["function_declarations"] = tools declarations["function_declarations"] = tools
return declarations return declarations
@@ -156,9 +563,9 @@ class FuncCall:
continue continue
_l.append( _l.append(
{ {
"name": f["name"], "name": f.name,
"parameters": f["parameters"], "parameters": f.parameters,
"description": f["description"], "description": f.description,
} }
) )
func_definition = json.dumps(_l, ensure_ascii=False) func_definition = json.dumps(_l, ensure_ascii=False)
@@ -208,14 +615,11 @@ class FuncCall:
func_name = tool["name"] func_name = tool["name"]
args = tool["args"] args = tool["args"]
# 调用函数 # 调用函数
tool_callable = None func_tool = self.get_func(func_name)
for func in self.func_list: if not func_tool:
if func.name == func_name:
tool_callable = func.star_handler_metadata.handler
break
if not tool_callable:
raise Exception(f"Request function {func_name} not found.") raise Exception(f"Request function {func_name} not found.")
ret = await tool_callable(**args)
ret = await func_tool.execute(**args)
if ret: if ret:
tool_call_result.append(str(ret)) tool_call_result.append(str(ret))
return tool_call_result, True return tool_call_result, True
@@ -225,3 +629,8 @@ class FuncCall:
def __repr__(self): def __repr__(self):
return str(self.func_list) return str(self.func_list)
async def terminate(self):
for name in self.mcp_client_dict.keys():
await self._terminate_mcp_client(name)
logger.debug(f"清理 MCP 客户端 {name} 资源")

Some files were not shown because too many files have changed in this diff Show More