Compare commits

...

424 Commits

Author SHA1 Message Date
Soulter
3b54a24037 feat: 初步实现可视化编辑 object 配置项 2025-02-22 17:12:35 +08:00
Soulter
2807e1e892 feat: add template of FastGPT 2025-02-22 15:43:14 +08:00
Soulter
0a2abd8214 Merge pull request #572 from Soulter/feat-dashscope
支持阿里云百炼应用智能体、工作流
2025-02-22 15:04:46 +08:00
Soulter
8beb7acdb1 feat: 支持为 dify 和 dashscope 提供商设置默认固定变量 #552 2025-02-22 14:48:18 +08:00
Soulter
466c80b94d feat: 阿里云百炼应用工作流支持自定义动态变量 #552 2025-02-22 14:32:37 +08:00
Soulter
36c0cfc9a9 feat: 支持阿里云百炼应用智能体、工作流
#552
2025-02-22 14:08:51 +08:00
Soulter
35ba1b3345 fix: gewechat verify code 2025-02-22 11:37:34 +08:00
Soulter
d00821d1c7 Update README.md 2025-02-22 10:07:18 +08:00
Soulter
6c1b3f242b Merge pull request #568 from Raven95676/master
🐛 fix: 修复webchat未处理base64的问题
2025-02-22 01:07:20 +08:00
Raven95676
9f9da1e0c9 🐛 fix: 修复webchat未处理base64的问题 2025-02-21 23:39:53 +08:00
崔永亮
14fb4b70bd feat: 支持 gewechat 设置验证码 #448 2025-02-21 23:08:23 +08:00
崔永亮
b1049540a4 feat: claude 支持纯图片 2025-02-21 22:26:31 +08:00
Fridemn
5e2909df33 Merge pull request #559 from Rt39/feat-claude-api
添加对Anthropic Claude API的支持
2025-02-21 21:12:52 +08:00
崔永亮
c122dad21f feat: 添加自定义api base 2025-02-21 21:07:59 +08:00
Rt39
48ae686602 feat: add claude template 2025-02-20 23:58:10 -05:00
Rt39
bf2c3a1a81 fix: 根据Codacy Production / Codacy Static Code Analysis修改格式问题 2025-02-20 21:15:07 -05:00
Rt39
96e7a93886 feat: 添加对Claude API的支持 2025-02-20 19:59:16 -05:00
Soulter
dba1ed1e19 v3.4.30 2025-02-21 01:31:36 +08:00
Soulter
a24514876b fix: 修复 dify 无法使用事件钩子的问题以及出现 GeneratorExit 的问题 #533 #264 2025-02-21 01:14:13 +08:00
Soulter
466a1c1c41 🐛 fix: 修复某些情况下导致插件报错 AttributeError 的问题 #549 2025-02-21 00:38:08 +08:00
Soulter
a2d5e9f40f feat: add xAI template 2025-02-20 16:34:32 +08:00
Soulter
1bbff1d161 v3.4.29 2025-02-19 20:05:33 +08:00
Soulter
0948bae99b feat: 添加代码执行器 Docker 宿主机绝对路径配置及相关功能
Co-authored-by: Bocity <haolovej@vip.qq.com>
2025-02-19 19:56:31 +08:00
Soulter
850db41596 feat: gemini source 初步支持对 API Key 进行负载均衡请求 #534 2025-02-19 19:06:37 +08:00
Soulter
7bafc87e2b 🐛 fix: 修复部分单指令失效的问题 2025-02-19 19:04:23 +08:00
Soulter
1a0de02a15 fix: 尝试修复gewechat群聊用户名出现unknown 2025-02-19 17:07:11 +08:00
Soulter
6d5d278624 fix: 尝试修复 gewechat 微信群聊情况下可能导致 unknown 的问题 #537 2025-02-19 16:42:30 +08:00
Soulter
3b4cc48fa0 👌 perf: 开启对话隔离的群聊以及私聊下,非op可以可以使用 /del 和 /reset #519 2025-02-19 16:22:42 +08:00
Soulter
c908461088 Merge pull request #543 from Soulter/refactor-command-group
更换为预编译指令的方式处理指令组指令并且让事件钩子也支持 yield 的方式发送消息
2025-02-19 15:54:26 +08:00
Soulter
53d1398d30 fix: 修复子指令组不能被调用的问题 2025-02-19 15:53:01 +08:00
Soulter
782c0367d0 feat: 事件钩子支持 yield 方式发送消息 2025-02-19 15:29:10 +08:00
Soulter
4678222e9b 👌 refactor: 更换为预编译指令的方式处理指令组指令 2025-02-19 14:55:14 +08:00
Soulter
f71dc3e4be 🐛 fix: reminder time zone issue 2025-02-19 00:15:14 +08:00
Soulter
f6233893bd 🐛 fix: 修复 reminder rm失败 #529 2025-02-19 00:10:18 +08:00
Soulter
6427bcf130 👌perf: 查询模型列表时,可以显示当前使用的模型名称 #523 2025-02-17 22:35:45 +08:00
Soulter
8fa41b706c Merge pull request #522 from yuanxinlyx/fix-keyerror-ls-command
fix: resolve KeyError when current conversation is not in paginated list
2025-02-17 21:45:40 +08:00
YuanxinLu
4706c4438d fix: resolve KeyError when current conversation is not in paginated list 2025-02-17 03:15:59 +08:00
Soulter
0c8ebc2b06 chore: clean up 2025-02-16 16:52:13 +08:00
Soulter
b3b5ebc2ca v3.4.28 2025-02-16 16:19:03 +08:00
Soulter
b8aa23ccc5 🐛fix: 修复转发消息的字数阈值功能#510 2025-02-16 15:54:29 +08:00
Soulter
364843db29 Merge pull request #389 from Nothingness-Void/新增过滤掉正则表达式内容
新增过滤掉正则表达式内容
2025-02-16 15:28:51 +08:00
Soulter
aa56c8f7e6 Merge branch 'master' into 新增过滤掉正则表达式内容 2025-02-16 15:27:30 +08:00
Soulter
8e9fd27058 merge branch master 2025-02-16 15:17:44 +08:00
Soulter
b75908cb2a Merge pull request #517 from Cvandia/master
 feat: 添加命令和命令组的别名支持
2025-02-16 14:51:47 +08:00
Soulter
af6df49ce1 perf: 补充别名为可选参数以前向兼容 2025-02-16 14:50:49 +08:00
Cvandia
bd3bdb5769 feat: 添加命令和命令组的别名支持 2025-02-16 14:44:17 +08:00
Soulter
98fe193b21 Merge pull request #477 from AraragiEro/master
[Feature] 希望添加更为灵活的filter.permission_type使用方式,使用户能自定义权限类型
2025-02-16 13:53:07 +08:00
Soulter
26cbc9e8b1 chore: cleanup 2025-02-16 13:32:28 +08:00
Alero
ebb8c43fd0 bug: 尝试修复cleancode错误 2025-02-16 10:56:17 +08:00
Soulter
8c7344f1c4 👌perf(qq): supports to pass OneBot notice, request event 2025-02-16 01:04:08 +08:00
Soulter
5c32a17787 👌perf: 优化了分段回复和回复时at,引用都打开时的一些体验性问题 2025-02-15 19:29:34 +08:00
Soulter
aff520e69a fix: 修复 Dify 下无法主动回复的问题 #494 2025-02-15 18:31:21 +08:00
Alero
45e627c33c fix: a bug when add filter to root command group 2025-02-14 23:52:31 +08:00
Alero
7a1b158f83 fix: cleancode err 2025-02-14 22:46:22 +08:00
Alero
6374c5d49d fix: add & | operation to customfilter 2025-02-14 22:33:32 +08:00
Alero
fd460b19d4 fix: cleancode err 2025-02-14 20:43:54 +08:00
Alero
dff7cc4ca5 feat: when custom filter cant pass, won't raise error anymore.
and when you use a command group and dont have custom filter access, the return group tree wont contain the command that you dont have permisson.
2025-02-14 20:34:31 +08:00
Alero
d013320bec feat: more powerful CustomFilter 2025-02-14 19:15:19 +08:00
Soulter
fc6dcfaf21 🐛 fix: cannot search plugin 2025-02-14 18:45:56 +08:00
Soulter
a001270bd2 feat: webui supports to search plugin via name 2025-02-14 18:43:04 +08:00
Soulter
9e67883fbd 🐛 fix: add no_proxy env vars to support localhost requests, fix 502 error when use ollama #504 2025-02-14 16:51:02 +08:00
Soulter
f1a448708c 🐛 fix: segmented reply caused incomplete non-llm-response #503 2025-02-14 16:19:09 +08:00
Soulter
a4bfa96502 feat: 支持自定义 Dify 工作流文本输入变量名 #441 2025-02-14 15:41:02 +08:00
Soulter
595b83a256 🐛 FIX: cannot send file in private chat when turn on the reply with quote #262 2025-02-14 14:41:41 +08:00
Soulter
8d34f77321 v3.4.27 2025-02-14 01:53:26 +08:00
Soulter
67095f97b1 🐛 fix: delete conversation
 feat: supports active reply whitelist
2025-02-14 01:43:52 +08:00
Soulter
50740c94ab 🐛 fix: cannot input text before mention in gewechat #492 2025-02-14 01:09:48 +08:00
Soulter
4db4cfeda2 👌 perf: format datetime labels in MessageStat component #460 2025-02-14 00:30:34 +08:00
Soulter
ad13cef89c 👌perf: sort models by id when listing models #384 2025-02-14 00:08:12 +08:00
Soulter
855fc6fcd1 Display the Japanese translation entry 2025-02-13 23:36:50 +08:00
Soulter
8f12244e51 Merge pull request #491 from eltociear/add-japanese-readme
docs: add Japanese README
2025-02-13 22:56:21 +08:00
Ikko Eltociear Ashimine
fe0213465c docs: add Japanese README
I created Japanese translated README.
2025-02-13 14:45:52 +09:00
Soulter
f984047004 fix: unable to send c2c message using webhook qqofficial platform #484 2025-02-13 00:01:16 +08:00
Soulter
19e9e2d090 fix: fix dify cannot set/unset variables #482 2025-02-12 23:58:04 +08:00
Soulter
7fe3b97d00 fix: improve content safety check handling for at or wake commands 2025-02-12 23:42:32 +08:00
Soulter
9cd243da47 fix: handle empty content in gemini context 2025-02-12 23:39:41 +08:00
Soulter
e43208c2e9 fix: update session_id assignment logic for group messages 2025-02-12 14:04:55 +08:00
Soulter
dc016fc22f feat: update validate_config to return a tuple contains casted data 2025-02-12 13:50:24 +08:00
Alero
c6f037cae2 fix: a undefine mistake 2025-02-12 03:25:01 +08:00
Alero
f049830e28 Merge branch 'master' of github.com:AraragiEro/AstrBot 2025-02-12 03:06:23 +08:00
Alero
dd1995ae0b feat: add a way to define custom permission filter. 2025-02-12 03:05:51 +08:00
Soulter
23dc233569 chore: remove useless config items 2025-02-12 02:32:57 +08:00
Soulter
0977aa7d0d chore: fix the default port of qo webhook 2025-02-12 02:28:15 +08:00
Soulter
24862b0672 docs: update the comments of register_llm_tool 2025-02-12 02:27:39 +08:00
Soulter
f05a57efc3 chore: v3.4.26 2025-02-12 01:55:36 +08:00
Soulter
65331a9d7c feat: 支持基于对数函数的分段回复延时时间计算 2025-02-12 01:44:08 +08:00
Soulter
f7ae287e40 fix: ensure result is retrieved again to handle potential plugin chain replacements 2025-02-12 00:27:25 +08:00
Soulter
45f380b1f6 feat: add configuable port for dashboard and improve the method of getting local ip address 2025-02-11 23:00:24 +08:00
Soulter
9e6b329df4 Merge pull request #472 from Akuma-real/master
fix: correct dashboard update tooltip typo
2025-02-11 22:04:19 +08:00
Soulter
43cd34d94c feat: supports to check the content safety of LLM output #474 2025-02-11 22:03:44 +08:00
Soulter
9fa00aff9a 支持完善的 Dify Chat 模式对话管理 2025-02-11 21:30:17 +08:00
Soulter
9a56dcb1be fix: cannot reset conversation in dify chat mode #469 2025-02-11 21:29:28 +08:00
鬼鬼Sama
fdfe7bbe59 fix: correct dashboard update tooltip typo 2025-02-11 20:16:09 +08:00
Soulter
3a99a60792 perf: gewechat send all events to pipeline 2025-02-11 20:00:39 +08:00
Soulter
fa2b4e14df fix: gewechat cannot send message directly 2025-02-11 19:49:20 +08:00
Soulter
35322a6900 Merge pull request #465 from Soulter/feat-qo-webhook
支持 Webhook 方式接入 QQ 官方机器人平台
2025-02-11 18:10:14 +08:00
Soulter
2ccf29d61e Update README.md 2025-02-11 17:28:03 +08:00
Soulter
b068013343 perf: better handle in qq official send 2025-02-11 01:25:17 -05:00
Soulter
d839e72998 feat: 支持 Webhook 方式接入 QQ 官方机器人接口 2025-02-11 01:18:25 -05:00
Soulter
d7c9a8ed29 chore: webhook server, client 2025-02-11 11:19:50 +08:00
Soulter
6837d4d692 chore: update version 2025-02-11 02:05:06 +08:00
Soulter
8aba83735b Update README.md 2025-02-11 01:31:31 +08:00
Soulter
aa51187747 perf(core): change log level to debug for platform and provider adapter instantiation 2025-02-11 01:25:52 +08:00
Soulter
5f07a9ae95 perf(core): better handle in loading platforms 2025-02-11 01:23:50 +08:00
Soulter
a2ca767bf4 v3.4.25 2025-02-11 01:12:23 +08:00
Soulter
5806c74e7c chore(core): display the unsupported message segments 2025-02-11 01:10:17 +08:00
Soulter
0481e1d45e fix(core): github mirror not applied successfully 2025-02-11 01:10:17 +08:00
Soulter
3177b61421 feat(platform): support lark platform 2025-02-11 01:07:14 +08:00
Soulter
6009cf5dfa feat: 添加 moonshot 配置模板 #446 2025-02-10 18:54:59 +08:00
Soulter
0a970e8c31 feat: 支持gewechat文件输出 2025-02-10 18:46:54 +08:00
Soulter
aa276ca6af fix: 修复gewechat无法at人和发语音失败的问题 #447 #438 2025-02-10 18:11:22 +08:00
Soulter
9f02dd13ff fix: 修复qq在@和回复开启的情况下转发消息异常的问题 2025-02-10 13:07:09 +08:00
Soulter
609e723322 v3.4.24 2025-02-10 00:34:02 +08:00
Soulter
c564a1d53e fix: raw_completion 没有正确传递 #439 2025-02-10 00:26:53 +08:00
Soulter
a7fe31f28b fix: 修复指令不经过唤醒前缀也能生效的问题。在引用消息的时候无法使用前缀唤醒机器人 #444 2025-02-09 22:35:52 +08:00
Soulter
a84dc599d6 fix: 修复 /tts 指令 2025-02-09 22:14:10 +08:00
Soulter
8da029add9 feat: 支持 TTS, STT 提供商的显示和快捷切换 2025-02-09 22:08:51 +08:00
Soulter
ba45a2d270 feat: 支持设置GitHub反向代理地址 2025-02-09 18:51:53 +08:00
Soulter
cb56b22aea Update README.md 2025-02-09 16:49:00 +08:00
Soulter
23cc5b31ba perf: 从压缩包上传插件时,去除branch尾缀 2025-02-09 14:59:27 +08:00
Soulter
e8d99f0460 fix: 修复戳一戳消息报错 2025-02-09 13:57:33 +08:00
Soulter
6bcd10cd5c fix: gemini 报错时显示 apikey 2025-02-09 13:56:55 +08:00
Soulter
619fb20c5f fix: drun 不支持函数调用的报错 2025-02-09 01:20:11 +08:00
Soulter
386a312e96 fix: 修复一些typo 2025-02-08 22:52:24 +08:00
Soulter
2759d347e6 update: add socksio, echatpy, cryptography to dockerfile 2025-02-08 22:10:17 +08:00
Soulter
b6ec327b49 perf:完善主动会话 2025-02-08 22:04:36 +08:00
Soulter
ee02d622ba v3.4.23 2025-02-08 21:42:37 +08:00
Soulter
5c4a6083f5 Merge pull request #433 from Cvandia/master
支持 fishaudio tts 文字转语音
2025-02-08 21:20:03 +08:00
Soulter
49e63a3d3d perf: 优化报错显示 2025-02-08 21:19:25 +08:00
Soulter
6bae9dc9ed 👌 perf: 当响应头不为audio/wav时抛出报错 2025-02-08 21:16:09 +08:00
Cvandia
5fa1979a46 🐛 fix: 移除调试过程的不必要的文件写入操作 2025-02-08 20:49:37 +08:00
Cvandia
b40d4fa315 Merge remote-tracking branch 'upstream/master' 2025-02-08 20:45:49 +08:00
Soulter
4d2ff7cd5b fix: 修复 qq 回复别人的时候也会触发机器人, Onebot at 使用 string #330 2025-02-08 20:35:10 +08:00
Cvandia
d8ec0e64d0 Merge remote-tracking branch 'upstream/master' 2025-02-08 19:40:56 +08:00
Cvandia
82e979cc07 feat: 添加 FishAudio TTS API 支持,更新配置和依赖项 2025-02-08 19:37:43 +08:00
Soulter
8c132a51f5 fix: 修复子指令设置permission之后会导致其一定会被执行 #427 2025-02-08 18:51:30 +08:00
Soulter
40bd372cc1 fix: 重启gewe的时候机器人会疯狂发消息 #421 2025-02-08 18:02:42 +08:00
Soulter
212e114270 perf: 优化了一些提示 2025-02-08 15:55:46 +08:00
Soulter
b0e9de6951 perf: 增加DIFY超时时间 #422 2025-02-08 12:58:54 +08:00
Soulter
3489522bbb feat: 支持展示插件是否有更新 2025-02-08 12:22:36 +08:00
Soulter
96237abc03 fix: 当群聊自动回复时,不会带上人格的Prompt #419 2025-02-08 10:17:43 +08:00
Xu Void
7155b4f0ac Update default.py 2025-02-08 10:16:31 +08:00
Soulter
a8b2b09e0f v3.4.22 2025-02-08 00:01:47 +08:00
Soulter
6858b8c555 perf: 当图片数据为空时不加入上下文 #379 2025-02-07 23:57:25 +08:00
Soulter
0e493b1a0e Merge pull request #411 from zhaolj/fix-bug-#298
fix bug #298
2025-02-07 23:39:03 +08:00
Soulter
37d478f970 fix: 移除了分段回复llm提示词辅助 2025-02-07 23:21:05 +08:00
zhaolj
7d0d42a49f fix bug #298 2025-02-07 22:57:49 +08:00
Soulter
0eb1684ef1 fix: 修复 openai_source 尝试弹出最早的记录失败的问题 2025-02-07 22:38:04 +08:00
Soulter
9b0b723143 fix: 联网搜索失败,函数调用无返回值 #342 2025-02-07 22:07:56 +08:00
Soulter
532bc6e1e6 fix: Google Search 报 429 错误时,放宽 Exception 至其他搜索引擎 #405 2025-02-07 21:32:06 +08:00
Soulter
fe3ed4c454 fix: 自部署文转图不生效 #352 2025-02-07 20:24:11 +08:00
Soulter
b5ec89e586 fix: 插件错误信息点击关闭没反应 #394 2025-02-07 20:05:45 +08:00
Soulter
895e7397c2 remove: 移除了 put_history_to_prompt。当主动回复时,将群聊记录将自动放入prompt,当未主动回复但是开启群聊增强时,群聊记录将放入system prompt 2025-02-07 20:00:30 +08:00
Soulter
59b767957a fix: 400 Bad Request: The browser (or proxy) sent a request that this server could not understand. #396 2025-02-07 18:26:31 +08:00
Soulter
17d4bf8f22 perf: 管理面板优化新增列表项的提示 2025-02-06 20:19:53 +08:00
Soulter
836be3b097 update: changelogs 2025-02-06 18:51:47 +08:00
Soulter
310415bea9 feat: 聊天增强图像转述支持自定义 Provider id 2025-02-06 18:49:16 +08:00
Soulter
aafc1276a9 v3.4.21 2025-02-06 18:34:43 +08:00
Soulter
2993e794cc perf: hint 2025-02-06 17:45:15 +08:00
Soulter
58cb9cfb2d chore: clean code 2025-02-06 17:43:04 +08:00
Soulter
fbdf0901d5 fix: 修复reminder时区问题 2025-02-06 17:41:34 +08:00
Soulter
af8c81b621 feat: 支持重载插件 2025-02-06 17:27:53 +08:00
Soulter
06b5275e48 perf: 增加报错显示 2025-02-06 16:43:40 +08:00
Soulter
ad95572d5f perf: 更好的 list 可视化 2025-02-06 15:59:45 +08:00
Xu Void
0021cfc4bc 新增过滤掉正则表达式内容
Fixes #338

新增过滤掉正则表达式内容

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/Soulter/AstrBot/issues/338?shareId=XXXX-XXXX-XXXX-XXXX).
2025-02-06 15:28:28 +08:00
Soulter
aebc7850f4 fix: openrouter 报错 no endpoints found that support tool use #371 2025-02-06 15:25:15 +08:00
Soulter
1b7efbc607 支持列表展示插件市场 2025-02-06 15:18:11 +08:00
Soulter
3800e96d14 fix: 修复metadata不生效的问题
feat: 支持查看插件行为
2025-02-06 15:10:24 +08:00
Soulter
461f1bb07c feat: 支持插件handler优先级 2025-02-06 12:35:43 +08:00
Soulter
7d4c07e4f6 feat: 支持设置 timeout 2025-02-06 12:31:39 +08:00
Soulter
31b788f463 fix: 修复不支持图片的模型请求异常 2025-02-06 01:50:53 +08:00
Soulter
96ab761f73 fix: 修复reminder无法删除的问题 2025-02-05 22:45:02 +08:00
Soulter
2b3f05c039 update: 优化部分注释 2025-02-05 19:58:48 +08:00
Soulter
f2e8303b66 fix: KeyError _mood_imitation_dialogs_processed 2025-02-05 18:52:55 +08:00
Soulter
2a614b545b fix: 修复可能的 KeyError 2025-02-05 17:17:05 +08:00
Soulter
5c0ab21f68 fix: 修复 /model 异常 2025-02-05 17:05:47 +08:00
Soulter
689d109438 typo: myid -> sid 2025-02-05 16:59:21 +08:00
Soulter
2a6934b283 perf: 无对话状态的提示 2025-02-05 16:56:13 +08:00
Soulter
760cb94e9a v3.4.20 2025-02-05 16:06:52 +08:00
Soulter
2a6cff0013 feat: 支持重命名对话 2025-02-05 16:06:18 +08:00
Soulter
ce578f0417 feat: 支持使用 LLM 辅助分段回复 #338 2025-02-05 15:40:52 +08:00
Soulter
1745bdb9e2 perf: 优化一些问题 2025-02-05 15:39:59 +08:00
Soulter
3f90b89c3c 添加屏蔽无权限指令回复的功能 #361 2025-02-05 15:06:38 +08:00
Soulter
f343e40d15 Merge pull request #370 from Soulter/feat-conversation
feat: 更好的对话管理
2025-02-05 14:56:47 +08:00
Soulter
5cc4be9e65 perf: 优化部分显示问题 2025-02-05 14:51:40 +08:00
Soulter
da5aada002 fix: 修复指令组情况下可能造成多指令出触发的问题 2025-02-05 13:52:53 +08:00
Soulter
07f2ee9ad9 fix: 修复 /reset 指令 2025-02-05 13:33:36 +08:00
Soulter
12f4e1146f feat: 更好的对话管理 2025-02-05 13:26:53 +08:00
Soulter
92c57e5476 fix: 修复级联指令组时出现载入错误的问题 2025-02-05 11:11:04 +08:00
Soulter
a923baacd8 Update README.md 2025-02-05 01:56:09 +08:00
Soulter
999b094d55 Merge pull request #358 from eltociear/patch-1
chore: update main.py
2025-02-05 01:34:04 +08:00
Soulter
d4213f2352 perf: announcement plugin market 2025-02-05 01:19:54 +08:00
Ikko Eltociear Ashimine
3f65c9a066 chore: update main.py
occured -> occurred
2025-02-05 02:18:41 +09:00
Soulter
1d427e2645 perf: 优化插件页面 2025-02-05 01:10:53 +08:00
Soulter
36414c4b00 perf: 优化aiocqhttp适配器对用户非法输入的处理 2025-02-05 00:02:18 +08:00
Soulter
47e253d76c fix: 修复权限过滤算子导致的问题 #350 2025-02-04 23:31:46 +08:00
Soulter
b73cf84df0 v3.4.19 2025-02-04 16:37:15 +08:00
Soulter
a5b885a774 fix: schema 中 object hint 不显示 #290
feat: 优化插件市场的访问
2025-02-04 16:36:00 +08:00
Soulter
0c785413da chore: clean code 2025-02-04 15:51:26 +08:00
Soulter
482d7ef5f7 v3.4.19 2025-02-04 15:47:24 +08:00
Soulter
9f9073c0ff feat: 支持设置所有指令的权限
feat: 插件指令支持设置指令描述
feat: plugin 指令支持查看插件的指令
2025-02-04 15:41:45 +08:00
Soulter
ef05ff4abd fix: 管理员指令 /reset /persona 2025-02-04 13:50:23 +08:00
Soulter
5848aae435 Update README.md 2025-02-04 13:44:02 +08:00
Soulter
fb06f33de0 Update README.md 2025-02-04 12:51:17 +08:00
Soulter
0d7ddb149e fix: 修复请求 gemini 推理模型出现 candidates 错误的问题 #333 2025-02-04 00:30:23 +08:00
Soulter
4f2d7b9c4e feat: 适配 Azure OpenAI #332 2025-02-03 23:59:04 +08:00
Soulter
c02ed96f6f perf: gewechat 服务端回调接口默认暴露在所有地址 2025-02-03 18:51:19 +08:00
Soulter
3b2ac891b2 fix: 修复限流器不可用的问题 #263 2025-02-03 18:51:19 +08:00
Soulter
ef0108881b Update Dockerfile 2025-02-03 17:48:17 +08:00
Soulter
af48975a6b chore: v3.4.18 2025-02-03 16:14:27 +08:00
Soulter
6441b149ab fix: 修复主动概率回复关闭后仍然回复的问题 #317 2025-02-03 14:33:53 +08:00
Soulter
f8892881f8 fix: 尝试修复 gewechat 群聊收不到 at 的回复 #294 2025-02-03 14:28:14 +08:00
Soulter
228aec5401 perf: 移除了默认人格 2025-02-03 14:17:45 +08:00
Soulter
68ad48ff55 fix: 修复HTTP代理删除后不生效 #319 2025-02-03 14:11:50 +08:00
Soulter
541ba64032 fix: 调用Gemini API输出多余空行问题 #318 2025-02-03 13:27:56 +08:00
Soulter
2d870b798c feat: 添加硅基流动模版 2025-02-03 13:24:22 +08:00
Soulter
0f1fe1ab63 fix: 硅基流动 not a vlm 和 tool calling not supported 报错 #305 # 291
perf: 安装和更新插件后全量重启避免奇奇怪怪的bug
feat: 支持 /tool off_all 停用所有函数工具
2025-02-03 13:20:49 +08:00
Soulter
73cc86ddb1 perf: 回复时艾特发送者之后添加空格或换行 #312 2025-02-03 12:04:26 +08:00
Soulter
23128f4be2 perf: 主动回复不支持 qq_official 的 hint 2025-02-03 12:00:05 +08:00
Soulter
92200d0e82 fix: docker容器内时区不对 2025-02-03 01:15:09 +08:00
Soulter
d6e8655792 fix: 抱错时首先移除 tool 2025-02-02 23:17:59 +08:00
Soulter
37076d7920 perf: siliconcloud 不支持 tool 的模型 2025-02-02 23:05:36 +08:00
Soulter
78347ec91b perf: 当人格长度为1时设置默认人格
feat: 支持取消人格
2025-02-02 22:36:50 +08:00
Soulter
9ded102a0a chore: v3.4.17 2025-02-02 20:39:26 +08:00
Soulter
59b7d8b8cb chore: clean code 2025-02-02 20:15:57 +08:00
Soulter
f5b97f6762 perf: 优化 404 提示 2025-02-02 19:55:32 +08:00
Soulter
d47da241af feat: openai tts 更换模型 #300 2025-02-02 19:47:39 +08:00
Soulter
4611ce15eb feat: [beta] 支持群聊内基于概率的主动回复 2025-02-02 19:23:46 +08:00
Soulter
aa8c56a688 fix: 相同type的provider共享了记忆 2025-02-02 19:13:47 +08:00
Soulter
ef44d4471a feat: 增加模型响应后的插件钩子
remove: 移除了默认的r1过滤
2025-02-02 16:42:21 +08:00
Soulter
5581eae957 fix: deepseek-r1模型存在遗留“</think>”的问题 #279
Open
2025-02-02 14:59:17 +08:00
Soulter
ec46dfaac9 perf: 人格情景在发现格式不对时仍然加载而不是跳过 #282 2025-02-02 14:59:17 +08:00
Soulter
6042a047bd 修复Gemini函数调用时,parameters为空对象导致的错误
fix: 修复Gemini函数调用时,parameters为空对象导致的错误
2025-02-02 14:50:34 +08:00
Soulter
6ca9e2a753 perf: websearch 可选配置引用链接 #287 2025-02-02 14:42:13 +08:00
Camreishi
618eabfe5c fix: 修复Gemini函数调用时,parameters为空对象导致的错误
Closes #288
2025-02-02 13:25:08 +08:00
Soulter
bb5db2e9d0 fix: 修复弹出记录报错的问题 #272 2025-02-02 13:24:05 +08:00
Soulter
97e4d169b3 perf: 未启用模型提供商时的异常处理 2025-02-02 11:23:33 +08:00
Soulter
50e44b1473 perf: 移除默认人格 2025-02-02 11:12:17 +08:00
Soulter
38588dd3fa update compose 2025-02-02 00:11:55 +08:00
Soulter
d183388347 perf: 去除gewechat默认配置 2025-02-01 23:20:25 +08:00
Soulter
1e69d59384 fix: 配置提示 typo 2025-02-01 22:56:34 +08:00
Soulter
00f008f94d Update compose.yml 2025-02-01 21:31:24 +08:00
Soulter
3c28001a74 v3.4.16 2025-02-01 19:31:59 +08:00
Soulter
76a6218be6 fix: 修复webui无法从本地上传插件的问题 2025-02-01 19:31:29 +08:00
Soulter
6c1de1bbd6 Update README.md 2025-02-01 16:19:01 +08:00
Soulter
d7678081da perf: Provider 重复时不直接报错闪退 #265 2025-02-01 14:36:41 +08:00
Soulter
5e4ba563cb perf: 弱化更新报错 #267 2025-02-01 14:29:39 +08:00
Soulter
8afbe77b0a Update README.md 2025-02-01 12:11:58 +08:00
Soulter
2ef139b59a fix: 修复每次启动astrbot都需要微信扫码的问题 2025-01-31 01:28:49 +08:00
Soulter
1f0d2d9b89 fix: QQ官方机器人开启 reply with metion 和 reply with quote 后,无法正常回复消息 #244 2025-01-30 01:36:25 +08:00
Soulter
37a1f144ab chore: update changelog of 3.4.15 2025-01-30 00:32:50 +08:00
Soulter
9a7a654596 perf: 插件处于禁用状态时其所属的函数调用工具不可被启用 #254 2025-01-30 00:27:10 +08:00
Soulter
9abccd63cf chore: remove stt.py 2025-01-29 23:47:50 +08:00
Soulter
93fea77182 chore: bump to v3.4.15 2025-01-29 23:43:09 +08:00
Soulter
19797243f6 perf: 增加插件链接 2025-01-29 19:56:09 +08:00
Soulter
c9c733d925 Merge branch 'dev' 2025-01-29 19:43:52 +08:00
Soulter
a7d7678c78 fix: 修复白名单为空时依然终止事件 #259 2025-01-29 17:17:27 +08:00
Soulter
c0911921c7 feat: 配置Schema以及插件支持配置 2025-01-29 16:54:57 +08:00
Soulter
4a4241d57a Update README.md 2025-01-29 13:26:51 +08:00
Soulter
c9426bb6eb config 2025-01-29 12:25:54 +08:00
Soulter
db4abd169a fix: 优化分段回复 2025-01-28 14:42:15 +08:00
Soulter
80b6958599 fix: 修复 config validator 不起效的问题 2025-01-28 14:18:21 +08:00
Soulter
80058c781a fix: 修复r1思考标签问题和分段回复间隔时间问题 2025-01-28 14:03:10 +08:00
Soulter
44bd2e36f3 Update README.md 2025-01-28 02:15:11 +08:00
Soulter
3589a5e5be perf: 强化ltm异常处理 2025-01-27 21:47:35 +08:00
Soulter
13ef033f0e fix: 群聊增强的参数类型转换 2025-01-27 21:40:20 +08:00
Soulter
3f8c68bbca fix: f-string expression part cannot include a backslash
long_term_memory.py, line 69
2025-01-27 21:01:50 +08:00
Soulter
4275cea82b chore: v3.4.14 2025-01-27 20:09:03 +08:00
Soulter
a0bcb5339a perf: 自动删除 deepseek-r1 模型自带的 think 标签 2025-01-27 20:04:39 +08:00
Soulter
43deec4a4b Merge pull request #255 from Soulter/feat-ltm
支持记录非唤醒状态下群聊历史记录
2025-01-27 20:02:43 +08:00
Soulter
2bc433a30b feat: 支持记录非唤醒状态下群聊历史记录 2025-01-27 20:00:32 +08:00
Soulter
eb2b395932 perf: /t2i 即时生效 2025-01-27 19:33:38 +08:00
Soulter
2bfd1c0bf2 perf: 自动移除 ollama 不支持 tool 的模型的 tool 请求 2025-01-27 19:25:28 +08:00
Soulter
7228c4b13f fix: 修复 TTS 部分变量名错误导致请求失败 2025-01-27 18:45:34 +08:00
Soulter
9351d7471f perf: 优化 gewechat 消息下发异常处理 2025-01-27 18:11:31 +08:00
Soulter
1cf49998bc Update README.md 2025-01-27 11:34:27 +08:00
Soulter
6ae86597e8 chore: v3.4.13 2025-01-26 16:51:13 +08:00
Soulter
c578ff25bd fix: stt_enabled 未初始化 #252 2025-01-26 16:51:02 +08:00
Soulter
2934a3e3be chore: logo 2025-01-26 15:18:23 +08:00
Soulter
ceaa69da75 feat: 支持消息分段回复 2025-01-26 13:45:32 +08:00
Soulter
fa8e731576 Update README.md 2025-01-25 22:45:47 +08:00
Soulter
685c0a106a perf: use pysilk instead of pilk 避免构建问题 2025-01-25 20:18:40 +08:00
Soulter
7f539090dd perf: 更新项目时连带更新依赖 2025-01-25 20:04:28 +08:00
Soulter
2089273f95 Merge pull request #251 from Soulter/feat-tts
适配 OpenAI TTS API,并支持 Napcat,Gewechat,Lagrange 的语音输出
2025-01-25 19:51:22 +08:00
Soulter
838bb4c7ad chore: remove duration 2025-01-25 19:49:53 +08:00
Soulter
637acd1a12 feat: 适配 OpenAI TTS API,并支持 Napcat,Gewechat,Lagrange 的语音输出 2025-01-25 19:46:00 +08:00
Soulter
03fa9a847f feat: gewechat 支持语音、图片 2025-01-25 16:34:40 +08:00
Soulter
d488c88e78 feat: 支持路径映射,解决docker部署两端文件系统不一致导致的富媒体文件路径不存在问题 2025-01-24 14:08:08 +08:00
Soulter
baae842210 fix: napcat 下语音消息接收异常 2025-01-24 13:41:13 +08:00
Soulter
ec1fb838b6 perf: notice 2025-01-22 21:38:05 +08:00
Soulter
13281179df perf: notice 2025-01-22 21:36:28 +08:00
Soulter
276a42c9a1 Bump to 3.4.11 2025-01-22 21:16:24 +08:00
Soulter
7a70a730ba perf: 任务报错后的优雅报错输出 2025-01-22 21:14:26 +08:00
Soulter
d0fe59631c perf: 优化更新项目时重启可能会导致Address already in use的问题 2025-01-22 20:57:15 +08:00
Soulter
106892e933 fix: 修复appid保存的问题和部分群聊at失效的问题和群聊@的sender username显示异常的问题 2025-01-22 20:34:52 +08:00
Soulter
19543a41b3 Update README.md 2025-01-22 19:56:07 +08:00
Soulter
b172b760ab feat: 为平台和提供商适配器添加默认 ID 配置 #248 2025-01-22 16:52:34 +08:00
Soulter
4b5d49cb41 Bump to 3.4.10 2025-01-22 00:19:20 +08:00
Soulter
3fd35b6058 feat: 管理面板更新面板按钮 #245 2025-01-22 00:17:43 +08:00
Soulter
5f86c4ab99 perf: 增强 LLM 请求错误处理 #243 2025-01-21 16:29:19 +08:00
Soulter
c94a7f6629 perf: 针对 api_base 的明显提示,修改 ollama 模板的api_base #247 2025-01-21 16:15:04 +08:00
Soulter
7d6beb4141 fix: QQ 图片发送不了 #246 2025-01-21 16:12:10 +08:00
Soulter
e2117e690a feat: 支持登出gewechat 2025-01-21 13:12:09 +08:00
Soulter
fb791290e2 fix: 添加gewechat适配器过滤器 2025-01-21 12:39:57 +08:00
Soulter
5dd1488b5d perf: 优化webui和主程序更新的协调
fix: 修复某些请求不能正确应用代理的问题
2025-01-21 01:08:15 +08:00
Soulter
529cd64d82 perf: help显示AstrBot和webui版本 2025-01-21 00:10:59 +08:00
Soulter
d2bd3e8da8 bump to v3.4.9 2025-01-20 23:35:34 +08:00
Soulter
e42ce7dd86 perf: 优化了用户体验 2025-01-20 23:27:13 +08:00
Soulter
40709462ee chore: bump domain to astrbot.app 2025-01-20 19:02:54 +08:00
Soulter
2ad6c01a4d Update README.md 2025-01-20 15:48:39 +08:00
Soulter
70c12e788e feat: LLM额外唤醒词与机器人唤醒词冲突时的处理 2025-01-20 10:22:25 +08:00
Soulter
1713791c90 docs: update webui demo 2025-01-20 00:46:29 +08:00
Soulter
9aa23fd412 Update README.md 2025-01-19 21:32:42 +08:00
Soulter
e4ba09cd93 chore: remove package-lock.json 2025-01-19 18:20:40 +08:00
Soulter
171fdf1fbc fix: 消息链无元素时仍然插入了@和回复 2025-01-18 23:25:42 +08:00
Soulter
01f4e0b961 feat: gewechat 主动消息 2025-01-18 22:31:17 +08:00
Soulter
be2d5a91c7 chore: bump to v3.4.8 2025-01-18 22:19:35 +08:00
Soulter
a1d89d9478 Merge pull request #242 from Soulter/feat-gewechat
初步接入 gewechat 文字交互
2025-01-18 22:16:53 +08:00
Soulter
98d1dc3b65 feat: 初步接入 gewechat 文字交互 2025-01-18 22:01:36 +08:00
Soulter
b80eb3acc0 feat: 支持回复时 At 和引用发送者 #241 2025-01-18 17:31:11 +08:00
Soulter
05ccc1995b fix: 清除残留的 personalities 2025-01-18 17:31:11 +08:00
Soulter
0de244889e chore: gitsponsors 2025-01-18 10:54:37 +08:00
Soulter
e6c5c3a493 chore: bump to v3.4.7 2025-01-16 11:26:05 +08:00
Soulter
164aa2ccd2 Merge pull request #240 from Soulter/feat-better-persona
feat: 更好的人格情景管理
2025-01-16 11:20:28 +08:00
Soulter
f1599e26b3 perf: webchat 主动信息 2025-01-16 11:19:02 +08:00
Soulter
ed64a4d32d chore: 整理hint 2025-01-16 11:11:30 +08:00
Soulter
2ee4b431d4 fix: 无tool导致的报错 #239 2025-01-15 11:16:31 +08:00
Soulter
cd8a73ed19 feat: 更好的人格情景管理和管理面板支持删除列表默认模版项 2025-01-14 21:08:57 +08:00
Soulter
e6c985ce4e feat: 优化WebChat长连接的逻辑 2025-01-13 12:42:32 +08:00
Soulter
a20446aeb9 🎉 chore: bump to v3.4.6 2025-01-13 02:17:23 +08:00
Soulter
7b23d76559 feat: 支持并完善服务提供商默认配置模板接口 2025-01-13 02:05:57 +08:00
Soulter
8315cf5818 perf: 面板文件更新检查和引导提示和AboutPage 2025-01-12 13:01:40 +08:00
Soulter
ed16265bde fix: 更新官方文档链接并优化管理面板版本检查日志 2025-01-12 12:23:27 +08:00
Soulter
dff205faf6 feat: 添加聊天功能路由和更新管理面板命令 2025-01-12 12:18:19 +08:00
Soulter
9aae8aee0c Update README.md 2025-01-12 11:45:39 +08:00
Soulter
7c818ced2b perf: 文件和语音功能适配 Lagrange 2025-01-12 11:44:33 +08:00
Soulter
218e887558 fix: download_file 修复 SSL 连接错误处理 2025-01-12 11:44:33 +08:00
Soulter
a68860b35a chore: compress the banner 2025-01-12 10:52:17 +08:00
Soulter
82d4d43383 🎉 Bump to v3.4.5 2025-01-11 23:35:22 +08:00
Soulter
94618e8feb feat: 添加 aiodocker 依赖 2025-01-11 22:02:15 +08:00
Soulter
55de7d4494 🎉 Bump to v3.4.5 2025-01-11 21:40:48 +08:00
Soulter
7ed639f741 🎉 bump to v3.4.5 2025-01-11 21:06:06 +08:00
Soulter
41f2870c29 Merge pull request #236 from Soulter/feat-stt
支持 Speech To Text,并适配腾讯修改过的 Silk 语音格式
2025-01-11 21:00:04 +08:00
Soulter
ba198490fa feat: 支持自部署 Whisper 模型 2025-01-11 20:31:21 +08:00
Soulter
0f9ab082ab perf: 优化webchat,没有结果返回时的反馈 2025-01-11 19:45:42 +08:00
Soulter
97b58965f2 feat: webchat可显示Provider状态 2025-01-11 19:31:56 +08:00
Soulter
f2566c68e3 feat: 按 K 语音 2025-01-11 19:07:26 +08:00
Soulter
a456bf5449 fix: 初始化reminder时的一些问题 2025-01-11 18:55:18 +08:00
Soulter
a09998f910 feat: webchat 支持语音输入 2025-01-11 18:54:40 +08:00
Soulter
be662b913c feat: 支持 Whisper STT,并适配 Tencent 语音格式 2025-01-11 17:19:28 +08:00
Soulter
e7ddc8448d perf: 代码执行器在成功执行后清空文件buffer 2025-01-11 11:31:56 +08:00
Soulter
29374f8d8a fix: 修复 /dashbord_update 指令 2025-01-11 00:25:02 +08:00
Soulter
359b971103 Merge pull request #235 from Soulter/feat-webchat
WebChat 支持
2025-01-11 00:17:18 +08:00
Soulter
fbdb1ae208 chore: bump to v3.4.4 2025-01-11 00:14:08 +08:00
Soulter
22c13c1eff perf: webchat支持传图 2025-01-11 00:06:19 +08:00
Soulter
5fc63aeaf1 perf: ui 2025-01-10 22:45:14 +08:00
Soulter
d4f32673ab fix: 修复持久化问题 2025-01-10 22:08:43 +08:00
Soulter
480dffb51b feat: 初步实现 webchat 页面 2025-01-10 21:48:15 +08:00
Soulter
966df00124 feat: 支持从管理面板(控制台页)手动安装 pip 库 2025-01-10 15:35:57 +08:00
Soulter
3e2b4bc727 feat: 支持动态设置会话变量以适用 Dify 输入变量 2025-01-10 12:32:20 +08:00
Soulter
5929a8d42b Update README.md 2025-01-09 23:11:11 +08:00
Soulter
f8ab40eb39 chore: 上传管理面板package.json 2025-01-09 22:25:46 +08:00
Soulter
55e9233b93 docs: v3.4.3 changelog 2025-01-09 22:19:11 +08:00
Soulter
b7277b51fd feat: 管理面板支持显示不在metadata中的配置 2025-01-09 22:03:53 +08:00
Soulter
1fa9111b2b perf: 进一步防止llm递归调用 2025-01-09 22:03:22 +08:00
Soulter
90a9e496d9 feat: 适配器类插件支持设置默认配置模板 2025-01-09 19:45:18 +08:00
Soulter
2a7dce1eb0 chore: clean code 2025-01-09 16:34:39 +08:00
Soulter
0c0841cc03 fix: websearch 在 cmd_config 中失效的问题 2025-01-09 16:33:58 +08:00
Soulter
4c9fe016bf fix: test_pipeline 2025-01-09 16:00:43 +08:00
Soulter
acc90f140c chore: bump dashboard_release_url 2025-01-09 15:50:24 +08:00
Soulter
68a7bc3930 Merge pull request #232 from Soulter/feat-python-interpreter
初步实现代码执行器
2025-01-09 15:43:40 +08:00
Soulter
12ea64be0e fix: dashboard command bug 2025-01-09 15:42:04 +08:00
Soulter
7f30a673f7 fix: 修复 qq_official 无法发图 2025-01-09 15:20:54 +08:00
Soulter
897e100c32 Merge pull request #234 from Soulter/233-gemini-native-support
支持通过 Google GenAI 访问 Gemini 模型
2025-01-09 14:23:44 +08:00
Soulter
0d4ad5cb31 fix: 修复 APScheduler 任务错过后不执行的问题 2025-01-09 14:23:07 +08:00
Soulter
b124bd0d0e feat: 支持通过 Google GenAI 访问 Gemini 模型 2025-01-09 14:05:48 +08:00
Soulter
6bc2f84602 Update README.md
qingcloud 在新网的账户余额不足导致原域名无法续费
2025-01-09 10:35:02 +08:00
Soulter
d787a28c40 feat: 支持使用 /dashboard update 更新管理面板 2025-01-09 00:59:28 +08:00
Soulter
6b078a5731 cd: build dashboard files automatically 2025-01-09 00:57:48 +08:00
Soulter
17dddbfe21 chore: 禁用插件 2025-01-08 23:34:54 +08:00
Soulter
3ff3c9e144 perf: 检测到docker不可用时自动禁用本插件 2025-01-08 23:32:49 +08:00
Soulter
f5a37d82cc Merge branch 'master' into feat-python-interpreter 2025-01-08 23:13:52 +08:00
Soulter
d3d428dc9d fix: 管理面板支持禁用/启用插件 2025-01-08 23:04:03 +08:00
Soulter
8dc8c5b5dc feat: 支持对插件禁用/启用 2025-01-08 22:28:20 +08:00
Soulter
e6b06f914b perf: provider 偏好项记忆 2025-01-08 20:46:34 +08:00
Soulter
4dc502a8b6 fix: 修复事件监听器会让wakestage失效的问题 2025-01-08 20:24:01 +08:00
Soulter
b1d1a13d5f perf: 支持图片输入 2025-01-08 19:56:03 +08:00
Soulter
75cc4cac5a perf: 代码执行器添加部分控制指令,添加更多可用库 2025-01-08 13:26:16 +08:00
Soulter
1b7e4fbbdc perf: 退出时关闭 aiohttp client session 2025-01-08 12:43:34 +08:00
Soulter
9789e2f6c1 perf: 代码执行器请求llm不持久化历史记录 2025-01-08 02:12:35 +08:00
Soulter
b8fb0bee24 feat: 初步实现代码执行器 #210 2025-01-08 02:10:27 +08:00
Soulter
419f77e245 Update README.md 2025-01-07 20:56:25 +08:00
Soulter
59b1c3473b Merge pull request #230 from Soulter/feat-dify
接入 Dify
2025-01-07 20:14:33 +08:00
Soulter
6db58ca375 perf: 优化在prompt为空的情况下不请求provider 2025-01-07 20:01:47 +08:00
Soulter
4832b342b0 Merge branch 'master' into feat-dify 2025-01-07 19:59:54 +08:00
Soulter
6cec542402 feat: 初步接入 Dify 2025-01-07 19:56:18 +08:00
Soulter
9644791783 feat: kdb 2024-12-30 18:06:09 +08:00
Soulter
5031c307d1 update: readme 2024-12-26 23:39:29 +08:00
Soulter
aa49539e3e chore: fix test 2024-12-26 23:33:40 +08:00
Soulter
7b4118493b chore: fix test 2024-12-26 23:15:10 +08:00
Soulter
d1cc9ba4ce chore: update test workflow 2024-12-26 23:09:11 +08:00
Soulter
e0e92139d7 fix: test workflow 2024-12-26 23:07:50 +08:00
Soulter
62039392bb chore: fix test workflow 2024-12-26 23:06:30 +08:00
Soulter
b72c69892e test: dashboard test 2024-12-26 22:59:17 +08:00
Soulter
e6205e9aad ci: update workflow 2024-12-25 17:18:29 +08:00
Soulter
b8a6fb1720 chore: update tests 2024-12-25 12:50:29 +08:00
Soulter
7c06d82f27 perf: plugin manager 重复 reload 释放资源 2024-12-25 12:50:29 +08:00
Soulter
d92cb0f500 perf: 当没有provider时直接返回 2024-12-25 12:50:29 +08:00
Soulter
7fa72f2fe9 perf: adapt glm-4v-flash 2024-12-24 14:08:20 +08:00
Soulter
21d480a3b5 bugfixes 2024-12-22 05:31:29 +08:00
Soulter
771c045844 feat: 可配置是否启用白名单 2024-12-22 05:18:27 +08:00
Soulter
e6ce484c15 perf: 不加载已经outdated的reminder 2024-12-22 05:06:15 +08:00
Soulter
102a92f62d perf: 移动对 prompt 的内置修改的逻辑 2024-12-21 18:39:10 +08:00
Soulter
6c7ac70701 Bump version to v3.4.2 2024-12-21 16:40:04 +08:00
Soulter
9d8372289f fix: fstring format error #226 2024-12-21 16:38:53 +08:00
Soulter
766f6a1ba2 perf: use request_llm 2024-12-21 16:35:16 +08:00
Soulter
193ff24f4c feat: 添加发送消息后的事件钩子 2024-12-20 16:31:36 +08:00
Soulter
c675017374 feat: 新增LLM请求事件钩子和装饰消息结果钩子 2024-12-19 21:33:03 +08:00
Soulter
86cb852507 perf: llm-tuner adapter 检查路径 2024-12-18 21:25:04 +08:00
Soulter
73494e0d7d perf: 使用 astrbot-registry 下载面板静态资源 2024-12-18 21:24:39 +08:00
Soulter
ec61aa1b6f Merge pull request #224 from Soulter/dashboard
迁移 AstrBot Dashboard 源代码至 AstrBot
2024-12-17 23:45:13 +08:00
Soulter
6df0e78b22 upload: dashboard from Soulter/AstrBot-Dashboard 2024-12-17 23:40:32 +08:00
Soulter
63c604359b fix: update 2024-12-16 22:53:23 +08:00
Soulter
08212588a0 chore: update docker ci/cd workflow 2024-12-16 21:12:02 +08:00
245 changed files with 14683 additions and 1635 deletions

View File

@@ -16,3 +16,5 @@ venv*/
ENV/
.conda/
README*.md
dashboard/
data/

View File

@@ -2,6 +2,7 @@ on:
push:
tags:
- 'v*'
workflow_dispatch:
name: Auto Release
@@ -14,6 +15,15 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v4
- name: Dashboard Build
run: |
cd dashboard
npm install
npm run build
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
echo ${{ github.ref_name }} > dist/assets/version
zip -r dist.zip dist
- name: Fetch Changelog
run: |
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
@@ -21,4 +31,5 @@ jobs:
- name: Create Release
uses: ncipollo/release-action@v1
with:
bodyFile: ${{ env.changelog }}
bodyFile: ${{ env.changelog }}
artifacts: "dashboard/dist.zip"

View File

@@ -1,7 +1,14 @@
name: Run tests and upload coverage
on:
push
push:
branches:
- master
paths-ignore:
- 'README.md'
- 'changelogs/**'
- 'dashboard/**'
workflow_dispatch:
jobs:
test:
@@ -21,17 +28,16 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest pytest-cov pytest-asyncio
mkdir data
mkdir data/plugins
mkdir data/config
mkdir temp
- name: Run tests
run: |
export LLM_MODEL=${{ secrets.LLM_MODEL }}
export OPENAI_API_BASE=${{ secrets.OPENAI_API_BASE }}
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
PYTHONPATH=./ pytest --cov=. tests/ -v
mkdir data
mkdir data/plugins
mkdir data/config
mkdir data/temp
export TESTING=true
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov
uses: codecov/codecov-action@v4

View File

@@ -1,8 +1,9 @@
name: Docker Image CI/CD
on:
release:
types: [published]
push:
tags:
- 'v*'
workflow_dispatch:
jobs:
@@ -35,7 +36,7 @@ jobs:
push: true
tags: |
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event.release.tag_name }}
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.ref_name }}
- name: Post build notifications
run: echo "Docker image has been built and pushed successfully"

11
.gitignore vendored
View File

@@ -16,4 +16,13 @@ addons/plugins
tests/astrbot_plugin_openai
chroma
chroma
node_modules/
.DS_Store
package-lock.json
package.json
venv/*
packages/python_interpreter/workplace
.venv/*
.conda/

View File

@@ -12,7 +12,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN python -m pip install -r requirements.txt
RUN python -m pip install -r requirements.txt --no-cache-dir
RUN python -m pip install socksio wechatpy cryptography --no-cache-dir
EXPOSE 6185
EXPOSE 6186

174
README.md
View File

@@ -1,46 +1,158 @@
<p align="center">
<img width=200 src="https://github.com/user-attachments/assets/3dd6a669-0830-4db4-b821-c8b279ea19a6"/>
![6e1279651f16d7fdf4727558b72bbaf1](https://github.com/user-attachments/assets/ead4c551-fc3c-48f7-a6f7-afbfdb820512)
</p>
<div align="center">
<h1>AstrBot</h1>
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
</a>
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
<a href="https://astrbot.soulter.top/">查看文档</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://astrbot.app/">查看文档</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
</div>
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型LLM接入功能的聊天机器人及开发框架。
## ✨ 多消息平台部署
## ✨ 主要功能
1. QQ 群、QQ 频道、微信、Telegram
2. 支持文本转图片Markdown 渲染
## ✨ 多 LLM 配置
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力支持图片理解、语音转文字Whisper
2. **多消息平台接入**。支持接入 QQOneBot、QQ 频道、微信Gewechat、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://astrbot.app/others/dify.html),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat可在面板上与大模型对话。
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
1. 适配 OpenAI API支持接入 Gemini、GPT、Llama、Claude、DeepSeek、GLM 等各种大语言模型。
2. 支持 OneAPI 等分发平台。
3. 支持 LLMTuner 载入微调模型。
4. 支持 Ollama 载入自部署模型。
4. 支持网页搜索Web Search
> [!TIP]
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM无法在聊天页使用大模型。不要再修改 demo 的登录密码了 😭)
## ✨ 管理面板
## ✨ 使用方式
1. 支持可视化修改配置
2. 日志实时查看
3. 简单的信息统计
4. 插件管理
#### Docker 部署
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
#### Windows 一键安装器部署
需要电脑上安装有 Python>3.10)。请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
#### Replit 部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### CasaOS 部署
社区贡献的部署方式。
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/casaos.html) 。
#### 手动部署
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
## ⚡ 消息平台支持情况
| 平台 | 支持性 | 详情 | 消息类型 |
| -------- | ------- | ------- | ------ |
| QQ(官方机器人接口) | ✔ | 私聊、群聊QQ 频道私聊、群聊 | 文字、图片 |
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 |
| 飞书 | ✔ | 群聊 | 文字、图片 |
| 微信对话开放平台 | 🚧 | 计划内 | - |
| Discord | 🚧 | 计划内 | - |
| WhatsApp | 🚧 | 计划内 | - |
| 小爱音响 | 🚧 | 计划内 | - |
# 🦌 接下来的路线图
> [!TIP]
> 欢迎在 Issue 提出更多建议 <3
- [ ] 完善并保证目前所有平台适配器的功能一致性
- [ ] 优化插件接口
- [ ] 默认支持更多 TTS 服务,如 GPT-Sovits
- [ ] 完善“聊天增强”部分,支持持久化记忆
- [ ] 规划 i18n
## ❤️ 贡献
欢迎任何 Issues/Pull Requests只需要将你的更改提交到此项目 )
对于新功能的添加,请先通过 Issue 讨论。
## 🌟 支持
- Star 这个项目!
- 在[爱发电](https://afdian.com/a/soulter)支持我!
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
## ✨ Demo
> [!NOTE]
> 代码执行器的文件输入/输出目前仅测试了 Napcat(QQ), Lagrange(QQ)
<div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
_✨基于 Docker 的沙箱化代码执行器Beta 测试中✨_
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
_✨ 自然语言待办事项 ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ 插件系统——部分插件展示 ✨_
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
_✨ 管理面板 ✨_
![webchat](https://drive.soulter.top/f/vlsA/ezgif-5-fb044b2542.gif)
_✨ 内置 Web Chat在线与机器人交互 ✨_
</div>
## ⭐ Star History
> [!TIP]
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star这是我维护这个开源项目的动力 <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div>
## Disclaimer
1. The project is protected under the `AGPL-v3` opensource license.
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
3. Please ensure compliance with local laws and regulations when using this project.
<!-- ## ✨ ATRI [Beta 测试]
@@ -51,23 +163,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
3. 表情包理解与回复
4. TTS
-->
## ✨ 云部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
## ❤️ 贡献
_私は、高性能ですから!_
欢迎任何 Issues/Pull Requests只需要将你的更改提交到此项目 )
对于新功能的添加,请先通过 Issue 讨论。
## 🔭 展望
1. 更强大的 Agent 系统。
2. 打造插件工作流平台。
## ✨ Support
- Star 这个项目!
- 在[爱发电](https://afdian.com/a/soulter)支持我!
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~

170
README_ja.md Normal file
View File

@@ -0,0 +1,170 @@
<p align="center">
![6e1279651f16d7fdf4727558b72bbaf1](https://github.com/user-attachments/assets/ead4c551-fc3c-48f7-a6f7-afbfdb820512)
</p>
<div align="center">
_✨ 簡単に使えるマルチプラットフォーム LLM チャットボットおよび開発フレームワーク ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
<a href="https://astrbot.app/">ドキュメントを見る</a>
<a href="https://github.com/Soulter/AstrBot/issues">問題を報告する</a>
</div>
AstrBot は、疎結合、非同期、複数のメッセージプラットフォームに対応したデプロイ、使いやすいプラグインシステム、および包括的な大規模言語モデルLLM接続機能を備えたチャットボットおよび開発フレームワークです。
## ✨ 主な機能
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換Whisperをサポートします。
2. **複数のメッセージプラットフォームの接続**。QQOneBot、QQ チャンネル、WeChatGewechat、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://astrbot.app/others/dify.html)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
> [!TIP]
> 管理パネルのオンラインデモを体験する: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> ユーザー名: `astrbot`, パスワード: `astrbot`。LLM が設定されていないため、チャットページで大規模モデルを使用することはできません。(デモのログインパスワードを変更しないでください 😭)
## ✨ 使用方法
#### Docker デプロイ
公式ドキュメント [Docker を使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) を参照してください。
#### Windows ワンクリックインストーラーのデプロイ
コンピュータに Python>3.10)がインストールされている必要があります。公式ドキュメント [Windows ワンクリックインストーラーを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/windows.html) を参照してください。
#### Replit デプロイ
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### CasaOS デプロイ
コミュニティが提供するデプロイ方法です。
公式ドキュメント [ソースコードを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/casaos.html) を参照してください。
#### 手動デプロイ
公式ドキュメント [ソースコードを使用して AstrBot をデプロイする](https://astrbot.app/deploy/astrbot/cli.html) を参照してください。
## ⚡ メッセージプラットフォームのサポート状況
| プラットフォーム | サポート状況 | 詳細 | メッセージタイプ |
| -------- | ------- | ------- | ------ |
| QQ(公式ロボットインターフェース) | ✔ | プライベートチャット、グループチャット、QQ チャンネルプライベートチャット、グループチャット | テキスト、画像 |
| QQ(OneBot) | ✔ | プライベートチャット、グループチャット | テキスト、画像、音声 |
| WeChat(個人アカウント) | ✔ | WeChat 個人アカウントのプライベートチャット、グループチャット | テキスト、画像、音声 |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | プライベートチャット、グループチャット | テキスト、画像 |
| [WeChat(企業 WeChat)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | プライベートチャット | テキスト、画像、音声 |
| Feishu | ✔ | グループチャット | テキスト、画像 |
| WeChat 対話オープンプラットフォーム | 🚧 | 計画中 | - |
| Discord | 🚧 | 計画中 | - |
| WhatsApp | 🚧 | 計画中 | - |
| Xiaoai 音響 | 🚧 | 計画中 | - |
# 🦌 今後のロードマップ
> [!TIP]
> Issue でさらに多くの提案を歓迎します <3
- [ ] 現在のすべてのプラットフォームアダプターの機能の一貫性を確保し、改善する
- [ ] プラグインインターフェースの最適化
- [ ] GPT-Sovits などの TTS サービスをデフォルトでサポート
- [ ] "チャット強化" 部分を完成させ、永続的な記憶をサポート
- [ ] i18n の計画
## ❤️ 貢献
Issue や Pull Request を歓迎します!このプロジェクトに変更を加えるだけです :)
新機能の追加については、まず Issue で議論してください。
## 🌟 サポート
- このプロジェクトに Star を付けてください!
- [愛発電](https://afdian.com/a/soulter)で私をサポートしてください!
- [WeChat](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)で私をサポートしてください~
## ✨ デモ
> [!NOTE]
> コードエグゼキューターのファイル入力/出力は現在 Napcat(QQ)、Lagrange(QQ) でのみテストされています
<div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
_✨ Docker ベースのサンドボックス化されたコードエグゼキューターベータテスト中✨_
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ 多モーダル、ウェブ検索、長文の画像変換設定可能✨_
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
_✨ 自然言語タスク ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ プラグインシステム - 一部のプラグインの展示 ✨_
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width="600">
_✨ 管理パネル ✨_
![webchat](https://drive.soulter.top/f/vlsA/ezgif-5-fb044b2542.gif)
_✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_
</div>
## ⭐ Star History
> [!TIP]
> このプロジェクトがあなたの生活や仕事に役立った場合、またはこのプロジェクトの将来の発展に関心がある場合は、プロジェクトに Star を付けてください。これはこのオープンソースプロジェクトを維持するためのモチベーションです <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
</div>
## スポンサー
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
## 免責事項
1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。
2. WeChat個人アカウントのデプロイメントには [Gewechat](https://github.com/Devo919/Gewechat) サービスを利用しています。AstrBot は Gewechat との接続を保証するだけであり、アカウントのリスク管理に関しては、このプロジェクトの著者は一切の責任を負いません。
3. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。
<!-- ## ✨ ATRI [ベータテスト]
この機能はプラグインとしてロードされます。プラグインリポジトリのアドレス:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. 《ATRI ~ My Dear Moments》の主人公 ATRI のキャラクターセリフを微調整データセットとして使用した `Qwen1.5-7B-Chat Lora` 微調整モデル。
2. 長期記憶
3. ミームの理解と返信
4. TTS
-->
_私は、高性能ですから!_

View File

@@ -1,13 +1,13 @@
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core.utils.personality import personalities
from astrbot.core import html_renderer
from astrbot.core.provider.register import register_llm_tool as llm_tool
from astrbot.core import sp
from astrbot.core.star.register import register_llm_tool as llm_tool
__all__ = [
"AstrBotConfig",
"logger",
"personalities",
"html_renderer",
"llm_tool",
"sp"
]

View File

@@ -1,9 +1,8 @@
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core.utils.personality import personalities
from astrbot.core import html_renderer
from astrbot.core.provider.register import register_llm_tool as llm_tool
from astrbot.core.star.register import register_llm_tool as llm_tool
# event
from astrbot.core.message.message_event_result import (

View File

@@ -1,9 +1,18 @@
from astrbot.core.message.message_event_result import (
MessageEventResult, MessageChain, CommandResult, EventResultType
)
MessageEventResult,
MessageChain,
CommandResult,
EventResultType,
ResultContentType,
)
from astrbot.core.platform import AstrMessageEvent
__all__ = [
'MessageEventResult', 'MessageChain', 'CommandResult', 'EventResultType', 'AstrMessageEvent'
]
"MessageEventResult",
"MessageChain",
"CommandResult",
"EventResultType",
"AstrMessageEvent",
"ResultContentType",
]

View File

@@ -4,12 +4,19 @@ from astrbot.core.star.register import (
register_event_message_type as event_message_type,
register_regex as regex,
register_platform_adapter_type as platform_adapter_type,
register_permission_type as permission_type
register_permission_type as permission_type,
register_custom_filter as custom_filter,
register_on_llm_request as on_llm_request,
register_on_llm_response as on_llm_response,
register_llm_tool as llm_tool,
register_on_decorating_result as on_decorating_result,
register_after_message_sent as after_message_sent
)
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
from astrbot.core.star.filter.permission import PermissionTypeFilter, PermissionType
from astrbot.core.star.filter.custom_filter import CustomFilter
__all__ = [
'command',
@@ -23,5 +30,12 @@ __all__ = [
'PlatformAdapterTypeFilter',
'PlatformAdapterType',
'PermissionTypeFilter',
'CustomFilter',
'custom_filter',
'PermissionType',
'on_llm_request',
'llm_tool',
'on_decorating_result',
'after_message_sent',
'on_llm_response'
]

View File

@@ -2,4 +2,5 @@ from astrbot.core.platform import (
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
)
from astrbot.core.platform.register import register_platform_adapter
from astrbot.core.platform.register import register_platform_adapter
from astrbot.core.message.components import *

View File

@@ -1 +1,2 @@
from astrbot.core.provider import Provider, Personality, ProviderMetaData
from astrbot.core.provider import Provider, STTProvider, Personality
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData, LLMResponse

View File

@@ -1,12 +1,26 @@
import os
import asyncio
from .log import LogManager, LogBroker
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config.default import DB_PATH
from astrbot.core.config import AstrBotConfig
os.makedirs("data", exist_ok=True)
html_renderer = HtmlRenderer()
astrbot_config = AstrBotConfig()
t2i_base_url = astrbot_config.get('t2i_endpoint', 'https://t2i.soulter.top/text2img')
html_renderer = HtmlRenderer(t2i_base_url)
logger = LogManager.GetLogger(log_name='astrbot')
if os.environ.get('TESTING', ""):
logger.setLevel('DEBUG')
db_helper = SQLiteDatabase(DB_PATH)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
sp = SharedPreferences() # 简单的偏好设置存储
pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', ''))
web_chat_queue = asyncio.Queue(maxsize=32)
web_chat_back_queue = asyncio.Queue(maxsize=32)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"

View File

@@ -2,7 +2,7 @@ import os
import json
import logging
import enum
from .default import DEFAULT_CONFIG
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
from typing import Dict
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
@@ -13,29 +13,72 @@ class RateLimitStrategy(enum.Enum):
DISCARD = "discard"
class AstrBotConfig(dict):
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项'''
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项
def __init__(self):
- 初始化时会将传入的 default_config 与配置文件进行比对,如果配置文件中缺少配置项则会自动插入默认值并进行一次写入操作。会递归检查配置项。
- 如果配置文件路径对应的文件不存在,则会自动创建并写入默认配置。
- 如果传入了 schema将会通过 schema 解析出 default_config此时传入的 default_config 会被忽略。
'''
def __init__(
self,
config_path: str = ASTRBOT_CONFIG_PATH,
default_config: dict = DEFAULT_CONFIG,
schema: dict = None
):
super().__init__()
# 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件
object.__setattr__(self, 'config_path', config_path)
object.__setattr__(self, 'default_config', default_config)
object.__setattr__(self, 'schema', schema)
if schema:
default_config = self._config_schema_to_default_config(schema)
if not self.check_exist():
'''不存在时载入默认配置'''
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(DEFAULT_CONFIG, f, indent=4, ensure_ascii=False)
with open(config_path, "w", encoding="utf-8-sig") as f:
json.dump(default_config, f, indent=4, ensure_ascii=False)
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
with open(config_path, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
if conf_str.startswith(u'/ufeff'): # remove BOM
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
conf = json.loads(conf_str)
# 检查配置完整性,并插入
has_new = self.check_config_integrity(DEFAULT_CONFIG, conf)
has_new = self.check_config_integrity(default_config, conf)
self.update(conf)
if has_new:
self.save_config()
self.update(conf)
def _config_schema_to_default_config(self, schema: dict) -> dict:
'''将 Schema 转换成 Config'''
conf = {}
def _parse_schema(schema: dict, conf: dict):
for k, v in schema.items():
if v['type'] not in DEFAULT_VALUE_MAP:
raise TypeError(f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}")
if 'default' in v:
default = v['default']
else:
default = DEFAULT_VALUE_MAP[v['type']]
if v['type'] == 'object':
conf[k] = {}
_parse_schema(v['items'], conf[k])
else:
conf[k] = default
_parse_schema(schema, conf)
return conf
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
'''检查配置完整性,如果有新的配置项则返回 True'''
has_new = False
@@ -61,7 +104,7 @@ class AstrBotConfig(dict):
'''
if replace_config:
self.update(replace_config)
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
with open(self.config_path, "w", encoding="utf-8-sig") as f:
json.dump(self, f, indent=2, ensure_ascii=False)
def __getattr__(self, item):
@@ -81,4 +124,4 @@ class AstrBotConfig(dict):
self[key] = value
def check_exist(self) -> bool:
return os.path.exists(ASTRBOT_CONFIG_PATH)
return os.path.exists(self.config_path)

View File

@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.4.0"
VERSION = "3.4.30"
DB_PATH = "data/data_v3.db"
# 默认配置
@@ -17,39 +17,84 @@ DEFAULT_CONFIG = {
},
"reply_prefix": "",
"forward_threshold": 200,
"enable_id_white_list": True,
"id_whitelist": [],
"id_whitelist_log": True,
"wl_ignore_admin_on_group": True,
"wl_ignore_admin_on_friend": True,
"reply_with_mention": False,
"reply_with_quote": False,
"path_mapping": [],
"segmented_reply": {
"enable": False,
"only_llm_result": True,
"interval_method": "random",
"interval": "1.5,3.5",
"log_base": 2.6,
"words_count_threshold": 150,
"regex": ".*?[。?!~…]+|.+$",
"content_cleanup_rule": "",
},
"no_permission_reply": True,
},
"provider": [],
"provider_settings": {
"enable": True,
"wake_prefix": "",
"web_search": False,
"web_search_link": False,
"identifier": False,
"datetime_system_prompt": True,
"default_personality": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
"default_personality": "default",
"prompt_prefix": "",
},
"provider_stt_settings": {
"enable": False,
"provider_id": "",
},
"provider_tts_settings": {
"enable": False,
"provider_id": "",
},
"provider_ltm_settings": {
"group_icl_enable": False,
"group_message_max_cnt": 300,
"image_caption": False,
"image_caption_provider_id": "",
"image_caption_prompt": "Please describe the image using Chinese.",
"active_reply": {
"enable": False,
"method": "possibility_reply",
"possibility_reply": 0.1,
"prompt": "",
"whitelist": []
}
},
"content_safety": {
"also_use_in_response": False,
"internal_keywords": {"enable": True, "extra_keywords": []},
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
},
"admins_id": [],
"admins_id": [
"astrbot"
],
"t2i": False,
"t2i_word_threshold": 150,
"http_proxy": "",
"dashboard": {
"enable": True,
"username": "astrbot",
"password": "77b90590a8945a7d36c963981a307dc9",
"port": 6185
},
"platform": [],
"wake_prefix": ["/"],
"log_level": "INFO",
"t2i_endpoint": "",
"pip_install_arg": "",
"plugin_repo_mirror": ""
"plugin_repo_mirror": "",
"knowledge_db": {},
"persona": [],
}
@@ -71,20 +116,45 @@ CONFIG_METADATA_2 = {
"enable_group_c2c": True,
"enable_guild_direct_message": True,
},
"qq_official_webhook(QQ)": {
"id": "default",
"type": "qq_official_webhook",
"enable": False,
"appid": "",
"secret": "",
"port": 6196
},
"aiocqhtp(QQ)": {
"id": "default",
"type": "aiocqhttp",
"enable": False,
"ws_reverse_host": "",
"ws_reverse_host": "0.0.0.0",
"ws_reverse_port": 6199,
},
"vchat(微信)": {"id": "default", "type": "vchat", "enable": False},
"gewechat(微信)": {
"id": "gwchat",
"type": "gewechat",
"enable": False,
"base_url": "http://localhost:2531",
"nickname": "soulter",
"host": "这里填写你的局域网IP或者公网服务器IP",
"port": 11451,
},
"lark(飞书)": {
"id": "lark",
"type": "lark",
"enable": False,
"lark_bot_name": "",
"app_id": "",
"app_secret": "",
"domain": "https://open.feishu.cn"
},
},
"items": {
"id": {
"description": "ID",
"type": "string",
"hint": "提供商 ID 名,用于在多实例下方便管理和识别。自定义ID 不能重复。",
"hint": "用于在多实例下方便管理和识别。自定义ID 不能重复。",
},
"type": {
"description": "适配器类型",
@@ -126,6 +196,12 @@ CONFIG_METADATA_2 = {
"type": "int",
"hint": "aiocqhttp 适配器的反向 Websocket 端口。",
},
"lark_bot_name": {
"description": "飞书机器人的名字",
"type": "string",
"hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
"obvious_hint": True
}
},
},
"platform_settings": {
@@ -152,6 +228,58 @@ CONFIG_METADATA_2 = {
},
},
},
"no_permission_reply": {
"description": "无权限回复",
"type": "bool",
"hint": "启用后,当用户没有权限执行某个操作时,机器人会回复一条消息。",
},
"segmented_reply": {
"description": "分段回复",
"type": "object",
"items": {
"enable": {
"description": "启用分段回复",
"type": "bool",
},
"only_llm_result": {
"description": "仅对 LLM 结果分段",
"type": "bool",
},
"interval_method": {
"description": "间隔时间计算方法",
"type": "string",
"options": ["random", "log"],
"hint": "分段回复的间隔时间计算方法。random 为随机时间log 为根据消息长度计算,$y=log_{log\_base}(x)$x为字数y的单位为秒。",
},
"interval": {
"description": "随机间隔时间(秒)",
"type": "string",
"hint": "`random` 方法用。每一段回复的间隔时间,格式为 `最小时间,最大时间`。如 `0.75,2.5`",
},
"log_base": {
"description": "对数函数底数",
"type": "float",
"hint": "`log` 方法用。对数函数的底数。默认为 2.6",
},
"words_count_threshold": {
"description": "字数阈值",
"type": "int",
"hint": "超过这个字数的消息不会被分段回复。默认为 150",
},
"regex": {
"description": "正则表达式",
"type": "string",
"obvious_hint": True,
"hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)",
},
"content_cleanup_rule": {
"description": "过滤分段后的内容",
"type": "string",
"obvious_hint": True,
"hint": "移除分段后的内容中的指定的内容。支持正则表达式。如填写 `[。?!]` 将移除所有的句号、问号、感叹号。re.sub(r'<regex>', '', text)",
},
},
},
"reply_prefix": {
"description": "回复前缀",
"type": "string",
@@ -162,11 +290,16 @@ CONFIG_METADATA_2 = {
"type": "int",
"hint": "超过一定字数后,机器人会将消息折叠成 QQ 群聊的 “转发消息”,以防止刷屏。目前仅 QQ 平台适配器适用。",
},
"enable_id_white_list": {
"description": "启用 ID 白名单",
"type": "bool",
},
"id_whitelist": {
"description": "ID 白名单",
"type": "list",
"items": {"type": "int"},
"hint": "填写后,将只处理所填写的 ID 发来的消息事件。为空时表示不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID当一条消息没通过白名单时会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
"items": {"type": "string"},
"obvious_hint": True,
"hint": "AstrBot 只处理所填写的 ID 发来的消息事件。为空时不启用白名单过滤。可以使用 /sid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID当一条消息没通过白名单时会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978。管理员可使用 /wl 添加白名单",
},
"id_whitelist_log": {
"description": "打印白名单日志",
@@ -181,12 +314,34 @@ CONFIG_METADATA_2 = {
"description": "管理员私聊消息无视 ID 白名单",
"type": "bool",
},
"reply_with_mention": {
"description": "回复时 @ 发送者",
"type": "bool",
"hint": "启用后,机器人回复消息时会 @ 发送者。实际效果以具体的平台适配器为准。",
},
"reply_with_quote": {
"description": "回复时引用消息",
"type": "bool",
"hint": "启用后,机器人回复消息时会引用原消息。实际效果以具体的平台适配器为准。",
},
"path_mapping": {
"description": "路径映射",
"type": "list",
"items": {"type": "string"},
"obvious_hint": True,
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
},
},
},
"content_safety": {
"description": "内容安全",
"type": "object",
"items": {
"also_use_in_response": {
"description": "对大模型响应安全审核",
"type": "bool",
"hint": "启用后,大模型的响应也会通过内容安全审核。",
},
"baidu_aip": {
"description": "百度内容审核配置",
"type": "object",
@@ -225,38 +380,86 @@ CONFIG_METADATA_2 = {
},
},
"provider_group": {
"name": "大语言模型",
"name": "服务提供商",
"metadata": {
"provider": {
"description": "大语言模型配置",
"description": "服务提供商配置",
"type": "list",
"config_template": {
"openai": {
"id": "default",
"id": "openai",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"api_base": "",
"api_base": "https://api.openai.com/v1",
"timeout": 120,
"model_config": {
"model": "gpt-4o-mini",
},
},
"azure_openai": {
"id": "azure",
"type": "openai_chat_completion",
"enable": True,
"api_version": "2024-05-01-preview",
"key": [],
"api_base": "",
"timeout": 120,
"model_config": {
"model": "gpt-4o-mini",
},
},
"xAI": {
"id": "xai",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.x.ai/v1",
"timeout": 120,
"model_config": {
"model": "grok-2-latest",
},
},
"anthropic(claude)": {
"id": "claude",
"type": "anthropic_chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.anthropic.com/v1",
"timeout": 120,
"model_config": {
"model": "claude-3-5-sonnet-latest",
"max_tokens": 4096,
},
},
"ollama": {
"id": "ollama_default",
"type": "openai_chat_completion",
"enable": True,
"key": ["ollama"], # ollama 的 key 默认是 ollama
"api_base": "http://localhost:11434",
"api_base": "http://localhost:11434/v1",
"model_config": {
"model": "llama3.1-8b",
},
},
"gemini": {
"gemini(OpenAI兼容)": {
"id": "gemini_default",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
"timeout": 120,
"model_config": {
"model": "gemini-1.5-flash",
},
},
"gemini(googlegenai原生)": {
"id": "gemini_default",
"type": "googlegenai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://generativelanguage.googleapis.com/",
"timeout": 120,
"model_config": {
"model": "gemini-1.5-flash",
},
@@ -267,20 +470,44 @@ CONFIG_METADATA_2 = {
"enable": True,
"key": [],
"api_base": "https://api.deepseek.com/v1",
"timeout": 120,
"model_config": {
"model": "deepseek-chat",
},
},
"zhipu": {
"id": "zhipu_default",
"type": "openai_chat_completion",
"type": "zhipu_chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
"model_config": {
"model": "glm-4-flash",
},
},
"siliconflow": {
"id": "siliconflow",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://api.siliconflow.cn/v1",
"model_config": {
"model": "deepseek-ai/DeepSeek-V3",
},
},
"moonshot(kimi)": {
"id": "moonshot",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://api.moonshot.cn/v1",
"model_config": {
"model": "moonshot-v1-8k",
},
},
"llmtuner": {
"id": "llmtuner_default",
"type": "llm_tuner",
@@ -290,9 +517,116 @@ CONFIG_METADATA_2 = {
"llmtuner_template": "",
"finetuning_type": "lora",
"quantization_bit": 4,
}
},
"dify": {
"id": "dify_app_default",
"type": "dify",
"enable": True,
"dify_api_type": "chat",
"dify_api_key": "",
"dify_api_base": "https://api.dify.ai/v1",
"dify_workflow_output_key": "",
"dify_query_input_key": "astrbot_text_query",
"variables": {},
"timeout": 60,
},
"dashscope": {
"id": "dashscope",
"type": "dashscope",
"enable": True,
"dashscope_app_type": "agent",
"dashscope_api_key": "",
"dashscope_app_id": "",
"variables": {},
"timeout": 60,
},
"fastgpt": {
"id": "fastgpt",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.fastgpt.in/api/v1",
"timeout": 60,
},
"whisper(API)": {
"id": "whisper",
"type": "openai_whisper_api",
"enable": False,
"api_key": "",
"api_base": "",
"model": "whisper-1",
},
"whisper(本地加载)": {
"whisper_hint": "(不用修改我)",
"enable": False,
"id": "whisper",
"type": "openai_whisper_selfhost",
"model": "tiny",
},
"openai_tts(API)": {
"id": "openai_tts",
"type": "openai_tts_api",
"enable": False,
"api_key": "",
"api_base": "",
"model": "tts-1",
"openai-tts-voice": "alloy",
"timeout": "20",
},
"fishaudio_tts(API)": {
"id": "fishaudio_tts",
"type": "fishaudio_tts_api",
"enable": False,
"api_key": "",
"api_base": "https://api.fish-audio.cn/v1",
"fishaudio-tts-character": "可莉",
"timeout": "20",
},
},
"items": {
"variables": {
"description": "工作流固定输入变量",
"type": "object",
"obvious_hint": True,
"hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
},
# "fastgpt_app_type": {
# "description": "应用类型",
# "type": "string",
# "hint": "FastGPT 应用的应用类型。",
# "options": ["agent", "workflow", "plugin"],
# "obvious_hint": True,
# },
"dashscope_app_type": {
"description": "应用类型",
"type": "string",
"hint": "阿里云百炼应用的应用类型。",
"options": ["agent", "agent-arrange", "dialog-workflow", "task-workflow"],
"obvious_hint": True,
},
"timeout": {
"description": "超时时间",
"type": "int",
"hint": "超时时间,单位为秒。",
},
"openai-tts-voice": {
"description": "voice",
"type": "string",
"obvious_hint": True,
"hint": "OpenAI TTS 的声音。OpenAI 默认支持:'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'",
},
"fishaudio-tts-character": {
"description": "character",
"type": "string",
"obvious_hint": True,
"hint": "fishaudio TTS 的角色。默认为可莉。更多角色请访问https://fish.audio/zh-CN/discovery",
},
"whisper_hint": {
"description": "本地部署 Whisper 模型须知",
"type": "string",
"hint": "启用前请 pip 安装 openai-whisper 库N卡用户大约下载 2GB主要是 torch 和 cudaCPU 用户大约下载 1 GB并且安装 ffmpeg。否则将无法正常转文字。",
"obvious_hint": True,
},
"id": {
"description": "ID",
"type": "string",
@@ -317,7 +651,8 @@ CONFIG_METADATA_2 = {
"api_base": {
"description": "API Base URL",
"type": "string",
"hint": "API Base URL 请在在模型提供商处获得。支持 Ollama 开放的 API 地址。如果您确认填写正确但是使用时出现 404 异常,可以尝试在地址末尾加上 `/v1`。",
"hint": "API Base URL 请在在模型提供商处获得。出现 404 报错,尝试在地址末尾加上 /v1",
"obvious_hint": True,
},
"base_model_path": {
"description": "基座模型路径",
@@ -360,7 +695,35 @@ CONFIG_METADATA_2 = {
"temperature": {"description": "温度", "type": "float"},
"top_p": {"description": "Top P值", "type": "float"},
},
"editable": True,
},
"dify_api_key": {
"description": "API Key",
"type": "string",
"hint": "Dify API Key。此项必填。",
},
"dify_api_base": {
"description": "API Base URL",
"type": "string",
"hint": "Dify API Base URL。默认为 https://api.dify.ai/v1",
},
"dify_api_type": {
"description": "Dify 应用类型",
"type": "string",
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, agent, workflow 三种应用类型",
"options": ["chat", "agent", "workflow"],
},
"dify_workflow_output_key": {
"description": "Dify Workflow 输出变量名",
"type": "string",
"hint": "Dify Workflow 输出变量名。当应用类型为 workflow 时才使用。默认为 astrbot_wf_output。",
},
"dify_query_input_key": {
"description": "Prompt 输入变量名",
"type": "string",
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
"obvious": True,
}
},
},
"provider_settings": {
@@ -370,7 +733,8 @@ CONFIG_METADATA_2 = {
"enable": {
"description": "启用大语言模型聊天",
"type": "bool",
"hint": "是否启用大语言模型聊天。默认启用",
"hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。",
"obvious_hint": True,
},
"wake_prefix": {
"description": "LLM 聊天额外唤醒前缀",
@@ -380,22 +744,31 @@ CONFIG_METADATA_2 = {
"web_search": {
"description": "启用网页搜索",
"type": "bool",
"hint": "能访问 Google 时效果最佳。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
"obvious_hint": True,
"hint": "能访问 Google 时效果最佳(国内需要在 `其他配置` 开启 HTTP 代理)。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
},
"web_search_link": {
"description": "网页搜索引用链接",
"type": "bool",
"obvious_hint": True,
"hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。",
},
"identifier": {
"description": "启动识别群员",
"type": "bool",
"obvious_hint": True,
"hint": "在 Prompt 前加上群成员的名字以让模型更好地了解群聊状态。启用将略微增加 token 开销。",
},
"datetime_system_prompt": {
"description": "启用日期时间系统提示",
"type": "bool",
"obvious_hint": True,
"hint": "启用后,会在系统提示词中加上当前机器的日期时间。",
},
"default_personality": {
"description": "默认人格",
"description": "默认采用的人格情景的名称",
"type": "string",
"hint": "默认人格(情境设置/System Prompt文本。",
"hint": "",
},
"prompt_prefix": {
"description": "Prompt 前缀文本",
@@ -404,6 +777,151 @@ CONFIG_METADATA_2 = {
},
},
},
"persona": {
"description": "人格情景设置",
"type": "list",
"config_template": {
"新人格情景": {
"name": "",
"prompt": "",
"begin_dialogs": [],
"mood_imitation_dialogs": [],
}
},
"tmpl_display_title": "name",
"items": {
"name": {
"description": "人格名称",
"type": "string",
"hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。",
"obvious_hint": True,
},
"prompt": {
"description": "设定(系统提示词)",
"type": "text",
"hint": "填写人格的身份背景、性格特征、兴趣爱好、个人经历、口头禅等。",
},
"begin_dialogs": {
"description": "预设对话",
"type": "list",
"items": {"type": "string"},
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
"obvious_hint": True,
},
"mood_imitation_dialogs": {
"description": "对话风格模仿",
"type": "list",
"items": {"type": "string"},
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
"obvious_hint": True,
},
},
},
"provider_stt_settings": {
"description": "语音转文本(STT)",
"type": "object",
"items": {
"enable": {
"description": "启用语音转文本(STT)",
"type": "bool",
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。",
"obvious_hint": True,
},
"provider_id": {
"description": "提供商 ID不填则默认第一个STT提供商",
"type": "string",
"hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。",
},
},
},
"provider_tts_settings": {
"description": "文本转语音(TTS)",
"type": "object",
"items": {
"enable": {
"description": "启用文本转语音(TTS)",
"type": "bool",
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。",
"obvious_hint": True,
},
"provider_id": {
"description": "提供商 ID不填则默认第一个TTS提供商",
"type": "string",
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
},
},
},
"provider_ltm_settings": {
"description": "聊天记忆增强(Beta)",
"type": "object",
"items": {
"group_icl_enable": {
"description": "群聊内记录各群员对话",
"type": "bool",
"obvious_hint": True,
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
},
"group_message_max_cnt": {
"description": "群聊消息最大数量",
"type": "int",
"obvious_hint": True,
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
},
"image_caption": {
"description": "启用图像转述(需要模型支持)",
"type": "bool",
"obvious_hint": True,
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型。",
},
"image_caption_provider_id": {
"description": "图像转述提供商 ID",
"type": "string",
"obvious_hint": True,
"hint": "可选。图像转述提供商 ID。如为空将选择聊天使用的提供商。",
},
"image_caption_prompt": {
"description": "图像转述提示词",
"type": "string",
},
"active_reply": {
"description": "主动回复",
"type": "object",
"items": {
"enable": {
"description": "启用主动回复",
"type": "bool",
"obvious_hint": True,
"hint": "启用后会根据触发概率主动回复群聊内的对话。QQ官方API(qq_official)不可用",
},
"whitelist": {
"description": "主动回复白名单",
"type": "list",
"items": {"type": "string"},
"obvious_hint": True,
"hint": "启用后,只有在白名单内的群聊会被主动回复。为空时不启用白名单过滤。需要通过 /sid 获取 SID 添加到这里。",
},
"method": {
"description": "回复方法",
"type": "string",
"options": ["possibility_reply"],
"hint": "回复方法。possibility_reply 为根据概率回复",
},
"possibility_reply": {
"description": "回复概率",
"type": "float",
"obvious_hint": True,
"hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。",
},
"prompt": {
"description": "提示词",
"type": "string",
"obvious_hint": True,
"hint": "提示词。当提示词为空时,如果触发回复,则向 LLM 请求的是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
},
},
},
},
},
},
},
"misc_config_group": {
@@ -413,18 +931,24 @@ CONFIG_METADATA_2 = {
"description": "机器人唤醒前缀",
"type": "list",
"items": {"type": "string"},
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。",
"obvious_hint": True,
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。更改此配置将影响整个 Bot 的功能唤醒,包括所有指令。如果您不保留 `/`则内置指令help等将需要通过您的唤醒前缀来触发。",
},
"t2i": {
"description": "文本转图像",
"type": "bool",
"hint": "启用后,超出一定长度的文本将会通过 AstrBot API 渲染成 Markdown 图片发送。可以缓解审核和消息过长刷屏的问题,并提高 Markdown 文本的可读性。",
},
"t2i_word_threshold": {
"description": "文本转图像字数阈值",
"type": "int",
"hint": "超出此字符长度的文本将会被转换成图片。字数不能低于 50。",
},
"admins_id": {
"description": "管理员 ID",
"type": "list",
"items": {"type": "int"},
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/myid` 指令获得。回车添加,可添加多个。",
"items": {"type": "string"},
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/sid` 指令获得。回车添加,可添加多个。",
},
"http_proxy": {
"description": "HTTP 代理",
@@ -450,7 +974,8 @@ CONFIG_METADATA_2 = {
"plugin_repo_mirror": {
"description": "插件仓库镜像",
"type": "string",
"hint": "插件仓库的镜像地址,用于加速插件的下载。",
"hint": "已废弃,请使用管理面板->设置页的代理地址选择",
"obvious_hint": True,
"options": [
"default",
"https://ghp.ci/",

View File

@@ -0,0 +1,119 @@
import uuid
import json
import asyncio
from astrbot.core import sp
from typing import Dict, List
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation
class ConversationManager():
'''负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。'''
def __init__(self, db_helper: BaseDatabase):
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
self._start_periodic_save()
def _start_periodic_save(self):
asyncio.create_task(self._periodic_save())
async def _periodic_save(self):
while True:
await asyncio.sleep(self.save_interval)
self._save_to_storage()
def _save_to_storage(self):
sp.put("session_conversation", self.session_conversations)
async def new_conversation(self, unified_msg_origin: str) -> str:
'''新建对话,并将当前会话的对话转移到新对话'''
conversation_id = str(uuid.uuid4())
self.db.new_conversation(
user_id=unified_msg_origin,
cid=conversation_id
)
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
return conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
'''切换会话的对话'''
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
async def delete_conversation(self, unified_msg_origin: str, conversation_id: str=None):
'''删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.delete_conversation(
user_id=unified_msg_origin,
cid=conversation_id
)
del self.session_conversations[unified_msg_origin]
sp.put("session_conversation", self.session_conversations)
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
'''获取会话当前的对话 ID'''
return self.session_conversations.get(unified_msg_origin, None)
async def get_conversation(self, unified_msg_origin: str, conversation_id: str) -> Conversation:
'''获取会话的对话'''
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
'''获取会话的所有对话'''
return self.db.get_conversations(unified_msg_origin)
async def update_conversation(self, unified_msg_origin: str, conversation_id: str, history: List[Dict]):
'''更新会话的对话'''
if conversation_id:
self.db.update_conversation(
user_id=unified_msg_origin,
cid=conversation_id,
history=json.dumps(history)
)
async def update_conversation_title(self, unified_msg_origin: str, title: str):
'''更新会话的对话标题'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_title(
user_id=unified_msg_origin,
cid=conversation_id,
title=title
)
async def update_conversation_persona_id(self, unified_msg_origin: str, persona_id: str):
'''更新会话的对话 Persona ID'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_persona_id(
user_id=unified_msg_origin,
cid=conversation_id,
persona_id=persona_id
)
async def get_human_readable_context(self, unified_msg_origin, conversation_id, page=1, page_size=10):
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
history = json.loads(conversation.history)
contexts = []
temp_contexts = []
for record in history:
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages

View File

@@ -1,11 +1,12 @@
import traceback
import asyncio
import time
import threading
import os
from .event_bus import EventBus
from . import astrbot_config
from asyncio import Queue
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
@@ -16,20 +17,24 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker
self.astrbot_config = AstrBotConfig()
self.astrbot_config = astrbot_config
self.db = db
if self.astrbot_config['http_proxy']:
os.environ['https_proxy'] = self.astrbot_config['http_proxy']
os.environ['http_proxy'] = self.astrbot_config['http_proxy']
os.environ['https_proxy'] = self.astrbot_config['http_proxy']
os.environ['http_proxy'] = self.astrbot_config['http_proxy']
os.environ['no_proxy'] = 'localhost,127.0.0.1'
async def initialize(self):
logger.info("AstrBot v"+ VERSION)
logger.setLevel(self.astrbot_config['log_level'])
if os.environ.get("TESTING", ""):
logger.setLevel("DEBUG")
else:
logger.setLevel(self.astrbot_config['log_level'])
self.event_queue = Queue()
self.event_queue.closed = False
@@ -37,12 +42,22 @@ class AstrBotCoreLifecycle:
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
self.star_context = Context(self.event_queue, self.astrbot_config, self.db)
self.star_context.platform_manager = self.platform_manager
self.star_context.provider_manager = self.provider_manager
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
self.conversation_manager = ConversationManager(self.db)
self.star_context = Context(
self.event_queue,
self.astrbot_config,
self.db,
self.provider_manager,
self.platform_manager,
self.conversation_manager,
self.knowledge_db_manager
)
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
self.plugin_manager.reload()
await self.plugin_manager.reload()
'''扫描、注册插件、实例化插件类'''
await self.provider_manager.initialize()
@@ -69,18 +84,38 @@ class AstrBotCoreLifecycle:
for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks]
# self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks]
tasks_ = [event_bus_task, *platform_tasks, *extra_tasks]
for task in tasks_:
self.curr_tasks.append(asyncio.create_task(self._task_wrapper(task), name=task.get_name()))
self.start_time = int(time.time())
async def _task_wrapper(self, task: asyncio.Task):
try:
await task
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
for line in traceback.format_exc().split("\n"):
logger.error(f"| {line}")
logger.error("-------")
async def start(self):
self._load()
logger.info("AstrBot 启动完成。")
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
async def stop(self):
self.event_queue.closed = True
for task in self.curr_tasks:
task.cancel()
await self.provider_manager.terminate()
for task in self.curr_tasks:
try:

View File

@@ -1,7 +1,7 @@
import abc
from dataclasses import dataclass
from typing import List
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
@dataclass
class BaseDatabase(abc.ABC):
@@ -76,4 +76,38 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
'''通过 url 或 path 获取 ATRI 视觉数据'''
raise NotImplementedError
@abc.abstractmethod
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
'''通过 user_id 和 cid 获取 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def new_conversation(self, user_id: str, cid: str):
'''新建 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def get_conversations(self, user_id: str) -> List[Conversation]:
raise NotImplementedError
@abc.abstractmethod
def update_conversation(self, user_id: str, cid: str, history: str):
'''更新 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def delete_conversation(self, user_id: str, cid: str):
'''删除 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def update_conversation_title(self, user_id: str, cid: str, title: str):
'''更新 Conversation 标题'''
raise NotImplementedError
@abc.abstractmethod
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
'''更新 Conversation Persona ID'''
raise NotImplementedError

View File

@@ -33,16 +33,16 @@ class Stats():
command: List[Command] = field(default_factory=list)
llm: List[Provider] = field(default_factory=list)
'''LLM 聊天时持久化的信息'''
@dataclass
class LLMHistory():
'''LLM 聊天时持久化的信息'''
provider_type: str
session_id: str
content: str
@dataclass
class ATRIVision():
'''Deprecated'''
id: str
url_or_path: str
caption: str
@@ -51,4 +51,20 @@ class ATRIVision():
platform_name: str
session_id: str
sender_nickname: str
timestamp: int = -1
timestamp: int = -1
@dataclass
class Conversation():
'''LLM 对话存储
对于网页聊天history 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
'''
user_id: str
cid: str
history: str = ""
'''字符串格式的列表。'''
created_at: int = 0
updated_at: int = 0
title: str = ""
persona_id: str = ""

View File

@@ -5,7 +5,8 @@ from astrbot.core.db.po import (
Platform,
Stats,
LLMHistory,
ATRIVision
ATRIVision,
Conversation
)
from . import BaseDatabase
from typing import Tuple
@@ -24,6 +25,37 @@ class SQLiteDatabase(BaseDatabase):
c = self.conn.cursor()
c.executescript(sql)
self.conn.commit()
# 检查 webchat_conversation 的 title 字段是否存在
c.execute(
'''
PRAGMA table_info(webchat_conversation)
'''
)
res = c.fetchall()
has_title = False
has_persona_id = False
for row in res:
if row[1] == "title":
has_title = True
if row[1] == "persona_id":
has_persona_id = True
if not has_title:
c.execute(
'''
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
'''
)
self.conn.commit()
if not has_persona_id:
c.execute(
'''
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
'''
)
self.conn.commit()
c.close()
def _get_conn(self, db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
@@ -199,7 +231,92 @@ class SQLiteDatabase(BaseDatabase):
c.close()
return Stats(platform, [], [])
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
res = c.fetchone()
c.close()
if not res:
return
return Conversation(*res)
def new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
self._exec_sql(
'''
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
''', (user_id, cid, history, updated_at, created_at)
)
def get_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
''', (user_id,)
)
res = c.fetchall()
c.close()
conversations = []
for row in res:
cid = row[0]
created_at = row[1]
updated_at = row[2]
title = row[3]
persona_id = row[4]
conversations.append(Conversation("", cid, '[]', created_at, updated_at, title, persona_id))
return conversations
def update_conversation(self, user_id: str, cid: str, history: str):
'''更新对话,并且同时更新时间'''
updated_at = int(time.time())
self._exec_sql(
'''
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
''', (history, updated_at, user_id, cid)
)
def update_conversation_title(self, user_id: str, cid: str, title: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
''', (title, user_id, cid)
)
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
''', (persona_id, user_id, cid)
)
def delete_conversation(self, user_id: str, cid: str):
self._exec_sql(
'''
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
def insert_atri_vision_data(self, vision: ATRIVision):
ts = int(time.time())

View File

@@ -35,4 +35,14 @@ CREATE TABLE IF NOT EXISTS atri_vision(
session_id VARCHAR(32),
sender_nickname VARCHAR(32),
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT,
cid TEXT,
history TEXT,
created_at INTEGER,
updated_at INTEGER,
title TEXT,
persona_id TEXT
);

View File

@@ -54,6 +54,7 @@ class ComponentType(Enum):
CardImage = "CardImage"
TTS = "TTS"
Unknown = "Unknown"
File = "File"
class BaseMessageComponent(BaseModel):
@@ -305,7 +306,7 @@ class Image(BaseMessageComponent):
class Reply(BaseMessageComponent):
type: ComponentType = "Reply"
id: int
id: T.Union[str, int]
text: T.Optional[str] = ""
qq: T.Optional[int] = 0
time: T.Optional[int] = 0
@@ -324,11 +325,13 @@ class RedBag(BaseMessageComponent):
class Poke(BaseMessageComponent):
type: ComponentType = "Poke"
qq: int
type: str = ""
id: T.Optional[int] = 0
qq: T.Optional[int] = 0
def __init__(self, **_):
super().__init__(**_)
def __init__(self, type: str, **_):
type = f"Poke:{type}"
super().__init__(type=type, **_)
class Forward(BaseMessageComponent):
@@ -338,14 +341,14 @@ class Forward(BaseMessageComponent):
def __init__(self, **_):
super().__init__(**_)
class Node(BaseMessageComponent): # 该 component 仅支持使用 sendGroupForwardMessage 发送
class Node(BaseMessageComponent):
'''群合并转发消息'''
type: ComponentType = "Node"
id: T.Optional[int] = 0
name: T.Optional[str] = ""
uin: T.Optional[int] = 0
content: T.Optional[T.Union[str, list]] = ""
seq: T.Optional[T.Union[str, list]] = "" # 不清楚是什么
id: T.Optional[int] = 0 # 忽略
name: T.Optional[str] = "" # qq昵称
uin: T.Optional[int] = 0 # qq号
content: T.Optional[T.Union[str, list]] = "" # 子消息段列表
seq: T.Optional[T.Union[str, list]] = "" # 忽略
time: T.Optional[int] = 0
def __init__(self, content: T.Union[str, list], **_):
@@ -415,6 +418,17 @@ class Unknown(BaseMessageComponent):
def toString(self):
return ""
class File(BaseMessageComponent):
'''
目前此消息段只适配了 Napcat。
'''
type: ComponentType = "File"
name: T.Optional[str] = "" # 名字
file: T.Optional[str] = "" # url本地路径
def __init__(self, name: str, file: str):
super().__init__(name=name, file=file)
ComponentTypes = {
"plain": Plain,
@@ -441,5 +455,6 @@ ComponentTypes = {
"json": Json,
"cardimage": CardImage,
"tts": TTS,
"unknown": Unknown
"unknown": Unknown,
'file': File,
}

View File

@@ -13,12 +13,10 @@ class MessageChain():
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后将会依次发送 chain 中的每个 component。
'''
chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: Optional[bool] = None # None 为跟随用户设置
is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后将会依次发送 chain 中的每个 component。
def message(self, message: str):
'''添加一条文本消息到消息链 `chain` 中。
@@ -77,16 +75,6 @@ class MessageChain():
'''
self.use_t2i_ = use_t2i
return self
def is_split(self, is_split: bool):
'''设置是否分条发送消息。默认为 False。启用后将会依次发送 chain 中的每个 component。
Note:
具体的效果以各适配器实现为准。
'''
self.is_split_ = is_split
return self
class EventResultType(enum.Enum):
'''用于描述事件处理的结果类型。
@@ -97,7 +85,14 @@ class EventResultType(enum.Enum):
'''
CONTINUE = enum.auto()
STOP = enum.auto()
class ResultContentType(enum.Enum):
'''用于描述事件结果的内容的类型。
'''
LLM_RESULT = enum.auto()
'''调用 LLM 产生的结果'''
GENERAL_RESULT = enum.auto()
'''普通的消息结果'''
@dataclass
class MessageEventResult(MessageChain):
'''MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。
@@ -106,12 +101,13 @@ class MessageEventResult(MessageChain):
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后将会依次发送 chain 中的每个 component。
`result_type` (EventResultType): 事件处理的结果类型。
'''
result_type: Optional[EventResultType] = field(default_factory=lambda: EventResultType.CONTINUE)
result_content_type: Optional[ResultContentType] = field(default_factory=lambda: ResultContentType.GENERAL_RESULT)
def stop_event(self) -> 'MessageEventResult':
'''终止事件传播。
'''
@@ -130,5 +126,24 @@ class MessageEventResult(MessageChain):
'''
return self.result_type == EventResultType.STOP
def set_result_content_type(self, typ: ResultContentType) -> 'MessageEventResult':
'''设置事件处理的结果类型。
Args:
result_type (EventResultType): 事件处理的结果类型。
'''
self.result_content_type = typ
return self
def is_llm_result(self) -> bool:
'''是否为 LLM 结果。
'''
return self.result_content_type == ResultContentType.LLM_RESULT
def get_plain_text(self) -> str:
'''获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。
'''
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
CommandResult = MessageEventResult

View File

@@ -2,7 +2,9 @@ from astrbot.core.message.message_event_result import MessageEventResult, EventR
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .rate_limit_check.stage import RateLimitStage
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage
from .respond.stage import RespondStage
@@ -10,8 +12,9 @@ from .respond.stage import RespondStage
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"RateLimitCheckStage", # 检查会话是否超过频率限制
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果比如添加回复前缀、t2i、转换为语音 等
"RespondStage" # 发送消息
@@ -20,7 +23,9 @@ STAGES_ORDER = [
__all__ = [
"WakingCheckStage",
"WhitelistCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",

View File

@@ -17,11 +17,14 @@ class ContentSafetyCheckStage(Stage):
config = ctx.astrbot_config['content_safety']
self.strategy_selector = StrategySelector(config)
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
async def process(self, event: AstrMessageEvent, check_text: str = None) -> Union[None, AsyncGenerator[None, None]]:
'''检查内容安全'''
ok, info = self.strategy_selector.check(event.get_message_str())
text = check_text if check_text else event.get_message_str()
ok, info = self.strategy_selector.check(text)
if not ok:
event.set_result(MessageEventResult().message("你的消息中包含不适当的内容,已被屏蔽。"))
if event.is_at_or_wake_command:
event.set_result(MessageEventResult().message("你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。"))
yield
event.stop_event()
logger.info(f"内容安全检查不通过,原因:{info}")
return

View File

@@ -1,6 +1,6 @@
from . import ContentSafetyStrategy
from typing import List, Tuple
from astrbot import logger
class StrategySelector:
def __init__(self, config: dict) -> None:
@@ -15,7 +15,8 @@ class StrategySelector:
try:
from .baidu_aip import BaiduAipStrategy
except ImportError:
raise ImportError("使用百度内容审核应该先 pip install baidu-aip")
logger.warning("使用百度内容审核应该先 pip install baidu-aip")
return
self.enabled_strategies.append(
BaiduAipStrategy(
config["baidu_aip"]["app_id"],

View File

@@ -0,0 +1,70 @@
import traceback
import asyncio
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger
from astrbot.core.message.components import Plain, Record, Image
@register_stage
class PreProcessStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.plugin_manager = ctx.plugin_manager
self.stt_settings: dict = self.config.get('provider_stt_settings', {})
self.platform_settings: dict = self.config.get('platform_settings', {})
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''在处理事件之前的预处理'''
# 路径映射
if mappings := self.platform_settings.get('path_mapping', []):
# 支持 RecordImage 消息段的路径映射。
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, (Record, Image)) and component.url:
for mapping in mappings:
from_, to_ = mapping.split(":")
from_ = from_.removesuffix("/")
to_ = to_.removesuffix("/")
url = component.url.removeprefix("file://")
if url.startswith(from_):
component.url = url.replace(from_, to_, 1)
logger.debug(f"路径映射: {url} -> {component.url}")
message_chain[idx] = component
# STT
if self.stt_settings.get('enable', False):
# TODO: 独立
stt_provider = self.plugin_manager.context.provider_manager.curr_stt_provider_inst
if stt_provider:
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record) and component.url:
path = component.url.removeprefix("file://")
retry = 5
for i in range(retry):
try:
result = await stt_provider.get_text(audio_url=path)
if result:
logger.info("语音转文本结果: " + result)
message_chain[idx] = Plain(result)
event.message_str += result
event.message_obj.message_str += result
break
except FileNotFoundError as e:
# napcat workaround
logger.warning(e)
logger.warning(f"重试中: {i + 1}/{retry}")
await asyncio.sleep(0.5)
continue
except BaseException as e:
logger.error(traceback.format_exc())
logger.error(f"语音转文本失败: {e}")
break

View File

@@ -1,114 +1,159 @@
'''
本地 Agent 模式的 LLM 调用 Stage
'''
import traceback
import inspect
import json
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.star.star import star_map
from astrbot.core.provider.entites import ProviderRequest, LLMResponse
from astrbot.core.star.star_handler import star_handlers_registry, EventType
class LLMRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix']
self.identifier = ctx.astrbot_config['provider_settings']['identifier']
self.ctx = ctx
self.bot_wake_prefixs = ctx.astrbot_config['wake_prefix'] # list
self.provider_wake_prefix = ctx.astrbot_config['provider_settings']['wake_prefix'] # str
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
# Chat 唤醒前缀
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
return
event.message_str = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
logger.info(f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。")
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp):]
self.conv_manager = ctx.plugin_manager.context.conversation_manager
if self.prompt_prefix:
event.message_str = self.prompt_prefix + event.message_str
if self.identifier:
user_id = event.message_obj.sender.user_id
user_nickname = event.message_obj.sender.nickname
user_info = f"[User ID: {user_id}, Nickname: {user_nickname}]\n"
event.message_str = user_info + event.message_str
image_urls = []
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
image_urls.append(image_url)
tools = self.ctx.plugin_manager.context.get_llm_tool_manager()
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
try:
llm_response = await provider.text_chat(
prompt=event.message_str,
session_id=event.session_id,
image_urls=image_urls,
func_tool=tools
)
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if provider is None:
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
else:
req = ProviderRequest(prompt="", image_urls=[])
if self.provider_wake_prefix:
if not event.message_str.startswith(self.provider_wake_prefix):
return
req.prompt = event.message_str[len(self.provider_wake_prefix):]
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
req.image_urls.append(image_url)
# 获取对话上下文
conversation_id = await self.conv_manager.get_curr_conversation_id(event.unified_msg_origin)
if not conversation_id:
conversation_id = await self.conv_manager.new_conversation(event.unified_msg_origin)
req.session_id = event.unified_msg_origin
conversation = await self.conv_manager.get_conversation(event.unified_msg_origin, conversation_id)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
event.set_extra("provider_request", req)
if not req.prompt and not req.image_urls:
return
# 执行请求 LLM 前事件钩子。
# 装饰 system_prompt 等功能
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
for handler in handlers:
try:
await handler.handler(event, req)
except BaseException:
logger.error(traceback.format_exc())
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
try:
logger.debug(f"提供商请求 Payload: {req}")
if _nested:
req.func_tool = None # 暂时不支持递归工具调用
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
# 执行 LLM 响应后的事件。
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent)
for handler in handlers:
try:
await handler.handler(event, llm_response)
except BaseException:
logger.error(traceback.format_exc())
# 保存到历史记录
await self._save_to_history(event, req, llm_response)
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if llm_response.role == 'assistant':
# text completion
event.set_result(MessageEventResult().message(llm_response.completion_text))
event.set_result(MessageEventResult().message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT))
elif llm_response.role == 'err':
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"))
elif llm_response.role == 'tool':
# function calling
function_calling_result = {}
for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args):
func_tool = tools.get_func(func_tool_name)
func_tool = req.func_tool.get_func(func_tool_name)
logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
try:
# 尝试调用工具函数
star_cls_obj = star_map.get(func_tool.module_name).star_cls
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None
if hasattr(func_tool.func_obj, '__self__'):
# 猜测没有通过装饰器去注册
try:
ready_to_call = func_tool.func_obj(event, **func_tool_args)
except TypeError:
# 向下兼容
ready_to_call = func_tool.func_obj(event, self.ctx.plugin_manager.context, **func_tool_args)
else:
ready_to_call = func_tool.func_obj(star_cls_obj, event, **func_tool_args)
if isinstance(ready_to_call, AsyncGenerator):
async for mer in ready_to_call:
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None无返回值
if mer:
assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(mer)
yield
else:
if event.get_result():
yield
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个 coroutine
ret = await ready_to_call
if ret:
# 如果有返回值
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(ret)
# 执行后续步骤来发送消息
if event.is_stopped() and event.get_result():
# 主动停止事件传播,并且有结果
event.continue_event()
yield
event.clear_result()
event.stop_event()
yield
elif not event.is_stopped and not event.get_result():
continue
wrapper = self._call_handler(self.ctx, event, func_tool.handler, **func_tool_args)
async for resp in wrapper:
if resp is not None: # 有 return 返回
function_calling_result[func_tool_name] = resp
else:
yield
yield # 有生成器返回
event.clear_result() # 清除上一个 handler 的结果
except BaseException as e:
logger.warning(traceback.format_exc())
function_calling_result[func_tool_name] = "When calling the function, an error occurred: " + str(e)
if function_calling_result:
# 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。
# 我们重新执行一遍这个 stage
req.func_tool = None # 暂时不支持递归工具调用
extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n"
for tool_name, tool_result in function_calling_result.items():
extra_prompt += f"Tool: {tool_name}\nTool Result: {tool_result}\n"
req.prompt += extra_prompt
async for _ in self.process(event, _nested=True):
yield
else:
if llm_response.completion_text:
event.set_result(MessageEventResult().message(llm_response.completion_text))
except BaseException:
logger.error(traceback.format_exc())
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e)))
return
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"))
return
async def _save_to_history(self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse):
if llm_response.role == "assistant":
# 文本回复
contexts = req.contexts
new_record = {
"role": "user",
"content": req.prompt
}
contexts.append(new_record)
contexts.append({
"role": "assistant",
"content": llm_response.completion_text
})
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
await self.conv_manager.update_conversation(
event.unified_msg_origin,
req.conversation.cid,
history=contexts_to_save
)

View File

@@ -1,13 +1,16 @@
'''
本地 Agent 模式的 AstrBot 插件调用 Stage
'''
from ...context import PipelineContext
from ..stage import Stage
from typing import Dict, Any, List, AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core import logger
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.star.star import star_map
import traceback
import inspect
class StarRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
@@ -24,59 +27,23 @@ class StarRequestSubStage(Stage):
for handler in activated_handlers:
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_str not in star_map:
if handler.handler_module_path not in star_map:
# 孤立无援的 star handler
continue
star_cls_obj = star_map.get(handler.handler_module_str).star_cls
logger.debug(f"执行 Star Handler {handler.handler_full_name}")
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None
if hasattr(handler.handler, '__self__'):
# 猜测没有通过装饰器去注册
try:
ready_to_call = handler.handler(event, **params)
except TypeError:
# 向下兼容
ready_to_call = handler.handler(event, self.ctx.plugin_manager.context, **params)
else:
ready_to_call = handler.handler(star_cls_obj, event, **params)
if isinstance(ready_to_call, AsyncGenerator):
async for mer in ready_to_call:
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None无返回值
if mer:
assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(mer)
yield
else:
if event.get_result():
yield
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个 coroutine
ret = await ready_to_call
if ret:
# 如果有返回值
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(ret)
# 执行后续步骤来发送消息
if event.is_stopped() and event.get_result():
# 插件主动停止事件传播,并且有结果
event.continue_event()
yield
event.clear_result()
event.stop_event()
yield
elif not event.is_stopped and not event.get_result():
continue
else:
yield
logger.debug(f"执行插件 handler {handler.handler_full_name}")
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
async for ret in wrapper:
yield ret
event.clear_result() # 清除上一个 handler 的结果
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_str).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
if event.is_at_or_wake_command:
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
event.stop_event()

View File

@@ -5,6 +5,8 @@ from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core import logger
@register_stage
class ProcessStage(Stage):
@@ -23,14 +25,35 @@ class ProcessStage(Stage):
'''处理事件
'''
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
# 有插件 Handler 被激活
if activated_handlers:
async for _ in self.star_request_sub_stage.process(event):
yield
if self.ctx.astrbot_config['provider_settings'].get('enable', True):
if not event._has_send_oper:
'''当没有发送操作'''
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
async for resp in self.star_request_sub_stage.process(event):
# 生成器返回值处理
if isinstance(resp, ProviderRequest):
# Handler 的 LLM 请求
event.set_extra("provider_request", resp)
_t = False
async for _ in self.llm_request_sub_stage.process(event):
yield
_t = True
yield
if not _t:
yield
else:
yield
# 调用 LLM 相关请求
if not self.ctx.astrbot_config['provider_settings'].get('enable', True):
return
if not event._has_send_oper and event.is_at_or_wake_command:
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
# 事件没有终止传播
provider = self.ctx.plugin_manager.context.get_using_provider()
if not provider:
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
return
async for _ in self.llm_request_sub_stage.process(event):
yield

View File

@@ -61,11 +61,12 @@ class RateLimitStage(Stage):
stall_duration = (next_window_time - now).total_seconds()
match self.rl_strategy:
case RateLimitStrategy.STALL:
case RateLimitStrategy.STALL.value:
logger.info(f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。")
await asyncio.sleep(stall_duration)
case RateLimitStrategy.DISCARD:
event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
case RateLimitStrategy.DISCARD.value:
# event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
logger.info(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。")
return event.stop_event()
self._remove_expired_timestamps(timestamps, now + timedelta(seconds=stall_duration))

View File

@@ -1,20 +1,97 @@
import random
import asyncio
import math
import traceback
from typing import Union, AsyncGenerator
from ..stage import register_stage
from ..stage import register_stage, Stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core import logger
from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.message.components import Plain, Reply, At
@register_stage
class RespondStage:
class RespondStage(Stage):
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
# 分段回复
self.enable_seg: bool = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
self.interval_method = ctx.astrbot_config['platform_settings']['segmented_reply']['interval_method']
self.log_base = float(ctx.astrbot_config['platform_settings']['segmented_reply']['log_base'])
interval_str: str = ctx.astrbot_config['platform_settings']['segmented_reply']['interval']
interval_str_ls = interval_str.replace(" ", "").split(",")
try:
self.interval = [float(t) for t in interval_str_ls]
except BaseException as e:
logger.error(f'解析分段回复的间隔时间失败。{e}')
self.interval = [1.5, 3.5]
logger.info(f"分段回复间隔时间:{self.interval}")
async def _word_cnt(self, text: str) -> int:
'''分段回复 统计字数'''
if all(ord(c) < 128 for c in text):
word_count = len(text.split())
else:
word_count = len([c for c in text if c.isalnum()])
return word_count
async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float:
'''分段回复 计算间隔时间'''
if self.interval_method == 'log':
if isinstance(comp, Plain):
wc = await self._word_cnt(comp.text)
i = math.log(wc + 1, self.log_base)
return random.uniform(i, i + 0.5)
else:
return random.uniform(1, 1.75)
else:
# random
return random.uniform(self.interval[0], self.interval[1])
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
if result is None:
return
if len(result.chain) > 0:
await event.send(result)
await event._pre_send()
if self.enable_seg and ((self.only_llm_result and result.is_llm_result()) or not self.only_llm_result):
decorated_comps = []
if self.reply_with_mention:
for comp in result.chain:
if isinstance(comp, At):
decorated_comps.append(comp)
result.chain.remove(comp)
break
if self.reply_with_quote:
for comp in result.chain:
if isinstance(comp, Reply):
decorated_comps.append(comp)
result.chain.remove(comp)
break
# 分段回复
for comp in result.chain:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
await event.send(MessageChain([*decorated_comps, comp]))
else:
await event.send(result)
await event._post_send()
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
for handler in handlers:
try:
await handler.handler(event)
except BaseException:
logger.error(traceback.format_exc())
event.clear_result()

View File

@@ -1,38 +1,143 @@
import time
import re
import traceback
from typing import Union, AsyncGenerator
from ..stage import register_stage
from ..stage import Stage, register_stage, registered_stages
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core import logger
from astrbot.core.message.components import Plain, Image
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
from astrbot.core import html_renderer
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@register_stage
class ResultDecorateStage:
class ResultDecorateStage(Stage):
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
self.t2i = ctx.astrbot_config['t2i']
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
self.t2i_word_threshold = ctx.astrbot_config['t2i_word_threshold']
try:
self.t2i_word_threshold = int(self.t2i_word_threshold)
if self.t2i_word_threshold < 50:
self.t2i_word_threshold = 50
except BaseException:
self.t2i_word_threshold = 150
self.forward_threshold = ctx.astrbot_config['platform_settings']['forward_threshold']
# 分段回复
self.words_count_threshold = int(ctx.astrbot_config['platform_settings']['segmented_reply']['words_count_threshold'])
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex']
self.content_cleanup_rule = ctx.astrbot_config['platform_settings']['segmented_reply']['content_cleanup_rule']
# exception
self.content_safe_check_reply = ctx.astrbot_config['content_safety']['also_use_in_response']
self.content_safe_check_stage = None
if self.content_safe_check_reply:
for stage in registered_stages:
if stage.__class__.__name__ == "ContentSafetyCheckStage":
self.content_safe_check_stage = stage
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
if result is None:
return
# 回复时检查内容安全
if self.content_safe_check_reply and self.content_safe_check_stage and result.is_llm_result():
text = ""
for comp in result.chain:
if isinstance(comp, Plain):
text += comp.text
async for _ in self.content_safe_check_stage.process(event, check_text=text):
yield
# 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
for handler in handlers:
try:
await handler.handler(event)
except BaseException:
logger.error(traceback.format_exc())
# 需要再获取一次。插件可能直接对 chain 进行了替换。
result = event.get_result()
if result is None:
return
if len(result.chain) > 0:
# 回复前缀
if self.reply_prefix:
result.chain.insert(0, Plain(self.reply_prefix))
for comp in result.chain:
if isinstance(comp, Plain):
comp.text = self.reply_prefix + comp.text
break
# 分段回复
if self.enable_segmented_reply:
if (self.only_llm_result and result.is_llm_result()) or not self.only_llm_result:
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain):
if len(comp.text) > self.words_count_threshold:
# 不分段回复
new_chain.append(comp)
continue
split_response = []
for line in comp.text.split("\n"):
split_response.extend(re.findall(self.regex, line))
if not split_response:
new_chain.append(comp)
continue
for seg in split_response:
if self.content_cleanup_rule:
seg = re.sub(self.content_cleanup_rule, "", seg)
if seg.strip():
new_chain.append(Plain(seg))
else:
# 非 Plain 类型的消息段不分段
new_chain.append(comp)
result.chain = new_chain
# TTS
if self.ctx.astrbot_config['provider_tts_settings']['enable'] and result.is_llm_result():
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info("TTS 请求: " + comp.text)
audio_path = await tts_provider.get_audio(comp.text)
logger.info("TTS 结果: " + audio_path)
if audio_path:
new_chain.append(Record(file=audio_path, url=audio_path))
else:
logger.error(f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}")
new_chain.append(comp)
except BaseException:
logger.error(traceback.format_exc())
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
else:
new_chain.append(comp)
result.chain = new_chain
# 文本转图片
if (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
elif (result.use_t2i_ is None and self.ctx.astrbot_config['t2i']) or result.use_t2i_:
plain_str = ""
for comp in result.chain:
if isinstance(comp, Plain):
plain_str += "\n\n" + comp.text
else:
break
if plain_str and len(plain_str) > 150:
if plain_str and len(plain_str) > self.t2i_word_threshold:
render_start = time.time()
try:
url = await html_renderer.render_t2i(plain_str, return_url=True)
@@ -42,4 +147,34 @@ class ResultDecorateStage:
if time.time() - render_start > 3:
logger.warning("文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。")
if url:
result.chain = [Image.fromURL(url)]
result.chain = [Image.fromURL(url)]
# 触发转发消息
has_forwarded = False
if event.get_platform_name() == 'aiocqhttp':
word_cnt = 0
for comp in result.chain:
if isinstance(comp, Plain):
word_cnt += len(comp.text)
if word_cnt > self.forward_threshold:
node = Node(
uin=event.get_self_id(),
name="AstrBot",
content=[
*result.chain
]
)
result.chain = [node]
has_forwarded = True
if not has_forwarded:
# at 回复
if self.reply_with_mention and event.get_message_type() != MessageType.FRIEND_MESSAGE:
result.chain.insert(0, At(qq=event.get_sender_id(), name=event.get_sender_name()))
if len(result.chain) > 1 and isinstance(result.chain[1], Plain):
result.chain[1].text = "\n" + result.chain[1].text
# 引用回复
if self.reply_with_quote:
if not any(isinstance(item, File) for item in result.chain):
result.chain.insert(0, Reply(id=event.message_obj.message_id))

View File

@@ -12,14 +12,14 @@ class PipelineScheduler():
async def initialize(self):
for stage in registered_stages:
logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
await stage.initialize(self.ctx)
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
for i in range(from_stage, len(registered_stages)):
stage = registered_stages[i]
logger.debug(f"执行阶段 {stage.__class__ .__name__}")
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
coro = stage.process(event)
if isinstance(coro, AsyncGenerator):
async for _ in coro:
@@ -41,4 +41,8 @@ class PipelineScheduler():
async def execute(self, event: AstrMessageEvent):
'''执行 pipeline'''
await self._process_stages(event)
if not event._has_send_oper and event.get_platform_name() == "webchat":
await event.send(None)
logger.debug("pipeline 执行完毕。")

View File

@@ -1,8 +1,11 @@
from __future__ import annotations
import abc
from typing import List, AsyncGenerator, Union
import inspect
from astrbot.api import logger
from typing import List, AsyncGenerator, Union, Awaitable
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
registered_stages: List[Stage] = []
'''维护了所有已注册的 Stage 实现类'''
@@ -29,4 +32,37 @@ class Stage(abc.ABC):
'''
raise NotImplementedError
async def _call_handler(
self,
ctx: PipelineContext,
event: AstrMessageEvent,
handler: Awaitable,
*args,
**kwargs,
) -> AsyncGenerator[None, None]:
'''调用 Handler。'''
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError as e:
# 向下兼容
logger.debug(str(e))
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
if isinstance(ready_to_call, AsyncGenerator):
async for ret in ready_to_call:
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None无返回值
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个 coroutine
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret

View File

@@ -2,11 +2,12 @@ from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.message.components import At
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.message.components import At, Reply
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
@register_stage
class WakingCheckStage(Stage):
@@ -21,6 +22,9 @@ class WakingCheckStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
"no_permission_reply", True
)
async def process(
self, event: AstrMessageEvent
@@ -47,6 +51,7 @@ class WakingCheckStage(Stage):
# 如果是群聊,且第一个消息段是 At 消息,但不是 At 机器人或 At 全体成员,则不唤醒
break
is_wake = True
event.is_at_or_wake_command = True
event.is_wake = True
event.message_str = event.message_str[len(wake_prefix) :].strip()
break
@@ -60,54 +65,51 @@ class WakingCheckStage(Stage):
is_wake = True
event.is_wake = True
wake_prefix = ""
event.is_at_or_wake_command = True
break
# 检查是否是私聊
if event.is_private_chat():
is_wake = True
event.is_wake = True
event.is_at_or_wake_command = True
wake_prefix = ""
# 检查插件的 handler filter
activated_handlers = []
handlers_parsed_params = {} # 注册了指令的 handler
for handler in star_handlers_registry:
# filter 需要满足 AND 的逻辑关系
passed = True
child_command_handler_md = None
for handler in star_handlers_registry.get_handlers_by_event_type(EventType.AdapterMessageEvent):
# filter 需满足 AND 逻辑关系
passed = True
permission_not_pass = False
if len(handler.event_filters) == 0:
# 不可能有这种情况, 也不允许有这种情况
continue
for filter in handler.event_filters:
try:
if isinstance(filter, CommandGroupFilter):
"""如果指令组过滤成功, 会返回叶子指令的 StarHandlerMetadata"""
ok, child_command_handler_md = filter.filter(
event, self.ctx.astrbot_config
)
if not ok:
passed = False
else:
handler = child_command_handler_md # handler 覆盖
break
if isinstance(filter, PermissionTypeFilter):
if not filter.filter(event, self.ctx.astrbot_config):
permission_not_pass = True
else:
if not filter.filter(event, self.ctx.astrbot_config):
passed = False
break
except Exception as e:
# event.set_result(MessageEventResult().message(f"插件 {handler.handler_full_name} 报错:{e}"))
# yield
await event.send(
MessageEventResult().message(
f"插件 {handler.handler_full_name} 报错:{e}"
f"插件 {star_map[handler.handler_module_path].name}: {e}"
)
)
event.stop_event()
passed = False
break
if passed:
if permission_not_pass:
if self.no_permission_reply:
await event.send(MessageChain().message(f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"))
event.stop_event()
return
is_wake = True
event.is_wake = True
@@ -116,6 +118,7 @@ class WakingCheckStage(Stage):
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
"parsed_params"
)
event.clear_extra()
event.set_extra("activated_handlers", activated_handlers)

View File

@@ -10,12 +10,25 @@ class WhitelistCheckStage(Stage):
'''检查是否在群聊/私聊白名单
'''
async def initialize(self, ctx: PipelineContext) -> None:
self.enable_whitelist_check = ctx.astrbot_config['platform_settings']['enable_id_white_list']
self.whitelist = ctx.astrbot_config['platform_settings']['id_whitelist']
self.wl_ignore_admin_on_group = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_group']
self.wl_ignore_admin_on_friend = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_friend']
self.wl_log = ctx.astrbot_config['platform_settings']['id_whitelist_log']
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
if not self.enable_whitelist_check:
# 白名单检查未启用
return
if len(self.whitelist) == 0:
# 白名单为空,不检查
return
if event.get_platform_name() == 'webchat':
# WebChat 豁免
return
# 检查是否在白名单
if self.wl_ignore_admin_on_group:
if event.role == 'admin' and event.get_message_type() == MessageType.GROUP_MESSAGE:

View File

@@ -7,6 +7,8 @@ from astrbot.core.platform.message_type import MessageType
from typing import List, Union
from astrbot.core.message.components import Plain, Image, BaseMessageComponent, Face, At, AtAll, Forward
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.db.po import Conversation
@dataclass
class MessageSesion:
@@ -29,11 +31,19 @@ class AstrMessageEvent(abc.ABC):
platform_meta: PlatformMetadata,
session_id: str,):
self.message_str = message_str
'''纯文本的消息'''
self.message_obj = message_obj
'''消息对象AstrBotMessage。带有完整的消息结构。'''
self.platform_meta = platform_meta
'''消息平台的信息, 其中 name 是平台的类型,如 aiocqhttp'''
self.session_id = session_id
'''用户的会话 ID。可以直接使用下面的 unified_msg_origin'''
self.role = "member"
self.is_wake = False
'''用户是否是管理员。如果是管理员,这里是 admin'''
self.is_wake = False # 是否通过 WakingStage
'''是否唤醒'''
self.is_at_or_wake_command = False
'''是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True但是不会让这个属性置为 True'''
self._extras = {}
self.session = MessageSesion(
platform_name=platform_meta.name,
@@ -41,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
session_id=session_id
)
self.unified_msg_origin = str(self.session)
'''统一的消息来源字符串。格式为 platform_name:message_type:session_id'''
self._result: MessageEventResult = None
'''消息事件的结果'''
@@ -176,6 +186,15 @@ class AstrMessageEvent(abc.ABC):
await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name)
self._has_send_oper = True
async def _pre_send(self):
'''调度器会在执行 send() 前调用该方法'''
pass
async def _post_send(self):
'''调度器会在执行 send() 后调用该方法'''
pass
def set_result(self, result: Union[MessageEventResult, str]):
'''设置消息事件的结果。
@@ -236,7 +255,9 @@ class AstrMessageEvent(abc.ABC):
清除消息事件的结果。
'''
self._result = None
'''消息链相关'''
def make_result(self) -> MessageEventResult:
'''
创建一个空的消息事件结果。
@@ -275,4 +296,40 @@ class AstrMessageEvent(abc.ABC):
'''
mer = MessageEventResult()
mer.chain = chain
return mer
return mer
'''LLM 请求相关'''
def request_llm(
self,
prompt: str,
func_tool_manager = None,
session_id: str = None,
image_urls: List[str] = [],
contexts: List = [],
system_prompt: str = "",
conversation: Conversation = None
) -> ProviderRequest:
'''
创建一个 LLM 请求。
Examples:
```py
yield event.request_llm(prompt="hi")
```
prompt: 提示词
session_id: 已经过时,留空即可
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
contexts: 当指定 contexts 时,将会使用 contexts 作为上下文。
func_tool_manager: 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。
conversation: 可选。如果指定,将在指定的对话中进行 LLM 请求。对话的人格会被用于 LLM 请求,并且结果将会被记录到对话中。
'''
return ProviderRequest(
prompt = prompt,
session_id = session_id,
image_urls = image_urls,
func_tool = func_tool_manager,
contexts = contexts,
system_prompt = system_prompt,
conversation=conversation
)

View File

@@ -4,7 +4,7 @@ from typing import List
from asyncio import Queue
from .register import platform_cls_map
from astrbot.core import logger
from .sources.webchat.webchat_adapter import WebChatAdapter
class PlatformManager():
def __init__(self, config: AstrBotConfig, event_queue: Queue):
@@ -15,16 +15,25 @@ class PlatformManager():
self.settings = config['platform_settings']
self.event_queue = event_queue
for platform in self.platforms_config:
if not platform['enable']:
continue
match platform['type']:
case "aiocqhttp":
from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401
case "qq_official":
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
case "vchat":
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
try:
for platform in self.platforms_config:
if not platform['enable']:
continue
match platform['type']:
case "aiocqhttp":
from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401
case "qq_official":
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
case "qq_official_webhook":
from .sources.qqofficial_webhook.qo_webhook_adapter import QQOfficialWebhookPlatformAdapter # noqa: F401
case "gewechat":
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
case "lark":
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。")
except Exception as e:
logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}")
async def initialize(self):
for platform in self.platforms_config:
@@ -34,9 +43,11 @@ class PlatformManager():
logger.error(f"未找到适用于 {platform['type']}({platform['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
cls_type = platform_cls_map[platform['type']]
logger.info(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
logger.debug(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
inst = cls_type(platform, self.settings, self.event_queue)
self.platform_insts.append(inst)
self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue))
def get_insts(self):
return self.platform_insts

View File

@@ -1,5 +1,12 @@
from dataclasses import dataclass
@dataclass
class PlatformMetadata():
name: str # 平台的名称
description: str # 平台的描述
name: str
'''平台的名称'''
description: str
'''平台的描述'''
default_config_tmpl: dict = None
'''平台的默认配置模板'''
adapter_display_name: str = None
'''显示在 WebUI 配置页中的平台名称,如空则是 name'''

View File

@@ -7,15 +7,34 @@ platform_registry: List[PlatformMetadata] = []
platform_cls_map: Dict[str, Type] = {}
'''维护了平台适配器名称和适配器类的映射'''
def register_platform_adapter(adapter_name: str, desc: str):
'''用于注册平台适配器的带参装饰器'''
def register_platform_adapter(
adapter_name: str,
desc: str,
default_config_tmpl: dict = None,
adapter_display_name: str = None
):
'''用于注册平台适配器的带参装饰器。
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
'''
def decorator(cls):
if adapter_name in platform_cls_map:
raise ValueError(f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。")
# 添加必备选项
if default_config_tmpl:
if 'type' not in default_config_tmpl:
default_config_tmpl['type'] = adapter_name
if 'enable' not in default_config_tmpl:
default_config_tmpl['enable'] = False
if 'id' not in default_config_tmpl:
default_config_tmpl['id'] = adapter_name
pm = PlatformMetadata(
name=adapter_name,
description=desc,
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls

View File

@@ -1,9 +1,7 @@
import os
import random
import asyncio
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image
from astrbot.api.message_components import Plain, Image, Record, At, Node, Music, Video
from aiocqhttp import CQHttp
from astrbot.core.utils.io import file_to_base64, download_image_by_url
@@ -20,27 +18,43 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
d = segment.toDict()
if isinstance(segment, Plain):
d['type'] = 'text'
if isinstance(segment, Image):
elif isinstance(segment, (Image, Record)):
# convert to base64
if segment.file and segment.file.startswith("file:///"):
image_base64 = file_to_base64(segment.file[8:])
bs64_data = file_to_base64(segment.file[8:])
image_file_path = segment.file[8:]
elif segment.file and segment.file.startswith("http"):
image_file_path = await download_image_by_url(segment.file)
image_base64 = file_to_base64(image_file_path)
d['data']['file'] = image_base64
bs64_data = file_to_base64(image_file_path)
elif segment.file and segment.file.startswith("base64://"):
bs64_data = segment.file
else:
bs64_data = file_to_base64(segment.file)
d['data'] = {
'file': bs64_data,
}
elif isinstance(segment, At):
d['data'] = {
'qq': str(segment.qq) # 转换为字符串
}
ret.append(d)
return ret
async def send(self, message: MessageChain):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
if os.environ.get('TEST_MODE', 'off') == 'on':
return
if message.is_split_: # 分条发送
for m in ret:
await self.bot.send(self.message_obj.raw_message, [m])
await asyncio.sleep(random.uniform(0.75, 2.5))
send_one_by_one = False
for seg in message.chain:
if isinstance(seg, (Node, Music)):
# 转发消息不能和普通消息混在一起发送
send_one_by_one = True
break
if send_one_by_one:
for seg in message.chain:
await self.bot.send(self.message_obj.raw_message, await AiocqhttpMessageEvent._parse_onebot_json(MessageChain([seg])))
await asyncio.sleep(0.5)
else:
await self.bot.send(self.message_obj.raw_message, ret)
await super().send(message)

View File

@@ -1,16 +1,20 @@
import os
import time
import asyncio
import logging
import uuid
from typing import Awaitable, Any
from aiocqhttp import CQHttp, Event
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from .aiocqhttp_message_event import *
from astrbot.api.message_components import *
from .aiocqhttp_message_event import * # noqa: F403
from astrbot.api.message_components import * # noqa: F403
from astrbot.api import logger
from .aiocqhttp_message_event import AiocqhttpMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from aiocqhttp.exceptions import ActionFailed
from astrbot.core.utils.io import download_file
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
class AiocqhttpAdapter(Platform):
@@ -42,20 +46,83 @@ class AiocqhttpAdapter(Platform):
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
await super().send_by_session(session, message_chain)
def convert_message(self, event: Event) -> AstrBotMessage:
async def convert_message(self, event: Event) -> AstrBotMessage:
logger.debug(f"[aiocqhttp] RawMessage {event}")
if event['post_type'] == 'message':
abm = await self._convert_handle_message_event(event)
elif event['post_type'] == 'notice':
abm = await self._convert_handle_notice_event(event)
elif event['post_type'] == 'request':
abm = await self._convert_handle_request_event(event)
return abm
async def _convert_handle_request_event(self, event: Event) -> AstrBotMessage:
'''OneBot V11 请求类事件'''
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.tag = "aiocqhttp"
abm.sender = MessageMember(
user_id=event.user_id,
nickname=event.user_id
)
abm.type = MessageType.OTHER_MESSAGE
if 'group_id' in event and event['group_id']:
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id)
else:
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = abm.sender.user_id + "_" + str(event.group_id)
abm.message_str = ''
abm.message = []
abm.timestamp = int(time.time())
abm.message_id = uuid.uuid4().hex
abm.raw_message = event
return abm
async def _convert_handle_notice_event(self, event: Event) -> AstrBotMessage:
'''OneBot V11 通知类事件'''
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(
user_id=event.user_id,
nickname=event.user_id
)
abm.type = MessageType.OTHER_MESSAGE
if 'group_id' in event and event['group_id']:
abm.group_id = str(event.group_id)
abm.type = MessageType.GROUP_MESSAGE
else:
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = abm.sender.user_id + "_" + str(event.group_id) # 也保留群组 id
else:
abm.session_id = str(event.group_id) if abm.type == MessageType.GROUP_MESSAGE else abm.sender.user_id
abm.message_str = ""
abm.message = []
abm.raw_message = event
abm.timestamp = int(time.time())
abm.message_id = uuid.uuid4().hex
abm.sender = MessageMember(str(event.sender['user_id']), event.sender['nickname'])
if 'sub_type' in event:
if event['sub_type'] == 'poke' and 'target_id' in event:
abm.message.append(Poke(qq=str(event['target_id']), type='poke')) # noqa: F405
return abm
async def _convert_handle_message_event(self, event: Event) -> AstrBotMessage:
'''OneBot V11 消息类事件'''
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(str(event.sender['user_id']), event.sender['nickname'])
if event['message_type'] == 'group':
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id)
elif event['message_type'] == 'private':
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session:
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = abm.sender.user_id + "_" + str(event.group_id) # 也保留群组 id
else:
abm.session_id = str(event.group_id) if abm.type == MessageType.GROUP_MESSAGE else abm.sender.user_id
@@ -72,32 +139,82 @@ class AiocqhttpAdapter(Platform):
except BaseException as e:
logger.error(f"回复消息失败: {e}")
return
logger.debug(f"aiocqhttp: 收到消息: {event.message}")
# 按消息段类型类型适配
for m in event.message:
t = m['type']
a = None
if t == 'text':
message_str += m['data']['text'].strip()
a = ComponentTypes[t](**m['data'])
elif t == 'file':
if m['data']['url'] and m['data']['url'].startswith("http"):
# Lagrange
logger.info("guessing lagrange")
file_name = m['data'].get('file_name', "file")
path = os.path.join("data/temp", file_name)
await download_file(m['data']['url'], path)
m['data'] = {
"file": path,
"name": file_name
}
else:
try:
# Napcat, LLBot
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
if not ret.get('file', None):
raise ValueError(f"无法解析文件响应: {ret}")
if not os.path.exists(ret['file']):
raise FileNotFoundError(f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot")
m['data'] = {
"file": ret['file'],
"name": ret['file_name']
}
except ActionFailed as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
except BaseException as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
a = ComponentTypes[t](**m['data']) # noqa: F405
abm.message.append(a)
abm.timestamp = int(time.time())
abm.message_str = message_str
abm.raw_message = event
return abm
def run(self) -> Awaitable[Any]:
if not self.host or not self.port:
return
logger.warning("aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port将使用默认值http://0.0.0.0:6199")
self.host = "0.0.0.0"
self.port = 6199
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
@self.bot.on_request()
async def request(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_notice()
async def notice(event: Event):
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message('group')
async def group(event: Event):
abm = self.convert_message(event)
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message('private')
async def private(event: Event):
abm = self.convert_message(event)
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)

View File

@@ -0,0 +1,445 @@
import threading
import asyncio
import aiohttp
import quart
import base64
import datetime
import re
import os
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.api.message_components import Plain, Image, At, Record
from astrbot.api import logger, sp
from .downloader import GeweDownloader
from astrbot.core.utils.io import download_image_by_url
class SimpleGewechatClient():
'''针对 Gewechat 的简单实现。
@author: Soulter
@website: https://github.com/Soulter
'''
def __init__(self, base_url: str, nickname: str, host: str, port: int, event_queue: asyncio.Queue):
self.base_url = base_url
if self.base_url.endswith('/'):
self.base_url = self.base_url[:-1]
self.download_base_url = self.base_url.split(':')[:-1] # 去掉端口
self.download_base_url = ':'.join(self.download_base_url) + ":2532/download/"
self.base_url += "/v2/api"
logger.info(f"Gewechat API: {self.base_url}")
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
if isinstance(port, str):
port = int(port)
self.token = None
self.headers = {}
self.nickname = nickname
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
self.server = quart.Quart(__name__)
self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST'])
self.server.add_url_rule('/astrbot-gewechat/file/<file_id>', view_func=self.handle_file, methods=['GET'])
self.host = host
self.port = port
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
self.event_queue = event_queue
self.multimedia_downloader = None
self.userrealnames = {}
async def get_token_id(self):
async with aiohttp.ClientSession() as session:
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
json_blob = await resp.json()
self.token = json_blob['data']
logger.info(f"获取到 Gewechat Token: {self.token}")
self.headers = {
"X-GEWE-TOKEN": self.token
}
async def _convert(self, data: dict) -> AstrBotMessage:
type_name = data['TypeName']
if type_name == "Offline":
logger.critical("收到 gewechat 下线通知。")
return
if 'Data' in data and 'CreateTime' in data['Data']:
# 得到系统 UTF+8 的 ts
tz_offset = datetime.timedelta(hours=8)
tz = datetime.timezone(tz_offset)
ts = datetime.datetime.now(tz).timestamp()
create_time = data['Data']['CreateTime']
if create_time < ts - 30:
logger.warning(f"消息时间戳过旧: {create_time},当前时间戳: {ts}")
return
abm = AstrBotMessage()
d = data['Data']
from_user_name = d['FromUserName']['string'] # 消息来源
d['to_wxid'] = from_user_name # 用于发信息
abm.message_id = str(d.get('MsgId'))
abm.session_id = from_user_name
abm.self_id = data['Wxid'] # 机器人的 wxid
user_id = "" # 发送人 wxid
content = d['Content']['string'] # 消息内容
at_me = False
if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE
_t = content.split(':\n')
user_id = _t[0]
content = _t[1]
if '\u2005' in content:
# at
# content = content.split('\u2005')[1]
content = re.sub(r'@[^\u2005]*\u2005', '', content)
abm.group_id = from_user_name
# at
msg_source = d['MsgSource']
if f'<atuserlist><![CDATA[,{abm.self_id}]]>' in msg_source \
or f'<atuserlist><![CDATA[{abm.self_id}]]>' in msg_source:
at_me = True
if '在群聊中@了你' in d.get('PushContent', ''):
at_me = True
else:
abm.type = MessageType.FRIEND_MESSAGE
user_id = from_user_name
abm.message = []
if at_me:
abm.message.insert(0, At(qq=abm.self_id))
# 解析用户真实名字
user_real_name = "unknown"
if abm.group_id:
if abm.group_id not in self.userrealnames or user_id not in self.userrealnames[abm.group_id]:
# 获取群成员列表,并且缓存
if abm.group_id not in self.userrealnames:
self.userrealnames[abm.group_id] = {}
member_list = await self.get_chatroom_member_list(abm.group_id)
logger.debug(f"获取到 {abm.group_id} 的群成员列表。")
if member_list and 'memberList' in member_list:
for member in member_list['memberList']:
self.userrealnames[abm.group_id][member['wxid']] = member['nickName']
if user_id in self.userrealnames[abm.group_id]:
user_real_name = self.userrealnames[abm.group_id][user_id]
else:
user_real_name = self.userrealnames[abm.group_id][user_id]
else:
user_real_name = d.get('PushContent', 'unknown : ').split(' : ')[0]
abm.sender = MessageMember(user_id, user_real_name)
abm.raw_message = d
abm.message_str = ""
# 不同消息类型
match d['MsgType']:
case 1:
# 文本消息
abm.message.append(Plain(content))
abm.message_str = content
case 3:
# 图片消息
file_url = await self.multimedia_downloader.download_image(
self.appid,
content
)
logger.debug(f"下载图片: {file_url}")
file_path = await download_image_by_url(file_url)
abm.message.append(Image(file=file_path, url=file_path))
case 34:
# 语音消息
# data = await self.multimedia_downloader.download_voice(
# self.appid,
# content,
# abm.message_id
# )
# print(data)
if 'ImgBuf' in d and 'buffer' in d['ImgBuf']:
voice_data = base64.b64decode(d['ImgBuf']['buffer'])
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
with open(file_path, "wb") as f:
f.write(voice_data)
abm.message.append(Record(file=file_path, url=file_path))
case _:
logger.info(f"未实现的消息类型: {d['MsgType']}")
abm.raw_message = d
logger.debug(f"abm: {abm}")
return abm
async def callback(self):
data = await quart.request.json
logger.debug(f"收到 gewechat 回调: {data}")
if data.get('testMsg', None):
return quart.jsonify({"r": "AstrBot ACK"})
abm = None
try:
abm = await self._convert(data)
except BaseException as e:
logger.warning(f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}")
if abm:
coro = getattr(self, "on_event_received")
if coro:
await coro(abm)
return quart.jsonify({"r": "AstrBot ACK"})
async def handle_file(self, file_id):
file_path = f"data/temp/{file_id}"
return await quart.send_file(file_path)
async def _set_callback_url(self):
logger.info("设置回调,请等待...")
await asyncio.sleep(3)
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/tools/setCallback",
headers=self.headers,
json={
"token": self.token,
"callbackUrl": self.callback_url
}
) as resp:
json_blob = await resp.json()
logger.info(f"设置回调结果: {json_blob}")
if json_blob['ret'] != 200:
raise Exception(f"设置回调失败: {json_blob}")
logger.info(f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。如果仍没收到请到管理面板聊天页输入 /gewe_logout 重新登录。")
async def start_polling(self):
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
await self.server.run_task(
host='0.0.0.0',
port=self.port,
shutdown_trigger=self.shutdown_trigger_placeholder
)
async def shutdown_trigger_placeholder(self):
while not self.event_queue.closed:
await asyncio.sleep(1)
logger.info("gewechat 适配器已关闭。")
async def check_online(self, appid: str):
# /login/checkOnline
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/checkOnline",
headers=self.headers,
json={
"appId": appid
}
) as resp:
json_blob = await resp.json()
return json_blob['data']
async def logout(self):
if self.appid:
online = await self.check_online(self.appid)
if online:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/logout",
headers=self.headers,
json={
"appId": self.appid
}
) as resp:
json_blob = await resp.json()
logger.info(f"登出结果: {json_blob}")
async def login(self):
if self.token is None:
await self.get_token_id()
self.multimedia_downloader = GeweDownloader(self.base_url, self.download_base_url, self.token)
if self.appid:
online = await self.check_online(self.appid)
if online:
logger.info(f"APPID: {self.appid} 已在线")
return
payload = {
"appId": self.appid
}
if self.appid:
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/getLoginQrCode",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
if json_blob['ret'] != 200:
raise Exception(f"获取二维码失败: {json_blob}")
qr_data = json_blob['data']['qrData']
qr_uuid = json_blob['data']['uuid']
appid = json_blob['data']['appId']
logger.info(f"APPID: {appid}")
logger.warning(f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}")
# 执行登录
retry_cnt = 64
payload.update({
"uuid": qr_uuid,
"appId": appid
})
verify_flag = False
while retry_cnt > 0:
retry_cnt -= 1
# 需要验证码
if verify_flag or os.path.exists("data/temp/gewe_code"):
with open("data/temp/gewe_code", "r") as f:
code = f.read().strip()
if not code:
logger.warning("未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456")
await asyncio.sleep(5)
continue
payload['captchCode'] = code
logger.info(f"使用验证码: {code}")
try:
os.remove("data/temp/gewe_code")
except:
logger.warning("删除验证码文件 data/temp/gewe_code 失败。")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/checkLogin",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.info(f"检查登录状态: {json_blob}")
ret = json_blob['ret']
msg = ''
if json_blob['data'] and 'msg' in json_blob['data']:
msg = json_blob['data']['msg']
if ret == 500 and '安全验证码' in msg:
logger.warning("此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456")
verify_flag = True
else:
status = json_blob['data']['status']
nickname = json_blob['data'].get('nickName', '')
if status == 1:
logger.info(f"等待确认...{nickname}")
elif status == 2:
logger.info(f"绿泡泡平台登录成功: {nickname}")
break
elif status == 0:
logger.info("等待扫码...")
else:
logger.warning(f"未知状态: {status}")
await asyncio.sleep(5)
if appid:
sp.put(f"gewechat-appid-{self.nickname}", appid)
self.appid = appid
logger.info(f"已保存 APPID: {appid}")
'''API'''
async def get_chatroom_member_list(self, chatroom_wxid: str):
payload = {
"appId": self.appid,
"chatroomId": chatroom_wxid
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomMemberList",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
return json_blob['data']
async def post_text(self, to_wxid, content: str, ats: str = ""):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"content": content,
}
if ats:
payload['ats'] = ats
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postText",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送消息结果: {json_blob}")
async def post_image(self, to_wxid, image_url: str):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"imgUrl": image_url,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postImage",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送图片结果: {json_blob}")
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"voiceUrl": voice_url,
"voiceDuration": voice_duration
}
logger.debug(f"发送语音: {payload}")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postVoice",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送语音结果: {json_blob}")
async def post_file(self, to_wxid, file_url: str, file_name: str):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"fileUrl": file_url,
"fileName": file_name
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postFile",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送文件结果: {json_blob}")

View File

@@ -0,0 +1,51 @@
from astrbot import logger
import aiohttp
import json
class GeweDownloader():
def __init__(self, base_url: str, download_base_url: str, token: str):
self.base_url = base_url
self.download_base_url = download_base_url
self.headers = {
"Content-Type": "application/json",
"X-GEWE-TOKEN": token
}
async def _post_json(self, baseurl: str, route: str, payload: dict):
async with aiohttp.ClientSession() as session:
async with session.post(
f"{baseurl}{route}",
headers=self.headers,
json=payload
) as resp:
return await resp.read()
async def download_voice(self, appid: str, xml: str, msg_id: str):
payload = {
"appId": appid,
"xml": xml,
"msgId": msg_id
}
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
async def download_image(self, appid: str, xml: str) -> str:
'''返回一个可下载的 URL'''
choices = [2, 3] # 2:常规图片 3:缩略图
for choice in choices:
try:
payload = {
"appId": appid,
"xml": xml,
"type": choice
}
data = await self._post_json(self.base_url, "/message/downloadImage", payload)
json_blob = json.loads(data)
if 'fileUrl' in json_blob['data']:
return self.download_base_url + json_blob['data']['fileUrl']
except BaseException as e:
logger.error(f"gewe download image: {e}")
continue
raise Exception("无法下载图片")

View File

@@ -0,0 +1,139 @@
import wave
import uuid
import traceback
import os
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record, At, File
from .client import SimpleGewechatClient
def get_wav_duration(file_path):
with wave.open(file_path, 'rb') as wav_file:
file_size = os.path.getsize(file_path)
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
if n_frames == 2147483647:
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
elif n_frames == 0:
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
else:
duration = n_frames / float(framerate)
return duration
class GewechatPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client: SimpleGewechatClient
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
@staticmethod
async def send_with_client(message: MessageChain, user_name: str):
pass
async def send(self, message: MessageChain):
to_wxid = self.message_obj.raw_message.get('to_wxid', None)
if not to_wxid:
logger.error("无法获取到 to_wxid。")
return
# 检查@
ats = []
ats_names = []
for comp in message.chain:
if isinstance(comp, At):
ats.append(comp.qq)
ats_names.append(comp.name)
has_at = False
for comp in message.chain:
if isinstance(comp, Plain):
text = comp.text
payload = {
"to_wxid": to_wxid,
"content": text,
}
if not has_at and ats:
ats = f"{','.join(ats)}"
ats_names = f"@{' @'.join(ats_names)}"
text = f"{ats_names} {text}"
payload["content"] = text
payload["ats"] = ats
has_at = True
await self.client.post_text(**payload)
elif isinstance(comp, Image):
img_url = comp.file
img_path = ""
if img_url.startswith("file:///"):
img_path = img_url[8:]
elif comp.file and comp.file.startswith("http"):
img_path = await download_image_by_url(comp.file)
else:
img_path = img_url
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
temp_directory = os.path.abspath('data/temp')
img_path = os.path.abspath(img_path)
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
with open(img_path, "rb") as f:
img_path = save_temp_img(f.read())
file_id = os.path.basename(img_path)
img_url = f"{self.client.file_server_url}/{file_id}"
logger.debug(f"gewe callback img url: {img_url}")
await self.client.post_image(to_wxid, img_url)
elif isinstance(comp, Record):
# 默认已经存在 data/temp 中
record_url = comp.file
record_path = ""
if record_url.startswith("file:///"):
record_path = record_url[8:]
elif record_url.startswith("http"):
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
else:
record_path = record_url
silk_path = f"data/temp/{uuid.uuid4()}.silk"
try:
duration = await wav_to_tencent_silk(record_path, silk_path)
except Exception as e:
logger.error(traceback.format_exc())
await self.send(MessageChain().message(f"语音文件转换失败。{str(e)}"))
logger.info("Silk 语音文件格式转换至: " + record_path)
if duration == 0:
duration = get_wav_duration(record_path)
file_id = os.path.basename(silk_path)
record_url = f"{self.client.file_server_url}/{file_id}"
logger.debug(f"gewe callback record url: {record_url}")
await self.client.post_voice(to_wxid, record_url, duration*1000)
elif isinstance(comp, File):
file_path = comp.file
file_name = comp.name
if file_path.startswith("file:///"):
file_path = file_path[8:]
elif file_path.startswith("http"):
await download_file(file_path, f"data/temp/{file_name}")
else:
file_path = file_path
file_id = os.path.basename(file_path)
file_url = f"{self.client.file_server_url}/{file_id}"
logger.debug(f"gewe callback file url: {file_url}")
await self.client.post_file(to_wxid, file_url, file_id)
elif isinstance(comp, At):
pass
else:
logger.debug(f"gewechat 忽略: {comp.type}")
await super().send(message)

View File

@@ -0,0 +1,89 @@
import sys
import asyncio
import os
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from astrbot.api import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from .gewechat_event import GewechatPlatformEvent
from .client import SimpleGewechatClient
from astrbot.core.message.components import Plain
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器")
class GewechatPlatformAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settingss = platform_settings
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
self.client = None
@override
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
to_wxid = session.session_id
if not to_wxid:
logger.error("无法获取到 to_wxid。")
return
for comp in message_chain.chain:
if isinstance(comp, Plain):
await self.client.post_text(to_wxid, comp.text)
await super().send_by_session(session, message_chain)
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"gewechat",
"基于 gewechat 的 Wechat 适配器",
)
@override
def run(self):
self.client = SimpleGewechatClient(
self.config['base_url'],
self.config['nickname'],
self.config['host'],
self.config['port'],
self._event_queue,
)
async def on_event_received(abm: AstrBotMessage):
await self.handle_msg(abm)
self.client.on_event_received = on_event_received
return self._run()
async def logout(self):
await self.client.logout()
async def _run(self):
await self.client.login()
await self.client.start_polling()
async def handle_msg(self, message: AstrBotMessage):
if message.type == MessageType.GROUP_MESSAGE:
if self.settingss['unique_session']:
message.session_id = message.sender.user_id + "_" + message.group_id
message_event = GewechatPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client
)
self.commit_event(message_event)

View File

@@ -0,0 +1,175 @@
import base64
import time
import asyncio
import json
import re
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from typing import Union, List
from astrbot.api.message_components import Image, Plain, At
from astrbot.core.platform.astr_message_event import MessageSesion
from .lark_event import LarkMessageEvent
from ...register import register_platform_adapter
from astrbot.core.message.components import BaseMessageComponent
from astrbot import logger
import lark_oapi as lark
from lark_oapi.api.im.v1 import *
@register_platform_adapter("lark", "飞书机器人官方 API 适配器")
class LarkPlatformAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.unique_session = platform_settings['unique_session']
self.appid = platform_config['app_id']
self.appsecret = platform_config['app_secret']
self.domain = platform_config.get('domain', lark.FEISHU_DOMAIN)
self.bot_name = platform_config.get('lark_bot_name', "astrbot")
if not self.bot_name:
logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。")
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1):
await self.convert_msg(event)
def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1):
asyncio.create_task(on_msg_event_recv(event))
self.event_handler = lark.EventDispatcherHandler.builder("", "") \
.register_p2_im_message_receive_v1(do_v2_msg_event) \
.build()
self.client = lark.ws.Client(
app_id=self.appid,
app_secret=self.appsecret,
log_level=lark.LogLevel.ERROR,
domain=self.domain,
event_handler=self.event_handler
)
self.lark_api = (
lark.Client.builder()
.app_id(self.appid)
.app_secret(self.appsecret)
.build()
)
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"lark",
"飞书机器人官方 API 适配器",
)
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
message = event.event.message
abm = AstrBotMessage()
abm.timestamp = int(message.create_time) / 1000
abm.message = []
abm.type = MessageType.GROUP_MESSAGE if message.chat_type == 'group' else MessageType.FRIEND_MESSAGE
if message.chat_type == 'group':
abm.group_id = message.chat_id
abm.self_id = self.bot_name
abm.message_str = ""
at_list = {}
if message.mentions:
for m in message.mentions:
at_list[m.key] = At(qq=m.id.open_id, name=m.name)
if m.name == self.bot_name:
abm.self_id = m.id.open_id
content_json_b = json.loads(message.content)
if message.message_type == 'text':
message_str_raw = content_json_b['text'] # 带有 @ 的消息
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
at_users = re.findall(at_pattern, message_str_raw)
# 拆分文本去掉AT符号部分
parts = re.split(at_pattern, message_str_raw)
for i in range(len(parts)):
s = parts[i].strip()
if not s:
continue
if s in at_list:
abm.message.append(at_list[s])
else:
abm.message.append(Plain(parts[i].strip()))
elif message.message_type == 'post':
_ls = []
content_ls = content_json_b.get('content', [])
for comp in content_ls:
if isinstance(comp, list):
_ls.extend(comp)
elif isinstance(comp, dict):
_ls.append(comp)
content_json_b = _ls
elif message.message_type == 'image':
content_json_b = [
{"tag": "img", "image_key": content_json_b["image_key"], "style": []}
]
if message.message_type in ('post', 'image'):
for comp in content_json_b:
if comp['tag'] == 'at':
abm.message.append(at_list[comp['user_id']])
elif comp['tag'] == 'text' and comp['text'].strip():
abm.message.append(Plain(comp['text'].strip()))
elif comp['tag'] == 'img':
image_key = comp['image_key']
request = GetMessageResourceRequest.builder() \
.message_id(message.message_id) \
.file_key(image_key) \
.type("image") \
.build()
response = await self.lark_api.im.v1.message_resource.aget(request)
if not response.success():
logger.error(f"无法下载飞书图片: {image_key}")
image_bytes = response.file.read()
image_base64 = base64.b64encode(image_bytes).decode()
abm.message.append(Image.fromBase64(image_base64))
for comp in abm.message:
if isinstance(comp, Plain):
abm.message_str += comp.text
abm.message_id = message.message_id
abm.raw_message = message
abm.sender = MessageMember(
user_id=event.event.sender.sender_id.open_id,
nickname=event.event.sender.sender_id.open_id[:8]
)
# 独立会话
if not self.unique_session:
if abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = abm.group_id
else:
abm.session_id = abm.sender.user_id
else:
abm.session_id = abm.sender.user_id
logger.debug(abm)
await self.handle_msg(abm)
async def handle_msg(self, abm: AstrBotMessage):
event = LarkMessageEvent(
message_str=abm.message_str,
message_obj=abm,
platform_meta=self.meta(),
session_id=abm.session_id,
bot=self.lark_api
)
self._event_queue.put_nowait(event)
async def run(self):
# self.client.start()
await self.client._connect()

View File

@@ -0,0 +1,96 @@
import json
import uuid
import lark_oapi as lark
from typing import List
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image as AstrBotImage, Record, At, Node, Music, Video
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from lark_oapi.api.im.v1 import *
from astrbot import logger
class LarkMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id, bot: lark.Client):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
@staticmethod
async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> List:
ret = []
_stage = []
for comp in message.chain:
if isinstance(comp, Plain):
_stage.append({
"tag": "md",
"text": comp.text
})
elif isinstance(comp, At):
_stage.append({
"tag": "at",
"user_id": comp.qq,
"style": []
})
elif isinstance(comp, AstrBotImage):
file_path = ""
if comp.file and comp.file.startswith("file:///"):
file_path = comp.file.replace('file:///', '')
elif comp.file and comp.file.startswith("http"):
image_file_path = await download_image_by_url(comp.file)
file_path = image_file_path
elif comp.file and comp.file.startswith("base64://"):
pass
else:
file_path = comp.file
request = CreateImageRequest.builder() \
.request_body( \
CreateImageRequestBody.builder() \
.image_type("message") \
.image(open(file_path, 'rb')) \
.build() \
) \
.build()
response = await lark_client.im.v1.image.acreate(request)
if not response.success():
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
image_key = response.data.image_key
print(image_key)
ret.append(_stage)
ret.append([{
"tag": "img",
"image_key": image_key
}])
_stage.clear()
else:
logger.warning(f"飞书 暂时不支持消息段: {comp.type}")
if _stage:
ret.append(_stage)
return ret
async def send(self, message: MessageChain):
res = await LarkMessageEvent._convert_to_lark(message, self.bot)
wrapped = {
"zh_cn": {
"title": "",
"content": res,
}
}
request = ReplyMessageRequest.builder() \
.message_id(self.message_obj.message_id) \
.request_body( \
ReplyMessageRequestBody.builder() \
.content(json.dumps(wrapped)) \
.msg_type("post") \
.uuid(str(uuid.uuid4())) \
.reply_in_thread(False) \
.build() \
) \
.build()
response = await self.bot.im.v1.message.areply(request)
if not response.success():
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
await super().send(message)

View File

@@ -8,18 +8,30 @@ from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image
from botpy import Client
from botpy.http import Route
from astrbot.api import logger
class QQOfficialMessageEvent(AstrMessageEvent):
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
self.send_buffer = None
async def send(self, message: MessageChain):
if not self.send_buffer:
self.send_buffer = message
else:
self.send_buffer.chain.extend(message.chain)
async def _post_send(self):
'''QQ 官方 API 仅支持回复一次'''
source = self.message_obj.raw_message
assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(message)
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
if not plain_text and not image_base64 and not image_path:
return
payload = {
'content': plain_text,
@@ -31,11 +43,13 @@ class QQOfficialMessageEvent(AstrMessageEvent):
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_group_message(group_openid=source.group_openid, **payload)
case botpy.message.C2CMessage:
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload)
case botpy.message.Message:
if image_path:
@@ -46,7 +60,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload['file_image'] = image_path
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
await super().send(message)
await super().send(self.send_buffer)
self.send_buffer = None
async def upload_group_and_c2c_image(self, image_base64: str, file_type: int, **kwargs) -> botpy.types.message.Media:
payload = {
@@ -73,9 +89,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
plain_text += i.text
elif isinstance(i, Image) and not image_base64:
if i.file and i.file.startswith("file:///"):
image_base64 = file_to_base64(i.file[8:])
image_base64 = file_to_base64(i.file[8:]).replace("base64://", "")
image_file_path = i.file[8:]
elif i.file and i.file.startswith("http"):
image_file_path = await download_image_by_url(i.file)
image_base64 = file_to_base64(image_file_path)
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
else:
image_base64 = file_to_base64(i.file).replace("base64://", "")
image_file_path = i.file
else:
logger.debug(f"qq_official 忽略 {i.type}")
return plain_text, image_base64, image_file_path

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import botpy
import logging
import time
@@ -11,7 +13,7 @@ from botpy import Client
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from typing import Union, List
from astrbot.api.message_components import Image, Plain
from astrbot.api.message_components import Image, Plain, At
from astrbot.core.platform.astr_message_event import MessageSesion
from .qqofficial_message_event import QQOfficialMessageEvent
from ...register import register_platform_adapter
@@ -28,25 +30,25 @@ class botClient(Client):
# 收到群消息
async def on_group_at_message_create(self, message: botpy.message.GroupMessage):
abm = self.platform._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE)
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE)
abm.session_id = abm.sender.user_id if self.platform.unique_session else message.group_openid
self._commit(abm)
# 收到频道消息
async def on_at_message_create(self, message: botpy.message.Message):
abm = self.platform._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE)
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE)
abm.session_id = abm.sender.user_id if self.platform.unique_session else message.channel_id
self._commit(abm)
# 收到私聊消息
async def on_direct_message_create(self, message: botpy.message.DirectMessage):
abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
abm.session_id = abm.sender.user_id
self._commit(abm)
# 收到 C2C 消息
async def on_c2c_message_create(self, message: botpy.message.C2CMessage):
abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
abm.session_id = abm.sender.user_id
self._commit(abm)
@@ -102,7 +104,8 @@ class QQOfficialPlatformAdapter(Platform):
"QQ 机器人官方 API 适配器",
)
def _parse_from_qqofficial(self, message: Union[botpy.message.Message, botpy.message.GroupMessage],
@staticmethod
def _parse_from_qqofficial(message: Union[botpy.message.Message, botpy.message.GroupMessage],
message_type: MessageType):
abm = AstrBotMessage()
abm.type = message_type
@@ -111,6 +114,7 @@ class QQOfficialPlatformAdapter(Platform):
abm.message_id = message.id
abm.tag = "qq_official"
msg: List[BaseMessageComponent] = []
if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage):
if isinstance(message, botpy.message.GroupMessage):
@@ -126,7 +130,7 @@ class QQOfficialPlatformAdapter(Platform):
)
abm.message_str = message.content.strip()
abm.self_id = "unknown_selfid"
msg.append(At(qq="qq_official"))
msg.append(Plain(abm.message_str))
if message.attachments:
for i in message.attachments:
@@ -146,7 +150,7 @@ class QQOfficialPlatformAdapter(Platform):
plain_content = message.content.replace(
"<@!"+str(abm.self_id)+">", "").strip()
msg.append(Plain(plain_content))
if message.attachments:
for i in message.attachments:
if i.content_type.startswith("image"):
@@ -161,11 +165,14 @@ class QQOfficialPlatformAdapter(Platform):
str(message.author.id),
str(message.author.username)
)
msg.append(At(qq="qq_official"))
msg.append(Plain(plain_content))
if isinstance(message, botpy.message.Message):
abm.group_id = message.channel_id
else:
raise ValueError(f"Unknown message type: {message_type}")
abm.self_id = "qq_official"
return abm
def run(self):

View File

@@ -0,0 +1,99 @@
import botpy
import logging
import asyncio
import botpy.message
import botpy.types
import botpy.types.message
from botpy import Client
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from astrbot.core.platform.astr_message_event import MessageSesion
from .qo_webhook_event import QQOfficialWebhookMessageEvent
from ...register import register_platform_adapter
from .qo_webhook_server import QQOfficialWebhook
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
# remove logger handler
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
# QQ 机器人官方框架
class botClient(Client):
def set_platform(self, platform: 'QQOfficialWebhookPlatformAdapter'):
self.platform = platform
# 收到群消息
async def on_group_at_message_create(self, message: botpy.message.GroupMessage):
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE)
abm.session_id = abm.sender.user_id if self.platform.unique_session else message.group_openid
self._commit(abm)
# 收到频道消息
async def on_at_message_create(self, message: botpy.message.Message):
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE)
abm.session_id = abm.sender.user_id if self.platform.unique_session else message.channel_id
self._commit(abm)
# 收到私聊消息
async def on_direct_message_create(self, message: botpy.message.DirectMessage):
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
abm.session_id = abm.sender.user_id
self._commit(abm)
# 收到 C2C 消息
async def on_c2c_message_create(self, message: botpy.message.C2CMessage):
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
abm.session_id = abm.sender.user_id
self._commit(abm)
def _commit(self, abm: AstrBotMessage):
self.platform.commit_event(QQOfficialWebhookMessageEvent(
abm.message_str,
abm,
self.platform.meta(),
abm.session_id,
self
))
@register_platform_adapter("qq_official_webhook", "QQ 机器人官方 API 适配器(Webhook)")
class QQOfficialWebhookPlatformAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.appid = platform_config['appid']
self.secret = platform_config['secret']
self.unique_session = platform_settings['unique_session']
intents = botpy.Intents(
public_messages=True,
public_guild_messages=True,
direct_message=True
)
self.client = botClient(
intents=intents, # 已经无用
bot_log=False,
timeout=20,
)
self.client.set_platform(self)
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"qq_official_webhook",
"QQ 机器人官方 API 适配器",
)
async def run(self):
self.webhook_helper = QQOfficialWebhook(
self.config,
self._event_queue,
self.client
)
await self.webhook_helper.initialize()
await self.webhook_helper.start_polling()

View File

@@ -0,0 +1,18 @@
import botpy
import botpy.message
import botpy.types
import botpy.types.message
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Reply
from botpy import Client
from botpy.http import Route
from astrbot.api import logger
from ..qqofficial.qqofficial_message_event import QQOfficialMessageEvent
class QQOfficialWebhookMessageEvent(QQOfficialMessageEvent):
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client):
super().__init__(message_str, message_obj, platform_meta, session_id, bot)

View File

@@ -0,0 +1,108 @@
import aiohttp
import quart
import json
import logging
import asyncio
import typing
from botpy import BotAPI, BotHttp, Client, Token, BotWebSocket, ConnectionSession
from astrbot.api import logger
import traceback
from cryptography.hazmat.primitives.asymmetric import ed25519
# remove logger handler
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
class QQOfficialWebhook():
def __init__(self, config: dict, event_queue: asyncio.Queue, botpy_client: Client):
self.appid = config['appid']
self.secret = config['secret']
self.port = config.get("port", 6196)
if isinstance(self.port, str):
self.port = int(self.port)
self.http: BotHttp = BotHttp(timeout=300)
self.api: BotAPI = BotAPI(http=self.http)
self.token = Token(self.appid, self.secret)
self.server = quart.Quart(__name__)
self.server.add_url_rule('/astrbot-qo-webhook/callback', view_func=self.callback, methods=['POST'])
self.client = botpy_client
self.event_queue = event_queue
async def initialize(self):
logger.info(f"正在登录到 QQ 官方机器人...")
self.user = await self.http.login(self.token)
logger.info(f"已登录 QQ 官方机器人账号: {self.user}")
# 直接注入到 botpy 的 Client移花接木
self.client.api = self.api
self.client.http = self.http
async def bot_connect():
pass
self._connection = ConnectionSession(
max_async=1,
connect=bot_connect,
dispatch=self.client.ws_dispatch,
loop=asyncio.get_event_loop(),
api=self.api,
)
async def repeat_seed(self, bot_secret: str, target_size: int = 32) -> bytes:
seed = bot_secret
while len(seed) < target_size:
seed *= 2
return seed[:target_size].encode('utf-8')
async def webhook_validation(self, validation_payload: dict):
seed = await self.repeat_seed(self.secret)
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
msg = validation_payload.get("event_ts", "") + validation_payload.get("plain_token", "")
# sign
signature = private_key.sign(msg.encode()).hex()
response = {
"plain_token": validation_payload.get("plain_token"),
"signature": signature
}
return response
async def callback(self):
msg: dict = await quart.request.json
logger.debug(f"收到 qq_official_webhook 回调: {msg}")
event = msg.get("t")
opcode = msg.get("op")
data = msg.get("d")
if opcode == 13:
# validation
signed = await self.webhook_validation(data)
print(signed)
return signed
if event and opcode == BotWebSocket.WS_DISPATCH_EVENT:
event = msg["t"].lower()
try:
func = self._connection.parser[event]
except KeyError:
logger.error("_parser unknown event %s.", event)
else:
func(msg)
return {"opcode": 12}
async def start_polling(self):
await self.server.run_task(
host='0.0.0.0',
port=self.port,
shutdown_trigger=self.shutdown_trigger_placeholder
)
async def shutdown_trigger_placeholder(self):
while not self.event_queue.closed:
await asyncio.sleep(1)
logger.info("qq_official_webhook 适配器已关闭。")

View File

@@ -1,44 +0,0 @@
import random
import asyncio
from astrbot.core.utils.io import download_image_by_url
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image
from vchat import Core
class VChatPlatformEvent(AstrMessageEvent):
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, client: Core):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
@staticmethod
async def send_with_client(client: Core, message: MessageChain, user_name: str):
plain = ""
for comp in message.chain:
if isinstance(comp, Plain):
if message.is_split_:
await client.send_msg(comp.text, user_name)
else:
plain += comp.text
elif isinstance(comp, Image):
if comp.file and comp.file.startswith("file:///"):
file_path = comp.file.replace("file:///", "")
with open(file_path, "rb") as f:
await client.send_image(user_name, fd=f)
elif comp.file and comp.file.startswith("http"):
image_path = await download_image_by_url(comp.file)
with open(image_path, "rb") as f:
await client.send_image(user_name, fd=f)
else:
logger.error(f"不支持的 vchat(微信适配器) 消息类型: {comp}")
await asyncio.sleep(random.uniform(0.5, 1.5)) # 🤓
if plain:
await client.send_msg(plain, user_name)
async def send(self, message: MessageChain):
await VChatPlatformEvent.send_with_client(self.client, message, self.message_obj.raw_message.from_.username)
await super().send(message)

View File

@@ -1,119 +0,0 @@
import sys
import time
import uuid
import asyncio
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from astrbot.api.message_components import *
from astrbot.api import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from .vchat_message_event import VChatPlatformEvent
from ...register import register_platform_adapter
from vchat import Core
from vchat import model
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
@register_platform_adapter("vchat", "基于 VChat 的 Wechat 适配器")
class VChatPlatformAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settingss = platform_settings
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
self.client_self_id = uuid.uuid4().hex[:8]
@override
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
from_username = session.session_id.split('$$')[0]
await VChatPlatformEvent.send_with_client(self.client, message_chain, from_username)
await super().send_by_session(session, message_chain)
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"vchat",
"基于 VChat 的 Wechat 适配器",
)
@override
def run(self):
self.client = Core()
@self.client.msg_register(msg_types=model.ContentTypes.TEXT,
contact_type=model.ContactTypes.CHATROOM | model.ContactTypes.USER)
async def _(msg: model.Message):
if isinstance(msg.content, model.UselessContent):
return
if msg.create_time < self.start_time:
logger.debug(f"忽略旧消息: {msg}")
return
logger.debug(f"收到消息: {msg.todict()}")
abmsg = self.convert_message(msg)
# await self.handle_msg(abmsg) # 不能直接调用,否则会阻塞
asyncio.create_task(self.handle_msg(abmsg))
# TODO: 对齐微信服务器时间
self.start_time = int(time.time())
return self._run()
async def _run(self):
await self.client.init()
await self.client.auto_login(hot_reload=True, enable_cmd_qr=True)
await self.client.run()
def convert_message(self, msg: model.Message) -> AstrBotMessage:
# credits: https://github.com/z2z63/astrbot_plugin_vchat/blob/master/main.py#L49
assert isinstance(msg.content, model.TextContent)
amsg = AstrBotMessage()
amsg.message = [Plain(msg.content.content)]
amsg.self_id = self.client_self_id
if msg.content.is_at_me:
amsg.message.insert(0, At(qq=amsg.self_id))
sender = msg.chatroom_sender or msg.from_
amsg.sender = MessageMember(sender.username, sender.nickname)
if msg.content.is_at_me:
amsg.message_str = msg.content.content.split("\u2005")[1].strip()
else:
amsg.message_str = msg.content.content
amsg.message_id = msg.message_id
if isinstance(msg.from_, model.User):
amsg.type = MessageType.FRIEND_MESSAGE
elif isinstance(msg.from_, model.Chatroom):
amsg.type = MessageType.GROUP_MESSAGE
amsg.group_id = msg.from_.username
else:
logger.error(f"不支持的 Wechat 消息类型: {msg.from_}")
amsg.raw_message = msg
if self.settingss['unique_session']:
session_id = msg.from_.username + "$$" + msg.to.username
if msg.chatroom_sender is not None:
session_id += '$$' + msg.chatroom_sender.username
else:
session_id = msg.from_.username
amsg.session_id = session_id
return amsg
async def handle_msg(self, message: AstrBotMessage):
message_event = VChatPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client
)
logger.info(f"处理消息: {message_event}")
self.commit_event(message_event)

View File

@@ -0,0 +1,112 @@
import time
import asyncio
import uuid
import os
from typing import Awaitable, Any
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain, Image, Record # noqa: F403
from astrbot.api import logger
from astrbot.core import web_chat_queue, web_chat_back_queue
from .webchat_event import WebChatMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
class QueueListener:
def __init__(self, queue: asyncio.Queue, callback: callable) -> None:
self.queue = queue
self.callback = callback
async def run(self):
while True:
data = await self.queue.get()
await self.callback(data)
@register_platform_adapter("webchat", "webchat")
class WebChatAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings['unique_session']
self.imgs_dir = "data/webchat/imgs"
self.metadata = PlatformMetadata(
"webchat",
"webchat",
)
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
# abm.session_id = f"webchat!{username}!{cid}"
plain = ""
cid = session.session_id.split("!")[-1]
for comp in message_chain.chain:
if isinstance(comp, Plain):
plain += comp.text
web_chat_back_queue.put_nowait((plain, cid))
await super().send_by_session(session, message_chain)
async def convert_message(self, data: tuple) -> AstrBotMessage:
username, cid, payload = data
abm = AstrBotMessage()
abm.self_id = "webchat"
abm.tag = "webchat"
abm.sender = MessageMember(username, username)
abm.type = MessageType.FRIEND_MESSAGE
abm.session_id = f"webchat!{username}!{cid}"
abm.message_id = str(uuid.uuid4())
abm.message = []
if payload['message']:
abm.message.append(Plain(payload['message']))
if payload['image_url']:
if isinstance(payload['image_url'], list):
for img in payload['image_url']:
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, img)))
else:
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, payload['image_url'])))
if payload['audio_url']:
if isinstance(payload['audio_url'], list):
for audio in payload['audio_url']:
path = os.path.join(self.imgs_dir, audio)
abm.message.append(Record(file=path, path=path))
else:
path = os.path.join(self.imgs_dir, payload['audio_url'])
abm.message.append(Record(file=path, path=path))
logger.debug(f"WebChatAdapter: {abm.message}")
message_str = payload['message']
abm.timestamp = int(time.time())
abm.message_str = message_str
abm.raw_message = data
return abm
def run(self) -> Awaitable[Any]:
async def callback(data: tuple):
abm = await self.convert_message(data)
await self.handle_msg(abm)
bot = QueueListener(web_chat_queue, callback)
return bot.run()
def meta(self) -> PlatformMetadata:
return self.metadata
async def handle_msg(self, message: AstrBotMessage):
message_event = WebChatMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id
)
self.commit_event(message_event)

View File

@@ -0,0 +1,50 @@
import os
import uuid
import base64
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image
from astrbot.core.utils.io import download_image_by_url
from astrbot.core import web_chat_back_queue
class WebChatMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.imgs_dir = "data/webchat/imgs"
os.makedirs(self.imgs_dir, exist_ok=True)
async def send(self, message: MessageChain):
if not message:
web_chat_back_queue.put_nowait(None)
return
cid = self.session_id.split("!")[-1]
for comp in message.chain:
if isinstance(comp, Plain):
web_chat_back_queue.put_nowait((comp.text, cid))
elif isinstance(comp, Image):
# save image to local
filename = str(uuid.uuid4()) + ".jpg"
path = os.path.join(self.imgs_dir, filename)
if comp.file and comp.file.startswith("file:///"):
ph = comp.file[8:]
with open(path, "wb") as f:
with open(ph, "rb") as f2:
f.write(f2.read())
elif comp.file.startswith("base64://"):
base64_str = comp.file[9:]
image_data = base64.b64decode(base64_str)
with open(path, "wb") as f:
f.write(image_data)
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
else:
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
else:
logger.debug(f"webchat 忽略: {comp.type}")
web_chat_back_queue.put_nowait(None)
await super().send(message)

View File

@@ -1,9 +1,10 @@
from .provider import Provider, Personality
from .provider import Provider, Personality, STTProvider
from .provider_metadata import ProviderMetaData
from .entites import ProviderMetaData
__all__ = [
"Provider",
"Personality",
"ProviderMetaData",
"STTProvider"
]

View File

@@ -0,0 +1,64 @@
import enum
from dataclasses import dataclass, field
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
from openai.types.chat.chat_completion import ChatCompletion
from astrbot.core.db.po import Conversation
class ProviderType(enum.Enum):
CHAT_COMPLETION = "chat_completion"
SPEECH_TO_TEXT = "speech_to_text"
TEXT_TO_SPEECH = "text_to_speech"
@dataclass
class ProviderMetaData():
type: str
'''提供商适配器名称,如 openai, ollama'''
desc: str = ""
'''提供商适配器描述.'''
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: Type = None
default_config_tmpl: dict = None
'''平台的默认配置模板'''
provider_display_name: str = None
'''显示在 WebUI 配置页中的提供商名称,如空则是 type'''
@dataclass
class ProviderRequest():
prompt: str
'''提示词'''
session_id: str = ""
'''会话 ID'''
image_urls: List[str] = None
'''图片 URL 列表'''
func_tool: FuncCall = None
'''工具'''
contexts: List = None
'''上下文。格式与 openai 的上下文格式一致:
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
'''
system_prompt: str = ""
'''系统提示词'''
conversation: Conversation = None
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self.contexts}, system_prompt={self.system_prompt})"
def __str__(self):
return self.__repr__()
@dataclass
class LLMResponse:
role: str
'''角色, assistant, tool, err'''
completion_text: str = ""
'''LLM 返回的文本'''
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
'''工具调用参数'''
tools_call_name: List[str] = field(default_factory=list)
'''工具调用名称'''
raw_completion: ChatCompletion = None
_new_record: Dict[str, any] = None

View File

@@ -1,25 +1,9 @@
import json
import textwrap
from typing import Awaitable, Dict, List
from typing import Dict, List, Awaitable
from dataclasses import dataclass
class FuncCallJsonFormatError(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return self.msg
class FuncNotFoundError(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return self.msg
@dataclass
class FuncTool:
"""
@@ -29,8 +13,8 @@ class FuncTool:
name: str
parameters: Dict
description: str
func_obj: Awaitable
module_name: str = None
handler: Awaitable
handler_module_path: str = None # 必须要保留这个handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
active: bool = True
'''是否激活'''
@@ -56,8 +40,7 @@ class FuncCall:
name: str,
func_args: list,
desc: str,
func_obj: Awaitable,
module_name: str = None,
handler: Awaitable,
) -> None:
"""
为函数调用function-calling / tools-use添加工具
@@ -80,8 +63,7 @@ class FuncCall:
name=name,
parameters=params,
description=desc,
func_obj=func_obj,
module_name=module_name,
handler=handler,
)
self.func_list.append(_func)
@@ -119,10 +101,59 @@ class FuncCall:
}
)
return _l
def get_func_desc_anthropic_style(self) -> list:
"""
获得 Anthropic API 风格的**已经激活**的工具描述
"""
tools = []
for f in self.func_list:
if not f.active:
continue
# Convert internal format to Anthropic style
tool = {
"name": f.name,
"description": f.description,
"input_schema": {
"type": "object",
"properties": f.parameters.get("properties", {}),
# Keep the required field from the original parameters if it exists
"required": f.parameters.get("required", [])
}
}
tools.append(tool)
return tools
def get_func_desc_google_genai_style(self) -> Dict:
declarations = {}
tools = []
for f in self.func_list:
if not f.active:
continue
func_declaration = {
"name": f.name,
"description": f.description
}
# 检查并添加非空的properties参数
params = f.parameters if isinstance(f.parameters, dict) else {}
if params.get("properties", {}):
func_declaration["parameters"] = params
tools.append(func_declaration)
if tools:
declarations["function_declarations"] = tools
return declarations
async def func_call(self, question: str, session_id: str, provider) -> tuple:
_l = []
for f in self.func_list:
if not f.active:
continue
_l.append(
{
"name": f["name"],
@@ -179,12 +210,19 @@ class FuncCall:
# 调用函数
tool_callable = None
for func in self.func_list:
if func["name"] == func_name:
tool_callable = func["func_obj"]
if func.name == func_name:
tool_callable = func.star_handler_metadata.handler
break
if not tool_callable:
raise FuncNotFoundError(f"Request function {func_name} not found.")
raise Exception(f"Request function {func_name} not found.")
ret = await tool_callable(**args)
if ret:
tool_call_result.append(str(ret))
return tool_call_result, True
def __str__(self):
return str(self.func_list)
def __repr__(self):
return str(self.func_list)

View File

@@ -1,13 +0,0 @@
from typing import Dict, List
from dataclasses import dataclass
@dataclass
class LLMResponse:
role: str
'''角色'''
completion_text: str = None
'''LLM 返回的文本'''
tools_call_args: List[Dict[str, any]] = None
'''工具调用参数'''
tools_call_name: List[str] = None
'''工具调用名称'''

View File

@@ -1,52 +1,244 @@
import traceback
import uuid
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider
from .provider import Provider, STTProvider, TTSProvider, Personality
from .entites import ProviderType
from typing import List
from astrbot.core.db import BaseDatabase
from collections import defaultdict
from .register import provider_cls_map, llm_tools
from astrbot.core import logger
from astrbot.core import logger, sp
class ProviderManager():
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
self.providers_config: List = config['provider']
self.provider_settings: dict = config['provider_settings']
self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
self.provider_tts_settings: dict = config.get('provider_tts_settings', {})
self.persona_configs: list = config.get('persona', [])
# 人格情景管理
# 目前没有拆成独立的模块
self.default_persona_name = self.provider_settings.get('default_personality', 'default')
self.personas: List[Personality] = []
self.selected_default_persona = None
for persona in self.persona_configs:
begin_dialogs = persona.get("begin_dialogs", [])
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
bd_processed = []
mid_processed = ""
if begin_dialogs:
if len(begin_dialogs) % 2 != 0:
logger.error(f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。")
begin_dialogs = []
user_turn = True
for dialog in begin_dialogs:
bd_processed.append({
"role": "user" if user_turn else "assistant",
"content": dialog,
"_no_save": None # 不持久化到 db
})
user_turn = not user_turn
if mood_imitation_dialogs:
if len(mood_imitation_dialogs) % 2 != 0:
logger.error(f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。")
mood_imitation_dialogs = []
user_turn = True
for dialog in mood_imitation_dialogs:
role = "A" if user_turn else "B"
mid_processed += f"{role}: {dialog}\n"
if not user_turn:
mid_processed += '\n'
user_turn = not user_turn
try:
persona = Personality(
**persona,
_begin_dialogs_processed=bd_processed,
_mood_imitation_dialogs_processed=mid_processed
)
if persona['name'] == self.default_persona_name:
self.selected_default_persona = persona
self.personas.append(persona)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
if not self.selected_default_persona and len(self.personas) > 0:
# 默认选择第一个
self.selected_default_persona = self.personas[0]
if not self.selected_default_persona:
self.selected_default_persona = Personality(
prompt="You are a helpful and friendly assistant.",
name="default",
_begin_dialogs_processed=[],
_mood_imitation_dialogs_processed=""
)
self.personas.append(self.selected_default_persona)
self.provider_insts: List[Provider] = []
'''加载的 Provider 的实例'''
self.stt_provider_insts: List[STTProvider] = []
'''加载的 Speech To Text Provider 的实例'''
self.tts_provider_insts: List[TTSProvider] = []
'''加载的 Text To Speech Provider 的实例'''
self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None
'''当前使用的 Provider 实例'''
self.curr_stt_provider_inst: STTProvider = None
'''当前使用的 Speech To Text Provider 实例'''
self.curr_tts_provider_inst: TTSProvider = None
'''当前使用的 Text To Speech Provider 实例'''
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
# kdb(experimental)
self.curr_kdb_name = ""
kdb_cfg = config.get("knowledge_db", {})
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
changed = False
for provider_cfg in self.providers_config:
if not provider_cfg['enable']:
continue
if provider_cfg['id'] in self.loaded_ids:
raise ValueError(f"Provider ID 重复:{provider_cfg['id']}")
new_id = f"{provider_cfg['id']}_{str(uuid.uuid4())[:8]}"
logger.info(f"Provider ID 重复:{provider_cfg['id']}。已自动更改为 {new_id}")
provider_cfg['id'] = new_id
changed = True
self.loaded_ids[provider_cfg['id']] = True
match provider_cfg['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
try:
match provider_cfg['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial as ProviderOpenAIOfficial
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
case "anthropic_chat_completion":
from .sources.anthropic_source import ProviderAnthropic as ProviderAnthropic
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader as LLMTunerModelLoader
case "dify":
from .sources.dify_source import ProviderDify as ProviderDify
case "dashscope":
from .sources.dashscope_source import ProviderDashscope as ProviderDashscope
case "googlegenai_chat_completion":
from .sources.gemini_source import ProviderGoogleGenAI as ProviderGoogleGenAI
case "openai_whisper_api":
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost
case "openai_tts_api":
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI as ProviderOpenAITTSAPI
case "fishaudio_tts_api":
from .sources.fishaudio_tts_api_source import ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI
except (ImportError, ModuleNotFoundError) as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
continue
except Exception as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
continue
if changed:
try:
config.save_config()
except Exception as e:
logger.warning(f"保存配置文件失败:{e}")
async def initialize(self):
selected_provider_id = sp.get("curr_provider")
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
selected_tts_provider_id = self.provider_settings.get("provider_id")
provider_enabled = self.provider_settings.get("enable", False)
stt_enabled = self.provider_stt_settings.get("enable", False)
tts_enabled = self.provider_tts_settings.get("enable", False)
for provider_config in self.providers_config:
if not provider_config['enable']:
continue
if provider_config['type'] not in provider_cls_map:
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的 大模型提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
cls_type = provider_cls_map[provider_config['type']]
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 ...")
inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
self.provider_insts.append(inst)
provider_metadata = provider_cls_map[provider_config['type']]
logger.debug(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
try:
# 按任务实例化提供商
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.stt_provider_insts.append(inst)
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
self.curr_stt_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.tts_provider_insts.append(inst)
if selected_tts_provider_id == provider_config['id'] and tts_enabled:
self.curr_tts_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = provider_metadata.cls_type(
provider_config,
self.provider_settings,
self.db_helper,
self.provider_settings.get('persistant_history', True),
self.selected_default_persona
)
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if selected_provider_id == provider_config['id'] and provider_enabled:
self.curr_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
except Exception as e:
traceback.print_exc()
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
if len(self.provider_insts) > 0:
if len(self.provider_insts) > 0 and not self.curr_provider_inst and provider_enabled:
self.curr_provider_inst = self.provider_insts[0]
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
if not self.curr_provider_inst:
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
if stt_enabled and not self.curr_stt_provider_inst:
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
if tts_enabled and not self.curr_tts_provider_inst:
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
def get_insts(self):
return self.provider_insts
return self.provider_insts
async def terminate(self):
for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"):
await provider_inst.terminate()

View File

@@ -5,51 +5,34 @@ from typing import List
from astrbot.core.db import BaseDatabase
from astrbot.core import logger
from typing import TypedDict
from astrbot.core.provider.tool import FuncCall
from astrbot.core.provider.llm_response import LLMResponse
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entites import LLMResponse
from dataclasses import dataclass
class Personality(TypedDict):
prompt: str = ""
name: str = ""
begin_dialogs: List[str] = []
mood_imitation_dialogs: List[str] = []
# cache
_begin_dialogs_processed: List[dict] = []
_mood_imitation_dialogs_processed: str = ""
@dataclass
class ProviderMeta():
id: str
model: str
type: str
class Provider(abc.ABC):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
persistant_history: bool = True,
db_helper: BaseDatabase = None
) -> None:
class AbstractProvider(abc.ABC):
def __init__(self, provider_config: dict) -> None:
super().__init__()
self.model_name = ""
'''当前使用的模型名称'''
self.session_memory = defaultdict(list)
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
self.provider_config = provider_config
self.provider_settings = provider_settings
self.curr_personality = Personality(prompt=provider_settings['default_personality'])
'''维护了当前的使用的 persona即人格。'''
self.db_helper = db_helper
'''用于持久化的数据库操作对象。'''
if persistant_history:
# 读取历史记录
try:
for history in db_helper.get_llm_history(provider_type=provider_config['type']):
self.session_memory[history.session_id] = json.loads(history.content)
except BaseException as e:
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
@@ -59,6 +42,31 @@ class Provider(abc.ABC):
'''获得当前使用的模型名称'''
return self.model_name
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
class Provider(AbstractProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
persistant_history: bool = True,
db_helper: BaseDatabase = None,
default_persona: Personality = None
) -> None:
super().__init__(provider_config)
self.provider_settings = provider_settings
self.curr_personality: Personality = default_persona
'''维护了当前的使用的 persona即人格。可能为 None'''
@abc.abstractmethod
def get_current_key(self) -> str:
raise NotImplementedError()
@@ -76,22 +84,6 @@ class Provider(abc.ABC):
'''获得支持的模型列表'''
raise NotImplementedError()
@abc.abstractmethod
async def get_human_readable_context(self, session_id: str, page: int, page_size: int):
'''获取人类可读的上下文
page 从 1 开始
Example:
["User: 你好", "Assistant: 你好!"]
Return:
contexts: List[str]: 上下文列表
total_pages: int: 总页数
'''
raise NotImplementedError()
@abc.abstractmethod
async def text_chat(self,
prompt: str,
@@ -99,37 +91,62 @@ class Provider(abc.ABC):
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts: List=None,
system_prompt: str=None,
**kwargs) -> LLMResponse:
'''获得 LLM 的文本对话结果。会使用当前的模型进行对话。
Args:
prompt: 提示词
session_id: 会话 ID
session_id: 会话 ID(此属性已经被废弃)
image_urls: 图片 URL 列表
tools: Function-calling 工具
contexts: 上下文
kwargs: 其他参数
Notes:
- 可以选择性地传入 session_id如果传入了 session_id将会使用 session_id 对应的上下文进行对话,
并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。
- 如果传入了 image_urls将会在对话时附上图片。如果模型不支持图片输入将会抛出错误。
- 如果传入了 tools将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling将会抛出错误。
- 如果传入了 contexts将会**直接**使用所提供的 contexts 进行对话。
传入此值通常意味着你需要自己维护 contextAstrBot 将不会记录上下文,并且会忽略 prompt、session_id、image_urls、tools。
'''
raise NotImplementedError()
async def pop_record(self, context: List):
'''
弹出 context 第一条非系统提示词对话记录
'''
poped = 0
indexs_to_pop = []
for idx, record in enumerate(context):
if record["role"] == "system":
continue
else:
indexs_to_pop.append(idx)
poped += 1
if poped == 2:
break
for idx in reversed(indexs_to_pop):
context.pop(idx)
class STTProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
@abc.abstractmethod
async def forget(self, session_id: str) -> bool:
'''重置某一个 session_id 的上下文'''
async def get_text(self, audio_url: str) -> str:
'''获取音频的文本'''
raise NotImplementedError()
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
class TTSProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
@abc.abstractmethod
async def get_audio(self, text: str) -> str:
'''获取文本的音频,返回音频文件路径'''
raise NotImplementedError()

View File

@@ -1,6 +0,0 @@
from dataclasses import dataclass
@dataclass
class ProviderMetaData():
type: str # 提供商适配器名称,如 openai, ollama
desc: str = "" # 提供商适配器描述.

View File

@@ -1,69 +1,47 @@
import docstring_parser
from typing import List, Dict, Type, Awaitable
from .provider_metadata import ProviderMetaData
from typing import List, Dict, Type
from .entites import ProviderMetaData, ProviderType
from astrbot.core import logger
from .tool import FuncCall, SUPPORTED_TYPES
from .func_tool_manager import FuncCall
provider_registry: List[ProviderMetaData] = []
'''维护了通过装饰器注册的 Provider'''
provider_cls_map: Dict[str, Type] = {}
'''维护了 Provider 类型名称和 Provider 的映射'''
provider_cls_map: Dict[str, ProviderMetaData] = {}
'''维护了 Provider 类型名称和 ProviderMetadata 的映射'''
llm_tools = FuncCall()
def register_provider_adapter(provider_type_name: str, desc: str):
def register_provider_adapter(
provider_type_name: str,
desc: str,
provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
default_config_tmpl: dict = None,
provider_display_name: str = None
):
'''用于注册平台适配器的带参装饰器'''
def decorator(cls):
if provider_type_name in provider_cls_map:
raise ValueError(f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。")
# 添加必备选项
if default_config_tmpl:
if 'type' not in default_config_tmpl:
default_config_tmpl['type'] = provider_type_name
if 'enable' not in default_config_tmpl:
default_config_tmpl['enable'] = False
if 'id' not in default_config_tmpl:
default_config_tmpl['id'] = provider_type_name
pm = ProviderMetaData(
type=provider_type_name,
desc=desc,
provider_type=provider_type,
cls_type=cls,
default_config_tmpl=default_config_tmpl,
provider_display_name=provider_display_name
)
provider_registry.append(pm)
provider_cls_map[provider_type_name] = cls
logger.debug(f"Provider {provider_type_name} 已注册")
provider_cls_map[provider_type_name] = pm
logger.debug(f"服务提供商 Provider {provider_type_name} 已注册")
return cls
return decorator
def register_llm_tool(name: str = None):
'''为函数调用function-calling / tools-use添加工具。
请务必按照以下格式编写一个工具包括函数注释AstrBot 会尝试解析该函数注释)
```
@llm_tool(name="get_weather") # 如果 name 不填,将使用函数名
async def get_weather(event: AstrMessageEvent, location: str) -> MessageEventResult:
\'\'\'获取天气信息。
Args:
location(string): 地点
\'\'\'
# 处理逻辑
```
可接受的参数类型有string, number, object, array, boolean。
'''
name_ = name
def decorator(func_obj: Awaitable):
llm_tool_name = name_ if name_ else func_obj.__name__
module_name = func_obj.__module__
docstring = docstring_parser.parse(func_obj.__doc__)
args = []
for arg in docstring.params:
if arg.type_name not in SUPPORTED_TYPES:
raise ValueError(f"LLM 函数工具 {func_obj.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}")
args.append({
"type": arg.type_name,
"name": arg.arg_name,
"description": arg.description
})
llm_tools.add_func(llm_tool_name, args, docstring.short_description, func_obj, module_name)
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
return func_obj
return decorator

View File

@@ -0,0 +1,189 @@
from typing import List
from mimetypes import guess_type
from anthropic import AsyncAnthropic
from anthropic.types import Message
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from .openai_source import ProviderOpenAIOfficial
@register_provider_adapter("anthropic_chat_completion", "Anthropic Claude API 提供商适配器")
class ProviderAnthropic(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True,
default_persona: Personality = None
) -> None:
# Skip OpenAI's __init__ and call Provider's __init__ directly
Provider.__init__(self, provider_config, provider_settings, persistant_history, db_helper, default_persona)
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.client = AsyncAnthropic(
api_key=self.chosen_api_key,
timeout=self.timeout,
base_url=self.base_url
)
self.set_model(provider_config['model_config']['model'])
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
tool_list = tools.get_func_desc_anthropic_style()
if tool_list:
payloads['tools'] = tool_list
completion = await self.client.messages.create(
**payloads,
stream=False
)
assert isinstance(completion, Message)
logger.debug(f"completion: {completion}")
if len(completion.content) == 0:
raise Exception("API 返回的 completion 为空。")
# TODO: 如果进行函数调用,思维链被截断,用户可能需要思维链的内容
# 选最后一条消息如果要进行函数调用anthropic会先返回文本消息的思维链然后再返回函数调用请求
content = completion.content[-1]
llm_response = LLMResponse("assistant")
if content.type == "text":
# text completion
completion_text = str(content.text).strip()
llm_response.completion_text = completion_text
# Anthropic每次只返回一个函数调用
if completion.stop_reason == "tool_use":
# tools call (function calling)
args_ls = []
func_name_ls = []
func_name_ls.append(content.name)
args_ls.append(content.input)
llm_response.role = "tool"
llm_response.tools_call_args = args_ls
llm_response.tools_call_name = func_name_ls
if not llm_response.completion_text and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}")
raise Exception(f"API 返回的 completion 无法解析:{completion}")
llm_response.raw_completion = completion
return llm_response
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
**kwargs
) -> LLMResponse:
if not prompt:
prompt = "<image>"
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
for part in context_query:
if '_no_save' in part:
del part['_no_save']
model_config = self.provider_config.get("model_config", {})
payloads = {
"messages": context_query,
**model_config
}
# Anthropic has a different way of handling system prompts
if system_prompt:
payloads['system'] = system_prompt
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
retry_cnt = 20
while retry_cnt > 0:
logger.warning(f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}")
try:
await self.pop_record(context_query)
response = await self.client.messages.create(
messages=context_query,
**model_config
)
llm_response = LLMResponse("assistant")
llm_response.completion_text = response.content[0].text
llm_response.raw_completion = response
return llm_response
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
return LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
else:
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
raise e
return llm_response
async def assemble_context(self, text: str, image_urls: List[str] = None):
'''组装上下文,支持文本和图片'''
if not image_urls:
return {"role": "user", "content": text}
content = []
content.append({"type": "text", "text": text})
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
# Get mime type for the image
mime_type, _ = guess_type(image_url)
if not mime_type:
mime_type = "image/jpeg" # Default to JPEG if can't determine
content.append({
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": image_data.split("base64,")[1] if "base64," in image_data else image_data
}
})
return {"role": "user", "content": content}

View File

@@ -0,0 +1,128 @@
import asyncio
import functools
from typing import List
from .. import Provider, Personality
from ..entites import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from .openai_source import ProviderOpenAIOfficial
from astrbot.core import logger, sp
from dashscope import Application
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
class ProviderDashscope(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=False,
default_persona: Personality = None,
) -> None:
Provider.__init__(
self,
provider_config,
provider_settings,
persistant_history,
db_helper,
default_persona,
)
self.api_key = provider_config.get("dashscope_api_key", "")
if not self.api_key:
raise Exception("阿里云百炼 API Key 不能为空。")
self.app_id = provider_config.get("dashscope_app_id", "")
if not self.app_id:
raise Exception("阿里云百炼 APP ID 不能为空。")
self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
if not self.dashscope_app_type:
raise Exception("阿里云百炼 APP 类型不能为空。")
self.model_name = "dashscope"
self.variables: dict = provider_config.get("variables", {})
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
payload_vars.update(session_var)
if self.dashscope_app_type in ["agent", "dialog-workflow"]:
# 支持多轮对话的
new_record = {"role": "user", "content": prompt}
if image_urls:
logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
contexts_no_img = await self._remove_image_from_context(contexts)
context_query = [*contexts_no_img, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if "_no_save" in part:
del part["_no_save"]
# 调用阿里云百炼 API
partial = functools.partial(
Application.call,
app_id=self.app_id,
api_key=self.api_key,
messages=context_query,
biz_params=payload_vars or None,
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
else:
# 不支持多轮对话的
# 调用阿里云百炼 API
partial = functools.partial(
Application.call,
app_id=self.app_id,
promtp=prompt,
api_key=self.api_key,
biz_params=payload_vars or None,
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
logger.debug(f"dashscope resp: {response}")
if response.status_code != 200:
logger.error(
f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档https://help.aliyun.com/zh/model-studio/developer-reference/error-code"
)
return LLMResponse(
role="err",
completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
)
output_text = response.output.get("text", "")
return LLMResponse(role="assistant", completion_text=output_text)
async def forget(self, session_id):
return True
async def get_current_key(self):
return self.api_key
async def set_key(self, key):
raise Exception("阿里云百炼 适配器不支持设置 API Key。")
async def get_models(self):
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。")
async def terminate(self):
pass

View File

@@ -0,0 +1,152 @@
from typing import List
from .. import Provider, Personality
from ..entites import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_image_by_url
from astrbot.core import logger, sp
@register_provider_adapter("dify", "Dify APP 适配器。")
class ProviderDify(Provider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=False,
default_persona: Personality=None
) -> None:
super().__init__(
provider_config, provider_settings, persistant_history, db_helper, default_persona
)
self.api_key = provider_config.get("dify_api_key", "")
if not self.api_key:
raise Exception("Dify API Key 不能为空。")
api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
self.api_client = DifyAPIClient(self.api_key, api_base)
self.api_type = provider_config.get("dify_api_type", "")
if not self.api_type:
raise Exception("Dify API 类型不能为空。")
self.model_name = "dify"
self.workflow_output_key = provider_config.get("dify_workflow_output_key", "astrbot_wf_output")
self.dify_query_input_key = provider_config.get("dify_query_input_key", "astrbot_text_query")
self.variables: dict = provider_config.get("variables", {})
if not self.dify_query_input_key:
self.dify_query_input_key = "astrbot_text_query"
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.conversation_ids = {}
'''记录当前 session id 的对话 ID'''
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
result = ""
conversation_id = self.conversation_ids.get(session_id, "")
files_payload = []
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
file_response = await self.api_client.file_upload(image_path, user=session_id)
if 'id' not in file_response:
logger.warning(f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。")
continue
files_payload.append({
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response['id'],
})
else:
# TODO: 处理更多情况
logger.warning(f"未知的图片链接:{image_url},图片将忽略。")
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
payload_vars.update(session_var)
try:
match self.api_type:
case "chat" | "agent":
async for chunk in self.api_client.chat_messages(
inputs={
**payload_vars,
},
query=prompt,
user=session_id,
conversation_id=conversation_id,
files=files_payload,
timeout=self.timeout
):
logger.debug(f"dify resp chunk: {chunk}")
if chunk['event'] == "message" or \
chunk['event'] == "agent_message":
result += chunk['answer']
if not conversation_id:
self.conversation_ids[session_id] = chunk['conversation_id']
conversation_id = chunk['conversation_id']
case "workflow":
async for chunk in self.api_client.workflow_run(
inputs={
self.dify_query_input_key: prompt,
"astrbot_session_id": session_id,
**payload_vars,
},
user=session_id,
files=files_payload,
timeout=self.timeout
):
match chunk['event']:
case "workflow_started":
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。")
case "node_finished":
logger.debug(f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。")
case "workflow_finished":
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。")
if chunk['data']['error']:
logger.error(f"Dify 工作流出现错误:{chunk['data']['error']}")
raise Exception(f"Dify 工作流出现错误:{chunk['data']['error']}")
if self.workflow_output_key not in chunk['data']['outputs']:
raise Exception(f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}")
result = chunk['data']['outputs'][self.workflow_output_key]
case _:
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
except Exception as e:
logger.error(f"Dify 请求失败:{str(e)}")
return LLMResponse(role="err", completion_text=f"Dify 请求失败:{str(e)}")
return LLMResponse(role="assistant", completion_text=result)
async def forget(self, session_id):
self.conversation_ids[session_id] = ""
return True
async def get_current_key(self):
return self.api_key
async def set_key(self, key):
raise Exception("Dify 适配器不支持设置 API Key。")
async def get_models(self):
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
raise Exception("暂不支持获得 Dify 的历史消息记录。")
async def terminate(self):
await self.api_client.close()

View File

@@ -0,0 +1,105 @@
import uuid
import ormsgpack
from pydantic import BaseModel, conint
from httpx import AsyncClient
from typing import Annotated, Literal
from ..provider import TTSProvider
from ..entites import ProviderType
from ..register import register_provider_adapter
class ServeReferenceAudio(BaseModel):
audio: bytes
text: str
class ServeTTSRequest(BaseModel):
text: str
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
# 音频格式
format: Literal["wav", "pcm", "mp3"] = "mp3"
mp3_bitrate: Literal[64, 128, 192] = 128
# 参考音频
references: list[ServeReferenceAudio] = []
# 参考模型 ID
# 例如 https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
# 其中reference_id为 7f92f8afb8ec43bf81429cc1c9199cb1
reference_id: str | None = None
# 对中英文文本进行标准化,这可以提高数字的稳定性
normalize: bool = True
# 平衡模式将延迟减少到300毫秒但可能会降低稳定性
latency: Literal["normal", "balanced"] = "normal"
@register_provider_adapter(
"fishaudio_tts_api", "FishAudio TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
)
class ProviderFishAudioTTSAPI(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key: str = provider_config.get("api_key", "")
self.character: str = provider_config.get("fishaudio-tts-character", "可莉")
self.api_base: str = provider_config.get(
"api_base", "https://api.fish-audio.cn/v1"
)
self.headers = {
"Authorization": f"Bearer {self.chosen_api_key}",
}
self.set_model(provider_config.get("model", None))
async def _get_reference_id_by_character(self, character: str) -> str:
"""
获取角色的reference_id
Args:
character: 角色名称
Returns:
reference_id: 角色的reference_id
exception:
APIException: 获取语音角色列表为空
"""
sort_options = ["score", "task_count", "created_at"]
async with AsyncClient(base_url=self.api_base.replace("/v1", "")) as client:
for sort_by in sort_options:
params = {"title": character, "sort_by": sort_by}
response = await client.get(
"/model", params=params, headers=self.headers
)
resp_data = response.json()
if resp_data["total"] == 0:
continue
for item in resp_data["items"]:
if character in item["title"]:
return item["_id"]
return None
async def _generate_request(self, text: str) -> dict:
return ServeTTSRequest(
text=text,
format="wav",
reference_id=await self._get_reference_id_by_character(self.character),
)
async def get_audio(self, text: str) -> str:
path = f"data/temp/fishaudio_tts_api_{uuid.uuid4()}.wav"
self.headers["content-type"] = "application/msgpack"
request = await self._generate_request(text)
async with AsyncClient(base_url=self.api_base).stream(
"POST",
"/tts",
headers=self.headers,
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
) as response:
if response.headers["content-type"] == "audio/wav":
with open(path, "wb") as f:
async for chunk in response.aiter_bytes():
f.write(chunk)
return path
text = await response.aread()
raise Exception(f"Fish Audio API请求失败: {text}")

View File

@@ -0,0 +1,284 @@
import base64
import aiohttp
import random
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
class SimpleGoogleGenAIClient():
def __init__(self, api_key: str, api_base: str, timeout: int=120) -> None:
self.api_key = api_key
if api_base.endswith("/"):
self.api_base = api_base[:-1]
else:
self.api_base = api_base
self.client = aiohttp.ClientSession(trust_env=True)
self.timeout = timeout
async def models_list(self) -> List[str]:
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
async with self.client.get(request_url, timeout=self.timeout) as resp:
response = await resp.json()
models = []
for model in response["models"]:
if 'generateContent' in model["supportedGenerationMethods"]:
models.append(model["name"].replace("models/", ""))
return models
async def generate_content(
self,
contents: List[dict],
model: str="gemini-1.5-flash",
system_instruction: str="",
tools: dict=None
):
payload = {}
if system_instruction:
payload["system_instruction"] = {
"parts": {"text": system_instruction}
}
if tools:
payload["tools"] = [tools]
payload["contents"] = contents
logger.debug(f"payload: {payload}")
request_url = f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
async with self.client.post(request_url, json=payload, timeout=self.timeout) as resp:
if "application/json" in resp.headers.get("Content-Type"):
try:
response = await resp.json()
except Exception as e:
text = await resp.text()
logger.error(f"Gemini 返回了非 json 数据: {text}")
raise e
return response
else:
text = await resp.text()
logger.error(f"Gemini 返回了非 json 数据: {text}")
raise Exception("Gemini 返回了非 json 数据: ")
@register_provider_adapter("googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器")
class ProviderGoogleGenAI(Provider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True,
default_persona: Personality=None
) -> None:
super().__init__(provider_config, provider_settings, persistant_history, db_helper, default_persona)
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.timeout = provider_config.get("timeout", 180)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.client = SimpleGoogleGenAIClient(
api_key=self.chosen_api_key,
api_base=provider_config.get("api_base", None),
timeout=self.timeout
)
self.set_model(provider_config['model_config']['model'])
async def get_models(self):
return await self.client.models_list()
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
tool = None
if tools:
tool = tools.get_func_desc_google_genai_style()
if not tool:
tool = None
system_instruction = ""
for message in payloads["messages"]:
if message["role"] == "system":
system_instruction = message["content"]
break
google_genai_conversation = []
for message in payloads["messages"]:
if message["role"] == "user":
if isinstance(message["content"], str):
if not message['content']:
message['content'] = "<empty_content>"
google_genai_conversation.append({
"role": "user",
"parts": [{"text": message["content"]}]
})
elif isinstance(message["content"], list):
# images
parts = []
for part in message["content"]:
if part["type"] == "text":
if not part["text"]:
part["text"] = "<empty_content>"
parts.append({"text": part["text"]})
elif part["type"] == "image_url":
parts.append({"inline_data": {
"mime_type": "image/jpeg",
"data": part["image_url"]["url"].replace("data:image/jpeg;base64,", "") # base64
}})
google_genai_conversation.append({
"role": "user",
"parts": parts
})
elif message["role"] == "assistant":
if not message["content"]:
message["content"] = "<empty_content>"
google_genai_conversation.append({
"role": "model",
"parts": [{"text": message["content"]}]
})
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
result = await self.client.generate_content(
contents=google_genai_conversation,
model=self.get_model(),
system_instruction=system_instruction,
tools=tool
)
logger.debug(f"result: {result}")
if "candidates" not in result:
raise Exception("Gemini 返回异常结果: " + str(result))
candidates = result["candidates"][0]['content']['parts']
llm_response = LLMResponse("assistant")
for candidate in candidates:
if 'text' in candidate:
llm_response.completion_text += candidate['text']
elif 'functionCall' in candidate:
llm_response.role = "tool"
llm_response.tools_call_args.append(candidate['functionCall']['args'])
llm_response.tools_call_name.append(candidate['functionCall']['name'])
llm_response.completion_text = llm_response.completion_text.strip()
return llm_response
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=[],
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if '_no_save' in part:
del part['_no_save']
model_config = self.provider_config.get("model_config", {})
model_config['model'] = self.get_model()
payloads = {
"messages": context_query,
**model_config
}
llm_response = None
retry = 10
keys = self.api_keys.copy()
chosen_key = random.choice(keys)
for i in range(retry):
try:
self.client.api_key = chosen_key
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt = 20
while retry_cnt > 0:
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}")
try:
await self.pop_record(context_query)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
if retry_cnt == 0:
llm_response = LLMResponse("err", "err: 请尝试 /reset 重置会话")
elif "Function calling is not enabled" in str(e):
logger.info(f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。")
if 'tools' in payloads:
del payloads['tools']
llm_response = await self._query(payloads, None)
elif "429" in str(e) or "API key not valid" in str(e):
keys.remove(chosen_key)
if len(keys) > 0:
chosen_key = random.choice(keys)
logger.info(f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}...")
continue
else:
logger.error(f"A检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}...")
raise Exception("API 资源已耗尽,且没有可用的 Key 重试...")
else:
logger.error(f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}")
raise e
return llm_response
def get_current_key(self) -> str:
return self.client.api_key
def get_keys(self) -> List[str]:
return self.api_keys
def set_key(self, key):
self.client.api_key = key
async def assemble_context(self, text: str, image_urls: List[str] = None):
'''
组装上下文。
'''
if image_urls:
user_content = {"role": "user","content": [{"type": "text", "text": text}]}
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
return user_content
else:
return {"role": "user","content": text}
async def encode_image_bs64(self, image_url: str) -> str:
'''
将图片转换为 base64
'''
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode('utf-8')
return "data:image/jpeg;base64," + image_bs64
return ''

View File

@@ -2,112 +2,101 @@ import json
import os
from llmtuner.chat import ChatModel
from typing import List
from .. import Provider
from .. import Provider, Personality
from ..entites import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from astrbot import logger
from ..register import register_provider_adapter
@register_provider_adapter("llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型")
@register_provider_adapter(
"llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型"
)
class LLMTunerModelLoader(Provider):
def __init__(
self,
provider_config: dict,
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True
db_helper: BaseDatabase,
persistant_history=True,
default_persona=None,
) -> None:
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
self.base_model_path = provider_config['base_model_path']
self.adapter_model_path = provider_config['adapter_model_path']
self.model = ChatModel({
"model_name_or_path": self.base_model_path,
"adapter_name_or_path": self.adapter_model_path,
"template": provider_config['llmtuner_template'],
"finetuning_type": provider_config['finetuning_type'],
"quantization_bit": provider_config['quantization_bit'],
})
self.set_model(os.path.basename(self.base_model_path) + "_" + os.path.basename(self.adapter_model_path))
super().__init__(
provider_config, provider_settings, persistant_history, db_helper, default_persona
)
if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists(
provider_config["adapter_model_path"]
):
raise FileNotFoundError("模型文件路径不存在。")
self.base_model_path = provider_config["base_model_path"]
self.adapter_model_path = provider_config["adapter_model_path"]
self.model = ChatModel(
{
"model_name_or_path": self.base_model_path,
"adapter_name_or_path": self.adapter_model_path,
"template": provider_config["llmtuner_template"],
"finetuning_type": provider_config["finetuning_type"],
"quantization_bit": provider_config["quantization_bit"],
}
)
self.set_model(
os.path.basename(self.base_model_path)
+ "_"
+ os.path.basename(self.adapter_model_path)
)
async def assemble_context(self, text: str, image_urls: List[str] = None):
'''
"""
组装上下文。
'''
"""
return {"role": "user", "content": text}
async def text_chat(self,
prompt: str,
session_id: str,
image_urls: List[str] = None,
tools = None,
contexts: List=None,
**kwargs) -> str:
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = [],
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
system_prompt = ""
if not contexts:
contexts = [*self.session_memory[session_id], {"role": "user", "content": prompt}]
system_prompt = self.curr_personality["prompt"]
else:
# 提取出系统提示
system_idxs = []
for idx, context in enumerate(contexts):
if context["role"] == "system":
system_idxs.append(idx)
for idx in reversed(system_idxs):
system_prompt += " " + contexts.pop(idx)["content"]
new_record = {"role": "user", "content": prompt}
query_context = [*contexts, new_record]
# 提取出系统提示
system_idxs = []
for idx, context in enumerate(query_context):
if context["role"] == "system":
system_idxs.append(idx)
if '_no_save' in context:
del context['_no_save']
logger.debug(f"请求上下文:{contexts}")
logger.debug(f"请求 System Prompt{system_prompt}")
for idx in reversed(system_idxs):
system_prompt += " " + query_context.pop(idx)["content"]
conf = {
"messages": contexts,
"messages": query_context,
"system": system_prompt,
}
if tools:
conf['tools'] = tools
responses = await self.model.achat(**conf)
logger.debug(f"返回上下文:{responses}")
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.meta().type)
self.session_memory[session_id].append({"role": "user", "content": prompt})
self.session_memory[session_id].append({"role": "assistant", "content": responses[-1].response_text})
return responses[-1].response_text
if func_tool:
tool_list = func_tool.get_func_desc_openai_style()
if tool_list:
conf['tools'] = tool_list
responses = await self.model.achat(**conf)
llm_response = LLMResponse("assistant", responses[-1].response_text)
return llm_response
async def forget(self, session_id):
logger.info("llmtuner reset")
self.session_memory[session_id] = []
return True
async def get_current_key(self):
return "none"
async def set_key(self, key):
pass
async def get_models(self):
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
return [self.get_model()]

View File

@@ -1,20 +1,19 @@
import traceback
import base64
import json
import datetime
import os
from openai import AsyncOpenAI, NOT_GIVEN
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai._exceptions import NotFoundError
from openai._exceptions import NotFoundError, UnprocessableEntityError
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider
from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.tool import FuncCall
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.llm_response import LLMResponse
from astrbot.core.provider.entites import LLMResponse
@register_provider_adapter("openai_chat_completion", "OpenAI API Chat Completion 提供商适配器")
class ProviderOpenAIOfficial(Provider):
@@ -23,97 +22,74 @@ class ProviderOpenAIOfficial(Provider):
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True
persistant_history = True,
default_persona: Personality = None
) -> None:
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
super().__init__(provider_config, provider_settings, persistant_history, db_helper, default_persona)
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.enable_datetime = provider_config.get("datetime_system_prompt", True)
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config['model_config']['model'])
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
# 适配 azure openai #332
if "api_version" in provider_config:
# 使用 azure api
self.client = AsyncAzureOpenAI(
api_key=self.chosen_api_key,
api_version=provider_config.get("api_version", None),
base_url=provider_config.get("api_base", None),
timeout=self.timeout
)
else:
# 使用 openai api
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=self.timeout
)
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
contexts.append(f"Assistant: {record['content']}")
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
model_config = provider_config.get("model_config", {})
model = model_config.get("model", "unknown")
self.set_model(model)
async def get_models(self):
try:
models_str = []
models = await self.client.models.list()
models = models.data
models = sorted(models.data, key=lambda x: x.id)
for model in models:
models_str.append(model.id)
return models_str
except NotFoundError as e:
raise Exception(f"获取模型列表失败:{e}")
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
'''
弹出第一条记录
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) == 0:
return None
for i in range(len(self.session_memory[session_id])):
# 检查是否是 system prompt
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
# 如果只有一个 system prompt才不删掉
f = False
for j in range(i+1, len(self.session_memory[session_id])):
if self.session_memory[session_id][j]['user']['role'] == "system":
f = True
break
if not f:
continue
record = self.session_memory[session_id].pop(i)
break
return record
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
payloads["tools"] = tools.get_func_desc_openai_style()
tool_list = tools.get_func_desc_openai_style()
if tool_list:
payloads['tools'] = tool_list
completion = await self.client.chat.completions.create(
**payloads,
stream=False
)
assert isinstance(completion, ChatCompletion)
logger.debug(f"completion: {completion.usage}")
logger.debug(f"completion: {completion}")
if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0]
llm_response = LLMResponse("assistant")
if choice.message.content:
# text completion
completion_text = str(choice.message.content).strip()
return LLMResponse("assistant", completion_text)
elif choice.message.tool_calls:
llm_response.completion_text = completion_text
if choice.message.tool_calls:
# tools call (function calling)
args_ls = []
func_name_ls = []
@@ -123,66 +99,136 @@ class ProviderOpenAIOfficial(Provider):
args = json.loads(tool_call.function.arguments)
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls)
else:
raise Exception("Internal Error")
async def text_chat(self,
prompt: str,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
system_prompt = ""
if self.curr_personality["prompt"]:
system_prompt = self.curr_personality["prompt"]
if self.enable_datetime:
system_prompt += f"Current datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}"
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
else:
context_query = contexts
llm_response.role = "tool"
llm_response.tools_call_args = args_ls
llm_response.tools_call_name = func_name_ls
logger.debug(f"请求上下文:{context_query}, {self.get_model()}")
if choice.finish_reason == 'content_filter':
raise Exception("API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。")
if not llm_response.completion_text and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}")
raise Exception(f"API 返回的 completion 无法解析:{completion}")
payloads = {
"messages": context_query,
**self.provider_config.get("model_config", {})
}
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
if llm_response.role == "assistant":
# 文本回复
if not contexts:
# 添加用户 record
self.session_memory[session_id].append(new_record)
# 添加 assistant record
self.session_memory[session_id].append({
"role": "assistant",
"content": llm_response.completion_text
})
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
llm_response.raw_completion = completion
return llm_response
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
return True
async def text_chat(
self,
prompt: str,
session_id: str=None,
image_urls: List[str]=[],
func_tool: FuncCall=None,
contexts=[],
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if '_no_save' in part:
del part['_no_save']
model_config = self.provider_config.get("model_config", {})
model_config['model'] = self.get_model()
payloads = {
"messages": context_query,
**model_config
}
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
except UnprocessableEntityError as e:
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads['messages'] = new_contexts
context_query = new_contexts
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
# 重试 10 次
retry_cnt = 20
while retry_cnt > 0:
logger.warning(f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}")
try:
await self.pop_record(context_query)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
if retry_cnt == 0:
llm_response = LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
elif "The model is not a VLM" in str(e): # siliconcloud
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads['messages'] = new_contexts
llm_response = await self._query(payloads, func_tool)
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
elif 'does not support Function Calling' in str(e) \
or 'does not support tools' in str(e) \
or 'Function call is not supported' in str(e) \
or 'Function calling is not enabled' in str(e) \
or 'Tool calling is not supported' in str(e) \
or 'No endpoints found that support tool use' in str(e) \
or 'model does not support function calling' in str(e) \
or ('tool' in str(e) and 'support' in str(e).lower()) \
or ('function' in str(e) and 'support' in str(e).lower()):
logger.info(f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。")
if 'tools' in payloads:
del payloads['tools']
llm_response = await self._query(payloads, None)
else:
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
if 'tool' in str(e).lower() and 'support' in str(e).lower():
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
if 'Connection error.' in str(e):
proxy = os.environ.get("http_proxy", None)
if proxy:
logger.error(f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}")
raise e
return llm_response
async def _remove_image_from_context(self, contexts: List):
'''
从上下文中删除所有带有 image 的记录
'''
new_contexts = []
flag = False
for context in contexts:
if flag:
flag = False # 删除 image 后下一条LLM 响应)也要删除
continue
if isinstance(context['content'], list):
flag = True
# continue
new_content = []
for item in context['content']:
if isinstance(item, dict) and 'image_url' in item:
continue
new_content.append(item)
if not new_content:
# 用户只发了图片
new_content = [{"type": "text", "text": "[图片]"}]
context['content'] = new_content
new_contexts.append(context)
return new_contexts
def get_current_key(self) -> str:
return self.client.api_key
@@ -202,8 +248,14 @@ class ProviderOpenAIOfficial(Provider):
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
return user_content
else:

View File

@@ -0,0 +1,40 @@
import uuid
import os
from openai import AsyncOpenAI, NOT_GIVEN
from ..provider import TTSProvider
from ..entites import ProviderType
from ..register import register_provider_adapter
@register_provider_adapter("openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH)
class ProviderOpenAITTSAPI(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key = provider_config.get("api_key", "")
self.voice = provider_config.get("openai-tts-voice", "alloy")
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config.get("model", None))
async def get_audio(self, text: str) -> str:
path = f'data/temp/openai_tts_api_{uuid.uuid4()}.wav'
async with self.client.audio.speech.with_streaming_response.create(
model=self.model_name,
voice=self.voice,
response_format='wav',
input=text
) as response:
with open(path, 'wb') as f:
async for chunk in response.iter_bytes(chunk_size=1024):
f.write(chunk)
return path

View File

@@ -0,0 +1,74 @@
import uuid
import os
from openai import AsyncOpenAI, NOT_GIVEN
from ..provider import STTProvider
from ..entites import ProviderType
from astrbot.core.utils.io import download_file
from ..register import register_provider_adapter
from astrbot.core import logger
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
@register_provider_adapter("openai_whisper_api", "OpenAI Whisper API", provider_type=ProviderType.SPEECH_TO_TEXT)
class ProviderOpenAIWhisperAPI(STTProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key = provider_config.get("api_key", "")
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config.get("model", None))
async def _convert_audio(self, path: str) -> str:
from pyffmpeg import FFmpeg
filename = str(uuid.uuid4()) + '.mp3'
ff = FFmpeg()
output_path = ff.convert(path, os.path.join('data/temp', filename))
return output_path
async def _is_silk_file(self, file_path):
silk_header = b"SILK"
with open(file_path, "rb") as f:
file_header = f.read(8)
if silk_header in file_header:
return True
else:
return False
async def get_text(self, audio_url: str) -> str:
'''only supports mp3, mp4, mpeg, m4a, wav, webm'''
is_tencent = False
if audio_url.startswith("http"):
if "multimedia.nt.qq.com.cn" in audio_url:
is_tencent = True
name = str(uuid.uuid4())
path = os.path.join("data/temp", name)
await download_file(audio_url, path)
audio_url = path
if not os.path.exists(audio_url):
raise FileNotFoundError(f"文件不存在: {audio_url}")
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
is_silk = await self._is_silk_file(audio_url)
if is_silk:
logger.info("Converting silk file to wav ...")
output_path = os.path.join('data/temp', str(uuid.uuid4()) + '.wav')
await tencent_silk_to_wav(audio_url, output_path)
audio_url = output_path
result = await self.client.audio.transcriptions.create(
model=self.model_name,
file=open(audio_url, "rb"),
)
return result.text

View File

@@ -0,0 +1,72 @@
import uuid
import os
import asyncio
import whisper
from ..provider import STTProvider
from ..entites import ProviderType
from astrbot.core.utils.io import download_file
from ..register import register_provider_adapter
from astrbot.core import logger
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
@register_provider_adapter("openai_whisper_selfhost", "OpenAI Whisper 模型部署", provider_type=ProviderType.SPEECH_TO_TEXT)
class ProviderOpenAIWhisperSelfHost(STTProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.set_model(provider_config.get("model", None))
self.model = None
async def initialize(self):
loop = asyncio.get_event_loop()
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
self.model = await loop.run_in_executor(None, whisper.load_model, self.model_name)
logger.info("Whisper 模型加载完成。")
async def _convert_audio(self, path: str) -> str:
from pyffmpeg import FFmpeg
filename = str(uuid.uuid4()) + '.mp3'
ff = FFmpeg()
output_path = ff.convert(path, os.path.join('data/temp', filename))
return output_path
async def _is_silk_file(self, file_path):
silk_header = b"SILK"
with open(file_path, "rb") as f:
file_header = f.read(8)
if silk_header in file_header:
return True
else:
return False
async def get_text(self, audio_url: str) -> str:
loop = asyncio.get_event_loop()
is_tencent = False
if audio_url.startswith("http"):
if "multimedia.nt.qq.com.cn" in audio_url:
is_tencent = True
name = str(uuid.uuid4())
path = os.path.join("data/temp", name)
await download_file(audio_url, path)
audio_url = path
if not os.path.exists(audio_url):
raise FileNotFoundError(f"文件不存在: {audio_url}")
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
is_silk = await self._is_silk_file(audio_url)
if is_silk:
logger.info("Converting silk file to wav ...")
output_path = os.path.join('data/temp', str(uuid.uuid4()) + '.wav')
await tencent_silk_to_wav(audio_url, output_path)
audio_url = output_path
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
return result['text']

View File

@@ -0,0 +1,78 @@
import traceback
from astrbot.core.db import BaseDatabase
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from .openai_source import ProviderOpenAIOfficial
@register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器")
class ProviderZhipu(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True,
default_persona = None
) -> None:
super().__init__(provider_config, provider_settings, db_helper, persistant_history, default_persona)
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=[],
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
context_query = [*contexts, new_record]
model_cfgs: dict = self.provider_config.get("model_config", {})
model = self.get_model()
# glm-4v-flash 只支持一张图片
if model.lower() == 'glm-4v-flash' and image_urls and len(context_query) > 1:
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
logger.debug(context_query)
new_context_query_ = []
for i in range(0, len(context_query) - 1, 2):
if isinstance(context_query[i].get("content", ""), list):
continue
new_context_query_.append(context_query[i])
new_context_query_.append(context_query[i+1])
new_context_query_.append(context_query[-1]) # 保留最后一条记录
context_query = new_context_query_
logger.debug(context_query)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
payloads = {
"messages": context_query,
**model_cfgs
}
try:
llm_response = await self._query(payloads, func_tool)
return llm_response
except Exception as e:
if "maximum context length" in str(e):
retry_cnt = 10
while retry_cnt > 0:
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
try:
self.pop_record(session_id)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
else:
raise e

View File

@@ -0,0 +1,25 @@
from typing import List
from openai import AsyncOpenAI
class SimpleOpenAIEmbedding():
def __init__(
self,
model,
api_key,
api_base=None,
) -> None:
self.client = AsyncOpenAI(
api_key=api_key,
base_url=api_base
)
self.model = model
async def get_embedding(self, text) -> List[float]:
'''
获取文本的嵌入
'''
embedding = await self.client.embeddings.create(
input=text,
model=self.model
)
return embedding.data[0].embedding

View File

@@ -0,0 +1,92 @@
import os
from typing import List, Dict
from astrbot.core import logger
from .store import Store
from astrbot.core.config import AstrBotConfig
class KnowledgeDBManager():
def __init__(self, astrbot_config: AstrBotConfig) -> None:
self.db_path = "data/knowledge_db/"
self.config = astrbot_config.get("knowledge_db", {})
self.astrbot_config = astrbot_config
if not os.path.exists(self.db_path):
os.makedirs(self.db_path)
self.store_insts: Dict[str, Store] = {}
for name, cfg in self.config.items():
if cfg["strategy"] == "embedding":
logger.info(f"加载 Chroma Vector Store{name}")
try:
from .store.chroma_db import ChromaVectorStore
except ImportError as ie:
logger.error(f"{ie} 可能未安装 chromadb 库。")
continue
self.store_insts[name] = ChromaVectorStore(name, cfg["embedding_config"])
else:
logger.error(f"不支持的策略:{cfg['strategy']}")
async def list_knowledge_db(self) -> List[str]:
return [f for f in os.listdir(self.db_path) if os.path.isfile(os.path.join(self.db_path, f))]
async def create_knowledge_db(self, name: str, config: Dict):
'''
config 格式:
```
{
"strategy": "embedding", # 目前只支持 embedding
"chunk_method": {
"strategy": "fixed",
"chunk_size": 100,
"overlap_size": 10
},
"embedding_config": {
"strategy": "openai",
"base_url": "",
"model": "",
"api_key": ""
}
}
```
'''
if name in self.config:
raise ValueError(f"知识库已存在:{name}")
self.config[name] = config
self.astrbot_config["knowledge_db"] = self.config
self.astrbot_config.save_config()
async def insert_record(self, name: str, text: str):
if name not in self.store_insts:
raise ValueError(f"未找到知识库:{name}")
ret = []
match self.config[name]["chunk_method"]['strategy']:
case "fixed":
chunk_size = self.config[name]["chunk_method"]["chunk_size"]
chunk_overlap = self.config[name]["chunk_method"]["overlap_size"]
ret = self._fixed_chunk(text, chunk_size, chunk_overlap)
case _:
pass
for chunk in ret:
await self.store_insts[name].save(chunk)
async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]:
if name not in self.store_insts:
raise ValueError(f"未找到知识库:{name}")
inst = self.store_insts[name]
return await inst.query(query, top_n)
def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
chunks = []
start = 0
while start < len(text):
end = start + chunk_size
chunks.append(text[start:end])
start += chunk_size - chunk_overlap
return chunks

View File

@@ -0,0 +1,8 @@
from typing import List
class Store():
async def save(self, text: str):
pass
async def query(self, query: str, top_n: int = 3) -> List[str]:
pass

View File

@@ -0,0 +1,39 @@
import chromadb
import uuid
from typing import List, Dict
from astrbot.api import logger
from ..embedding.openai_source import SimpleOpenAIEmbedding
from . import Store
class ChromaVectorStore(Store):
def __init__(self, name: str, embedding_cfg: Dict) -> None:
self.chroma_client = chromadb.PersistentClient(path='data/long_term_memory_chroma.db')
self.collection = self.chroma_client.get_or_create_collection(name=name)
self.embedding = None
if embedding_cfg["strategy"] == "openai":
self.embedding = SimpleOpenAIEmbedding(
model=embedding_cfg["model"],
api_key=embedding_cfg["api_key"],
api_base=embedding_cfg.get("base_url", None)
)
async def save(self, text: str, metadata: Dict = None):
logger.debug(f"Saving text: {text}")
embedding = await self.embedding.get_embedding(text)
self.collection.upsert(
documents=text,
metadatas=metadata,
ids=str(uuid.uuid4()),
embeddings=embedding
)
async def query(self, query: str, top_n=3, metadata_filter: Dict = None) -> List[str]:
embedding = await self.embedding.get_embedding(query)
results = self.collection.query(
query_embeddings=embedding,
n_results=top_n,
where=metadata_filter
)
return results['documents'][0]

View File

@@ -1,3 +1,7 @@
'''
此功能已过时,参考 https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta
'''
from typing import Union
import os
import json

View File

@@ -1,23 +1,22 @@
from asyncio import Queue
from typing import List, TypedDict, Union
from typing import List, Union
from astrbot.core.provider import Provider
from astrbot.core import sp
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.provider.tool import FuncCall
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.platform.manager import PlatformManager
from .star import star_registry, StarMetadata
from .star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata
from .star import star_registry, StarMetadata, star_map
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
class StarCommand(TypedDict):
full_command_name: str
command_name: str
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
class Context:
'''
@@ -38,43 +37,39 @@ class Context:
# back compatibility
_register_tasks: List[Awaitable] = []
_star_manager = None
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
def __init__(self,
event_queue: Queue,
config: AstrBotConfig,
db: BaseDatabase,
provider_manager: ProviderManager = None,
platform_manager: PlatformManager = None,
conversation_manager: ConversationManager = None,
knowledge_db_manager: KnowledgeDBManager = None
):
self._event_queue = event_queue
self._config = config
self._db = db
self.provider_manager = provider_manager
self.platform_manager = platform_manager
self.knowledge_db_manager = knowledge_db_manager
self.conversation_manager = conversation_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
'''根据插件名获取插件的 Metadata'''
for star in star_registry:
if star.name == star_name:
return star
def get_all_stars(self) -> List[StarMetadata]:
'''获取当前载入的所有插件 Metadata 的列表'''
return star_registry
def get_llm_tool_manager(self) -> FuncCall:
'''
获取 LLM Tool Manager
'''
'''获取 LLM Tool Manager其用于管理注册的所有的 Function-calling tools'''
return self.provider_manager.llm_tools
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
'''
为函数调用function-calling / tools-use添加工具。
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 异步处理函数。
异步处理函数会接收到额外的的关键词参数event: AstrMessageEvent, context: Context。
'''
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj.__module__)
def unregister_llm_tool(self, name: str) -> None:
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
self.provider_manager.llm_tools.remove_func(name)
def activate_llm_tool(self, name: str) -> bool:
'''激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
@@ -83,7 +78,18 @@ class Context:
'''
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
if func_tool.handler_module_path in star_map:
if not star_map[func_tool.handler_module_path].activated:
raise ValueError(f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。")
func_tool.active = True
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name in inactivated_llm_tools:
inactivated_llm_tools.remove(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
@@ -95,81 +101,66 @@ class Context:
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = False
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name not in inactivated_llm_tools:
inactivated_llm_tools.append(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
'''
注册一个命令。
[Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。
@param star_name: 插件Star名称。
@param command_name: 命令名称。
@param desc: 命令描述。
@param priority: 优先级。1-10。
@param awaitable: 异步处理函数。
'''
md = StarHandlerMetadata(
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
handler_name=awaitable.__name__,
handler_module_str=awaitable.__module__,
handler=awaitable,
event_filters=[],
desc=desc
)
if use_regex:
md.event_filters.append(RegexFilter(
regex=command_name
))
else:
md.event_filters.append(CommandFilter(
command_name=command_name,
handler_md=md
))
star_handlers_registry.append(md)
star_handlers_map[md.handler_full_name] = md
def register_provider(self, provider: Provider):
'''
注册一个 LLM Provider。
注册一个 LLM Provider(Chat_Completion 类型)
'''
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
'''
通过 ID 获取 LLM Provider。
'''
'''通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
for provider in self.provider_manager.provider_insts:
if provider.meta().id == provider_id:
return provider
return None
def get_all_providers(self) -> List[Provider]:
'''
获取所有 LLM Provider。
'''
'''获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
return self.provider_manager.provider_insts
def get_all_tts_providers(self) -> List[TTSProvider]:
'''获取所有用于 TTS 任务的 Provider。'''
return self.provider_manager.tts_provider_insts
def get_all_stt_providers(self) -> List[STTProvider]:
'''获取所有用于 STT 任务的 Provider。'''
return self.provider_manager.stt_provider_insts
def get_using_provider(self) -> Provider:
'''
获取当前使用的 LLM Provider。
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)
通过 /provider 指令切换。
'''
return self.provider_manager.curr_provider_inst
def get_using_tts_provider(self) -> TTSProvider:
'''
获取当前使用的用于 TTS 任务的 Provider。
'''
return self.provider_manager.curr_tts_provider_inst
def get_using_stt_provider(self) -> STTProvider:
'''
获取当前使用的用于 STT 任务的 Provider。
'''
return self.provider_manager.curr_stt_provider_inst
def get_config(self) -> AstrBotConfig:
'''
获取 AstrBot 配置信息。
'''
'''获取 AstrBot 的配置。'''
return self._config
def get_db(self) -> BaseDatabase:
'''
获取 AstrBot 数据库。
'''
'''获取 AstrBot 数据库。'''
return self._db
def get_event_queue(self) -> Queue:
@@ -202,6 +193,71 @@ class Context:
return True
return False
'''
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
'''
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
'''
为函数调用function-calling / tools-use添加工具。
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 异步处理函数。
异步处理函数会接收到额外的的关键词参数event: AstrMessageEvent, context: Context。
'''
md = StarHandlerMetadata(
event_type=EventType.OnLLMRequestEvent,
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
handler_name=func_obj.__name__,
handler_module_path=func_obj.__module__,
handler=func_obj,
event_filters=[],
desc=desc
)
star_handlers_registry.append(md)
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj)
def unregister_llm_tool(self, name: str) -> None:
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
self.provider_manager.llm_tools.remove_func(name)
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
'''
注册一个命令。
[Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。
@param star_name: 插件Star名称。
@param command_name: 命令名称。
@param desc: 命令描述。
@param priority: 优先级。1-10。
@param awaitable: 异步处理函数。
'''
md = StarHandlerMetadata(
event_type=EventType.AdapterMessageEvent,
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
handler_name=awaitable.__name__,
handler_module_path=awaitable.__module__,
handler=awaitable,
event_filters=[],
desc=desc
)
if use_regex:
md.event_filters.append(RegexFilter(
regex=command_name
))
else:
md.event_filters.append(CommandFilter(
command_name=command_name,
handler_md=md
))
star_handlers_registry.append(md)
def register_task(self, task: Awaitable, desc: str):
'''
注册一个异步任务。

View File

@@ -1,19 +1,23 @@
import re
import inspect
from typing import List, Any, Type, Dict
from . import HandlerFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.config import AstrBotConfig
from astrbot.core.utils.param_validation_mixin import ParameterValidationMixin
from .custom_filter import CustomFilter
from ..star_handler import StarHandlerMetadata
# 标准指令受到 wake_prefix 的制约。
class CommandFilter(HandlerFilter, ParameterValidationMixin):
class CommandFilter(HandlerFilter):
'''标准指令过滤器'''
def __init__(self, command_name: str, handler_md: StarHandlerMetadata = None):
def __init__(self, command_name: str, alias: set = None, handler_md: StarHandlerMetadata = None, parent_command_names: List[str] = [""]):
self.command_name = command_name
self.alias = alias if alias else set()
self.parent_command_names = parent_command_names
if handler_md:
self.init_handler_md(handler_md)
self.custom_filter_list: List[CustomFilter] = []
def print_types(self):
result = ""
@@ -22,6 +26,7 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
result += f"{k}({v.__name__}),"
else:
result += f"{k}({type(v).__name__})={v},"
result = result.rstrip(",")
return result
def init_handler_md(self, handle_md: StarHandlerMetadata):
@@ -42,23 +47,79 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
def get_handler_md(self) -> StarHandlerMetadata:
return self.handler_md
def add_custom_filter(self, custom_filter: CustomFilter):
self.custom_filter_list.append(custom_filter)
def custom_filter_ok(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
for custom_filter in self.custom_filter_list:
if not custom_filter.filter(event, cfg):
return False
return True
def validate_and_convert_params(self, params: List[Any], param_type: Dict[str, Type]) -> Dict[str, Any]:
'''将参数列表 params 根据 param_type 转换为参数字典。
'''
result = {}
for i, (param_name, param_type_or_default_val) in enumerate(param_type.items()):
if i >= len(params):
if isinstance(param_type_or_default_val, Type) or param_type_or_default_val is inspect.Parameter.empty:
# 是类型
raise ValueError(f"必要参数缺失。该指令完整参数: {self.print_types()}")
else:
# 是默认值
result[param_name] = param_type_or_default_val
else:
# 尝试强制转换
try:
if param_type_or_default_val is None:
if params[i].isdigit():
result[param_name] = int(params[i])
else:
result[param_name] = params[i]
elif isinstance(param_type_or_default_val, str):
# 如果 param_type_or_default_val 是字符串,直接赋值
result[param_name] = params[i]
elif isinstance(param_type_or_default_val, int):
result[param_name] = int(params[i])
elif isinstance(param_type_or_default_val, float):
result[param_name] = float(params[i])
else:
result[param_name] = param_type_or_default_val(params[i])
except ValueError:
raise ValueError(f"参数 {param_name} 类型错误。完整参数: {self.print_types()}")
return result
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
if not event.is_wake_up():
if not event.is_at_or_wake_command:
return False
if not self.custom_filter_ok(event, cfg):
return False
message_str = event.get_message_str().strip()
# 分割为列表(每个参数之间可能会有多个空格)
ls = re.split(r"\s+", message_str)
if self.command_name != ls[0]:
# 检查是否以指令开头
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
candidates = [self.command_name] + list(self.alias)
ok = False
for candidate in candidates:
for parent_command_name in self.parent_command_names:
if parent_command_name:
_full = f"{parent_command_name} {candidate}"
else:
_full = candidate
if message_str.startswith(f"{_full} ") or message_str == _full:
message_str = message_str[len(_full):].strip()
ok = True
break
if not ok:
return False
# params_str = message_str[len(self.command_name):].strip()
ls = ls[1:]
# 分割为列表
ls = message_str.split(" ")
# 去除空字符串
ls = [param for param in ls if param]
params = {}
try:
params = self.validate_and_convert_params(ls, self.handler_params)
except ValueError as e:
raise e

View File

@@ -6,65 +6,97 @@ from . import HandlerFilter
from .command import CommandFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.config import AstrBotConfig
from .custom_filter import CustomFilter
from ..star_handler import StarHandlerMetadata
# 指令组受到 wake_prefix 的制约。
class CommandGroupFilter(HandlerFilter):
def __init__(self, group_name: str):
def __init__(self, group_name: str, alias: set = None, parent_group: CommandGroupFilter = None):
self.group_name = group_name
self.alias = alias if alias else set()
self.sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]] = []
self.custom_filter_list: List[CustomFilter] = []
self.parent_group = parent_group
def add_sub_command_filter(self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]):
self.sub_command_filters.append(sub_command_filter)
def add_custom_filter(self, custom_filter: CustomFilter):
self.custom_filter_list.append(custom_filter)
def get_complete_command_names(self) -> List[str]:
'''遍历父节点获取完整的指令名。
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。'''
parent_cmd_names = self.parent_group.get_complete_command_names() if self.parent_group else []
if not parent_cmd_names:
# 根节点
return [self.group_name] + list(self.alias)
result = []
candidates = [self.group_name] + list(self.alias)
for parent_cmd_name in parent_cmd_names:
for candidate in candidates:
result.append(parent_cmd_name + " " + candidate)
return result
# 以树的形式打印出来
def print_cmd_tree(self, sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]], prefix: str = "") -> str:
def print_cmd_tree(self,
sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]],
prefix: str = "",
event: AstrMessageEvent = None,
cfg: AstrBotConfig = None,
) -> str:
result = ""
for sub_filter in sub_command_filters:
if isinstance(sub_filter, CommandFilter):
cmd_th = sub_filter.print_types()
result += f"{prefix}├── {sub_filter.command_name}"
if cmd_th:
result += f" ({cmd_th})"
else:
result += " (无参数指令)"
result += "\n"
custom_filter_pass = True
if event and cfg:
custom_filter_pass = sub_filter.custom_filter_ok(event, cfg)
if custom_filter_pass:
cmd_th = sub_filter.print_types()
result += f"{prefix}├── {sub_filter.command_name}"
if cmd_th:
result += f" ({cmd_th})"
else:
result += " (无参数指令)"
if sub_filter.handler_md and sub_filter.handler_md.desc:
result += f": {sub_filter.handler_md.desc}"
result += "\n"
elif isinstance(sub_filter, CommandGroupFilter):
result += f"{prefix}├── {sub_filter.group_name}"
result += "\n"
result += sub_filter.print_cmd_tree(sub_filter.sub_command_filters, prefix+"")
custom_filter_pass = True
if event and cfg:
custom_filter_pass = sub_filter.custom_filter_ok(event, cfg)
if custom_filter_pass:
result += f"{prefix}├── {sub_filter.group_name}"
result += "\n"
result += sub_filter.print_cmd_tree(sub_filter.sub_command_filters, prefix+"", event=event, cfg=cfg)
return result
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> Tuple[bool, StarHandlerMetadata]:
if not event.is_wake_up():
return False, None
message_str = event.get_message_str().strip()
ls = re.split(r"\s+", message_str)
if ls[0] != self.group_name:
return False, None
# 改写 message_str
ls = ls[1:]
event.message_str = " ".join(ls)
event.message_str = event.message_str.strip()
if event.message_str == "":
# 当前还是指令组
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters)
def custom_filter_ok(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
for custom_filter in self.custom_filter_list:
if not custom_filter.filter(event, cfg):
return False
return True
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
if not event.is_at_or_wake_command:
return False
# 判断当前指令组的自定义过滤器
if not self.custom_filter_ok(event, cfg):
return False
complete_command_names = self.get_complete_command_names()
if event.message_str.strip() in complete_command_names:
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree)
child_command_handler_md = None
for sub_filter in self.sub_command_filters:
if isinstance(sub_filter, CommandFilter):
if sub_filter.filter(event, cfg):
child_command_handler_md = sub_filter.get_handler_md()
return True, child_command_handler_md
elif isinstance(sub_filter, CommandGroupFilter):
ok, handler = sub_filter.filter(event, cfg)
if ok:
child_command_handler_md = handler
return True, child_command_handler_md
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters)
raise ValueError(f"指令组 {self.group_name} 下没有找到对应的指令。这个指令组下有如下指令:\n"+tree)
# complete_command_names = [name + " " for name in complete_command_names]
# return event.message_str.startswith(tuple(complete_command_names))
return False

View File

@@ -0,0 +1,53 @@
from abc import abstractmethod, ABCMeta
from . import HandlerFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.config import AstrBotConfig
class CustomFilterMeta(ABCMeta):
def __and__(cls, other):
if not issubclass(other, CustomFilter):
raise TypeError("Operands must be subclasses of CustomFilter.")
return CustomFilterAnd(cls(), other())
def __or__(cls, other):
if not issubclass(other, CustomFilter):
raise TypeError("Operands must be subclasses of CustomFilter.")
return CustomFilterOr(cls(), other())
class CustomFilter(HandlerFilter, metaclass=CustomFilterMeta):
def __init__(self, raise_error: bool = True, **kwargs):
self.raise_error = raise_error
@abstractmethod
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
''' 一个用于重写的自定义Filter '''
raise NotImplementedError
def __or__(self, other):
return CustomFilterOr(self, other)
def __and__(self, other):
return CustomFilterAnd(self, other)
class CustomFilterOr(CustomFilter):
def __init__(self, filter1: CustomFilter, filter2: CustomFilter):
super().__init__()
if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)):
raise ValueError("CustomFilter lass can only operate with other CustomFilter.")
self.filter1 = filter1
self.filter2 = filter2
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
return self.filter1.filter(event, cfg) or self.filter2.filter(event, cfg)
class CustomFilterAnd(CustomFilter):
def __init__(self, filter1: CustomFilter, filter2: CustomFilter):
super().__init__()
if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)):
raise ValueError("CustomFilter lass can only operate with other CustomFilter.")
self.filter1 = filter1
self.filter2 = filter2
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
return self.filter1.filter(event, cfg) and self.filter2.filter(event, cfg)

View File

@@ -6,8 +6,8 @@ from astrbot.core.config import AstrBotConfig
class PermissionType(enum.Flag):
'''权限类型。当选择 MEMBERADMIN 也可以通过。
'''
ADMIN = "admin"
MEMBER = "member"
ADMIN = enum.auto()
MEMBER = enum.auto()
class PermissionTypeFilter(HandlerFilter):
def __init__(self, permission_type: PermissionType, raise_error: bool = True):
@@ -19,7 +19,8 @@ class PermissionTypeFilter(HandlerFilter):
'''
if self.permission_type == PermissionType.ADMIN:
if not event.is_admin():
event.stop_event()
raise ValueError("您没有权限执行此操作。")
# event.stop_event()
# raise ValueError(f"您 (ID: {event.get_sender_id()}) 没有权限操作管理员指令。")
return False
return True

View File

@@ -8,12 +8,14 @@ class PlatformAdapterType(enum.Flag):
AIOCQHTTP = enum.auto()
QQOFFICIAL = enum.auto()
VCHAT = enum.auto()
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT
GEWECHAT = enum.auto()
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT | GEWECHAT
ADAPTER_NAME_2_TYPE = {
"aiocqhttp": PlatformAdapterType.AIOCQHTTP,
"qq_official": PlatformAdapterType.QQOFFICIAL,
"vchat": PlatformAdapterType.VCHAT
"vchat": PlatformAdapterType.VCHAT,
"gewechat": PlatformAdapterType.GEWECHAT
}
class PlatformAdapterTypeFilter(HandlerFilter):

View File

@@ -8,6 +8,7 @@ from astrbot.core.config import AstrBotConfig
class RegexFilter(HandlerFilter):
'''正则表达式过滤器'''
def __init__(self, regex: str):
self.regex_str = regex
self.regex = re.compile(regex)
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:

View File

@@ -5,7 +5,13 @@ from .star_handler import (
register_event_message_type,
register_platform_adapter_type,
register_regex,
register_permission_type
register_permission_type,
register_custom_filter,
register_on_llm_request,
register_on_llm_response,
register_llm_tool,
register_on_decorating_result,
register_after_message_sent
)
__all__ = [
@@ -15,5 +21,11 @@ __all__ = [
'register_event_message_type',
'register_platform_adapter_type',
'register_regex',
'register_permission_type'
'register_permission_type',
'register_custom_filter',
'register_on_llm_request',
'register_on_llm_response',
'register_llm_tool',
'register_on_decorating_result',
'register_after_message_sent'
]

View File

@@ -1,118 +1,194 @@
from __future__ import annotations
import docstring_parser
from ..star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata
from ..star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from ..filter.command import CommandFilter
from ..filter.command_group import CommandGroupFilter
from ..filter.event_message_type import EventMessageTypeFilter, EventMessageType
from ..filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
from ..filter.permission import PermissionTypeFilter, PermissionType
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
from ..filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
from astrbot.core import logger
def get_handler_full_name(awatable: Awaitable) -> str:
def get_handler_full_name(awaitable: Awaitable) -> str:
'''获取 Handler 的全名'''
return f"{awatable.__module__}_{awatable.__name__}"
return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create(handler: Awaitable, dont_add = False) -> StarHandlerMetadata:
def get_handler_or_create(
handler: Awaitable,
event_type: EventType,
dont_add = False,
**kwargs
) -> StarHandlerMetadata:
'''获取 Handler 或者创建一个新的 Handler'''
handler_full_name = get_handler_full_name(handler)
if handler_full_name in star_handlers_map:
return star_handlers_map[handler_full_name]
md = star_handlers_registry.get_handler_by_full_name(handler_full_name)
if md:
return md
else:
md = StarHandlerMetadata(
event_type=event_type,
handler_full_name=handler_full_name,
handler_name=handler.__name__,
handler_module_str=handler.__module__,
handler_module_path=handler.__module__,
handler=handler,
event_filters=[]
)
# 插件handler的附加额外信息
if handler.__doc__:
md.desc = handler.__doc__.strip()
if 'desc' in kwargs:
md.desc = kwargs['desc']
del kwargs['desc']
md.extras_configs = kwargs
if not dont_add:
star_handlers_registry.append(md)
star_handlers_map[handler_full_name] = md
return md
def register_command(command_name: str = None, *args):
'''注册一个 Command'''
def register_command(command_name: str = None, sub_command: str = None, alias: set = None, **kwargs):
'''注册一个 Command.
'''
new_command = None
add_to_event_filters = False
if isinstance(command_name, RegisteringCommandable):
# 子指令
new_command = CommandFilter(args[0], None)
parent_command_names = command_name.parent_group.get_complete_command_names()
logger.debug(f"parent_command_names: {parent_command_names}")
new_command = CommandFilter(sub_command, alias, None, parent_command_names=parent_command_names)
command_name.parent_group.add_sub_command_filter(new_command)
else:
# 裸指令
new_command = CommandFilter(command_name, None)
new_command = CommandFilter(command_name, alias, None)
add_to_event_filters = True
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable)
if not add_to_event_filters:
kwargs['sub_command'] = True # 打一个标记,表示这是一个子指令,再 wakingstage 阶段这个 handler 将会直接被跳过(其父指令会接管)
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
new_command.init_handler_md(handler_md)
if add_to_event_filters:
# 裸指令
handler_md.event_filters.append(new_command)
handler_md.event_filters.append(new_command)
return awaitable
return decorator
def register_command_group(command_group_name: str = None, *args):
'''注册一个 CommandGroup'''
new_group = None
def register_custom_filter(custom_type_filter, *args, **kwargs):
'''注册一个自定义的 CustomFilter
Args:
custom_type_filter: 在裸指令时为CustomFilter对象
在指令组时为父指令的RegisteringCommandable对象即self或者command_group的返回
raise_error: 如果没有权限,是否抛出错误到消息平台,并且停止事件传播。默认为 True
'''
add_to_event_filters = False
raise_error = True
# 判断是否是指令组指令组则添加到指令组的CommandGroupFilter对象中在waking_check的时候一起判断
if isinstance(custom_type_filter, RegisteringCommandable):
# 子指令, 此时函数为RegisteringCommandable对象的方法首位参数为RegisteringCommandable对象的self。
parent_register_commandable = custom_type_filter
custom_filter = args[0]
if len(args) > 1:
raise_error = args[1]
else:
# 裸指令
add_to_event_filters = True
custom_filter = custom_type_filter
if args:
raise_error = args[0]
if not isinstance(custom_filter, (CustomFilterAnd, CustomFilterOr)):
custom_filter = custom_filter(raise_error)
def decorator(awaitable):
# 裸指令子指令与指令组的区分指令组会因为标记跳过wake。
if not add_to_event_filters and isinstance(awaitable, RegisteringCommandable) or \
(add_to_event_filters and isinstance(awaitable, RegisteringCommandable)):
# 指令组 与 根指令组添加到本层的grouphandle中一起判断
awaitable.parent_group.add_custom_filter(custom_filter)
else:
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
if not add_to_event_filters and not isinstance(awaitable, RegisteringCommandable):
# 底层子指令
handle_full_name = get_handler_full_name(awaitable)
for sub_handle in parent_register_commandable.parent_group.sub_command_filters:
# 所有符合fullname一致的子指令handle添加自定义过滤器。
# 不确定是否会有多个子指令有一样的fullname比如一个方法添加多个command装饰器
sub_handle_md = sub_handle.get_handler_md()
if sub_handle_md and sub_handle_md.handler_full_name == handle_full_name:
sub_handle.add_custom_filter(custom_filter)
else:
# 裸指令
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
handler_md.event_filters.append(custom_filter)
return awaitable
return decorator
def register_command_group(
command_group_name: str = None, sub_command: str = None, alias: set = None, **kwargs
):
'''注册一个 CommandGroup
'''
new_group = None
if isinstance(command_group_name, RegisteringCommandable):
# 子指令组
new_group = CommandGroupFilter(args[0])
new_group = CommandGroupFilter(sub_command, alias, parent_group=command_group_name.parent_group)
command_group_name.parent_group.add_sub_command_filter(new_group)
else:
# 根指令组
new_group = CommandGroupFilter(command_group_name)
add_to_event_filters = True
new_group = CommandGroupFilter(command_group_name, alias)
def decorator(obj):
if add_to_event_filters:
# 根指令组
handler_md = get_handler_or_create(obj)
handler_md.event_filters.append(new_group)
# 根指令组
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
handler_md.event_filters.append(new_group)
return RegisteringCommandable(new_group)
return decorator
class RegisteringCommandable():
'''用于指令组级联注册'''
group = register_command_group
command = register_command
group: CommandGroupFilter = register_command_group
command: CommandFilter = register_command
custom_filter = register_custom_filter
def __init__(self, parent_group: CommandGroupFilter):
self.parent_group = parent_group
def register_event_message_type(event_message_type: EventMessageType):
def register_event_message_type(event_message_type: EventMessageType, **kwargs):
'''注册一个 EventMessageType'''
def decorator(awatable):
handler_md = get_handler_or_create(awatable)
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
handler_md.event_filters.append(EventMessageTypeFilter(event_message_type))
return awatable
return awaitable
return decorator
def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType):
def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType, **kwargs):
'''注册一个 PlatformAdapterType'''
def decorator(awatable):
handler_md = get_handler_or_create(awatable)
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md.event_filters.append(PlatformAdapterTypeFilter(platform_adapter_type))
return awatable
return awaitable
return decorator
def register_regex(regex: str):
def register_regex(regex: str, **kwargs):
'''注册一个 Regex'''
def decorator(awatable):
handler_md = get_handler_or_create(awatable)
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
handler_md.event_filters.append(RegexFilter(regex))
return awatable
return awaitable
return decorator
@@ -123,9 +199,121 @@ def register_permission_type(permission_type: PermissionType, raise_error: bool
permission_type: PermissionType
raise_error: 如果没有权限,是否抛出错误到消息平台,并且停止事件传播。默认为 True
'''
def decorator(awatable):
handler_md = get_handler_or_create(awatable)
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md.event_filters.append(PermissionTypeFilter(permission_type, raise_error))
return awatable
return awaitable
return decorator
def register_on_llm_request(**kwargs):
'''当有 LLM 请求时的事件
Examples:
```py
from astrbot.api.provider import ProviderRequest
@on_llm_request()
async def test(self, event: AstrMessageEvent, request: ProviderRequest) -> None:
request.system_prompt += "你是一个猫娘..."
```
请务必接收两个参数event, request
'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent, **kwargs)
return awaitable
return decorator
def register_on_llm_response(**kwargs):
'''当有 LLM 请求后的事件
Examples:
```py
from astrbot.api.provider import LLMResponse
@on_llm_response()
async def test(self, event: AstrMessageEvent, response: LLMResponse) -> None:
...
```
请务必接收两个参数event, request
'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent, **kwargs)
return awaitable
return decorator
def register_llm_tool(name: str = None):
'''为函数调用function-calling / tools-use添加工具。
请务必按照以下格式编写一个工具包括函数注释AstrBot 会尝试解析该函数注释)
```
@llm_tool(name="get_weather") # 如果 name 不填,将使用函数名
async def get_weather(event: AstrMessageEvent, location: str):
\'\'\'获取天气信息。
Args:
location(string): 地点
\'\'\'
# 处理逻辑
```
可接受的参数类型有string, number, object, array, boolean。
返回值:
- 返回 str结果会被加入下一次 LLM 请求的 prompt 中,用于让 LLM 总结工具返回的结果
- 返回 None结果不会被加入下一次 LLM 请求的 prompt 中。
可以使用 yield 发送消息、终止事件。
发送消息:请参考文档。
终止事件:
```
event.stop_event()
yield
```
'''
name_ = name
def decorator(awaitable: Awaitable):
llm_tool_name = name_ if name_ else awaitable.__name__
docstring = docstring_parser.parse(awaitable.__doc__)
args = []
for arg in docstring.params:
if arg.type_name not in SUPPORTED_TYPES:
raise ValueError(f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}")
args.append({
"type": arg.type_name,
"name": arg.arg_name,
"description": arg.description
})
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
llm_tools.add_func(llm_tool_name, args, docstring.description, md.handler)
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
return awaitable
return decorator
def register_on_decorating_result(**kwargs):
'''在发送消息前的事件'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent, **kwargs)
return awaitable
return decorator
def register_after_message_sent(**kwargs):
'''在消息发送后的事件'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnAfterMessageSentEvent, **kwargs)
return awaitable
return decorator

View File

@@ -2,7 +2,8 @@ from __future__ import annotations
from types import ModuleType
from typing import List, Dict
from dataclasses import dataclass
from dataclasses import dataclass, field
from astrbot.core.config import AstrBotConfig
star_registry: List[StarMetadata] = []
star_map: Dict[str, StarMetadata] = {}
@@ -11,7 +12,7 @@ star_map: Dict[str, StarMetadata] = {}
@dataclass
class StarMetadata:
'''
Star 的元数据。
插件的元数据。
'''
name: str
author: str # 插件作者
@@ -20,18 +21,27 @@ class StarMetadata:
repo: str = None # 插件仓库地址
star_cls_type: type = None
'''Star 的类对象的类型'''
'''插件的类对象的类型'''
module_path: str = None
'''Star 的模块路径'''
'''插件的模块路径'''
star_cls: object = None
'''Star 的类对象'''
'''插件的类对象'''
module: ModuleType = None
'''Star 的模块对象'''
'''插件的模块对象'''
root_dir_name: str = None
'''Star 的根目录名'''
'''插件的目录名'''
reserved: bool = False
'''是否是 AstrBot 的保留 Star'''
'''是否是 AstrBot 的保留插件'''
activated: bool = True
'''是否被激活'''
config: AstrBotConfig = None
'''插件配置'''
star_handler_full_names: List[str] = field(default_factory=list)
'''注册的 Handler 的全名列表'''
def __str__(self) -> str:
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"

View File

@@ -1,31 +1,117 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Awaitable, List, Dict
import enum
import heapq
from dataclasses import dataclass, field
from typing import Awaitable, List, Dict, TypeVar, Generic
from .filter import HandlerFilter
from .star import star_map
star_handlers_registry: List[StarHandlerMetadata] = []
T = TypeVar('T', bound='StarHandlerMetadata')
class StarHandlerRegistry(Generic[T]):
'''用于存储所有的 Star Handler'''
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
'''用于快速查找。key 是 handler_full_name'''
_handlers = []
def append(self, handler: StarHandlerMetadata):
'''添加一个 Handler'''
if 'priority' not in handler.extras_configs:
handler.extras_configs['priority'] = 0
heapq.heappush(self._handlers, (-handler.extras_configs['priority'], handler))
self.star_handlers_map[handler.handler_full_name] = handler
def _print_handlers(self):
'''打印所有的 Handler'''
for _, handler in self._handlers:
print(handler.handler_full_name)
def get_handlers_by_event_type(self, event_type: EventType, only_activated=True) -> List[StarHandlerMetadata]:
'''通过事件类型获取 Handler'''
handlers = [
handler
for _, handler in self._handlers
if handler.event_type == event_type and
(not only_activated or (star_map[handler.handler_module_path] and star_map[handler.handler_module_path].activated))
]
return handlers
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
'''通过 Handler 的全名获取 Handler'''
return self.star_handlers_map.get(full_name, None)
def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]:
'''通过模块名获取 Handler'''
return [handler for _, handler in self._handlers if handler.handler_module_path == module_name]
def clear(self):
'''清空所有的 Handler'''
self.star_handlers_map.clear()
self._handlers.clear()
def remove(self, handler: StarHandlerMetadata):
'''删除一个 Handler'''
# self._handlers.remove(handler)
for i, h in enumerate(self._handlers):
if h[1] == handler:
self._handlers.pop(i)
break
try:
del self.star_handlers_map[handler.handler_full_name]
except KeyError:
pass
def __iter__(self):
'''使 StarHandlerRegistry 支持迭代'''
return (handler for _, handler in self._handlers)
def __len__(self):
'''返回 Handler 的数量'''
return len(self._handlers)
star_handlers_registry = StarHandlerRegistry()
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
'''用于快速查找。key 是 handler_full_name'''
class EventType(enum.Enum):
'''表示一个 AstrBot 内部事件的类型。如适配器消息事件、LLM 请求事件、发送消息前的事件等
用于对 Handler 的职能分组。
'''
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
OnLLMResponseEvent = enum.auto() # LLM 响应后
OnDecoratingResultEvent = enum.auto() # 发送消息前
OnCallingFuncToolEvent = enum.auto() # 调用函数工具
OnAfterMessageSentEvent = enum.auto() # 发送消息后
@dataclass
class StarHandlerMetadata():
'''描述一个 Star 所注册的某一个 Handler。'''
event_type: EventType
'''Handler 的事件类型'''
handler_full_name: str
'''格式为 f"{handler.__module__}_{handler.__name__}"'''
handler_name: str
'''Handler 的名字,也就是方法名'''
handler_module_str: str
handler_module_path: str
'''Handler 所在的模块路径。'''
handler: Awaitable
'''Handler 的函数对象,应当是一个异步函数'''
event_filters: List[HandlerFilter]
'''一个事件过滤器,用于描述这个 Handler 能够处理、应该处理的事件'''
'''一个适配器消息事件过滤器,用于描述这个 Handler 能够处理、应该处理的适配器消息事件'''
desc: str = ""
'''Handler 的描述信息'''
extras_configs: dict = field(default_factory=dict)
'''插件注册的一些其他的信息, 如 priority 等'''
def __lt__(self, other: StarHandlerMetadata):
'''定义小于运算符以支持优先队列'''
return self.extras_configs.get('priority', 0) < other.extras_configs.get('priority', 0)

View File

@@ -1,20 +1,24 @@
import inspect
import functools
import os
import sys
import json
import traceback
import yaml
import logging
from types import ModuleType
from typing import List
from pip import main as pip_main
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core import logger
from astrbot.core import logger, sp, pip_installer
from .context import Context
from . import StarMetadata
from .updator import PluginUpdator
from astrbot.core.utils.io import remove_dir
from .star import star_registry, star_map
from .star_handler import star_handlers_registry
from astrbot.core.provider.register import llm_tools
from .filter.permission import PermissionTypeFilter, PermissionType
class PluginManager:
def __init__(
@@ -25,12 +29,22 @@ class PluginManager:
self.updator = PluginUpdator(config['plugin_repo_mirror'])
self.context = context
self.context._star_manager = self
self.config = config
self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"))
'''存储插件的路径。即 data/plugins'''
self.plugin_config_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/config"))
'''存储插件配置的路径。data/config'''
self.reserved_plugin_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../packages"))
'''保留插件的路径。在 packages 目录下'''
self.conf_schema_fname = "_conf_schema.json"
'''插件配置 Schema 文件名'''
self.failed_plugin_info = ""
def _get_classes(self, arg: ModuleType):
'''获取指定模块(可以理解为一个 python 文件)下所有的类'''
classes = []
clsmembers = inspect.getmembers(arg, inspect.isclass)
for (name, _) in clsmembers:
@@ -89,21 +103,12 @@ class PluginManager:
plugin_path = os.path.join(plugin_dir, p)
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
pth = os.path.join(plugin_path, "requirements.txt")
logger.info(f"正在检查插件 {p} 的依赖: {pth}")
logger.info(f"正在安装插件 {p} 所需的依赖: {pth}")
try:
self._update_plugin_dept(os.path.join(plugin_path, "requirements.txt"))
pip_installer.install(requirements_path=pth)
except Exception as e:
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
def _update_plugin_dept(self, path):
'''更新插件的依赖'''
args = ['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/']
if self.config.pip_install_arg:
args.extend(self.config.pip_install_arg)
result_code = pip_main(args)
if result_code != 0:
raise Exception(str(result_code))
def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> StarMetadata:
'''v3.4.0 以前的方式载入插件元数据
@@ -123,7 +128,7 @@ class PluginManager:
if isinstance(metadata, dict):
if 'name' not in metadata or 'desc' not in metadata or 'version' not in metadata or 'author' not in metadata:
raise Exception("插件元数据信息不完整。")
raise Exception("插件元数据信息不完整。name, desc, version, author 是必须的字段。")
metadata = StarMetadata(
name=metadata['name'],
author=metadata['author'],
@@ -133,29 +138,69 @@ class PluginManager:
)
return metadata
def reload(self):
'''扫描并加载所有的 Star'''
star_handlers_registry.clear()
async def reload(self, specified_plugin_name=None):
'''扫描并加载所有的插件 当 specified_module_path 指定时,重载指定插件'''
specified_module_path = None
if specified_plugin_name:
for smd in star_registry:
if smd.name == specified_plugin_name:
specified_module_path = smd.module_path
break
# 终止插件
if not specified_module_path:
for smd in star_registry:
logger.debug(f"尝试终止插件 {smd.name} ...")
if hasattr(smd.star_cls, "__del__"):
smd.star_cls.__del__()
star_handlers_registry.clear()
star_map.clear()
star_registry.clear()
for key in list(sys.modules.keys()):
if key.startswith("data.plugins") or key.startswith("packages"):
del sys.modules[key]
else:
# 只重载指定插件
smd = star_map.get(specified_module_path)
if smd:
await self._unbind_plugin(smd.name, specified_module_path)
try:
del sys.modules[specified_module_path]
except KeyError:
logger.warning(f"模块 {specified_module_path} 未载入")
plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
return False, "未找到任何插件模块"
fail_rec = ""
# 导入 Star 模块,并尝试实例化 Star 类
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
alter_cmd = sp.get("alter_cmd", {})
# 导入插件模块,并尝试实例化插件类
for plugin_module in plugin_modules:
try:
module_str = plugin_module['module']
# module_path = plugin_module['module_path']
root_dir_name = plugin_module['pname']
reserved = plugin_module.get('reserved', False)
root_dir_name = plugin_module['pname'] # 插件的目录名
reserved = plugin_module.get('reserved', False) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
path = "data.plugins." if not reserved else "packages."
path += root_dir_name + "." + module_str
if specified_module_path and path != specified_module_path:
continue
logger.info(f"正在载入插件 {root_dir_name} ...")
# 尝试导入模块
path = "data.plugins." if not reserved else "packages."
path += root_dir_name + "." + module_str
try:
module = __import__(path, fromlist=[module_str])
except (ModuleNotFoundError, ImportError):
@@ -166,28 +211,78 @@ class PluginManager:
logger.error(traceback.format_exc())
logger.error(f"插件 {root_dir_name} 导入失败。原因:{str(e)}")
continue
# 检查 _conf_schema.json
plugin_config = None
plugin_dir_path = os.path.join(self.plugin_store_path, root_dir_name) \
if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
plugin_schema_path = os.path.join(plugin_dir_path, self.conf_schema_fname)
if os.path.exists(plugin_schema_path):
# 加载插件配置
with open(plugin_schema_path, 'r', encoding='utf-8') as f:
plugin_config = AstrBotConfig(
config_path=os.path.join(self.plugin_config_path, f"{root_dir_name}_config.json"),
schema=json.loads(f.read())
)
if path in star_map:
# 通过装饰器的方式注册插件
star_metadata = star_map[path]
star_metadata.star_cls = star_metadata.star_cls_type(context=self.context)
star_metadata.module = module
star_metadata.root_dir_name = root_dir_name
star_metadata.reserved = reserved
metadata = star_map[path]
try:
# yaml 文件的元数据优先
metadata_yaml = self._load_plugin_metadata(plugin_path=plugin_dir_path)
if metadata_yaml:
metadata.name = metadata_yaml.name
metadata.author = metadata_yaml.author
metadata.desc = metadata_yaml.desc
metadata.version = metadata_yaml.version
metadata.repo = metadata_yaml.repo
except Exception:
pass
if plugin_config:
metadata.config = plugin_config
try:
metadata.star_cls = metadata.star_cls_type(context=self.context, config=plugin_config)
except TypeError as _:
metadata.star_cls = metadata.star_cls_type(context=self.context)
else:
metadata.star_cls = metadata.star_cls_type(context=self.context)
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
# 绑定 handler
related_handlers = star_handlers_registry.get_handlers_by_module_name(metadata.module_path)
for handler in related_handlers:
handler.handler = functools.partial(handler.handler, metadata.star_cls)
# 绑定 llm_tool handler
for func_tool in llm_tools.func_list:
if func_tool.handler.__module__ == metadata.module_path:
func_tool.handler_module_path = metadata.module_path
func_tool.handler = functools.partial(func_tool.handler, metadata.star_cls)
if func_tool.name in inactivated_llm_tools:
func_tool.active = False
else:
# v3.4.0 以前的方式注册插件
logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。")
classes = self._get_classes(module)
try:
obj = getattr(module, classes[0])(context=self.context)
except BaseException as e:
logger.error(f"插件 {root_dir_name} 实例化失败。")
raise e
if plugin_config:
try:
obj = getattr(module, classes[0])(context=self.context, config=plugin_config) # 实例化插件类
except TypeError as _:
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
else:
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
metadata = None
plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj)
metadata = self._load_plugin_metadata(plugin_path=plugin_dir_path, plugin_obj=obj)
metadata.star_cls = obj
metadata.config = plugin_config
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
@@ -195,27 +290,63 @@ class PluginManager:
metadata.module_path = path
star_map[path] = metadata
star_registry.append(metadata)
logger.debug(f"插件 {root_dir_name} 载入成功。")
# 禁用/启用插件
if metadata.module_path in inactivated_plugins:
metadata.activated = False
full_names = []
for handler in star_handlers_registry.get_handlers_by_module_name(metadata.module_path):
full_names.append(handler.handler_full_name)
# 检查并且植入自定义的权限过滤器alter_cmd
if metadata.name in alter_cmd and handler.handler_name in alter_cmd[metadata.name]:
cmd_type = alter_cmd[metadata.name][handler.handler_name].get("permission", "member")
found_permission_filter = False
for filter_ in handler.event_filters:
if isinstance(filter_, PermissionTypeFilter):
if cmd_type == "admin":
filter_.permission_type = PermissionType.ADMIN
else:
filter_.permission_type = PermissionType.MEMBER
found_permission_filter = True
break
if not found_permission_filter:
handler.event_filters.append(PermissionTypeFilter(PermissionType.ADMIN if cmd_type == "admin" else PermissionType.MEMBER))
logger.debug(f"插入权限过滤器 {cmd_type}{metadata.name}{handler.handler_name} 方法。")
metadata.star_handler_full_names = full_names
# 执行 initialize() 方法
if hasattr(metadata.star_cls, "initialize"):
await metadata.star_cls.initialize()
except BaseException as e:
traceback.print_exc()
fail_rec += f"加载 {path} 插件时出现问题,原因 {str(e)}\n"
logger.error(f"----- 插件 {root_dir_name} 载入失败 -----")
errors = traceback.format_exc()
for line in errors.split('\n'):
logger.error(f"| {line}")
logger.error("----------------------------------")
fail_rec += f"加载 {root_dir_name} 插件时出现问题,原因 {str(e)}\n"
# 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
if not fail_rec:
return True, None
else:
self.failed_plugin_info = fail_rec
return False, fail_rec
async def install_plugin(self, repo_url: str):
plugin_path = await self.updator.install(repo_url)
self._check_plugin_dept_update()
async def install_plugin(self, repo_url: str, proxy=""):
plugin_path = await self.updator.install(repo_url, proxy)
# reload the plugin
await self.reload()
return plugin_path
def uninstall_plugin(self, plugin_name: str):
async def uninstall_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
@@ -224,22 +355,81 @@ class PluginManager:
root_dir_name = plugin.root_dir_name
ppath = self.plugin_store_path
del star_map[plugin.module_path]
# 从 star_registry 和 star_map 中删除
await self._unbind_plugin(plugin_name, plugin.module_path)
if not remove_dir(os.path.join(ppath, root_dir_name)):
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
del star_map[plugin_module_path]
for i, p in enumerate(star_registry):
if p.name == plugin_name:
del star_registry[i]
break
for handler in star_handlers_registry.get_handlers_by_module_name(plugin_module_path):
logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}")
star_handlers_registry.remove(handler)
keys_to_delete = [k for k, v in star_handlers_registry.star_handlers_map.items() if k.startswith(plugin_module_path)]
for k in keys_to_delete:
v = star_handlers_registry.star_handlers_map[k]
logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)")
del star_handlers_registry.star_handlers_map[k]
async def update_plugin(self, plugin_name: str):
async def update_plugin(self, plugin_name: str, proxy = ""):
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
if plugin.reserved:
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
await self.updator.update(plugin)
await self.updator.update(plugin, proxy=proxy)
await self.reload()
async def turn_off_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
inactivated_plugins: list = sp.get("inactivated_plugins", [])
if plugin.module_path not in inactivated_plugins:
inactivated_plugins.append(plugin.module_path)
inactivated_llm_tools: list = list(set(sp.get("inactivated_llm_tools", []))) # 后向兼容
def install_plugin_from_file(self, zip_file_path: str):
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
# 禁用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path:
func_tool.active = False
if func_tool.name not in inactivated_llm_tools:
inactivated_llm_tools.append(func_tool.name)
sp.put("inactivated_plugins", inactivated_plugins)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
plugin.activated = False
async def turn_on_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if plugin.module_path in inactivated_plugins:
inactivated_plugins.remove(plugin.module_path)
sp.put("inactivated_plugins", inactivated_plugins)
# 启用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path:
inactivated_llm_tools.remove(func_tool.name)
func_tool.active = True
sp.put("inactivated_llm_tools", inactivated_llm_tools)
plugin.activated = True
async def install_plugin_from_file(self, zip_file_path: str):
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
dir_name = dir_name.removesuffix("-master").removesuffix("-main").lower()
desti_dir = os.path.join(self.plugin_store_path, dir_name)
self.updator.unzip_file(zip_file_path, desti_dir)
# remove the zip
@@ -247,5 +437,4 @@ class PluginManager:
os.remove(zip_file_path)
except BaseException as e:
logger.warning(f"删除插件压缩包失败: {str(e)}")
self._check_plugin_dept_update()
await self.reload()

View File

@@ -15,20 +15,24 @@ class PluginUpdator(RepoZipUpdator):
def get_plugin_store_path(self) -> str:
return self.plugin_store_path
async def install(self, repo_url: str) -> str:
async def install(self, repo_url: str, proxy="") -> str:
repo_name = self.format_repo_name(repo_url)
plugin_path = os.path.join(self.plugin_store_path, repo_name)
await self.download_from_repo_url(plugin_path, repo_url)
await self.download_from_repo_url(plugin_path, repo_url, proxy)
self.unzip_file(plugin_path + ".zip", plugin_path)
return plugin_path
async def update(self, plugin: StarMetadata) -> str:
async def update(self, plugin: StarMetadata, proxy="") -> str:
repo_url = plugin.repo
if not repo_url:
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
if proxy:
proxy = proxy.removesuffix("/")
repo_url = f"{proxy}/{repo_url}"
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
@@ -53,7 +57,6 @@ class PluginUpdator(RepoZipUpdator):
files = os.listdir(os.path.join(target_dir, update_dir))
for f in files:
logger.info(f"移动更新文件/目录: {f}")
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
if os.path.exists(os.path.join(target_dir, f)):
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
@@ -63,7 +66,7 @@ class PluginUpdator(RepoZipUpdator):
shutil.move(os.path.join(target_dir, update_dir, f), target_dir)
try:
logger.info(f"删除临时更新文件: {zip_path}{os.path.join(target_dir, update_dir)}")
logger.info(f"删除临时文件: {zip_path}{os.path.join(target_dir, update_dir)}")
shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error)
os.remove(zip_path)
except BaseException:

View File

@@ -11,7 +11,7 @@ class AstrBotUpdator(RepoZipUpdator):
def __init__(self, repo_mirror: str = "") -> None:
super().__init__(repo_mirror)
self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))
self.ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases"
self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"
def terminate_child_processes(self):
try:

View File

@@ -0,0 +1,173 @@
import json
from astrbot.core import logger
from aiohttp import ClientSession
from typing import Dict, List, Any, AsyncGenerator
class DifyAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
self.api_key = api_key
self.api_base = api_base
self.session = ClientSession()
self.headers = {
"Authorization": f"Bearer {self.api_key}",
}
async def chat_messages(
self,
inputs: Dict,
query: str,
user: str,
response_mode: str = "streaming",
conversation_id: str = "",
files: List[Dict[str, Any]] = [],
timeout: float = 60,
) -> AsyncGenerator[Dict[str, Any], None]:
url = f"{self.api_base}/chat-messages"
payload = locals()
payload.pop("self")
payload.pop("timeout")
logger.info(f"chat_messages payload: {payload}")
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
if resp.status != 200:
text = await resp.text()
raise Exception(f"chat_messages 请求失败:{resp.status}. {text}")
buffer = ""
while True:
# 保持原有的8192字节限制防止数据过大导致高水位报错
chunk = await resp.content.read(8192)
if not chunk:
break
buffer += chunk.decode('utf-8')
blocks = buffer.split('\n\n')
# 处理完整的数据块
for block in blocks[:-1]:
if block.strip() and block.startswith('data:'):
try:
json_str = block[5:] # 移除 "data:" 前缀
json_obj = json.loads(json_str)
yield json_obj
except json.JSONDecodeError as e:
logger.error(f"JSON解析错误: {str(e)}")
logger.error(f"原始数据块: {json_str}")
# 保留最后一个可能不完整的块
buffer = blocks[-1] if blocks else ""
async def workflow_run(
self,
inputs: Dict,
user: str,
response_mode: str = "streaming",
files: List[Dict[str, Any]] = [],
timeout: float = 60,
):
url = f"{self.api_base}/workflows/run"
payload = locals()
payload.pop("self")
payload.pop("timeout")
logger.info(f"workflow_run payload: {payload}")
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
if resp.status != 200:
text = await resp.text()
raise Exception(f"workflow_run 请求失败:{resp.status}. {text}")
buffer = ""
while True:
# 保持原有的8192字节限制防止数据过大导致高水位报错
chunk = await resp.content.read(8192)
if not chunk:
break
buffer += chunk.decode('utf-8')
blocks = buffer.split('\n\n')
# 处理完整的数据块
for block in blocks[:-1]:
if block.strip() and block.startswith('data:'):
try:
json_str = block[5:] # 移除 "data:" 前缀
json_obj = json.loads(json_str)
yield json_obj
except json.JSONDecodeError as e:
logger.error(f"JSON解析错误: {str(e)}")
logger.error(f"原始数据块: {json_str}")
# 保留最后一个可能不完整的块
buffer = blocks[-1] if blocks else ""
async def file_upload(
self,
file_path: str,
user: str,
) -> Dict[str, Any]:
url = f"{self.api_base}/files/upload"
payload = {
"user": user,
"file": open(file_path, "rb"),
}
async with self.session.post(
url, data=payload, headers=self.headers
) as resp:
return await resp.json() # {"id": "xxx", ...}
async def close(self):
await self.session.close()
async def get_chat_convs(
self,
user: str,
limit: int = 20
):
# conversations. GET
url = f"{self.api_base}/conversations"
payload = {
"user": user,
"limit": limit,
}
async with self.session.get(
url, params=payload, headers=self.headers
) as resp:
return await resp.json()
async def delete_chat_conv(
self,
user: str,
conversation_id: str
):
# conversation. DELETE
url = f"{self.api_base}/conversations/{conversation_id}"
payload = {
"user": user,
}
async with self.session.delete(
url, json=payload, headers=self.headers
) as resp:
return await resp.json()
async def rename(
self,
conversation_id: str,
name: str,
user: str,
auto_generate: bool = False
):
# /conversations/:conversation_id/name
url = f"{self.api_base}/conversations/{conversation_id}/name"
payload = {
"user": user,
"name": name,
"auto_generate": auto_generate,
}
async with self.session.post(
url, json=payload, headers=self.headers
) as resp:
return await resp.json()

View File

@@ -5,6 +5,10 @@ import socket
import time
import aiohttp
import base64
import zipfile
import uuid
import psutil
from typing import Union
from PIL import Image
@@ -40,21 +44,21 @@ def port_checker(port: int, host: str = "localhost"):
return False
def save_temp_img(img: Image) -> str:
def save_temp_img(img: Union[Image.Image, str]) -> str:
os.makedirs("data/temp", exist_ok=True)
# 获得文件创建时间,清除超过1小时的
# 获得文件创建时间,清除超过 12 小时的
try:
for f in os.listdir("data/temp"):
path = os.path.join("data/temp", f)
if os.path.isfile(path):
ctime = os.path.getctime(path)
if time.time() - ctime > 3600:
if time.time() - ctime > 3600*12:
os.remove(path)
except Exception as e:
print(f"清除临时文件失败: {e}")
# 获得时间戳
timestamp = int(time.time())
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
p = f"data/temp/{timestamp}.jpg"
if isinstance(img, Image.Image):
@@ -64,23 +68,33 @@ def save_temp_img(img: Image) -> str:
f.write(img)
return p
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None) -> str:
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None, path = None) -> str:
'''
下载图片, 返回 path
'''
try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
if post:
async with session.post(url, json=post_data) as resp:
return save_temp_img(await resp.read())
if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
else:
async with session.get(url) as resp:
return save_temp_img(await resp.read())
except aiohttp.client_exceptions.ClientConnectorSSLError:
if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
except aiohttp.client.ClientConnectorSSLError:
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers('DEFAULT')
async with aiohttp.ClientSession(trust_env=False) as session:
async with aiohttp.ClientSession() as session:
if post:
async with session.get(url, ssl=ssl_context) as resp:
return save_temp_img(await resp.read())
@@ -90,36 +104,91 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
except Exception as e:
raise e
async def download_file(url: str, path: str):
async def download_file(url: str, path: str, show_progress: bool = False):
'''
从指定 url 下载文件到指定路径 path
'''
try:
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url, timeout=1800) as resp:
if resp.status != 200:
raise Exception(f"下载文件失败: {resp.status}")
total_size = int(resp.headers.get('content-length', 0))
downloaded_size = 0
start_time = time.time()
if show_progress:
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
with open(path, 'wb') as f:
while True:
chunk = await resp.content.read(8192)
if not chunk:
break
f.write(chunk)
except Exception as e:
raise e
downloaded_size += len(chunk)
if show_progress:
elapsed_time = time.time() - start_time
speed = downloaded_size / 1024 / elapsed_time # KB/s
print(f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end='')
except aiohttp.client.ClientConnectorSSLError:
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers('DEFAULT')
async with aiohttp.ClientSession() as session:
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
total_size = int(resp.headers.get('content-length', 0))
downloaded_size = 0
start_time = time.time()
if show_progress:
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
with open(path, 'wb') as f:
while True:
chunk = await resp.content.read(8192)
if not chunk:
break
f.write(chunk)
downloaded_size += len(chunk)
if show_progress:
elapsed_time = time.time() - start_time
speed = downloaded_size / 1024 / elapsed_time # KB/s
print(f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end='')
if show_progress:
print()
def file_to_base64(file_path: str) -> str:
with open(file_path, "rb") as f:
data_bytes = f.read()
base64_str = base64.b64encode(data_bytes).decode()
return "base64://" + base64_str
def get_local_ip_addresses():
ip = ''
net_interfaces = psutil.net_if_addrs()
network_ips = []
for interface, addrs in net_interfaces.items():
for addr in addrs:
if addr.family == socket.AF_INET: # 使用 socket.AF_INET 代替 psutil.AF_INET
network_ips.append(addr.address)
return network_ips
async def get_dashboard_version():
if os.path.exists("data/dist"):
if os.path.exists("data/dist/assets/version"):
with open("data/dist/assets/version", "r") as f:
v = f.read().strip()
return v
return None
async def download_dashboard():
'''下载管理面板文件'''
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(('8.8.8.8', 80))
ip = s.getsockname()[0]
except BaseException:
pass
finally:
s.close()
return ip
await download_file(dashboard_release_url, "data/dashboard.zip", show_progress=True)
except BaseException as _:
dashboard_release_url = "https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
await download_file(dashboard_release_url, "data/dashboard.zip", show_progress=True)
print("解压管理面板文件中...")
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
z.extractall("data")

Some files were not shown because too many files have changed in this diff Show More