Compare commits

..

228 Commits

Author SHA1 Message Date
Soulter
16ec462abd feat: WebUI ProviderPage 添加服务提供商会话隔离设置功能 2025-06-11 00:51:18 +08:00
Soulter
ca55465d3c chore: bump to 3.5.15 2025-06-11 00:32:46 +08:00
Soulter
7098c98dde fix: 修复 Windows 下部署项目时可能出现的 UnicodeDecodeError
fixes: #1548
2025-06-11 00:25:14 +08:00
Soulter
f56355da89 perf: 分段回复时,仅在输出的第一句话带上回复/引用
fixes: #521
2025-06-11 00:06:14 +08:00
Soulter
422160debd feat: 支持配置是否忽略@全体成员
fixes: #292
2025-06-10 23:55:50 +08:00
Soulter
8062cf406a fix: 优化配置完整性检查,同时保证配置项顺序的一致性 2025-06-10 23:30:58 +08:00
Soulter
0e802232ec feat: 新配置项,支持配置只@触发等待时是否回复 2025-06-10 23:29:45 +08:00
Soulter
f650a9205d perf(webui): 优化手机端的显示 2025-06-10 22:43:58 +08:00
Soulter
c85dbb2347 fix: 修复某些情况下,会话控制无效的问题 2025-06-10 22:26:11 +08:00
Soulter
a6a79128c8 chore: bump to v3.5.15 2025-06-10 22:18:05 +08:00
Soulter
42839627e8 fix: 修复在设置了 GitHub 加速地址后,插件无法更新的问题 2025-06-10 22:12:46 +08:00
Soulter
267e68a894 chore: bump docker image python version to 3.11 2025-06-10 21:40:20 +08:00
Soulter
b32b444438 Merge pull request #1776 from AstrBotDevs/feat-webchat-title
Feature: 支持重命名和自动生成 WebChat title;WebChat Route 和 UI 优化;支持 WebChatBox
2025-06-10 21:34:17 +08:00
Soulter
522d0f8313 chore: ts lint 2025-06-10 21:33:53 +08:00
Soulter
5715e5de67 chore: fix ts lint 2025-06-10 21:28:06 +08:00
Soulter
cc6b05e8b3 fix: remove fallback for returnUrl in AuthLogin.vue 2025-06-10 21:25:58 +08:00
Soulter
417747d5d0 feat: handle unauthorized access by redirecting to login page in ChatPage 2025-06-10 21:21:38 +08:00
Soulter
a34f439226 fix: update summary output condition and adjust max-width in ChatBoxPage 2025-06-10 18:36:26 +08:00
Soulter
b7ca014fd0 feat: enhance routing to support chatbox and improve path handling in ChatPage 2025-06-10 15:45:06 +08:00
Soulter
fa098d585a feat: add conversation detail routing and handle direct navigation in ChatPage 2025-06-10 15:39:26 +08:00
Soulter
c35a14e3ec fix: adjust padding and clean up unused code in ChatPage.vue 2025-06-10 15:06:33 +08:00
Soulter
60651736a5 feat: chatbox page 2025-06-10 15:02:18 +08:00
Soulter
581f9b7bd3 fix: typo fix
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-06-10 13:02:30 +08:00
Soulter
124eb04807 Merge pull request #1773 from AstrBotDevs/feat-seperate-provider
Feature: 支持对提供商会话隔离
2025-06-10 12:59:42 +08:00
Soulter
1d561da7fb style: clean code
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-06-10 12:59:20 +08:00
Soulter
16e3cd0784 fix: get_using_stt_provider is fetching using ProviderType.TEXT_TO_SPEECH but should use ProviderType.SPEECH_TO_TEXT for STT isolation.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-06-10 12:58:39 +08:00
Soulter
a6d91933dc feat: 支持自动生成webchat title 2025-06-10 10:58:49 +08:00
Raven95676
445c40f758 chore: update version 2025-06-10 10:29:31 +08:00
鸦羽
725a841a3b Merge pull request #1767 from AstrBotDevs/fix/1678
Fix: 调整Gemini原生工具启用行为
2025-06-10 08:22:41 +08:00
鸦羽
f77c453843 fix: clean code 2025-06-10 00:20:35 +00:00
Soulter
ba6718d5bc Merge pull request #1759 from Flartiny/dev
Feature: Add GreedyStr parameter support for commands
2025-06-10 00:06:34 +08:00
Soulter
cdb7a1b3fa style: merge else if into elif 2025-06-09 23:54:51 +08:00
Soulter
a03c79b89d style: use named expression 2025-06-09 23:51:54 +08:00
Soulter
98800d3426 fix(typo): "seperate_provider" -> "separate_provider" 2025-06-09 23:50:31 +08:00
Soulter
a616adaac4 fix: update provider manager set_provider() 2025-06-09 23:46:44 +08:00
Soulter
ffb5605c99 fix: default tts provider selection
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-06-09 23:38:15 +08:00
Soulter
621b556856 feat: 支持对提供商会话隔离
fixes: #1762 #602 #479
2025-06-09 23:33:00 +08:00
Soulter
a3ffecbb2a feat: add support for gemini_embedding provider 2025-06-09 14:43:05 +08:00
Soulter
ea64cebe2a ci: fix cloudflare r2 ci 2025-06-09 13:12:31 +08:00
鸦羽
e79487dd5f fix: add missing config 2025-06-09 05:03:15 +00:00
鸦羽
7fe1c1ec89 feat: add URL context feature to Gemini model configuration 2025-06-09 04:54:24 +00:00
Soulter
ab2bbff369 Merge pull request #1746 from Seayon/fix-wechat-at-message-parsing
 feat(wechatpadpro): 增强群聊消息中的@消息处理逻辑
2025-06-09 12:51:08 +08:00
Soulter
ec32825309 ci: fix cloudflare r2 upload 2025-06-09 12:41:20 +08:00
Soulter
fd0c182087 ci: fix ghcr token 2025-06-09 12:32:38 +08:00
Soulter
49fcff1daf 📦 release: v3.5.14 2025-06-09 12:31:02 +08:00
鸦羽
33b64ddf39 feat: enhance tool selection logic for Gemini model versions 2025-06-09 03:55:59 +00:00
Soulter
4c447aa648 perf: jwt token expire time change to 7 days 2025-06-09 11:52:48 +08:00
Soulter
ccbfc3d274 perf: 强化强制修改默认密码逻辑 2025-06-09 11:47:23 +08:00
Soulter
f83fe43bbb docs: alert 2025-06-09 10:12:09 +08:00
Seayon
19022d67f8 Merge branch 'master' into fix-wechat-at-message-parsing
# Conflicts:
#	astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py
2025-06-09 09:30:09 +08:00
Soulter
58a815dd6b feat: ltm edge fact viewer 2025-06-08 20:34:41 +08:00
Soulter
bc9fe82860 Merge pull request #1737 from zhx8702/feat-wehcatpro-voice-adapter
feat: wechatpadpro 添加语音接收和发送的适配
2025-06-07 15:13:10 +08:00
Soulter
b3cd9bf2b9 Merge pull request #1743 from lvboda/hotfix-platform-page-iframe-style-issue-1741
fix(PlatformPage): iframe overflow style issue (#1741)
2025-06-07 15:11:16 +08:00
Soulter
c5c2b829ec Merge pull request #1758 from RC-CHN/master
fix: 修复 asyncio.wait_for 参数顺序错误
2025-06-07 15:08:37 +08:00
Flartiny
9713f96401 feat: Add greedy parameter support for commands 2025-06-07 10:32:31 +08:00
Ruochen
11f35ebf96 fix: 修复 asyncio.wait_for 参数顺序错误 2025-06-07 09:50:30 +08:00
Soulter
7d403aa181 fix: syntax error 2025-06-07 01:20:56 +08:00
Soulter
64af810a4a Merge pull request #1736 from RC-CHN/master
fix:修复了部分模型供应商测试不可用,但实际可用的问题。
2025-06-06 21:37:19 +08:00
Soulter
30821905af perf: remove default list param,fix dashscope_source contexts params 2025-06-06 21:36:01 +08:00
Seayon
a9dbff756b feat(wechatpadpro): 增强群聊消息中的@消息处理逻辑
添加对群聊消息中@机器人场景的精确识别和处理,提升了消息解析的准确性。
支持多种@格式的检测,包括 msg_source 和 push_content 的判断。
2025-06-06 16:53:31 +08:00
lvboda
a6aba10d3d fix(PlatformPage): iframe overflow style issue (#1741) 2025-06-06 15:18:35 +08:00
RC-CHN
9c276c37fe Update astrbot/dashboard/routes/config.py
测试过对于dashscope类型供应商添加上下文是必要的,否则需要改动其_remove_image_from_context方法。

Co-authored-by: Soulter  <37870767+Soulter@users.noreply.github.com>
2025-06-06 14:01:58 +08:00
Soulter
6ab6c0fd4c Merge pull request #1735 from Flartiny/dev
feat: able to parse repo url of specific branch
2025-06-06 12:44:51 +08:00
Soulter
b6b0fe3fff perf: 优化 GitHub 仓库解析和下载的逻辑 2025-06-06 12:02:46 +08:00
zhx
0d5825bda9 feat: wechatpadpro 添加语音接收和发送的适配 2025-06-06 10:30:06 +08:00
Ruochen
cdfb64631a fix:修复dashscope类型供应商测试问题,延长了设置超时时间,改进prompt工程,修复了控制台打印日志超时时间不符 2025-06-06 09:21:09 +08:00
Ruochen
d161c281c8 Merge branch 'master' of https://github.com/RC-CHN/AstrBot 2025-06-06 00:39:25 +08:00
Flartiny
8fed5bf2a1 feat: able to parse repo url of specific branch 2025-06-06 00:09:10 +08:00
Soulter
98d2e9bd27 chore: stage 2025-06-05 23:30:18 +08:00
Soulter
a03af55edd ci 2025-06-05 13:38:20 +08:00
Soulter
86e2fd9aee ci: publish to ghcr.io 2025-06-05 13:35:14 +08:00
Soulter
97bd0e5e58 Merge pull request #1730 from lxfight/master
feat: 添加插件更新后自动刷新插件列表功能
2025-06-05 11:39:32 +08:00
Soulter
ceaba21986 ci: publish to ghcr.io 2025-06-05 11:19:16 +08:00
Soulter
172a77d942 ci: publish to ghcr.io 2025-06-05 11:16:57 +08:00
Soulter
4f9d2d2a7d ci: publish to ghcr.io 2025-06-05 11:12:56 +08:00
lxfight
8c929f6e05 feat: 添加插件更新后自动刷新插件列表功能 2025-06-05 10:56:04 +08:00
Soulter
3319b71f5b Merge pull request #1721 from zhx8702/feat-add-wechat-47-49
feat: 添加wechatpadpro 消息类型47 49的适配
2025-06-04 22:52:29 +08:00
Soulter
46ec028a5b Merge pull request #1718 from Kwicxy/webui_enhancement
feat: webUI优化
2025-06-04 22:48:49 +08:00
Soulter
0ce0ef3e5c Merge pull request #1715 from Flartiny/dev
fix: residual configuration items after plugin configuration modification
2025-06-04 22:32:19 +08:00
kwicxy
375b071cb2 Merge remote-tracking branch 'origin/webui_enhancement' into webui_enhancement 2025-06-04 19:00:54 +08:00
kwicxy
29e1417ff2 feat: optional newUsername field in account editing 2025-06-04 18:59:38 +08:00
kwicxy
75db2bd366 fix(auth): bad localStorage keymapping 2025-06-04 18:58:53 +08:00
zhx
60ca1efbda feat: 添加wechatpadpro 消息类型47 49的适配 2025-06-04 14:36:16 +08:00
Richard X.
2692e4978b fix: remove console.log()
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-06-03 21:06:51 +08:00
Richard X.
91982eb002 Merge branch 'AstrBotDevs:master' into webui_enhancement 2025-06-03 20:36:51 +08:00
Soulter
bb1dec76fa remove: wechat qr code
hahaha
2025-06-03 20:22:08 +08:00
Flartiny
f618b8fcdc fix: residual configuration items after plugin configuration modification 2025-06-03 14:04:04 +08:00
Raven95676
9147cab75b fix: add additional routes for Alkaid knowledge base and long-term memory 2025-05-31 14:29:04 +08:00
Raven95676
5f07bcc8e6 feat: add Gemini embedding provider and update OpenAI provider to support timeout configuration 2025-05-31 14:13:58 +08:00
Soulter
705cf2ea1b docs(README.md): knowledge base 2025-05-31 14:08:01 +08:00
Soulter
42c4394484 ci: upload dashboard artifact to Cloudflare R2 when auto release 2025-05-31 13:50:40 +08:00
Soulter
221221a3c1 ci: upload dashboard artifact to Cloudflare R2 when auto release 2025-05-31 13:47:59 +08:00
Soulter
9564166297 perf: knowledge base displays console when installing 2025-05-31 11:52:24 +08:00
Soulter
f5cf3c3c8e Merge pull request #1691 from AstrBotDevs/perf-pip-async
Feature: 将插件依赖检查和 pip 安装方法改为异步,以提高性能和响应速度
2025-05-31 11:51:39 +08:00
Soulter
18f919fb6b perf: pip_main wrapped in asyncio.to_thread
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-05-31 11:47:29 +08:00
Soulter
0924835253 feat: 将插件依赖检查和 pip 安装方法改为异步,以提高性能和响应速度 2025-05-31 11:44:58 +08:00
Soulter
20d2e5c578 perf: 优化日志流发送频率,防止积压超过 buffer size 导致前端显示异常 2025-05-31 11:25:51 +08:00
Soulter
907801605c 📦 release: v3.5.13 2025-05-31 11:02:56 +08:00
Soulter
93bc684e8c feat: 添加旧版本提供商类型映射以兼容性支持 2025-05-31 11:00:59 +08:00
Soulter
a76c98d57e Merge pull request #1685 from RC-CHN/master
Feature: 添加测试文本生成供应商可用功能
2025-05-31 10:59:46 +08:00
Soulter
d937a800d0 fix: provider name 2025-05-31 10:46:35 +08:00
Soulter
d16f3a227f Merge branch 'master' into master 2025-05-31 10:46:15 +08:00
Soulter
80c9a3eeda style: code style 2025-05-31 09:25:18 +08:00
Soulter
e68173b451 feat: knowledge-base 2025-05-30 23:18:48 +08:00
Soulter
40c27d87f5 feat: knowledge-base 2025-05-30 23:18:19 +08:00
Soulter
3c13b5049d feat: 支持知识库的分片、重叠设置等 2025-05-30 23:00:37 +08:00
Soulter
8288d5e51f feat: embedding provider 2025-05-30 18:07:52 +08:00
Ruochen
6e1449900a feat: 优化单个 provider 可用性测试的回退逻辑 2025-05-30 15:35:13 +08:00
RC-CHN
4ffbb18ab4 Merge branch 'AstrBotDevs:master' into master 2025-05-30 15:12:33 +08:00
Ruochen
b27271b7a3 feat:添加测试文本生成供应商可用功能 2025-05-30 15:10:15 +08:00
Soulter
ebb6665f64 feat: add open_config parameter handling and configuration button in KnowledgeBase 2025-05-30 14:30:04 +08:00
Soulter
e4e5731ffd 📦 release: v3.5.13 2025-05-30 13:30:23 +08:00
Soulter
2ab5810f13 perf: improve transaction performance in vector db 2025-05-30 12:59:26 +08:00
Soulter
af934c5d09 fix: correct dimension typo and enhance API registration logic 2025-05-30 11:42:39 +08:00
Soulter
1e0cf7c112 fix: update ExtensionCard actions and add readme link functionality 2025-05-30 10:50:54 +08:00
Soulter
46859c93c9 perf: improve WebUI 2025-05-30 10:45:05 +08:00
Richard X.
ea1f9cb3b2 Merge branch 'AstrBotDevs:master' into master 2025-05-30 10:37:59 +08:00
Soulter
1641549016 perf: improve WebUI 2025-05-30 10:36:48 +08:00
鸦羽
716a5dbb8a chore: add nh3 to requirements.txt 2025-05-30 10:35:48 +08:00
鸦羽
af98cb11c5 fix: handle missing nh3 library in plugin.py 2025-05-30 10:35:48 +08:00
Soulter
9a4c2cf341 fix: downgrade faiss-cpu dependency to version 1.10.0 2025-05-30 10:21:31 +08:00
Soulter
2bc3bcd102 fix: handle missing nh3 library gracefully for README cleaning 2025-05-30 10:17:33 +08:00
Soulter
d6c663f79d fix: do not display change password dialog in demo mode 2025-05-30 10:09:09 +08:00
kwicxy
9ed86e5f53 feat: Name trim of extension list to improve readability 2025-05-30 09:37:21 +08:00
kwicxy
303e0bc037 fix(dashboard): MessageStat chart tooltips now supports dark appearance 2025-05-30 09:36:06 +08:00
Richard X.
2cc24019f9 Merge branch 'AstrBotDevs:master' into master 2025-05-30 08:50:27 +08:00
kwicxy
83ce774d19 chore: Extension marketplace scroll behaviour updated 2025-05-30 00:01:53 +08:00
Soulter
2b4ee13b5e Merge pull request #1672 from Kwicxy/master
Feat: 暗黑主题功能初步实现
2025-05-29 23:41:10 +08:00
kwicxy
3a964561f0 style: minor code style changes 2025-05-29 22:57:50 +08:00
kwicxy
6959f86632 feat: Using localStorage to remember user's theme setting. 2025-05-29 22:46:02 +08:00
Raven95676
537d373e10 fix: Fix potential XSS risk in plugin README content 2025-05-29 22:35:24 +08:00
Soulter
cceadf222c Merge pull request #1676 from AstrBotDevs/fix-chat-get-file-bug
Fix: fixed a potential vulnerability in `/api/chat/get_file` endpoint.
2025-05-29 21:41:55 +08:00
Soulter
cf5a4af623 chore: remove duplicated auth header 2025-05-29 21:19:39 +08:00
Raven95676
39aea11c22 perf: enhance file access security in get_file method
Co-authored-by: anka-afk <1350989414@qq.com>
2025-05-29 21:03:51 +08:00
Raven95676
c2f1227700 fix: add authorization header to file download request in ChatPage.vue 2025-05-29 19:57:11 +08:00
Soulter
900f14d37c 🐛 fix: fixed a potential vulnerability in /api/chat/get_file endpoint.
I have fixed a potential vulnerability in the `/api/chat/get_file` endpoint that could allow unauthorized access to files by ensuring the request has a jwt token.
2025-05-29 19:17:31 +08:00
kwicxy
598249b1d6 Merge remote-tracking branch 'origin/master' 2025-05-29 18:26:53 +08:00
Richard X.
7ed15bdf04 Merge branch 'AstrBotDevs:master' into master 2025-05-29 18:17:39 +08:00
Raven95676
2fc0ec0f72 fix: update route 2025-05-29 17:28:33 +08:00
kwicxy
5e9c2a669b fix: Various bug fixes and improvements 2025-05-29 16:41:03 +08:00
Soulter
b310521884 📦 release: v3.5.12 2025-05-29 15:55:25 +08:00
Soulter
288945bf7e chore: aiosqlite to requirements.txt 2025-05-29 15:48:21 +08:00
Soulter
4fc07cff36 📦 release: v3.5.12 2025-05-29 15:46:40 +08:00
kwicxy
b884fe0e86 fix: Various bug fixes 2025-05-29 09:31:29 +08:00
kwicxy
855858c236 fix: Changed default theme to PurpleTheme 2025-05-29 09:31:15 +08:00
kwicxy
c11a2a5419 feat: Login page darkened 2025-05-29 09:00:27 +08:00
kwicxy
773a6572af feat: WebUI Dark Appearance 2025-05-29 01:43:21 +08:00
kwicxy
88ad373c9b 深色主题切换功能初步实现 2025-05-29 01:28:45 +08:00
Soulter
51666464b9 Merge pull request #1667 from AstrBotDevs/fix-priority
Fix: plugin priority was not properly applied
2025-05-28 15:34:50 +08:00
Soulter
5af9cf2f52 Merge pull request #1668 from AstrBotDevs/refactor-segment
Refactor: 重构转发节点等消息段的 toDict 相关逻辑
2025-05-28 15:33:32 +08:00
Soulter
12c4ae4b10 perf: to_dict in the base class 2025-05-28 03:26:42 -04:00
Soulter
4e1bef414a perf: empty array 2025-05-28 03:25:19 -04:00
Soulter
e896c18644 perf: video 2025-05-28 15:12:21 +08:00
Soulter
c852685e74 fix: typeerror 2025-05-28 01:18:45 -04:00
Soulter
1e99797df8 refactor: improve message segment handle 2025-05-28 12:53:00 +08:00
Soulter
52a4c986a8 fix: update star_handlers_registry iteration in TelegramPlatformAdapter 2025-05-28 00:31:04 +08:00
Soulter
c501728204 fix: plugin priority
fixes: #1662
2025-05-28 00:23:02 +08:00
Soulter
6b067fa6a7 Merge pull request #1665 from Raven95676/master
fix(telegram): 支持长消息分段发送并优化消息编辑逻辑
2025-05-27 23:39:14 +08:00
Soulter
a1cd5c53a9 chore: add comments 2025-05-27 23:38:35 +08:00
Soulter
a46d487e03 Merge pull request #1644 from RC-CHN/master
fix:为llm和model和provider指令添加了管理员权限检查
2025-05-27 23:25:40 +08:00
Raven95676
3deb6d3ab3 fix: clean code 2025-05-27 20:52:40 +08:00
Raven95676
af34cdd5d2 fix(telegram): 支持长消息分段发送并优化消息编辑逻辑 2025-05-27 20:15:16 +08:00
Soulter
6e1393235a 🐛 fix: provider command error 2025-05-27 17:20:57 +08:00
Soulter
343e0b54b9 feat: MCP supports Streamable HTTP transport method
fixes: #1637 #1342
2025-05-27 15:39:02 +08:00
Soulter
ecb70cb6f7 feat: add support for custom headers in SSE client configuration
fixes: #1659
2025-05-27 15:05:42 +08:00
Soulter
ca50618af6 perf: load providers when llm config is off and rebooting astrbot
fixes: #1466
2025-05-27 15:01:58 +08:00
Soulter
29c07ba83e 🐛 fix: function tools argument type issue
fixes: #1454
2025-05-27 13:54:16 +08:00
Ruochen
45fbb83a9f fix:为llm和model和provider指令添加了管理员权限检查 2025-05-25 00:24:20 +08:00
Soulter
ae7ba2df25 Merge pull request #1553 from Raven95676/Feature/use-file-service
Feature: T2I、TTS使用文件服务
2025-05-23 17:10:38 +08:00
Soulter
c3ef57cc32 Merge pull request #1588 from Zhenyi-Wang/feat/extend-wechatpadpro-for-timetask
feat: wechatpadpro对接获取联系人信息的2个接口
2025-05-23 17:02:54 +08:00
Soulter
7bb4ca5a14 perf: code quality 2025-05-23 17:01:57 +08:00
Soulter
063783d81d Merge pull request #1599 from HendricksJudy/master
Fix initialization bug and improve plugin utility
2025-05-23 16:58:25 +08:00
Soulter
42116c9b65 Merge pull request #1631 from AstrBotDevs/feat/alkaid
[WIP] Feature: 提供 AstrBot 后端服务插件接口、试验性嵌入式知识库(Alkaid)、移除不必要的包
2025-05-23 16:57:04 +08:00
Soulter
a36e11973d perf: code quality
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-23 16:56:09 +08:00
Soulter
5125568ea2 perf: 交换 if/else 表达式的分支以删除否定
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-23 16:49:08 +08:00
Soulter
0fa164e50d perf: 使用 HTML autocomplete 属性禁用浏览器自动填充
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-05-23 16:48:29 +08:00
Soulter
cf814e81ee chore: delete alkaid route 2025-05-23 16:41:33 +08:00
Soulter
43a45f18ce perf: knowledgebase delete 2025-05-23 15:50:10 +08:00
Soulter
ad51381063 perf: 动态路由注册 2025-05-23 15:18:16 +08:00
Soulter
0b0e4ce904 remove: vpet 2025-05-23 14:22:34 +08:00
Soulter
6a3e04d688 Merge remote-tracking branch 'origin/master' into feat/alkaid 2025-05-23 14:22:06 +08:00
Soulter
4107a17370 chore: add faiss and aiosqlite deps 2025-05-23 14:04:13 +08:00
Soulter
06b4d8f169 perf: vecdb similarity type 2025-05-23 13:45:00 +08:00
Soulter
1c0c820746 remove: loguru 2025-05-23 13:42:17 +08:00
Soulter
d061403a28 remove: loguru 2025-05-23 13:39:20 +08:00
Soulter
5c092321a6 feat: faiss vecdb implementation
remove: old knowledgedb deps
2025-05-23 13:16:24 +08:00
Soulter
bdd3f61c1f remove: old knowledge db impl and useless impls 2025-05-23 11:43:26 +08:00
Raven95676
8023557d6e feat: 强制修改默认密码 2025-05-22 18:30:29 +08:00
Raven95676
074b0ced7a perf: 移除冗余逻辑
经与@Soulter确认,metadata.yaml是必须有的文件,故在建议下删除
2025-05-22 18:21:41 +08:00
Soulter
3864b1ac9b Merge pull request #1620 from YOOkoishi/feat-add-volcengine-support
🐛 fix : 修改description,适配火山引擎基础的语音合成
2025-05-22 17:52:39 +08:00
YOO_koishi
6e9b43457d Merge branch 'master' of https://github.com/AstrBotDevs/AstrBot into feat-add-volcengine-support 2025-05-22 08:09:59 +08:00
YOO_koishi
ca1aec8920 🐛 fix : 修改description,适配火山引擎基础的语音合成 2025-05-22 08:09:36 +08:00
Soulter
acac580862 feat: ltm and kb 2025-05-20 20:50:22 +08:00
Soulter
673e1b2980 remove: vpet 2025-05-20 15:03:40 +08:00
Soulter
f62157be72 📦 release: v3.5.11 2025-05-20 02:00:54 -04:00
Soulter
f894ecf3b6 Merge pull request #1592 from YOOkoishi/feat-add-volcengine-support
 feat: add volcengine support
2025-05-20 13:58:44 +08:00
Soulter
66dd4e28ad Merge pull request #1604 from Siztas/fix-refresh-device-when-login-WeChatPadPro
fix:修复了WeChatPadPro在重新登录时为新设备的问题,延长初始化Auth_Key有效期至365天
2025-05-20 13:57:40 +08:00
YOO_koishi
939dc1b0fb Merge branch 'master' of https://github.com/AstrBotDevs/AstrBot into feat-add-volcengine-support 2025-05-20 13:52:03 +08:00
YOO_koishi
56bf5d38a1 🔧fix: 修改logger输出等级为debug级别 2025-05-20 13:51:11 +08:00
Soulter
d09b70b295 fix: 修复微信公众号(个人认证)下无法回复消息的问题 2025-05-20 01:38:13 -04:00
MiSeya
205180387a Fix:修复了WeChatPadPro在重新登录时为新设备的问题,延长初始化Auth_Key有效期至365天 2025-05-19 21:12:09 +08:00
HendricksJudy
39c8cfeda5 Merge pull request #2 from HendricksJudy/codex/fix-core-initialization-failure-handling-in-initialloader
Fix initialization bug and improve plugin utility
2025-05-19 01:43:22 -07:00
HendricksJudy
f38a329be5 Fix initialization and plugin download 2025-05-19 01:43:07 -07:00
YOO_koishi
a0cd069539 Merge branch 'master' of https://github.com/AstrBotDevs/AstrBot into feat-add-volcengine-support 2025-05-19 16:17:43 +08:00
YOO_koishi
bf306a2f01 🩹fix: 修改添加logger函数,添加speed_ratio选项,为一些选项添加description 2025-05-19 16:16:25 +08:00
Soulter
c31f93a8d1 Merge pull request #1595 from HendricksJudy/master
Fix lint issues and highlight typos
2025-05-19 09:29:02 +08:00
HendricksJudy
4730ab6309 Merge pull request #1 from HendricksJudy/codex/find-bugs-or-typos
Fix lint issues and highlight typos
2025-05-18 02:31:17 -07:00
HendricksJudy
1ae78ca98c chore: fix lint issues 2025-05-18 02:30:31 -07:00
Soulter
d2379da478 chore: use d3 2025-05-18 16:43:47 +08:00
Soulter
0f64981b20 feat: alkaid long term memory graph visualize 2025-05-18 13:26:44 +08:00
YOO_koishi
0002e49bb5 Merge branch 'master' of https://github.com/AstrBotDevs/AstrBot into feat-add-volcengine-support 2025-05-18 03:20:05 +08:00
YOO_koishi
db13a60274 feat: add-volcengine-tts-support 2025-05-18 03:18:36 +08:00
Soulter
db0f11a359 Merge pull request #1589 from Larch-C/master
🎈 perf: 优化了登录界面,解决了登录未自行跳转的问题
2025-05-17 21:40:14 +08:00
Soulter
ac7f43520b 🎈 perf: adjust login input padding style 2025-05-17 21:30:05 +08:00
Larch-C
f67b9f5f6e 🐞 fix: 解决了如果此前已经登录但未自行跳转的问题 2025-05-17 18:09:49 +08:00
Larch-C
c75156c4ce 🎈 perf: 优化了登录界面样式 2025-05-17 18:08:55 +08:00
Soulter
10270b5595 feat: alkaid framework and supports to customize webapi endpoint 2025-05-17 15:38:51 +08:00
Zhenyi Wang
f7458572ed feat: wechatpadpro对接获取联系人信息接口 2025-05-17 15:31:12 +08:00
Soulter
d57b7222b2 perf: 优化 WebUI About 页面、侧边栏和顶栏 2025-05-17 13:30:33 +08:00
Soulter
62e70a673a perf: 优化 Gemini 报错提示 2025-05-17 12:04:36 +08:00
Soulter
5e9eba6478 fix: extension market plugin card cannot apply installation 2025-05-16 22:43:38 -04:00
Raven95676
c5ccc1a084 feat(Video): 增加视频消息组件的文件转换和注册功能 2025-05-15 09:50:27 +08:00
YOO_koishi
6439917cbe Merge branch 'master' of https://github.com/AstrBotDevs/AstrBot into feat-add-volcengine-support 2025-05-14 22:45:02 +08:00
YOO_koishi
d21c18f657 change defualt.py 2025-05-14 22:43:40 +08:00
Raven95676
e6981290bc perf: 优化 Record 对象的文件和 URL 字段赋值逻辑 2025-05-14 20:05:38 +08:00
Raven95676
75c3d8abbd feat(t2i): 为本地文本转图像功能添加文件服务支持 2025-05-14 19:28:23 +08:00
Raven95676
d88683f498 feat(tts): 增加使用文件服务提供 TTS 语音文件的功能 2025-05-14 19:28:23 +08:00
Raven95676
40b9aa3a4c style: format code 2025-05-14 19:15:13 +08:00
123 changed files with 7433 additions and 1618 deletions

View File

@@ -23,6 +23,36 @@ jobs:
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
echo ${{ github.ref_name }} > dist/assets/version
zip -r dist.zip dist
- name: Upload to Cloudflare R2
env:
R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }}
R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
R2_BUCKET_NAME: "astrbot"
R2_OBJECT_NAME: "astrbot-webui-latest.zip"
VERSION_TAG: ${{ github.ref_name }}
run: |
echo "Installing rclone..."
curl https://rclone.org/install.sh | sudo bash
echo "Configuring rclone remote..."
mkdir -p ~/.config/rclone
cat <<EOF > ~/.config/rclone/rclone.conf
[r2]
type = s3
provider = Cloudflare
access_key_id = $R2_ACCESS_KEY_ID
secret_access_key = $R2_SECRET_ACCESS_KEY
endpoint = https://${R2_ACCOUNT_ID}.r2.cloudflarestorage.com
EOF
echo "Uploading dist.zip to R2 bucket: $R2_BUCKET_NAME/$R2_OBJECT_NAME"
mv dashboard/dist.zip dashboard/$R2_OBJECT_NAME
rclone copy dashboard/$R2_OBJECT_NAME r2:$R2_BUCKET_NAME --progress
mv dashboard/$R2_OBJECT_NAME dashboard/astrbot-webui-${VERSION_TAG}.zip
rclone copy dashboard/astrbot-webui-${VERSION_TAG}.zip r2:$R2_BUCKET_NAME --progress
mv dashboard/astrbot-webui-${VERSION_TAG}.zip dashboard/dist.zip
- name: Fetch Changelog
run: |

View File

@@ -11,24 +11,42 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: 拉取源码
- name: Pull The Codes
uses: actions/checkout@v3
with:
fetch-depth: 1
fetch-depth: 0 # Must be 0 so we can fetch tags
- name: 设置 QEMU
- name: Get latest tag (only on manual trigger)
id: get-latest-tag
if: github.event_name == 'workflow_dispatch'
run: |
tag=$(git describe --tags --abbrev=0)
echo "latest_tag=$tag" >> $GITHUB_OUTPUT
- name: Checkout to latest tag (only on manual trigger)
if: github.event_name == 'workflow_dispatch'
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
- name: Set QEMU
uses: docker/setup-qemu-action@v3
- name: 设置 Docker Buildx
- name: Set Docker Buildx
uses: docker/setup-buildx-action@v3
- name: 登录到 DockerHub
- name: Log in to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: 构建和推送 Docker hub
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: Soulter
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
- name: Build and Push Docker to DockerHub and Github GHCR
uses: docker/build-push-action@v6
with:
context: .
@@ -36,8 +54,9 @@ jobs:
push: true
tags: |
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.ref_name }}
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
ghcr.io/soulter/astrbot:latest
ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }}
- name: Post build notifications
run: echo "Docker image has been built and pushed successfully"

View File

@@ -1,4 +1,4 @@
FROM python:3.10-slim
FROM python:3.11-slim
WORKDIR /AstrBot
COPY . /AstrBot/

View File

@@ -31,13 +31,21 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
<!-- [![codecov](https://img.shields.io/codecov/c/github/soulter/astrbot?style=for-the-badge)](https://codecov.io/gh/Soulter/AstrBot)
-->
> [!NOTE]
> [!WARNING]
>
> 个人微信接入所依赖的开源项目 Gewechat 近期已停止维护,`v3.5.10` 已经支持接入 WeChatPadPro 替换 gewechat 方式。详见文档 [WeChatPadPro](https://astrbot.app/deploy/platform/wechat/wechatpadpro.html)
> 请务必修改默认密码以及保证 AstrBot 版本 >= 3.5.13。
## ✨ 近期更新
1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
<details><summary>1. AstrBot 现已自带知识库能力</summary>
📚 详见[文档](https://astrbot.app/use/knowledge-base.html)
![image](https://github.com/user-attachments/assets/28b639b0-bb5c-4958-8e94-92ae8cfd1ab4)
</details>
2. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
## ✨ 主要功能
@@ -171,7 +179,6 @@ pre-commit install
- Star 这个项目!
- 在[爱发电](https://afdian.com/a/soulter)支持我!
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
## ✨ Demo

View File

@@ -3,7 +3,6 @@ import tempfile
import httpx
import yaml
import re
from enum import Enum
from io import BytesIO
from pathlib import Path
@@ -59,7 +58,16 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
proxy=proxy if proxy else None, follow_redirects=True
) as client:
resp = client.get(download_url)
resp.raise_for_status()
if (
resp.status_code == 404
and "archive/refs/heads/master.zip" in download_url
):
alt_url = download_url.replace("master.zip", "main.zip")
click.echo("master 分支不存在,尝试下载 main 分支")
resp = client.get(alt_url)
resp.raise_for_status()
else:
resp.raise_for_status()
zip_content = BytesIO(resp.content)
with ZipFile(zip_content) as z:
z.extractall(temp_dir)
@@ -91,39 +99,6 @@ def load_yaml_metadata(plugin_dir: Path) -> dict:
return {}
def extract_py_metadata(plugin_dir: Path) -> dict:
"""从 Python 文件中提取插件元数据
Args:
plugin_dir: 插件目录路径
Returns:
dict: 包含元数据的字典,如果提取失败则返回空字典
"""
# 检查 main.py 或与目录同名的 py 文件
for pattern in ["main.py", f"{plugin_dir.name}.py"]:
for py_file in plugin_dir.glob(pattern):
try:
content = py_file.read_text(encoding="utf-8")
register_match = re.search(
r'@register_star\s*\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"(?:\s*,\s*"?([^")]+)"?)?\s*\)',
content,
)
if register_match:
# 映射匹配组到元数据键
metadata = {}
keys = ["name", "author", "desc", "version", "repo"]
for i, key in enumerate(keys):
if i + 1 <= len(
register_match.groups()
) and register_match.group(i + 1):
metadata[key] = register_match.group(i + 1)
return metadata
except Exception as e:
click.echo(f"读取 {py_file} 失败: {e}", err=True)
return {}
def build_plug_list(plugins_dir: Path) -> list:
"""构建插件列表,包含本地和在线插件信息
@@ -139,31 +114,22 @@ def build_plug_list(plugins_dir: Path) -> list:
for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]:
plugin_dir = plugins_dir / plugin_name
# 从不同来源加载元数据
# 从 metadata.yaml 加载元数据
metadata = load_yaml_metadata(plugin_dir)
# 如果元数据不完整,尝试从 Python 文件提取
if not metadata or not all(
# 如果成功加载元数据,添加到结果列表
if metadata and all(
k in metadata for k in ["name", "desc", "version", "author", "repo"]
):
py_metadata = extract_py_metadata(plugin_dir)
# 合并元数据,保留已有的值
for key, value in py_metadata.items():
if key not in metadata or not metadata[key]:
metadata[key] = value
# 如果成功提取元数据,添加到结果列表
if metadata:
result.append(
{
"name": str(metadata.get("name", "")),
"desc": str(metadata.get("desc", "")),
"version": str(metadata.get("version", "")),
"author": str(metadata.get("author", "")),
"repo": str(metadata.get("repo", "")),
"status": PluginStatus.INSTALLED,
"local_path": str(plugin_dir),
}
)
result.append({
"name": str(metadata.get("name", "")),
"desc": str(metadata.get("desc", "")),
"version": str(metadata.get("version", "")),
"author": str(metadata.get("author", "")),
"repo": str(metadata.get("repo", "")),
"status": PluginStatus.INSTALLED,
"local_path": str(plugin_dir),
})
# 获取在线插件列表
online_plugins = []
@@ -173,17 +139,15 @@ def build_plug_list(plugins_dir: Path) -> list:
resp.raise_for_status()
data = resp.json()
for plugin_id, plugin_info in data.items():
online_plugins.append(
{
"name": str(plugin_id),
"desc": str(plugin_info.get("desc", "")),
"version": str(plugin_info.get("version", "")),
"author": str(plugin_info.get("author", "")),
"repo": str(plugin_info.get("repo", "")),
"status": PluginStatus.NOT_INSTALLED,
"local_path": None,
}
)
online_plugins.append({
"name": str(plugin_id),
"desc": str(plugin_info.get("desc", "")),
"version": str(plugin_info.get("version", "")),
"author": str(plugin_info.get("author", "")),
"repo": str(plugin_info.get("repo", "")),
"status": PluginStatus.NOT_INSTALLED,
"local_path": None,
})
except Exception as e:
click.echo(f"获取在线插件列表失败: {e}", err=True)

View File

@@ -43,6 +43,7 @@ class AstrBotConfig(dict):
"""不存在时载入默认配置"""
with open(config_path, "w", encoding="utf-8-sig") as f:
json.dump(default_config, f, indent=4, ensure_ascii=False)
object.__setattr__(self, "first_deploy", True) # 标记第一次部署
with open(config_path, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
@@ -82,23 +83,61 @@ class AstrBotConfig(dict):
return conf
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
"""检查配置完整性,如果有新的配置项则返回 True"""
"""检查配置完整性,如果有新的配置项或顺序不一致则返回 True"""
has_new = False
# 创建一个新的有序字典以保持参考配置的顺序
new_conf = {}
# 先按照参考配置的顺序添加配置项
for key, value in refer_conf.items():
if key not in conf:
# logger.info(f"检查到配置项 {path + "." + key if path else key} 不存在,插入默认值 {value}")
# 配置项不存在,插入默认值
path_ = path + "." + key if path else key
logger.info(f"检查到配置项 {path_} 不存在,已插入默认值 {value}")
conf[key] = value
new_conf[key] = value
has_new = True
else:
if conf[key] is None:
conf[key] = value
# 配置项为 None使用默认值
new_conf[key] = value
has_new = True
elif isinstance(value, dict):
has_new |= self.check_config_integrity(
value, conf[key], path + "." + key if path else key
)
# 递归检查子配置项
if not isinstance(conf[key], dict):
# 类型不匹配,使用默认值
new_conf[key] = value
has_new = True
else:
# 递归检查并同步顺序
child_has_new = self.check_config_integrity(
value, conf[key], path + "." + key if path else key
)
new_conf[key] = conf[key]
has_new |= child_has_new
else:
# 直接使用现有配置
new_conf[key] = conf[key]
# 检查是否存在参考配置中没有的配置项
for key in list(conf.keys()):
if key not in refer_conf:
path_ = path + "." + key if path else key
logger.info(f"检查到配置项 {path_} 不存在,将从当前配置中删除")
has_new = True
# 顺序不一致也算作变更
if list(conf.keys()) != list(new_conf.keys()):
if path:
logger.info(f"检查到配置项 {path} 的子项顺序不一致,已重新排序")
else:
logger.info("检查到配置项顺序不一致,已重新排序")
has_new = True
# 更新原始配置
conf.clear()
conf.update(new_conf)
return has_new
def save_config(self, replace_config: Dict = None):

View File

@@ -5,7 +5,7 @@
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "3.5.10"
VERSION = "3.5.15"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db")
# 默认配置
@@ -40,12 +40,15 @@ DEFAULT_CONFIG = {
},
"no_permission_reply": True,
"empty_mention_waiting": True,
"empty_mention_waiting_need_reply": True,
"friend_message_needs_wake_prefix": False,
"ignore_bot_self_message": False,
"ignore_at_all": False,
},
"provider": [],
"provider_settings": {
"enable": True,
"default_provider_id": "",
"wake_prefix": "",
"web_search": False,
"web_search_link": False,
@@ -57,6 +60,7 @@ DEFAULT_CONFIG = {
"dequeue_context_length": 1,
"streaming_response": False,
"streaming_segmented": False,
"separate_provider": False,
},
"provider_stt_settings": {
"enable": False,
@@ -66,6 +70,7 @@ DEFAULT_CONFIG = {
"enable": False,
"provider_id": "",
"dual_output": False,
"use_file_service": False,
},
"provider_ltm_settings": {
"group_icl_enable": False,
@@ -91,6 +96,7 @@ DEFAULT_CONFIG = {
"t2i_word_threshold": 150,
"t2i_strategy": "remote",
"t2i_endpoint": "",
"t2i_use_file_service": False,
"http_proxy": "",
"dashboard": {
"enable": True,
@@ -176,6 +182,7 @@ CONFIG_METADATA_2 = {
"api_base_url": "https://api.weixin.qq.com/cgi-bin/",
"callback_server_host": "0.0.0.0",
"port": 6194,
"active_send_mode": False,
},
"wecom(企业微信)": {
"id": "wecom",
@@ -220,20 +227,25 @@ CONFIG_METADATA_2 = {
},
},
"items": {
"active_send_mode": {
"description": "是否换用主动发送接口",
"type": "bool",
"desc": "只有企业认证的公众号才能主动发送。主动发送接口的限制会少一些。",
},
"wpp_active_message_poll": {
"description": "是否启用主动消息轮询",
"type": "bool",
"hint": "只有当你发现微信消息没有按时同步到 AstrBot 时,才需要启用这个功能,默认不启用。"
"description": "是否启用主动消息轮询",
"type": "bool",
"hint": "只有当你发现微信消息没有按时同步到 AstrBot 时,才需要启用这个功能,默认不启用。",
},
"wpp_active_message_poll_interval": {
"description": "主动消息轮询间隔",
"type": "int",
"hint": "主动消息轮询间隔,单位为秒,默认 3 秒,最大不要超过 60 秒,否则可能被认为是旧消息。"
"description": "主动消息轮询间隔",
"type": "int",
"hint": "主动消息轮询间隔,单位为秒,默认 3 秒,最大不要超过 60 秒,否则可能被认为是旧消息。",
},
"kf_name": {
"description": "微信客服账号名",
"type": "string",
"hint": "可选。微信客服账号名(不是 ID)。可在 https://kf.weixin.qq.com/kf/frame#/accounts 获取"
"description": "微信客服账号名",
"type": "string",
"hint": "可选。微信客服账号名(不是 ID)。可在 https://kf.weixin.qq.com/kf/frame#/accounts 获取",
},
"telegram_token": {
"description": "Bot Token",
@@ -256,10 +268,10 @@ CONFIG_METADATA_2 = {
"hint": "Telegram 命令自动刷新间隔,单位为秒。",
},
"id": {
"description": "ID",
"description": "机器人名称",
"type": "string",
"obvious_hint": True,
"hint": "ID 不能和其它的平台适配器重复,否则将发生严重冲突",
"hint": "机器人名称(ID)不能和其它的平台适配器重复。",
},
"type": {
"description": "适配器类型",
@@ -347,9 +359,14 @@ CONFIG_METADATA_2 = {
"hint": "启用后,当用户没有权限执行某个操作时,机器人会回复一条消息。",
},
"empty_mention_waiting": {
"description": "只 @ 机器人是否触发等待回复",
"description": "只 @ 机器人是否触发等待",
"type": "bool",
"hint": "启用后,当消息内容只有 @ 机器人时,会触发等待回复,在 60 秒内的该用户的任意一条消息均会唤醒机器人。这在某些平台不支持 @ 和语音/图片等消息同时发送时特别有用。",
"hint": "启用后,当消息内容只有 @ 机器人时,会触发等待,在 60 秒内的该用户的任意一条消息均会唤醒机器人。这在某些平台不支持 @ 和语音/图片等消息同时发送时特别有用。",
},
"empty_mention_waiting_need_reply": {
"description": "只 @ 机器人触发等待时是否需要回复提醒",
"type": "bool",
"hint": "在上面一个配置项中,如果启用了触发等待,启用此项后,机器人会使用 LLM 生成一条回复。否则,将不回复而只是等待。",
},
"friend_message_needs_wake_prefix": {
"description": "私聊消息是否需要唤醒前缀",
@@ -361,6 +378,11 @@ CONFIG_METADATA_2 = {
"type": "bool",
"hint": "某些平台如 gewechat 会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
},
"ignore_at_all": {
"description": "是否忽略 @ 全体成员",
"type": "bool",
"hint": "启用后,机器人会忽略 @ 全体成员 的消息事件。",
},
"segmented_reply": {
"description": "分段回复",
"type": "object",
@@ -612,6 +634,7 @@ CONFIG_METADATA_2 = {
"gm_resp_image_modal": False,
"gm_native_search": False,
"gm_native_coderunner": False,
"gm_url_context": False,
"gm_safety_settings": {
"harassment": "BLOCK_MEDIUM_AND_ABOVE",
"hate_speech": "BLOCK_MEDIUM_AND_ABOVE",
@@ -818,7 +841,7 @@ CONFIG_METADATA_2 = {
"azure_tts_rate": "1",
"azure_tts_volume": "100",
"azure_tts_subscription_key": "",
"azure_tts_region": "eastus"
"azure_tts_region": "eastus",
},
"MiniMax TTS(API)": {
"id": "minimax_tts",
@@ -841,44 +864,158 @@ CONFIG_METADATA_2 = {
"minimax-voice-english-normalization": False,
"timeout": 20,
},
"火山引擎_TTS(API)": {
"id": "volcengine_tts",
"type": "volcengine_tts",
"provider_type": "text_to_speech",
"enable": False,
"api_key": "",
"appid": "",
"volcengine_cluster": "volcano_tts",
"volcengine_voice_type": "",
"volcengine_speed_ratio": 1.0,
"api_base": "https://openspeech.bytedance.com/api/v1/tts",
"timeout": 20,
},
"OpenAI Embedding": {
"id": "openai_embedding",
"type": "openai_embedding",
"provider_type": "embedding",
"enable": True,
"embedding_api_key": "",
"embedding_api_base": "",
"embedding_model": "",
"embedding_dimensions": 1536,
"timeout": 20,
},
"Gemini Embedding": {
"id": "gemini_embedding",
"type": "gemini_embedding",
"provider_type": "embedding",
"enable": True,
"embedding_api_key": "",
"embedding_api_base": "",
"embedding_model": "gemini-embedding-exp-03-07",
"embedding_dimensions": 768,
"timeout": 20,
},
},
"items": {
"embedding_dimensions": {
"description": "嵌入维度",
"type": "int",
"hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。",
},
"embedding_model": {
"description": "嵌入模型",
"type": "string",
"hint": "嵌入模型名称。",
},
"embedding_api_key": {
"description": "API Key",
"type": "string",
},
"embedding_api_base": {
"description": "API Base URL",
"type": "string",
},
"volcengine_cluster": {
"type": "string",
"description": "火山引擎集群",
"hint": "若使用语音复刻大模型可选volcano_icl或volcano_icl_concurr默认使用volcano_tts",
},
"volcengine_voice_type": {
"type": "string",
"description": "火山引擎音色",
"hint": "输入声音id(Voice_type)",
},
"volcengine_speed_ratio": {
"type": "float",
"description": "语速设置",
"hint": "语速设置,范围为 0.2 到 3.0,默认值为 1.0",
},
"volcengine_volume_ratio": {
"type": "float",
"description": "音量设置",
"hint": "音量设置,范围为 0.0 到 2.0,默认值为 1.0",
},
"azure_tts_voice": {
"type": "string",
"description": "音色设置",
"hint": "API 音色"
"hint": "API 音色",
},
"azure_tts_style": {
"type": "string",
"description": "风格设置",
"hint": "声音特定的讲话风格。 可以表达快乐、同情和平静等情绪。"
"hint": "声音特定的讲话风格。 可以表达快乐、同情和平静等情绪。",
},
"azure_tts_role": {
"type": "string",
"description": "模仿设置(可选)",
"hint": "讲话角色扮演。 声音可以模仿不同的年龄和性别,但声音名称不会更改。 例如,男性语音可以提高音调和改变语调来模拟女性语音,但语音名称不会更改。 如果角色缺失或不受声音的支持,则会忽略此属性。",
"options": ["Boy","Girl","YoungAdultFemale","YoungAdultMale","OlderAdultFemale","OlderAdultMale","SeniorFemale","SeniorMale","禁用"]
"options": [
"Boy",
"Girl",
"YoungAdultFemale",
"YoungAdultMale",
"OlderAdultFemale",
"OlderAdultMale",
"SeniorFemale",
"SeniorMale",
"禁用",
],
},
"azure_tts_rate": {
"type": "string",
"description": "语速设置",
"hint": "指示文本的讲出速率。可在字词或句子层面应用语速。 速率变化应为原始音频的 0.5 到 2 倍。"
"hint": "指示文本的讲出速率。可在字词或句子层面应用语速。 速率变化应为原始音频的 0.5 到 2 倍。",
},
"azure_tts_volume": {
"type": "string",
"description": "语音音量设置",
"hint": "指示语音的音量级别。 可在句子层面应用音量的变化。以从 0.0 到 100.0(从最安静到最大声,例如 75的数字表示。 默认值为 100.0。"
"hint": "指示语音的音量级别。 可在句子层面应用音量的变化。以从 0.0 到 100.0(从最安静到最大声,例如 75的数字表示。 默认值为 100.0。",
},
"azure_tts_region": {
"type": "string",
"description": "API 地区",
"hint": "Azure_TTS 处理数据所在区域,具体参考 https://learn.microsoft.com/zh-cn/azure/ai-services/speech-service/regions",
"options": ["southafricanorth", "eastasia", "southeastasia", "australiaeast", "centralindia", "japaneast", "japanwest", "koreacentral", "canadacentral", "northeurope", "westeurope", "francecentral", "germanywestcentral", "norwayeast", "swedencentral", "switzerlandnorth", "switzerlandwest", "uksouth", "uaenorth", "brazilsouth", "qatarcentral", "centralus", "eastus", "eastus2", "northcentralus", "southcentralus", "westcentralus", "westus", "westus2", "westus3"]
"options": [
"southafricanorth",
"eastasia",
"southeastasia",
"australiaeast",
"centralindia",
"japaneast",
"japanwest",
"koreacentral",
"canadacentral",
"northeurope",
"westeurope",
"francecentral",
"germanywestcentral",
"norwayeast",
"swedencentral",
"switzerlandnorth",
"switzerlandwest",
"uksouth",
"uaenorth",
"brazilsouth",
"qatarcentral",
"centralus",
"eastus",
"eastus2",
"northcentralus",
"southcentralus",
"westcentralus",
"westus",
"westus2",
"westus3",
],
},
"azure_tts_subscription_key": {
"type": "string",
"description": "服务订阅密钥",
"hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)"
"hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)",
},
"dashscope_tts_voice": {
"description": "语音合成模型",
@@ -902,6 +1039,12 @@ CONFIG_METADATA_2 = {
"hint": "启用后所有函数工具将全部失效",
"obvious_hint": True,
},
"gm_url_context": {
"description": "启用URL上下文功能",
"type": "bool",
"hint": "启用后所有函数工具将全部失效",
"obvious_hint": True,
},
"gm_safety_settings": {
"description": "安全过滤器",
"type": "object",
@@ -973,7 +1116,33 @@ CONFIG_METADATA_2 = {
"type": "string",
"description": "指定语言/方言",
"hint": "增强对指定的小语种和方言的识别能力,设置后可以提升在指定小语种/方言场景下的语音表现",
"options": [ "Chinese","Chinese,Yue","English","Arabic","Russian","Spanish","French","Portuguese","German","Turkish","Dutch","Ukrainian","Vietnamese","Indonesian","Japanese","Italian","Korean","Thai","Polish","Romanian","Greek","Czech","Finnish","Hindi","auto",],
"options": [
"Chinese",
"Chinese,Yue",
"English",
"Arabic",
"Russian",
"Spanish",
"French",
"Portuguese",
"German",
"Turkish",
"Dutch",
"Ukrainian",
"Vietnamese",
"Indonesian",
"Japanese",
"Italian",
"Korean",
"Thai",
"Polish",
"Romanian",
"Greek",
"Czech",
"Finnish",
"Hindi",
"auto",
],
},
"minimax-voice-speed": {
"type": "float",
@@ -1010,7 +1179,15 @@ CONFIG_METADATA_2 = {
"type": "string",
"description": "情绪",
"hint": "控制合成语音的情绪",
"options": ["happy","sad","angry","fearful","disgusted","surprised","neutral",],
"options": [
"happy",
"sad",
"angry",
"fearful",
"disgusted",
"surprised",
"neutral",
],
},
"minimax-voice-latex": {
"type": "bool",
@@ -1223,9 +1400,19 @@ CONFIG_METADATA_2 = {
"enable": {
"description": "启用大语言模型聊天",
"type": "bool",
"hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。",
"hint": "如需切换大语言模型提供商,请使用 /provider 命令。",
"obvious_hint": True,
},
"separate_provider": {
"description": "提供商会话隔离",
"type": "bool",
"hint": "启用后每个会话支持独立选择文本生成、STT、TTS 等提供商。如果会话在使用 /provider 指令时提示无权限,可以将会话加入管理员名单或者使用 /alter_cmd provider member 将指令设为非管理员指令。",
},
"default_provider_id": {
"description": "默认模型提供商 ID",
"type": "string",
"hint": "可选。每个聊天会话的默认提供商 ID。",
},
"wake_prefix": {
"description": "LLM 聊天额外唤醒前缀",
"type": "string",
@@ -1338,7 +1525,7 @@ CONFIG_METADATA_2 = {
"obvious_hint": True,
},
"provider_id": {
"description": "提供商 ID不填则默认第一个STT提供商",
"description": "提供商 ID",
"type": "string",
"hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。",
},
@@ -1355,7 +1542,7 @@ CONFIG_METADATA_2 = {
"obvious_hint": True,
},
"provider_id": {
"description": "提供商 ID不填则默认第一个TTS提供商",
"description": "提供商 ID",
"type": "string",
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
},
@@ -1365,6 +1552,11 @@ CONFIG_METADATA_2 = {
"hint": "启用后Bot 将同时输出语音和文字消息。",
"obvious_hint": True,
},
"use_file_service": {
"description": "使用文件服务提供 TTS 语音文件",
"type": "bool",
"hint": "启用后,如已配置 callback_api_base 将会使用文件服务提供TTS语音文件",
},
},
},
"provider_ltm_settings": {
@@ -1481,7 +1673,7 @@ CONFIG_METADATA_2 = {
"description": "对外可达的回调接口地址",
"type": "string",
"obvious_hint": True,
"hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址host因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185https://example.com 等。"
"hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址host因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185https://example.com 等。",
},
"log_level": {
"description": "控制台日志级别",
@@ -1500,6 +1692,11 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "当 t2i_strategy 为 remote 时生效。为空时使用 AstrBot API 服务",
},
"t2i_use_file_service": {
"description": "本地文本转图像使用文件服务提供文件",
"type": "bool",
"hint": "当 t2i_strategy 为 local 并且配置 callback_api_base 时生效。是否使用文件服务提供文件。",
},
"pip_install_arg": {
"description": "pip 安装参数",
"type": "string",

View File

@@ -1,6 +1,6 @@
"""
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
工作流程:
@@ -28,7 +28,6 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map
@@ -37,7 +36,7 @@ from astrbot.core.star.star_handler import star_map
class AstrBotCoreLifecycle:
"""
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、
EventBus 等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
"""
@@ -54,7 +53,7 @@ class AstrBotCoreLifecycle:
async def initialize(self):
"""
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
"""
# 初始化日志代理
@@ -73,9 +72,6 @@ class AstrBotCoreLifecycle:
# 初始化平台管理器
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
# 初始化知识库管理器
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
# 初始化对话管理器
self.conversation_manager = ConversationManager(self.db)
@@ -87,7 +83,6 @@ class AstrBotCoreLifecycle:
self.provider_manager,
self.platform_manager,
self.conversation_manager,
self.knowledge_db_manager,
)
# 初始化插件管理器

View File

@@ -1,113 +0,0 @@
import json
import aiosqlite
import os
from typing import Any
from .plugin_storage import PluginStorage
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
DBPATH = os.path.join(get_astrbot_data_path(), "plugin_data", "sqlite", "plugin_data.db")
class SQLitePluginStorage(PluginStorage):
"""插件数据的 SQLite 存储实现类。
该类提供异步方式将插件数据存储到 SQLite 数据库中,支持数据的增删改查操作。
所有数据以 (plugin, key) 作为复合主键进行索引。
"""
_instance = None # Standalone instance of the class
_db_conn = None
db_path = None
def __new__(cls):
"""
创建或获取 SQLitePluginStorage 的单例实例。
如果实例已存在,则返回现有实例;否则创建一个新实例。
数据在 `data/plugin_data/sqlite/plugin_data.db` 下。
"""
os.makedirs(os.path.dirname(DBPATH), exist_ok=True)
if cls._instance is None:
cls._instance = super(SQLitePluginStorage, cls).__new__(cls)
cls._instance.db_path = DBPATH
return cls._instance
async def _init_db(self):
"""初始化数据库连接(只执行一次)"""
if SQLitePluginStorage._db_conn is None:
SQLitePluginStorage._db_conn = await aiosqlite.connect(self.db_path)
await self._setup_db()
async def _setup_db(self):
"""
异步初始化数据库。
创建插件数据表,如果表不存在则创建,表结构包含 plugin、key 和 value 字段,
其中 plugin 和 key 组合作为主键。
"""
await self._db_conn.execute("""
CREATE TABLE IF NOT EXISTS plugin_data (
plugin TEXT,
key TEXT,
value TEXT,
PRIMARY KEY (plugin, key)
)
""")
await self._db_conn.commit()
async def set(self, plugin: str, key: str, value: Any):
"""
异步存储数据。
将指定插件的键值对存入数据库,如果键已存在则更新值。
值会被序列化为 JSON 字符串后存储。
Args:
plugin: 插件标识符
key: 数据键名
value: 要存储的数据值(任意类型,将被 JSON 序列化)
"""
await self._init_db()
await self._db_conn.execute(
"INSERT INTO plugin_data (plugin, key, value) VALUES (?, ?, ?) "
"ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value",
(plugin, key, json.dumps(value)),
)
await self._db_conn.commit()
async def get(self, plugin: str, key: str) -> Any:
"""
异步获取数据。
从数据库中获取指定插件和键名对应的值,
返回的值会从 JSON 字符串反序列化为原始数据类型。
Args:
plugin: 插件标识符
key: 数据键名
Returns:
Any: 存储的数据值,如果未找到则返回 None
"""
await self._init_db()
async with self._db_conn.execute(
"SELECT value FROM plugin_data WHERE plugin = ? AND key = ?",
(plugin, key),
) as cursor:
row = await cursor.fetchone()
return json.loads(row[0]) if row else None
async def delete(self, plugin: str, key: str):
"""
异步删除数据。
从数据库中删除指定插件和键名对应的数据项。
Args:
plugin: 插件标识符
key: 要删除的数据键名
"""
await self._init_db()
await self._db_conn.execute(
"DELETE FROM plugin_data WHERE plugin = ? AND key = ?", (plugin, key)
)
await self._db_conn.commit()

View File

@@ -11,7 +11,9 @@ class SQLiteDatabase(BaseDatabase):
super().__init__()
self.db_path = db_path
with open(os.path.dirname(__file__) + "/sqlite_init.sql", "r") as f:
with open(
os.path.dirname(__file__) + "/sqlite_init.sql", "r", encoding="utf-8"
) as f:
sql = f.read()
# 初始化数据库

View File

@@ -0,0 +1,46 @@
import abc
from dataclasses import dataclass
@dataclass
class Result:
similarity: float
data: dict
class BaseVecDB:
async def initialize(self):
"""
初始化向量数据库
"""
pass
@abc.abstractmethod
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
...
@abc.abstractmethod
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
"""
搜索最相似的文档。
Args:
query (str): 查询文本
top_k (int): 返回的最相似文档的数量
Returns:
List[Result]: 查询结果
"""
...
@abc.abstractmethod
async def delete(self, doc_id: str) -> bool:
"""
删除指定文档。
Args:
doc_id (str): 要删除的文档 ID
Returns:
bool: 删除是否成功
"""
...

View File

@@ -0,0 +1,3 @@
from .vec_db import FaissVecDB
__all__ = ["FaissVecDB"]

View File

@@ -0,0 +1,121 @@
import aiosqlite
import os
class DocumentStorage:
def __init__(self, db_path: str):
self.db_path = db_path
self.connection = None
self.sqlite_init_path = os.path.join(
os.path.dirname(__file__), "sqlite_init.sql"
)
async def initialize(self):
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
if not os.path.exists(self.db_path):
await self.connect()
async with self.connection.cursor() as cursor:
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
sql_script = f.read()
await cursor.executescript(sql_script)
await self.connection.commit()
else:
await self.connect()
async def connect(self):
"""Connect to the SQLite database."""
self.connection = await aiosqlite.connect(self.db_path)
async def get_documents(self, metadata_filters: dict, ids: list = None):
"""Retrieve documents by metadata filters and ids.
Args:
metadata_filters (dict): The metadata filters to apply.
Returns:
list: The list of document IDs(primary key, not doc_id) that match the filters.
"""
# metadata filter -> SQL WHERE clause
where_clauses = []
values = []
for key, val in metadata_filters.items():
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
values.append(val)
if ids is not None and len(ids) > 0:
ids = [str(i) for i in ids if i != -1]
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
values.extend(ids)
where_sql = " AND ".join(where_clauses) or "1=1"
result = []
async with self.connection.cursor() as cursor:
sql = "SELECT * FROM documents WHERE " + where_sql
await cursor.execute(sql, values)
for row in await cursor.fetchall():
result.append(await self.tuple_to_dict(row))
return result
async def get_document_by_doc_id(self, doc_id: str):
"""Retrieve a document by its doc_id.
Args:
doc_id (str): The doc_id of the document to retrieve.
Returns:
dict: The document data.
"""
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
row = await cursor.fetchone()
if row:
return await self.tuple_to_dict(row)
else:
return None
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
"""Retrieve a document by its doc_id.
Args:
doc_id (str): The doc_id.
new_text (str): The new text to update the document with.
"""
async with self.connection.cursor() as cursor:
await cursor.execute(
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
)
await self.connection.commit()
async def get_user_ids(self) -> list[str]:
"""Retrieve all user IDs from the documents table.
Returns:
list: A list of user IDs.
"""
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT DISTINCT user_id FROM documents")
rows = await cursor.fetchall()
return [row[0] for row in rows]
async def tuple_to_dict(self, row):
"""Convert a tuple to a dictionary.
Args:
row (tuple): The row to convert.
Returns:
dict: The converted dictionary.
"""
return {
"id": row[0],
"doc_id": row[1],
"text": row[2],
"metadata": row[3],
"created_at": row[4],
"updated_at": row[5],
}
async def close(self):
"""Close the connection to the SQLite database."""
if self.connection:
await self.connection.close()
self.connection = None

View File

@@ -0,0 +1,59 @@
try:
import faiss
except ModuleNotFoundError:
raise ImportError(
"faiss 未安装。请使用 'pip install faiss-cpu''pip install faiss-gpu' 安装。"
)
import os
import numpy as np
class EmbeddingStorage:
def __init__(self, dimension: int, path: str = None):
self.dimension = dimension
self.path = path
self.index = None
if path and os.path.exists(path):
self.index = faiss.read_index(path)
else:
base_index = faiss.IndexFlatL2(dimension)
self.index = faiss.IndexIDMap(base_index)
self.storage = {}
async def insert(self, vector: np.ndarray, id: int):
"""插入向量
Args:
vector (np.ndarray): 要插入的向量
id (int): 向量的ID
Raises:
ValueError: 如果向量的维度与存储的维度不匹配
"""
if vector.shape[0] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
)
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
self.storage[id] = vector
await self.save_index()
async def search(self, vector: np.ndarray, k: int) -> tuple:
"""搜索最相似的向量
Args:
vector (np.ndarray): 查询向量
k (int): 返回的最相似向量的数量
Returns:
tuple: (距离, 索引)
"""
faiss.normalize_L2(vector)
distances, indices = self.index.search(vector, k)
return distances, indices
async def save_index(self):
"""保存索引
Args:
path (str): 保存索引的路径
"""
faiss.write_index(self.index, self.path)

View File

@@ -0,0 +1,17 @@
-- 创建文档存储表,包含 faiss 中文档的 id文档文本create_atupdated_at
CREATE TABLE documents (
id INTEGER PRIMARY KEY AUTOINCREMENT,
doc_id TEXT NOT NULL,
text TEXT NOT NULL,
metadata TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
ALTER TABLE documents
ADD COLUMN group_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.group_id')) STORED;
ALTER TABLE documents
ADD COLUMN user_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED;
CREATE INDEX idx_documents_user_id ON documents(user_id);
CREATE INDEX idx_documents_group_id ON documents(group_id);

View File

@@ -0,0 +1,117 @@
import uuid
import json
import numpy as np
from .document_storage import DocumentStorage
from .embedding_storage import EmbeddingStorage
from ..base import Result, BaseVecDB
from astrbot.core.provider.provider import EmbeddingProvider
class FaissVecDB(BaseVecDB):
"""
A class to represent a vector database.
"""
def __init__(
self,
doc_store_path: str,
index_store_path: str,
embedding_provider: EmbeddingProvider,
):
self.doc_store_path = doc_store_path
self.index_store_path = index_store_path
self.embedding_provider = embedding_provider
self.document_storage = DocumentStorage(doc_store_path)
self.embedding_storage = EmbeddingStorage(
embedding_provider.get_dim(), index_store_path
)
self.embedding_provider = embedding_provider
async def initialize(self):
await self.document_storage.initialize()
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
metadata = metadata or {}
str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID
vector = await self.embedding_provider.get_embedding(content)
vector = np.array(vector, dtype=np.float32)
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute(
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
(str_id, content, json.dumps(metadata)),
)
await self.document_storage.connection.commit()
result = await self.document_storage.get_document_by_doc_id(str_id)
int_id = result["id"]
# 插入向量到 FAISS
await self.embedding_storage.insert(vector, int_id)
return int_id
async def retrieve(
self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None
) -> list[Result]:
"""
搜索最相似的文档。
Args:
query (str): 查询文本
k (int): 返回的最相似文档的数量
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
metadata_filters (dict): 元数据过滤器
Returns:
List[Result]: 查询结果
"""
embedding = await self.embedding_provider.get_embedding(query)
scores, indices = await self.embedding_storage.search(
vector=np.array([embedding]).astype("float32"),
k=fetch_k if metadata_filters else k,
)
# TODO: rerank
if len(indices[0]) == 0 or indices[0][0] == -1:
return []
# normalize scores
scores[0] = 1.0 - (scores[0] / 2.0)
# NOTE: maybe the size is less than k.
fetched_docs = await self.document_storage.get_documents(
metadata_filters=metadata_filters or {}, ids=indices[0]
)
if not fetched_docs:
return []
result_docs = []
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
for i, indice_idx in enumerate(indices[0]):
pos = idx_pos.get(indice_idx)
if pos is None:
continue
fetch_doc = fetched_docs[pos]
score = scores[0][i]
result_docs.append(Result(similarity=float(score), data=fetch_doc))
return result_docs[:k]
async def delete(self, doc_id: int):
"""
删除一条文档
"""
await self.document_storage.connection.execute(
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
)
await self.document_storage.connection.commit()
async def close(self):
await self.document_storage.close()
async def count_documents(self) -> int:
"""
计算文档数量
"""
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute("SELECT COUNT(*) FROM documents")
count = await cursor.fetchone()
return count[0] if count else 0

View File

@@ -26,13 +26,14 @@ class InitialLoader:
async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
core_task = []
try:
await core_lifecycle.initialize()
core_task = core_lifecycle.start()
except Exception as e:
logger.critical(traceback.format_exc())
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
return
core_task = core_lifecycle.start()
self.dashboard_server = AstrBotDashboard(
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event

View File

@@ -102,6 +102,10 @@ class BaseMessageComponent(BaseModel):
data[k] = v
return {"type": self.type.lower(), "data": data}
async def to_dict(self) -> dict:
# 默认情况下,回退到旧的同步 toDict()
return self.toDict()
class Plain(BaseMessageComponent):
type: ComponentType = "Plain"
@@ -118,6 +122,9 @@ class Plain(BaseMessageComponent):
self.text.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;")
)
def toDict(self):
return {"type": "text", "data": {"text": self.text.strip()}}
class Face(BaseMessageComponent):
type: ComponentType = "Face"
@@ -235,9 +242,6 @@ class Video(BaseMessageComponent):
path: T.Optional[str] = ""
def __init__(self, file: str, **_):
# for k in _.keys():
# if k == "c" and _[k] not in [2, 3]:
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
super().__init__(file=file, **_)
@staticmethod
@@ -250,6 +254,70 @@ class Video(BaseMessageComponent):
return Video(file=url, **_)
raise Exception("not a valid url")
async def convert_to_file_path(self) -> str:
"""将这个视频统一转换为本地文件路径。这个方法避免了手动判断视频数据类型,直接返回视频数据的本地路径(如果是网络 URL则会自动进行下载
Returns:
str: 视频的本地路径,以绝对路径表示。
"""
url = self.file
if url and url.startswith("file:///"):
return url[8:]
elif url and url.startswith("http"):
download_dir = os.path.join(get_astrbot_data_path(), "temp")
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
await download_file(url, video_file_path)
if os.path.exists(video_file_path):
return os.path.abspath(video_file_path)
else:
raise Exception(f"download failed: {url}")
elif os.path.exists(url):
return os.path.abspath(url)
else:
raise Exception(f"not a valid file: {url}")
async def register_to_file_service(self):
"""
将视频注册到文件服务。
Returns:
str: 注册后的URL
Raises:
Exception: 如果未配置 callback_api_base
"""
callback_host = astrbot_config.get("callback_api_base")
if not callback_host:
raise Exception("未配置 callback_api_base文件服务不可用")
file_path = await self.convert_to_file_path()
token = await file_token_service.register_file(file_path)
logger.debug(f"已注册:{callback_host}/api/file/{token}")
return f"{callback_host}/api/file/{token}"
async def to_dict(self):
"""需要和 toDict 区分开toDict 是同步方法"""
url_or_path = self.file
if url_or_path.startswith("http"):
payload_file = url_or_path
elif callback_host := astrbot_config.get("callback_api_base"):
callback_host = str(callback_host).removesuffix("/")
token = await file_token_service.register_file(url_or_path)
payload_file = f"{callback_host}/api/file/{token}"
logger.debug(f"Generated video file callback link: {payload_file}")
else:
payload_file = url_or_path
return {
"type": "video",
"data": {
"file": payload_file,
},
}
class At(BaseMessageComponent):
type: ComponentType = "At"
@@ -259,6 +327,12 @@ class At(BaseMessageComponent):
def __init__(self, **_):
super().__init__(**_)
def toDict(self):
return {
"type": "at",
"data": {"qq": str(self.qq)},
}
class AtAll(At):
qq: str = "all"
@@ -514,27 +588,47 @@ class Node(BaseMessageComponent):
id: T.Optional[int] = 0 # 忽略
name: T.Optional[str] = "" # qq昵称
uin: T.Optional[str] = "0" # qq号
content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表
content: T.Optional[list[BaseMessageComponent]] = []
seq: T.Optional[T.Union[str, list]] = "" # 忽略
time: T.Optional[int] = 0 # 忽略
def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_):
if isinstance(content, list):
_content = None
if all(isinstance(item, Node) for item in content):
_content = [node.toDict() for node in content]
else:
_content = ""
for chain in content:
_content += chain.toString()
content = _content
elif isinstance(content, Node):
content = content.toDict()
def __init__(self, content: list[BaseMessageComponent], **_):
if isinstance(content, Node):
# back
content = [content]
super().__init__(content=content, **_)
def toString(self):
# logger.warn("Protocol: node doesn't support stringify")
return ""
async def to_dict(self):
data_content = []
for comp in self.content:
if isinstance(comp, (Image, Record)):
# For Image and Record segments, we convert them to base64
bs64 = await comp.convert_to_base64()
data_content.append(
{
"type": comp.type.lower(),
"data": {"file": f"base64://{bs64}"},
}
)
elif isinstance(comp, File):
# For File segments, we need to handle the file differently
d = await comp.to_dict()
data_content.append(d)
elif isinstance(comp, (Node, Nodes)):
# For Node segments, we recursively convert them to dict
d = await comp.to_dict()
data_content.append(d)
else:
d = comp.toDict()
data_content.append(d)
return {
"type": "node",
"data": {
"user_id": str(self.uin),
"nickname": self.name,
"content": data_content,
},
}
class Nodes(BaseMessageComponent):
@@ -545,12 +639,20 @@ class Nodes(BaseMessageComponent):
super().__init__(nodes=nodes, **_)
def toDict(self):
"""Deprecated. Use to_dict instead"""
ret = {
"messages": [],
}
for node in self.nodes:
d = node.toDict()
d["data"]["uin"] = str(node.uin) # 转为字符串
ret["messages"].append(d)
return ret
async def to_dict(self):
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
ret = {"messages": []}
for node in self.nodes:
d = await node.to_dict()
ret["messages"].append(d)
return ret
@@ -723,6 +825,26 @@ class File(BaseMessageComponent):
return f"{callback_host}/api/file/{token}"
async def to_dict(self):
"""需要和 toDict 区分开toDict 是同步方法"""
url_or_path = await self.get_file(allow_return_url=True)
if url_or_path.startswith("http"):
payload_file = url_or_path
elif callback_host := astrbot_config.get("callback_api_base"):
callback_host = str(callback_host).removesuffix("/")
token = await file_token_service.register_file(url_or_path)
payload_file = f"{callback_host}/api/file/{token}"
logger.debug(f"Generated file callback link: {payload_file}")
else:
payload_file = url_or_path
return {
"type": "file",
"data": {
"name": self.name,
"file": payload_file,
},
}
class WechatEmoji(BaseMessageComponent):
type: ComponentType = "WechatEmoji"

View File

@@ -43,31 +43,31 @@ class PreProcessStage(Stage):
# STT
if self.stt_settings.get("enable", False):
# TODO: 独立
stt_provider = (
self.plugin_manager.context.provider_manager.curr_stt_provider_inst
)
if stt_provider:
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record) and component.url:
path = component.url.removeprefix("file://")
retry = 5
for i in range(retry):
try:
result = await stt_provider.get_text(audio_url=path)
if result:
logger.info("语音转文本结果: " + result)
message_chain[idx] = Plain(result)
event.message_str += result
event.message_obj.message_str += result
break
except FileNotFoundError as e:
# napcat workaround
logger.warning(e)
logger.warning(f"重试中: {i + 1}/{retry}")
await asyncio.sleep(0.5)
continue
except BaseException as e:
logger.error(traceback.format_exc())
logger.error(f"语音转文本失败: {e}")
break
ctx = self.plugin_manager.context
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
if not stt_provider:
return
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record) and component.url:
path = component.url.removeprefix("file://")
retry = 5
for i in range(retry):
try:
result = await stt_provider.get_text(audio_url=path)
if result:
logger.info("语音转文本结果: " + result)
message_chain[idx] = Plain(result)
event.message_str += result
event.message_obj.message_str += result
break
except FileNotFoundError as e:
# napcat workaround
logger.warning(e)
logger.warning(f"重试中: {i + 1}/{retry}")
await asyncio.sleep(0.5)
continue
except BaseException as e:
logger.error(traceback.format_exc())
logger.error(f"语音转文本失败: {e}")
break

View File

@@ -33,6 +33,7 @@ from mcp.types import (
TextResourceContents,
BlobResourceContents,
)
from astrbot.core import web_chat_back_queue
class LLMRequestSubStage(Stage):
@@ -67,7 +68,11 @@ class LLMRequestSubStage(Stage):
) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
umo = event.unified_msg_origin
provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo)
if provider is None:
return
@@ -283,7 +288,66 @@ class LLMRequestSubStage(Stage):
if img_b64 := event.get_extra("tool_call_img_respond"):
await event.send(MessageChain(chain=[Image.fromBase64(img_b64)]))
event.set_extra("tool_call_img_respond", None)
yield
if event.get_platform_name() == "webchat":
# 异步处理 WebChat 特殊情况
asyncio.create_task(self._handle_webchat(event, req))
async def _handle_webchat(self, event: AstrMessageEvent, req: ProviderRequest):
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, req.conversation.cid
)
if conversation and not req.conversation.title:
messages = json.loads(conversation.history)
latest_pair = messages[-2:]
if not latest_pair:
return
provider = self.ctx.plugin_manager.context.get_using_provider()
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
# if len(latest_pair) > 1:
# cleaned_text += (
# "\nAssistant: " + latest_pair[1].get("content", "").strip()
# )
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
llm_resp = await provider.text_chat(
system_prompt="You are expert in summarizing user's query.",
prompt=(
f"Please summarize the following query of user:\n"
f"{cleaned_text}\n"
"Only output the summary within 10 words, DO NOT INCLUDE any other text."
"You must use the same language as the user."
"If you think the dialog is too short to summarize, only output a special mark: `None`"
),
)
if llm_resp and llm_resp.completion_text:
logger.debug(
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}"
)
title = llm_resp.completion_text.strip()
if not title or "None" == title:
return
await self.conv_manager.update_conversation_title(
event.unified_msg_origin, title=title
)
# 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题
# webchat adapter 中session_id 的格式是 f"webchat!{username}!{cid}"
# TODO: 优化 WebChat 适配器的对话管理
if event.session_id:
username, cid = event.session_id.split("!")[1:3]
db_helper = self.ctx.plugin_manager.context._db
db_helper.update_conversation_title(
user_id=username,
cid=cid,
title=title,
)
web_chat_back_queue.put_nowait(
{
"type": "update_title",
"cid": cid,
"data": title,
}
)
async def _handle_llm_response(
self,

View File

@@ -29,11 +29,10 @@ class RespondStage(Stage):
Comp.Image: lambda comp: bool(comp.file), # 图片
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
Comp.Node: lambda comp: bool(comp.name)
and comp.uin != 0
and bool(comp.content), # 一个转发节点
Comp.Node: lambda comp: bool(comp.content), # 转发节点
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
Comp.File: lambda comp: bool(comp.file_ or comp.url),
Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情
}
async def initialize(self, ctx: PipelineContext):
@@ -192,6 +191,7 @@ class RespondStage(Stage):
await asyncio.sleep(i)
try:
await event.send(MessageChain([*decorated_comps, comp]))
decorated_comps = [] # 清空已发送的装饰组件
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break

View File

@@ -1,17 +1,18 @@
import time
import re
import time
import traceback
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage, registered_stages
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from typing import AsyncGenerator, Union
from astrbot.core import html_renderer, logger, file_token_service
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core import logger
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
from astrbot.core import html_renderer
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
from ..context import PipelineContext
from ..stage import Stage, register_stage, registered_stages
@register_stage
@@ -168,30 +169,55 @@ class ResultDecorateStage(Stage):
result.chain = new_chain
# TTS
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
event.unified_msg_origin
)
if (
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
and result.is_llm_result()
and tts_provider
):
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info("TTS 请求: " + comp.text)
logger.info(f"TTS 请求: {comp.text}")
audio_path = await tts_provider.get_audio(comp.text)
logger.info("TTS 结果: " + audio_path)
if audio_path:
new_chain.append(
Record(file=audio_path, url=audio_path)
)
if(self.ctx.astrbot_config["provider_tts_settings"]["dual_output"]):
new_chain.append(comp)
else:
logger.info(f"TTS 结果: {audio_path}")
if not audio_path:
logger.error(
f"由于 TTS 音频文件找到,消息段转语音失败: {comp.text}"
f"由于 TTS 音频文件找到,消息段转语音失败: {comp.text}"
)
new_chain.append(comp)
except BaseException:
continue
use_file_service = self.ctx.astrbot_config[
"provider_tts_settings"
]["use_file_service"]
callback_api_base = self.ctx.astrbot_config[
"callback_api_base"
]
dual_output = self.ctx.astrbot_config[
"provider_tts_settings"
]["dual_output"]
url = None
if use_file_service and callback_api_base:
token = await file_token_service.register_file(
audio_path
)
url = f"{callback_api_base}/api/file/{token}"
logger.debug(f"已注册:{url}")
new_chain.append(
Record(
file=url or audio_path,
url=url or audio_path,
)
)
if dual_output:
new_chain.append(comp)
except Exception:
logger.error(traceback.format_exc())
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
@@ -225,6 +251,14 @@ class ResultDecorateStage(Stage):
if url:
if url.startswith("http"):
result.chain = [Image.fromURL(url)]
elif (
self.ctx.astrbot_config["t2i_use_file_service"]
and self.ctx.astrbot_config["callback_api_base"]
):
token = await file_token_service.register_file(url)
url = f"{self.ctx.astrbot_config['callback_api_base']}/api/file/{token}"
logger.debug(f"已注册:{url}")
result.chain = [Image.fromURL(url)]
else:
result.chain = [Image.fromFileSystem(url)]

View File

@@ -4,7 +4,7 @@ from astrbot import logger
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.message.components import At
from astrbot.core.message.components import At, AtAll
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.star.filter.permission import PermissionTypeFilter
@@ -39,6 +39,9 @@ class WakingCheckStage(Stage):
self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get(
"ignore_bot_self_message", False
)
self.ignore_at_all = self.ctx.astrbot_config["platform_settings"].get(
"ignore_at_all", False
)
async def process(
self, event: AstrMessageEvent
@@ -79,10 +82,9 @@ class WakingCheckStage(Stage):
if not is_wake:
# 检查是否有 at 消息
for message in messages:
if isinstance(message, At) and (
if (isinstance(message, At) and (
str(message.qq) == str(event.get_self_id())
or str(message.qq) == "all"
):
)) or (isinstance(message, AtAll) and not self.ignore_at_all):
is_wake = True
event.is_wake = True
wake_prefix = ""

View File

@@ -3,9 +3,17 @@ import re
from typing import AsyncGenerator, Dict, List
from aiocqhttp import CQHttp
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import At, Image, Node, Nodes, Plain, Record, File
from astrbot.api.message_components import (
Image,
Node,
Nodes,
Plain,
Record,
Video,
File,
BaseMessageComponent,
)
from astrbot.api.platform import Group, MessageMember
from astrbot.core import file_token_service, astrbot_config, logger
class AiocqhttpMessageEvent(AstrMessageEvent):
@@ -15,28 +23,38 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
@staticmethod
async def _from_segment_to_dict(segment: BaseMessageComponent) -> dict:
"""修复部分字段"""
if isinstance(segment, (Image, Record)):
# For Image and Record segments, we convert them to base64
bs64 = await segment.convert_to_base64()
return {
"type": segment.type.lower(),
"data": {
"file": f"base64://{bs64}",
},
}
elif isinstance(segment, File):
# For File segments, we need to handle the file differently
d = await segment.to_dict()
return d
elif isinstance(segment, Video):
d = await segment.to_dict()
return d
else:
# For other segments, we simply convert them to a dict by calling toDict
return segment.toDict()
@staticmethod
async def _parse_onebot_json(message_chain: MessageChain):
"""解析成 OneBot json 格式"""
ret = []
for segment in message_chain.chain:
d = segment.toDict()
if isinstance(segment, Plain):
d["type"] = "text"
d["data"]["text"] = segment.text.strip()
# 如果是空文本或者只带换行符的文本,不发送
if not d["data"]["text"]:
if not segment.text.strip():
continue
elif isinstance(segment, (Image, Record)):
# convert to base64
bs64 = await segment.convert_to_base64()
d["data"] = {
"file": f"base64://{bs64}",
}
elif isinstance(segment, At):
d["data"] = {
"qq": str(segment.qq), # 转换为字符串
}
d = await AiocqhttpMessageEvent._from_segment_to_dict(segment)
ret.append(d)
return ret
@@ -54,7 +72,8 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
nodes = Nodes([seg])
seg = nodes
payload = seg.toDict()
payload = await seg.to_dict()
if self.get_group_id():
payload["group_id"] = self.get_group_id()
await self.bot.call_action("send_group_forward_msg", **payload)
@@ -64,21 +83,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
"send_private_forward_msg", **payload
)
elif isinstance(seg, File):
d = seg.toDict()
url_or_path = await seg.get_file(allow_return_url=True)
if url_or_path.startswith("http"):
payload_file = url_or_path
elif callback_host := astrbot_config.get("callback_api_base"):
callback_host = str(callback_host).removesuffix("/")
token = await file_token_service.register_file(url_or_path)
payload_file = f"{callback_host}/api/file/{token}"
logger.debug(f"Generated file callback link: {payload_file}")
else:
payload_file = url_or_path
d["data"] = {
"name": seg.name,
"file": payload_file,
}
d = await AiocqhttpMessageEvent._from_segment_to_dict(seg)
await self.bot.send(
self.message_obj.raw_message,
[d],

View File

@@ -221,6 +221,9 @@ class AiocqhttpAdapter(Platform):
a = None
if t == "text":
current_text = "".join(m["data"]["text"] for m in m_group).strip()
if not current_text:
# 如果文本段为空,则跳过
continue
message_str += current_text
a = ComponentTypes[t](text=current_text) # noqa: F405
abm.message.append(a)

View File

@@ -144,8 +144,8 @@ class TelegramPlatformAdapter(Platform):
command_dict = {}
skip_commands = {"start"}
for handler_md in star_handlers_registry._handlers:
handler_metadata = handler_md[1]
for handler_md in star_handlers_registry:
handler_metadata = handler_md
if not star_map[handler_metadata.handler_module_path].activated:
continue
for event_filter in handler_metadata.event_filters:

View File

@@ -1,4 +1,5 @@
import os
import re
import asyncio
import telegramify_markdown
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -18,6 +19,16 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class TelegramPlatformEvent(AstrMessageEvent):
# Telegram 的最大消息长度限制
MAX_MESSAGE_LENGTH = 4096
SPLIT_PATTERNS = {
"paragraph": re.compile(r"\n\n"),
"line": re.compile(r"\n"),
"sentence": re.compile(r"[.!?。!?]"),
"word": re.compile(r"\s"),
}
def __init__(
self,
message_str: str,
@@ -29,8 +40,33 @@ class TelegramPlatformEvent(AstrMessageEvent):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
@staticmethod
async def send_with_client(client: ExtBot, message: MessageChain, user_name: str):
def _split_message(self, text: str) -> list[str]:
if len(text) <= self.MAX_MESSAGE_LENGTH:
return [text]
chunks = []
while text:
if len(text) <= self.MAX_MESSAGE_LENGTH:
chunks.append(text)
break
split_point = self.MAX_MESSAGE_LENGTH
segment = text[: self.MAX_MESSAGE_LENGTH]
for _, pattern in self.SPLIT_PATTERNS.items():
if matches := list(pattern.finditer(segment)):
last_match = matches[-1]
split_point = last_match.end()
break
chunks.append(text[:split_point])
text = text[split_point:].lstrip()
return chunks
async def send_with_client(
self, client: ExtBot, message: MessageChain, user_name: str
):
image_path = None
has_reply = False
@@ -59,19 +95,22 @@ class TelegramPlatformEvent(AstrMessageEvent):
if isinstance(i, Plain):
if at_user_id and not at_flag:
i.text = f"@{at_user_id} " + i.text
i.text = f"@{at_user_id} {i.text}"
at_flag = True
text = i.text
try:
text = telegramify_markdown.markdownify(
i.text, max_line_length=None, normalize_whitespace=False
)
except Exception as e:
logger.warning(
f"MarkdownV2 conversion failed: {e}. Using plain text instead."
)
return
await client.send_message(text=text, parse_mode="MarkdownV2", **payload)
chunks = self._split_message(i.text)
for chunk in chunks:
try:
md_text = telegramify_markdown.markdownify(
chunk, max_line_length=None, normalize_whitespace=False
)
await client.send_message(
text=md_text, parse_mode="MarkdownV2", **payload
)
except Exception as e:
logger.warning(
f"MarkdownV2 send failed: {e}. Using plain text instead."
)
await client.send_message(text=chunk, **payload)
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await client.send_photo(photo=image_path, **payload)
@@ -147,17 +186,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
continue
# Plain
if not message_id:
try:
msg = await self.client.send_message(text=delta, **payload)
current_content = delta
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id
last_edit_time = (
asyncio.get_event_loop().time()
) # 记录初始消息发送时间
else:
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
current_time = asyncio.get_event_loop().time()
time_since_last_edit = current_time - last_edit_time
@@ -176,6 +205,18 @@ class TelegramPlatformEvent(AstrMessageEvent):
last_edit_time = (
asyncio.get_event_loop().time()
) # 更新上次编辑的时间
else:
# delta 长度一般不会大于 4096因此这里直接发送
try:
msg = await self.client.send_message(text=delta, **payload)
current_content = delta
delta = ""
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id
last_edit_time = (
asyncio.get_event_loop().time()
) # 记录初始消息发送时间
try:
if delta and current_content != delta:

View File

@@ -1,14 +1,15 @@
import asyncio
import base64
import json
import os
import time
from typing import Optional
import aiohttp
import anyio
import websockets
from astrbot import logger
from astrbot.api.message_components import Plain, Image
from astrbot.api.message_components import Plain, Image, At, Record
from astrbot.api.platform import Platform, PlatformMetadata
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.astrbot_message import (
@@ -22,6 +23,13 @@ from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from .wechatpadpro_message_event import WeChatPadProMessageEvent
try:
from .xml_data_parser import GeweDataParser
except ImportError as e:
logger.warning(
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
)
@register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器")
class WeChatPadProAdapter(Platform):
@@ -59,6 +67,18 @@ class WeChatPadProAdapter(Platform):
) # 持久化文件路径
self.ws_handle_task = None
# 添加图片消息缓存,用于引用消息处理
self.cached_images = {}
"""缓存图片消息。key是NewMsgId (对应引用消息的svrid)value是图片的base64数据"""
# 设置缓存大小限制,避免内存占用过大
self.max_image_cache = 50
# 添加文本消息缓存,用于引用消息处理
self.cached_texts = {}
"""缓存文本消息。key是NewMsgId (对应引用消息的svrid)value是消息文本内容"""
# 设置文本缓存大小限制
self.max_text_cache = 100
async def run(self) -> None:
"""
启动平台适配器的运行实例。
@@ -69,39 +89,42 @@ class WeChatPadProAdapter(Platform):
self.auth_key = loaded_credentials.get("auth_key")
self.wxid = loaded_credentials.get("wxid")
isLoginIn = await self.check_online_status()
# 检查在线状态
if self.auth_key and await self.check_online_status():
logger.info("WeChatPadPro 设备已在线,跳过扫码登录。")
if self.auth_key and isLoginIn:
logger.info("WeChatPadPro 设备已在线,凭据存在,跳过扫码登录。")
# 如果在线,连接 WebSocket 接收消息
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
else:
logger.info("WeChatPadPro 设备不在线或无可用凭据,开始扫码登录流程。")
# 1. 生成授权码
await self.generate_auth_key()
if not self.auth_key:
logger.error("无法获取授权码,WeChatPadPro 适配器启动失败")
return
logger.info("WeChatPadPro 无可用凭据,将生成新的授权码")
await self.generate_auth_key()
# 2. 获取登录二维码
qr_code_url = await self.get_login_qr_code()
if not isLoginIn:
logger.info("WeChatPadPro 设备已离线,开始扫码登录。")
qr_code_url = await self.get_login_qr_code()
if qr_code_url:
logger.info(f"请扫描以下二维码登录: {qr_code_url}")
else:
logger.error("无法获取登录二维码。")
return
if qr_code_url:
logger.info(f"请扫描以下二维码登录: {qr_code_url}")
else:
logger.error("无法获取登录二维码。")
return
# 3. 检测扫码状态
login_successful = await self.check_login_status()
# 3. 检测扫码状态
login_successful = await self.check_login_status()
if login_successful:
# 登录成功后,连接 WebSocket 接收消息
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
else:
logger.warning("登录失败或超时WeChatPadPro 适配器将关闭。")
await self.terminate()
return
if login_successful:
logger.info("登录成功WeChatPadPro适配器已连接。")
else:
logger.warning("登录失败或超时WeChatPadPro 适配器将关闭。")
await self.terminate()
return
# 登录成功后,连接 WebSocket 接收消息
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
self._shutdown_event = asyncio.Event()
await self._shutdown_event.wait()
@@ -156,16 +179,23 @@ class WeChatPadProAdapter(Platform):
if login_state == 1:
logger.info("WeChatPadPro 设备当前在线。")
return True
else:
logger.info(
f"WeChatPadPro 设备不在线,登录状态: {login_state}"
)
# login_state == 3 为离线状态
elif login_state == 3:
logger.info("WeChatPadPro 设备不在线")
return False
else:
logger.error(f"未知的在线状态: {login_state:}")
return False
# Code == 300 为微信退出状态。
elif response.status == 200 and response_data.get("Code") == 300:
logger.info("WeChatPadPro 设备已退出。")
return False
else:
logger.error(
f"检查在线状态失败: {response.status}, {response_data}"
)
return False
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return False
@@ -179,7 +209,7 @@ class WeChatPadProAdapter(Platform):
"""
url = f"{self.base_url}/admin/GenAuthKey1"
params = {"key": self.admin_key}
payload = {"Count": 1, "Days": 30} # 生成一个有效期30天的授权码
payload = {"Count": 1, "Days": 365} # 生成一个有效期365天的授权码
async with aiohttp.ClientSession() as session:
try:
@@ -336,12 +366,10 @@ class WeChatPadProAdapter(Platform):
message = await asyncio.wait_for(
websocket.recv(), timeout=wait_time
)
logger.info(message)
# logger.debug(message) # 不显示原始消息内容
asyncio.create_task(self.handle_websocket_message(message))
except asyncio.TimeoutError:
logger.warning(
f"WebSocket 连接空闲超过 {wait_time} s"
)
logger.warning(f"WebSocket 连接空闲超过 {wait_time} s")
break
except websockets.exceptions.ConnectionClosedOK:
logger.info("WebSocket 连接正常关闭。")
@@ -350,7 +378,9 @@ class WeChatPadProAdapter(Platform):
logger.error(f"处理 WebSocket 消息时发生错误: {e}")
break
except Exception as e:
logger.error(f"WebSocket 连接失败: {e}")
logger.error(
f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态或尝试重启WeChatPadPro适配器。"
)
await asyncio.sleep(5)
async def handle_websocket_message(self, message: str):
@@ -443,6 +473,7 @@ class WeChatPadProAdapter(Platform):
"""
if from_user_name == "weixin":
return False
at_me = False
if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = from_user_name
@@ -464,6 +495,14 @@ class WeChatPadProAdapter(Platform):
abm.session_id = f"{from_user_name}_{to_user_name}"
else:
abm.session_id = from_user_name
msg_source = raw_message.get("msg_source", "")
if self.wxid in msg_source:
at_me = True
if "在群聊中@了你" in raw_message.get("push_content", ""):
at_me = True
if at_me:
abm.message.insert(0, At(qq=abm.self_id, name=""))
else:
abm.type = MessageType.FRIEND_MESSAGE
abm.group_id = ""
@@ -544,6 +583,32 @@ class WeChatPadProAdapter(Platform):
logger.error(f"下载图片时发生错误: {e}")
return None
async def download_voice(
self, to_user_name: str, new_msg_id: str, bufid: str, length: int
):
"""下载原始音频。"""
url = f"{self.base_url}/message/GetMsgVoice"
params = {"key": self.auth_key}
payload = {
"Bufid": bufid,
"ToUserName": to_user_name,
"NewMsgId": new_msg_id,
"Length": length,
}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status == 200:
return await response.json()
logger.error(f"下载音频失败: {response.status}")
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"下载音频时发生错误: {e}")
return None
async def _process_message_content(
self, abm: AstrBotMessage, raw_message: dict, msg_type: int, content: str
):
@@ -555,12 +620,69 @@ class WeChatPadProAdapter(Platform):
if abm.type == MessageType.GROUP_MESSAGE:
parts = content.split(":\n", 1)
if len(parts) == 2:
abm.message_str = parts[1]
abm.message.append(Plain(abm.message_str))
message_content = parts[1]
abm.message_str = message_content
# 检查是否@了机器人,参考 gewechat 的实现方式
# 微信大部分客户端在@用户昵称后面,紧接着是一个\u2005字符四分之一空格
at_me = False
# 检查 msg_source 中是否包含机器人的 wxid
# wechatpadpro 的格式: <atuserlist>wxid</atuserlist>
# gewechat 的格式: <atuserlist><![CDATA[wxid]]></atuserlist>
msg_source = raw_message.get("msg_source", "")
if f"<atuserlist>{abm.self_id}</atuserlist>" in msg_source or f"<atuserlist>{abm.self_id}," in msg_source or f",{abm.self_id}</atuserlist>" in msg_source:
at_me = True
# 也检查 push_content 中是否有@提示
push_content = raw_message.get("push_content", "")
if "在群聊中@了你" in push_content:
at_me = True
if at_me:
# 被@了在消息开头插入At组件参考gewechat的做法
bot_nickname = await self._get_group_member_nickname(abm.group_id, abm.self_id)
abm.message.insert(0, At(qq=abm.self_id, name=bot_nickname or abm.self_id))
# 只有当消息内容不仅仅是@时才添加Plain组件
if "\u2005" in message_content:
# 检查@之后是否还有其他内容
parts = message_content.split("\u2005")
if len(parts) > 1 and any(part.strip() for part in parts[1:]):
abm.message.append(Plain(message_content))
else:
# 检查是否只包含@机器人
is_pure_at = False
if bot_nickname and message_content.strip() == f"@{bot_nickname}":
is_pure_at = True
if not is_pure_at:
abm.message.append(Plain(message_content))
else:
# 没有@机器人,作为普通文本处理
abm.message.append(Plain(message_content))
else:
abm.message.append(Plain(abm.message_str))
else: # 私聊消息
abm.message.append(Plain(abm.message_str))
# 缓存文本消息,以便引用消息可以查找
try:
# 获取msg_id作为缓存的key
new_msg_id = raw_message.get("new_msg_id")
if new_msg_id:
# 限制缓存大小
if (
len(self.cached_texts) >= self.max_text_cache
and self.cached_texts
):
# 删除最早的一条缓存
oldest_key = next(iter(self.cached_texts))
self.cached_texts.pop(oldest_key)
logger.debug(f"缓存文本消息new_msg_id={new_msg_id}")
self.cached_texts[str(new_msg_id)] = content
except Exception as e:
logger.error(f"缓存文本消息失败: {e}")
elif msg_type == 3:
# 图片消息
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
@@ -574,15 +696,87 @@ class WeChatPadProAdapter(Platform):
)
if image_bs64_data:
abm.message.append(Image.fromBase64(image_bs64_data))
# 缓存图片,以便引用消息可以查找
try:
# 获取msg_id作为缓存的key
new_msg_id = raw_message.get("new_msg_id")
if new_msg_id:
# 限制缓存大小
if (
len(self.cached_images) >= self.max_image_cache
and self.cached_images
):
# 删除最早的一条缓存
oldest_key = next(iter(self.cached_images))
self.cached_images.pop(oldest_key)
logger.debug(f"缓存图片消息new_msg_id={new_msg_id}")
self.cached_images[str(new_msg_id)] = image_bs64_data
except Exception as e:
logger.error(f"缓存图片消息失败: {e}")
elif msg_type == 47:
# 视频消息 (注意:表情消息也是 47需要区分)
logger.warning("收到视频消息,待实现。")
data_parser = GeweDataParser(
content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
raw_message=raw_message,
)
emoji_message = data_parser.parse_emoji()
if emoji_message is not None:
abm.message.append(emoji_message)
elif msg_type == 50:
# 语音/视频
logger.warning("收到语音/视频消息,待实现。")
elif msg_type == 34:
# 语音消息
bufid = 0
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
new_msg_id = raw_message.get("new_msg_id")
data_parser = GeweDataParser(
content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
raw_message=raw_message,
)
voicemsg = data_parser._format_to_xml().find("voicemsg")
bufid = voicemsg.get("bufid") or "0"
length = int(voicemsg.get("length") or 0)
voice_resp = await self.download_voice(
to_user_name=to_user_name,
new_msg_id=new_msg_id,
bufid=bufid,
length=length,
)
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
if voice_bs64_data:
voice_bs64_data = base64.b64decode(voice_bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(
temp_dir, f"wechatpadpro_voice_{abm.message_id}.silk"
)
async with await anyio.open_file(file_path, "wb") as f:
await f.write(voice_bs64_data)
abm.message.append(Record(file=file_path, url=file_path))
elif msg_type == 49:
# 引用消息
logger.warning("收到引用消息,待实现。")
try:
parser = GeweDataParser(
content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
cached_texts=self.cached_texts,
cached_images=self.cached_images,
raw_message=raw_message,
downloader=self._download_raw_image,
)
components = await parser.parse_mutil_49()
if components:
abm.message.extend(components)
abm.message_str = "\n".join(
c.text for c in components if isinstance(c, Plain)
)
except Exception as e:
logger.warning(f"msg_type 49 处理失败: {e}")
abm.message.append(Plain("[XML 消息处理失败]"))
abm.message_str = "[XML 消息处理失败]"
else:
logger.warning(f"收到未处理的消息类型: {msg_type}")
@@ -627,3 +821,67 @@ class WeChatPadProAdapter(Platform):
)
# 调用实例方法 send
await sending_event.send(message_chain)
async def get_contact_list(self):
"""
获取联系人列表。
"""
url = f"{self.base_url}/friend/GetContactList"
params = {"key": self.auth_key}
payload = {"CurrentChatRoomContactSeq": 0, "CurrentWxcontactSeq": 0}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status != 200:
logger.error(f"获取联系人列表失败: {response.status}")
return None
result = await response.json()
if result.get("Code") == 200 and result.get("Data"):
contact_list = (
result.get("Data", {})
.get("ContactList", {})
.get("contactUsernameList", [])
)
return contact_list
else:
logger.error(f"获取联系人列表失败: {result}")
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"获取联系人列表时发生错误: {e}")
return None
async def get_contact_details_list(
self, room_wx_id_list: list[str] = None, user_names: list[str] = None
) -> Optional[dict]:
"""
获取联系人详情列表。
"""
if room_wx_id_list is None:
room_wx_id_list = []
if user_names is None:
user_names = []
url = f"{self.base_url}/friend/GetContactDetailsList"
params = {"key": self.auth_key}
payload = {"RoomWxIDList": room_wx_id_list, "UserNames": user_names}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status != 200:
logger.error(f"获取联系人详情列表失败: {response.status}")
return None
result = await response.json()
if result.get("Code") == 200 and result.get("Data"):
contact_list = result.get("Data", {}).get("contactList", {})
return contact_list
else:
logger.error(f"获取联系人详情列表失败: {result}")
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"获取联系人详情列表时发生错误: {e}")
return None

View File

@@ -7,11 +7,17 @@ import aiohttp
from PIL import Image as PILImage # 使用别名避免冲突
from astrbot import logger
from astrbot.core.message.components import Image, Plain # Import Image
from astrbot.core.message.components import (
Image,
Plain,
WechatEmoji,
Record,
) # Import Image
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType
from astrbot.core.platform.platform_metadata import PlatformMetadata
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk_base64
if TYPE_CHECKING:
from .wechatpadpro_adapter import WeChatPadProAdapter
@@ -38,6 +44,10 @@ class WeChatPadProMessageEvent(AstrMessageEvent):
await self._send_text(session, comp.text)
elif isinstance(comp, Image):
await self._send_image(session, comp)
elif isinstance(comp, WechatEmoji):
await self._send_emoji(session, comp)
elif isinstance(comp, Record):
await self._send_voice(session, comp)
await super().send(message)
async def _send_image(self, session: aiohttp.ClientSession, comp: Image):
@@ -73,12 +83,42 @@ class WeChatPadProMessageEvent(AstrMessageEvent):
message_text = text
payload = {
"MsgItem": [
{"MsgType": 1, "TextContent": message_text, "ToUserName": self.session_id}
{
"MsgType": 1,
"TextContent": message_text,
"ToUserName": self.session_id,
}
]
}
url = f"{self.adapter.base_url}/message/SendTextMessage"
await self._post(session, url, payload)
async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji):
payload = {
"EmojiList": [
{
"EmojiMd5": comp.md5,
"EmojiSize": comp.md5_len,
"ToUserName": self.session_id,
}
]
}
url = f"{self.adapter.base_url}/message/SendEmojiMessage"
await self._post(session, url, payload)
async def _send_voice(self, session: aiohttp.ClientSession, comp: Record):
record_path = await comp.convert_to_file_path()
# 默认已经存在 data/temp 中
b64, duration = await wav_to_tencent_silk_base64(record_path)
payload = {
"ToUserName": self.session_id,
"VoiceData": b64,
"VoiceFormat": 4,
"VoiceSecond": duration,
}
url = f"{self.adapter.base_url}/message/SendVoice"
await self._post(session, url, payload)
@staticmethod
def _validate_base64(b64: str) -> bytes:
return base64.b64decode(b64, validate=True)

View File

@@ -0,0 +1,160 @@
from defusedxml import ElementTree as eT
from astrbot.api import logger
from astrbot.api.message_components import (
WechatEmoji as Emoji,
Plain,
Image,
BaseMessageComponent,
)
class GeweDataParser:
def __init__(
self,
content: str,
is_private_chat: bool = False,
cached_texts=None,
cached_images=None,
raw_message: dict = None,
downloader=None,
):
self._xml = None
self.content = content
self.is_private_chat = is_private_chat
self.cached_texts = cached_texts or {}
self.cached_images = cached_images or {}
self.downloader = downloader
raw_message = raw_message or {}
self.from_user_name = raw_message.get("from_user_name", {}).get("str", "")
self.to_user_name = raw_message.get("to_user_name", {}).get("str", "")
self.msg_id = raw_message.get("msg_id", "")
def _format_to_xml(self):
if self._xml:
return self._xml
try:
msg_str = self.content
if not self.is_private_chat:
parts = self.content.split(":\n", 1)
msg_str = parts[1] if len(parts) == 2 else self.content
self._xml = eT.fromstring(msg_str)
return self._xml
except Exception as e:
logger.error(f"[XML解析失败] {e}")
raise
async def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
"""
处理 msg_type == 49 的多种 appmsg 类型(目前支持 type==57
"""
try:
appmsg_type = self._format_to_xml().findtext(".//appmsg/type")
if appmsg_type == "57":
return await self.parse_reply()
except Exception as e:
logger.warning(f"[parse_mutil_49] 解析失败: {e}")
return None
async def parse_reply(self) -> list[BaseMessageComponent]:
"""
处理 type == 57 的引用消息支持文本1、图片3、嵌套4949
"""
components = []
try:
appmsg = self._format_to_xml().find("appmsg")
if appmsg is None:
return [Plain("[引用消息解析失败]")]
refermsg = appmsg.find("refermsg")
if refermsg is None:
return [Plain("[引用消息解析失败]")]
quote_type = int(refermsg.findtext("type", "0"))
nickname = refermsg.findtext("displayname", "未知发送者")
quote_content = refermsg.findtext("content", "")
svrid = refermsg.findtext("svrid")
match quote_type:
case 1: # 文本引用
quoted_text = self.cached_texts.get(str(svrid), quote_content)
components.append(Plain(f"[引用] {nickname}: {quoted_text}"))
case 3: # 图片引用
quoted_image_b64 = self.cached_images.get(str(svrid))
if not quoted_image_b64:
try:
quote_xml = eT.fromstring(quote_content)
img = quote_xml.find("img")
cdn_url = (
img.get("cdnbigimgurl") or img.get("cdnmidimgurl")
if img is not None
else None
)
if cdn_url and self.downloader:
image_resp = await self.downloader(
self.from_user_name, self.to_user_name, self.msg_id
)
quoted_image_b64 = (
image_resp.get("Data", {})
.get("Data", {})
.get("Buffer")
)
except Exception as e:
logger.warning(f"[引用图片解析失败] svrid={svrid} err={e}")
if quoted_image_b64:
components.extend(
[
Image.fromBase64(quoted_image_b64),
Plain(f"[引用] {nickname}: [引用的图片]"),
]
)
else:
components.append(
Plain(f"[引用] {nickname}: [引用的图片 - 未能获取]")
)
case 49: # 嵌套引用
try:
nested_root = eT.fromstring(quote_content)
nested_title = nested_root.findtext(".//appmsg/title", "")
components.append(Plain(f"[引用] {nickname}: {nested_title}"))
except Exception as e:
logger.warning(f"[嵌套引用解析失败] err={e}")
components.append(Plain(f"[引用] {nickname}: [嵌套引用消息]"))
case _: # 其他未识别类型
logger.info(f"[未知引用类型] quote_type={quote_type}")
components.append(Plain(f"[引用] {nickname}: [不支持的引用类型]"))
# 主消息标题
title = appmsg.findtext("title", "")
if title:
components.append(Plain(title))
except Exception as e:
logger.error(f"[parse_reply] 总体解析失败: {e}")
return [Plain("[引用消息解析失败]")]
return components
def parse_emoji(self) -> Emoji | None:
"""
处理 msg_type == 47 的表情消息emoji
"""
try:
emoji_element = self._format_to_xml().find(".//emoji")
if emoji_element is not None:
return Emoji(
md5=emoji_element.get("md5"),
md5_len=emoji_element.get("len"),
cdnurl=emoji_element.get("cdnurl"),
)
except Exception as e:
logger.error(f"[parse_emoji] 解析失败: {e}")
return None

View File

@@ -20,7 +20,7 @@ from requests import Response
from wechatpy.utils import check_signature
from wechatpy.crypto import WeChatCrypto
from wechatpy import WeChatClient
from wechatpy.messages import TextMessage, ImageMessage, VoiceMessage
from wechatpy.messages import TextMessage, ImageMessage, VoiceMessage, BaseMessage
from wechatpy.exceptions import InvalidSignatureException
from wechatpy import parse_message
from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent
@@ -87,7 +87,11 @@ class WecomServer:
logger.info(f"解析成功: {msg}")
if self.callback:
await self.callback(msg)
result_xml = await self.callback(msg)
if not result_xml:
return "success"
if isinstance(result_xml, str):
return result_xml
return "success"
@@ -117,6 +121,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.api_base_url = platform_config.get(
"api_base_url", "https://api.weixin.qq.com/cgi-bin/"
)
self.active_send_mode = self.config.get("active_send_mode", False)
if not self.api_base_url:
self.api_base_url = "https://api.weixin.qq.com/cgi-bin/"
@@ -138,9 +143,29 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.client.API_BASE_URL = self.api_base_url
async def callback(msg):
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
# msgid -> Future
self.wexin_event_workers: dict[str, asyncio.Future] = {}
async def callback(msg: BaseMessage):
try:
await self.convert_message(msg)
if self.active_send_mode:
await self.convert_message(msg, None)
else:
if msg.id in self.wexin_event_workers:
future = self.wexin_event_workers[msg.id]
logger.debug(f"duplicate message id checked: {msg.id}")
else:
future = asyncio.get_event_loop().create_future()
self.wexin_event_workers[msg.id] = future
await self.convert_message(msg, future)
# I love shield so much!
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
logger.debug(f"Got future result: {result}")
self.wexin_event_workers.pop(msg.id, None)
return result # xml. see weixin_offacc_event.py
except asyncio.TimeoutError:
pass
except Exception as e:
logger.error(f"转换消息时出现异常: {e}")
@@ -163,7 +188,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
async def run(self):
await self.server.start_polling()
async def convert_message(self, msg) -> AstrBotMessage | None:
async def convert_message(
self, msg, future: asyncio.Future = None
) -> AstrBotMessage | None:
abm = AstrBotMessage()
if isinstance(msg, TextMessage):
abm.message_str = msg.content
@@ -177,7 +204,6 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
abm.message_id = msg.id
abm.timestamp = msg.time
abm.session_id = abm.sender.user_id
abm.raw_message = msg
elif msg.type == "image":
assert isinstance(msg, ImageMessage)
abm.message_str = "[图片]"
@@ -191,7 +217,6 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
abm.message_id = msg.id
abm.timestamp = msg.time
abm.session_id = abm.sender.user_id
abm.raw_message = msg
elif msg.type == "voice":
assert isinstance(msg, VoiceMessage)
@@ -209,7 +234,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
audio = AudioSegment.from_file(path)
audio.export(path_wav, format="wav")
except Exception as e:
logger.error(f"转换音频失败: {e}。如果没有安装 pydub 和 ffmpeg 请先安装。")
logger.error(
f"转换音频失败: {e}。如果没有安装 pydub 和 ffmpeg 请先安装。"
)
path_wav = path
return
@@ -224,11 +251,16 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
abm.message_id = msg.id
abm.timestamp = msg.time
abm.session_id = abm.sender.user_id
abm.raw_message = msg
else:
logger.warning(f"暂未实现的事件: {msg.type}")
future.set_result(None)
return
# 很不优雅 :(
abm.raw_message = {
"message": msg,
"future": future,
"active_send_mode": self.active_send_mode,
}
logger.info(f"abm: {abm}")
await self.handle_msg(abm)

View File

@@ -4,6 +4,8 @@ from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record
from wechatpy import WeChatClient
from wechatpy.replies import TextReply, ImageReply, VoiceReply
from astrbot.api import logger
@@ -82,12 +84,23 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
async def send(self, message: MessageChain):
message_obj = self.message_obj
active_send_mode = message_obj.raw_message.get("active_send_mode", False)
for comp in message.chain:
if isinstance(comp, Plain):
# Split long text messages if needed
plain_chunks = await self.split_plain(comp.text)
for chunk in plain_chunks:
self.client.message.send_text(message_obj.sender.user_id, chunk)
if active_send_mode:
self.client.message.send_text(message_obj.sender.user_id, chunk)
else:
reply = TextReply(
content=chunk,
message=self.message_obj.raw_message["message"],
)
xml = reply.render()
future = self.message_obj.raw_message["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
await asyncio.sleep(0.5) # Avoid sending too fast
elif isinstance(comp, Image):
img_path = await comp.convert_to_file_path()
@@ -102,10 +115,22 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
)
return
logger.debug(f"微信公众平台上传图片返回: {response}")
self.client.message.send_image(
message_obj.sender.user_id,
response["media_id"],
)
if active_send_mode:
self.client.message.send_image(
message_obj.sender.user_id,
response["media_id"],
)
else:
reply = ImageReply(
media_id=response["media_id"],
message=self.message_obj.raw_message["message"],
)
xml = reply.render()
future = self.message_obj.raw_message["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
elif isinstance(comp, Record):
record_path = await comp.convert_to_file_path()
# 转成amr
@@ -124,10 +149,23 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
)
return
logger.info(f"微信公众平台上传语音返回: {response}")
self.client.message.send_voice(
message_obj.sender.user_id,
response["media_id"],
)
if active_send_mode:
self.client.message.send_voice(
message_obj.sender.user_id,
response["media_id"],
)
else:
reply = VoiceReply(
media_id=response["media_id"],
message=self.message_obj.raw_message["message"],
)
xml = reply.render()
future = self.message_obj.raw_message["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
else:
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}")

View File

@@ -19,6 +19,7 @@ class ProviderType(enum.Enum):
CHAT_COMPLETION = "chat_completion"
SPEECH_TO_TEXT = "speech_to_text"
TEXT_TO_SPEECH = "text_to_speech"
EMBEDDING = "embedding"
@dataclass
@@ -155,7 +156,9 @@ class ProviderRequest:
if self.image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": self.prompt if self.prompt else "[图片]"}],
"content": [
{"type": "text", "text": self.prompt if self.prompt else "[图片]"}
],
}
for image_url in self.image_urls:
if image_url.startswith("http"):

View File

@@ -4,6 +4,7 @@ import textwrap
import os
import asyncio
import logging
from datetime import timedelta
from typing import Dict, List, Awaitable, Literal, Any
from dataclasses import dataclass
@@ -20,6 +21,13 @@ try:
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
try:
from mcp.client.streamable_http import streamablehttp_client
except (ModuleNotFoundError, ImportError):
logger.warning(
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
)
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
SUPPORTED_TYPES = [
@@ -96,7 +104,10 @@ class MCPClient:
async def connect_to_server(self, mcp_server_config: dict, name: str):
"""连接到 MCP 服务器
如果 `url` 参数存在,则使用 SSE 的方式连接到 MCP 服务。
如果 `url` 参数存在
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
@@ -108,15 +119,41 @@ class MCPClient:
cfg.pop("active", None) # Remove active flag from config
if "url" in cfg:
# SSE transport method
self._streams_context = sse_client(url=cfg["url"])
streams = await self._streams_context.__aenter__()
is_sse = True
if cfg.get("transport") == "streamable_http":
is_sse = False
if is_sse:
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=cfg.get("timeout", 5),
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
)
streams = await self._streams_context.__aenter__()
# Create a new client session
# self.session = await self._session_context.__aenter__()
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*streams)
)
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*streams)
)
else:
timeout = timedelta(seconds=cfg.get("timeout", 30))
sse_read_timeout = timedelta(
seconds=cfg.get("sse_read_timeout", 60 * 5)
)
self._streams_context = streamablehttp_client(
url=cfg["url"],
headers=cfg.get("headers", {}),
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
read_s, write_s, _ = await self._streams_context.__aenter__()
# Create a new client session
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(read_stream=read_s, write_stream=write_s)
)
else:
server_params = mcp.StdioServerParameters(

View File

@@ -18,13 +18,6 @@ class ProviderManager:
self.persona_configs: list = config.get("persona", [])
self.astrbot_config = config
self.selected_provider_id = sp.get("curr_provider")
self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
self.selected_tts_provider_id = self.provider_settings.get("provider_id")
self.provider_enabled = self.provider_settings.get("enable", False)
self.stt_enabled = self.provider_stt_settings.get("enable", False)
self.tts_enabled = self.provider_tts_settings.get("enable", False)
# 人格情景管理
# 目前没有拆成独立的模块
self.default_persona_name = self.provider_settings.get(
@@ -98,15 +91,18 @@ class ProviderManager:
"""加载的 Speech To Text Provider 的实例"""
self.tts_provider_insts: List[TTSProvider] = []
"""加载的 Text To Speech Provider 的实例"""
self.embedding_provider_insts: List[Provider] = []
"""加载的 Embedding Provider 的实例"""
self.inst_map = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None
"""当前使用的 Provider 实例"""
"""默认的 Provider 实例"""
self.curr_stt_provider_inst: STTProvider = None
"""当前使用的 Speech To Text Provider 实例"""
"""默认的 Speech To Text Provider 实例"""
self.curr_tts_provider_inst: TTSProvider = None
"""当前使用的 Text To Speech Provider 实例"""
"""默认的 Text To Speech Provider 实例"""
self.db_helper = db_helper
# kdb(experimental)
@@ -115,18 +111,57 @@ class ProviderManager:
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
async def set_provider(
self, provider_id: str, provider_type: ProviderType, umo: str = None
):
"""设置提供商。
Args:
provider_id (str): 提供商 ID。
provider_type (ProviderType): 提供商类型。
umo (str, optional): 用户会话 ID用于提供商会话隔离。当用户启用了提供商会话隔离时此参数才生效。
"""
if provider_id not in self.inst_map:
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
if umo and self.provider_settings["separate_provider"]:
perf = sp.get("session_provider_perf", {})
session_perf = perf.get(umo, {})
session_perf[provider_type.value] = provider_id
perf[umo] = session_perf
sp.put("session_provider_perf", perf)
return
# 不启用提供商会话隔离模式的情况
self.curr_provider_inst = self.inst_map[provider_id]
if provider_type == ProviderType.TEXT_TO_SPEECH:
sp.put("curr_provider_tts", provider_id)
elif provider_type == ProviderType.SPEECH_TO_TEXT:
sp.put("curr_provider_stt", provider_id)
elif provider_type == ProviderType.CHAT_COMPLETION:
sp.put("curr_provider", provider_id)
async def initialize(self):
# 逐个初始化提供商
for provider_config in self.providers_config:
await self.load_provider(provider_config)
if not self.curr_provider_inst:
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
# 设置默认提供商
self.curr_provider_inst = self.inst_map.get(
self.provider_settings.get("default_provider_id")
)
if not self.curr_provider_inst and self.provider_insts:
self.curr_provider_inst = self.provider_insts[0]
if self.stt_enabled and not self.curr_stt_provider_inst:
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
self.curr_stt_provider_inst = self.inst_map.get(
self.provider_stt_settings.get("provider_id")
)
if not self.curr_stt_provider_inst and self.stt_provider_insts:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
if self.tts_enabled and not self.curr_tts_provider_inst:
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
self.curr_tts_provider_inst = self.inst_map.get(
self.provider_tts_settings.get("provider_id")
)
if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
# 初始化 MCP Client 连接
asyncio.create_task(
@@ -210,6 +245,18 @@ class ProviderManager:
from .sources.minimax_tts_api_source import (
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
)
case "volcengine_tts":
from .sources.volcengine_tts import (
ProviderVolcengineTTS as ProviderVolcengineTTS,
)
case "openai_embedding":
from .sources.openai_embedding_source import (
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
)
case "gemini_embedding":
from .sources.gemini_embedding_source import (
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
)
except (ImportError, ModuleNotFoundError) as e:
logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
@@ -242,14 +289,14 @@ class ProviderManager:
self.stt_provider_insts.append(inst)
if (
self.selected_stt_provider_id == provider_config["id"]
and self.stt_enabled
self.provider_stt_settings.get("provider_id")
== provider_config["id"]
):
self.curr_stt_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。"
)
if not self.curr_stt_provider_inst and self.stt_enabled:
if not self.curr_stt_provider_inst:
self.curr_stt_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
@@ -262,15 +309,12 @@ class ProviderManager:
await inst.initialize()
self.tts_provider_insts.append(inst)
if (
self.selected_tts_provider_id == provider_config["id"]
and self.tts_enabled
):
if self.provider_settings.get("provider_id") == provider_config["id"]:
self.curr_tts_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。"
)
if not self.curr_tts_provider_inst and self.tts_enabled:
if not self.curr_tts_provider_inst:
self.curr_tts_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
@@ -288,16 +332,24 @@ class ProviderManager:
self.provider_insts.append(inst)
if (
self.selected_provider_id == provider_config["id"]
and self.provider_enabled
self.provider_settings.get("default_provider_id")
== provider_config["id"]
):
self.curr_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。"
)
if not self.curr_provider_inst and self.provider_enabled:
if not self.curr_provider_inst:
self.curr_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
inst = provider_metadata.cls_type(
provider_config, self.provider_settings
)
if getattr(inst, "initialize", None):
await inst.initialize()
self.embedding_provider_insts.append(inst)
self.inst_map[provider_config["id"]] = inst
except Exception as e:
logger.error(traceback.format_exc())
@@ -318,39 +370,24 @@ class ProviderManager:
if len(self.provider_insts) == 0:
self.curr_provider_inst = None
elif (
self.curr_provider_inst is None
and len(self.provider_insts) > 0
and self.provider_enabled
):
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
self.curr_provider_inst = self.provider_insts[0]
self.selected_provider_id = self.curr_provider_inst.meta().id
logger.info(
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。"
)
if len(self.stt_provider_insts) == 0:
self.curr_stt_provider_inst = None
elif (
self.curr_stt_provider_inst is None
and len(self.stt_provider_insts) > 0
and self.stt_enabled
):
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id
logger.info(
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。"
)
if len(self.tts_provider_insts) == 0:
self.curr_tts_provider_inst = None
elif (
self.curr_tts_provider_inst is None
and len(self.tts_provider_insts) > 0
and self.tts_enabled
):
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id
logger.info(
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。"
)

View File

@@ -179,3 +179,25 @@ class TTSProvider(AbstractProvider):
async def get_audio(self, text: str) -> str:
"""获取文本的音频,返回音频文件路径"""
raise NotImplementedError()
class EmbeddingProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
@abc.abstractmethod
async def get_embedding(self, text: str) -> list[float]:
"""获取文本的向量"""
...
@abc.abstractmethod
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的向量"""
...
@abc.abstractmethod
def get_dim(self) -> int:
"""获取向量的维度"""
...

View File

@@ -104,11 +104,13 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts=[],
contexts=None,
system_prompt=None,
tool_calls_result: ToolCallsResult = None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
if not prompt:
prompt = "<image>"

View File

@@ -53,8 +53,8 @@ class OTTSProvider:
async def _generate_signature(self) -> str:
await self._sync_time()
timestamp = int(time.time()) + self.time_offset
nonce = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=10))
path = re.sub(r'^https?://[^/]+', '', self.api_url) or '/'
nonce = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=10))
path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/"
return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}"
async def get_audio(self, text: str, voice_params: Dict) -> str:
@@ -92,7 +92,7 @@ class AzureNativeProvider(TTSProvider):
def __init__(self, provider_config: dict, provider_settings: dict):
super().__init__(provider_config, provider_settings)
self.subscription_key = provider_config.get("azure_tts_subscription_key", "").strip()
if not re.fullmatch(r'^[a-zA-Z0-9]{32}$', self.subscription_key):
if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
raise ValueError("无效的Azure订阅密钥")
self.region = provider_config.get("azure_tts_region", "eastus").strip()
self.endpoint = f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
@@ -188,7 +188,7 @@ class AzureTTSProvider(TTSProvider):
raise ValueError(error_msg) from e
except KeyError as e:
raise ValueError(f"配置错误: 缺少必要参数 {e}") from e
if re.fullmatch(r'^[a-zA-Z0-9]{32}$', key_value):
if re.fullmatch(r"^[a-zA-Z0-9]{32}$", key_value):
return AzureNativeProvider(config, self.provider_settings)
raise ValueError("订阅密钥格式无效应为32位字母数字或other[...]格式")

View File

@@ -74,6 +74,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量

View File

@@ -61,12 +61,14 @@ class ProviderDify(Provider):
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
if image_urls is None:
image_urls = []
result = ""
conversation_id = self.conversation_ids.get(session_id, "")

View File

@@ -0,0 +1,63 @@
from google import genai
from google.genai import types
from google.genai.errors import APIError
from ..provider import EmbeddingProvider
from ..register import register_provider_adapter
from ..entities import ProviderType
@register_provider_adapter(
"gemini_embedding",
"Google Gemini Embedding 提供商适配器",
provider_type=ProviderType.EMBEDDING,
)
class GeminiEmbeddingProvider(EmbeddingProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
self.provider_settings = provider_settings
api_key: str = provider_config.get("embedding_api_key")
api_base: str = provider_config.get("embedding_api_base", None)
timeout: int = int(provider_config.get("timeout", 20))
http_options = types.HttpOptions(timeout=timeout * 1000)
if api_base:
if api_base.endswith("/"):
api_base = api_base[:-1]
http_options.base_url = api_base
self.client = genai.Client(api_key=api_key, http_options=http_options).aio
self.model = provider_config.get(
"embedding_model", "gemini-embedding-exp-03-07"
)
self.dimension = provider_config.get("embedding_dimensions", 768)
async def get_embedding(self, text: str) -> list[float]:
"""
获取文本的嵌入
"""
try:
result = await self.client.models.embed_content(
model=self.model, contents=text
)
return result.embeddings[0].values
except APIError as e:
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
"""
批量获取文本的嵌入
"""
try:
result = await self.client.models.embed_content(
model=self.model, contents=texts
)
return [embedding.values for embedding in result.embeddings]
except APIError as e:
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
def get_dim(self) -> int:
"""获取向量的维度"""
return self.dimension

View File

@@ -141,24 +141,66 @@ class ProviderGoogleGenAI(Provider):
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
modalities = ["Text"]
tool_list = None
tool_list = []
model_name = self.get_model()
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
native_search = self.provider_config.get("gm_native_search", False)
url_context = self.provider_config.get("gm_url_context", False)
if native_coderunner:
tool_list = [types.Tool(code_execution=types.ToolCodeExecution())]
if native_search:
logger.warning("已启用代码执行工具,搜索工具将被忽略")
if tools:
logger.warning("已启用代码执行工具,函数工具将被忽略")
elif native_search:
tool_list = [types.Tool(google_search=types.GoogleSearch())]
if tools:
logger.warning("已启用搜索工具,函数工具将被忽略")
if "gemini-2.5" in model_name:
if native_coderunner:
tool_list.append(types.Tool(code_execution=types.ToolCodeExecution()))
if native_search:
logger.warning("代码执行工具与搜索工具互斥,已忽略搜索工具")
if url_context:
logger.warning(
"代码执行工具与URL上下文工具互斥已忽略URL上下文工具"
)
else:
if native_search:
tool_list.append(types.Tool(google_search=types.GoogleSearch()))
if url_context:
if hasattr(types, "UrlContext"):
tool_list.append(types.Tool(url_context=types.UrlContext()))
else:
logger.warning(
"当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包"
)
elif "gemini-2.0-lite" in model_name:
if native_coderunner or native_search or url_context:
logger.warning(
"gemini-2.0-lite 不支持代码执行、搜索工具和URL上下文将忽略这些设置"
)
tool_list = None
else:
if native_coderunner:
tool_list.append(types.Tool(code_execution=types.ToolCodeExecution()))
if native_search:
logger.warning("代码执行工具与搜索工具互斥,已忽略搜索工具")
elif native_search:
tool_list.append(types.Tool(google_search=types.GoogleSearch()))
if url_context and not native_coderunner:
if hasattr(types, "UrlContext"):
tool_list.append(types.Tool(url_context=types.UrlContext()))
else:
logger.warning(
"当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包"
)
if not tool_list:
tool_list = None
if tools and tool_list:
logger.warning("已启用原生工具,函数工具将被忽略")
elif tools and (func_desc := tools.get_func_desc_google_genai_style()):
tool_list = [
types.Tool(function_declarations=func_desc["function_declarations"])
]
return types.GenerateContentConfig(
system_instruction=system_instruction,
temperature=temperature,
@@ -291,19 +333,19 @@ class ProviderGoogleGenAI(Provider):
result_parts: Optional[types.Part] = result.candidates[0].content.parts
if finish_reason == types.FinishReason.SAFETY:
raise Exception("模型生成内容未通过用户定义的内容安全检查")
raise Exception("模型生成内容未通过 Gemini 平台的安全检查")
if finish_reason in {
types.FinishReason.PROHIBITED_CONTENT,
types.FinishReason.SPII,
types.FinishReason.BLOCKLIST,
}:
raise Exception("模型生成内容违反Gemini平台政策")
raise Exception("模型生成内容违反 Gemini 平台政策")
# 防止旧版本SDK不存在IMAGE_SAFETY
if hasattr(types.FinishReason, "IMAGE_SAFETY"):
if finish_reason == types.FinishReason.IMAGE_SAFETY:
raise Exception("模型生成内容违反Gemini平台政策")
raise Exception("模型生成内容违反 Gemini 平台政策")
if not result_parts:
logger.debug(result.candidates)

View File

@@ -60,10 +60,12 @@ class LLMTunerModelLoader(Provider):
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = [],
contexts: List = None,
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
system_prompt = ""
new_record = {"role": "user", "content": prompt}
query_context = [*contexts, new_record]

View File

@@ -0,0 +1,43 @@
from openai import AsyncOpenAI
from ..provider import EmbeddingProvider
from ..register import register_provider_adapter
from ..entities import ProviderType
@register_provider_adapter(
"openai_embedding",
"OpenAI API Embedding 提供商适配器",
provider_type=ProviderType.EMBEDDING,
)
class OpenAIEmbeddingProvider(EmbeddingProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
self.provider_settings = provider_settings
self.client = AsyncOpenAI(
api_key=provider_config.get("embedding_api_key"),
base_url=provider_config.get(
"embedding_api_base", "https://api.openai.com/v1"
),
timeout=int(provider_config.get("timeout", 20)),
)
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
self.dimension = provider_config.get("embedding_dimensions", 1536)
async def get_embedding(self, text: str) -> list[float]:
"""
获取文本的嵌入
"""
embedding = await self.client.embeddings.create(input=text, model=self.model)
return embedding.data[0].embedding
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
"""
批量获取文本的嵌入
"""
embeddings = await self.client.embeddings.create(input=texts, model=self.model)
return [item.embedding for item in embeddings.data]
def get_dim(self) -> int:
"""获取向量的维度"""
return self.dimension

View File

@@ -195,7 +195,11 @@ class ProviderOpenAIOfficial(Provider):
for tool_call in choice.message.tool_calls:
for tool in tools.func_list:
if tool.name == tool_call.function.name:
args = json.loads(tool_call.function.arguments)
# workaround for #1454
if isinstance(tool_call.function.arguments, str):
args = json.loads(tool_call.function.arguments)
else:
args = tool_call.function.arguments
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
tool_call_ids.append(tool_call.id)
@@ -223,9 +227,9 @@ class ProviderOpenAIOfficial(Provider):
session_id: str = None,
image_urls: list[str] = None,
func_tool: FuncCall = None,
contexts: list=None,
system_prompt: str=None,
tool_calls_result: ToolCallsResult=None,
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult = None,
**kwargs,
) -> tuple:
"""准备聊天所需的有效载荷和上下文"""
@@ -340,9 +344,9 @@ class ProviderOpenAIOfficial(Provider):
async def text_chat(
self,
prompt,
session_id = None,
image_urls = None,
func_tool = None,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,

View File

@@ -0,0 +1,107 @@
import uuid
import base64
import json
import os
import traceback
import asyncio
import aiohttp
import requests
from ..provider import TTSProvider
from ..entities import ProviderType
from ..register import register_provider_adapter
from astrbot import logger
@register_provider_adapter(
"volcengine_tts", "火山引擎 TTS", provider_type=ProviderType.TEXT_TO_SPEECH
)
class ProviderVolcengineTTS(TTSProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.api_key = provider_config.get("api_key", "")
self.appid = provider_config.get("appid", "")
self.cluster = provider_config.get("volcengine_cluster", "")
self.voice_type = provider_config.get("volcengine_voice_type", "")
self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0)
self.api_base = provider_config.get("api_base", f"https://openspeech.bytedance.com/api/v1/tts")
self.timeout = provider_config.get("timeout", 20)
def _build_request_payload(self, text: str) -> dict:
return {
"app": {
"appid": self.appid,
"token": self.api_key,
"cluster": self.cluster
},
"user": {
"uid": str(uuid.uuid4())
},
"audio": {
"voice_type": self.voice_type,
"encoding": "mp3",
"speed_ratio": self.speed_ratio,
"volume_ratio": 1.0,
"pitch_ratio": 1.0,
},
"request": {
"reqid": str(uuid.uuid4()),
"text": text,
"text_type": "plain",
"operation": "query",
"with_frontend": 1,
"frontend_type": "unitTson"
}
}
async def get_audio(self, text: str) -> str:
"""异步方法获取语音文件路径"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer; {self.api_key}"
}
payload = self._build_request_payload(text)
logger.debug(f"请求头: {headers}")
logger.debug(f"请求 URL: {self.api_base}")
logger.debug(f"请求体: {json.dumps(payload, ensure_ascii=False)[:100]}...")
try:
async with aiohttp.ClientSession() as session:
async with session.post(
self.api_base,
data=json.dumps(payload),
headers=headers,
timeout=self.timeout
) as response:
logger.debug(f"响应状态码: {response.status}")
response_text = await response.text()
logger.debug(f"响应内容: {response_text[:200]}...")
if response.status == 200:
resp_data = json.loads(response_text)
if "data" in resp_data:
audio_data = base64.b64decode(resp_data["data"])
os.makedirs("data/temp", exist_ok=True)
file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None,
lambda: open(file_path, "wb").write(audio_data)
)
return file_path
else:
error_msg = resp_data.get("message", "未知错误")
raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
else:
raise Exception(f"火山引擎 TTS API 请求失败: {response.status}, {response_text}")
except Exception as e:
error_details = traceback.format_exc()
logger.debug(f"火山引擎 TTS 异常详情: {error_details}")
raise Exception(f"火山引擎 TTS 异常: {str(e)}")

View File

@@ -31,10 +31,12 @@ class ProviderZhipu(ProviderOpenAIOfficial):
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts=[],
contexts=None,
system_prompt=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
new_record = await self.assemble_context(prompt, image_urls)
context_query = []

View File

@@ -1,20 +0,0 @@
from typing import List
from openai import AsyncOpenAI
class SimpleOpenAIEmbedding:
def __init__(
self,
model,
api_key,
api_base=None,
) -> None:
self.client = AsyncOpenAI(api_key=api_key, base_url=api_base)
self.model = model
async def get_embedding(self, text) -> List[float]:
"""
获取文本的嵌入
"""
embedding = await self.client.embeddings.create(input=text, model=self.model)
return embedding.data[0].embedding

View File

@@ -1,95 +0,0 @@
import os
from typing import List, Dict
from astrbot.core import logger
from .store import Store
from astrbot.core.config import AstrBotConfig
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class KnowledgeDBManager:
def __init__(self, astrbot_config: AstrBotConfig) -> None:
self.db_path = os.path.join(get_astrbot_data_path(), "knowledge_db")
self.config = astrbot_config.get("knowledge_db", {})
self.astrbot_config = astrbot_config
if not os.path.exists(self.db_path):
os.makedirs(self.db_path)
self.store_insts: Dict[str, Store] = {}
for name, cfg in self.config.items():
if cfg["strategy"] == "embedding":
logger.info(f"加载 Chroma Vector Store{name}")
try:
from .store.chroma_db import ChromaVectorStore
except ImportError as ie:
logger.error(f"{ie} 可能未安装 chromadb 库。")
continue
self.store_insts[name] = ChromaVectorStore(
name, cfg["embedding_config"]
)
else:
logger.error(f"不支持的策略:{cfg['strategy']}")
async def list_knowledge_db(self) -> List[str]:
return [
f
for f in os.listdir(self.db_path)
if os.path.isfile(os.path.join(self.db_path, f))
]
async def create_knowledge_db(self, name: str, config: Dict):
"""
config 格式:
```
{
"strategy": "embedding", # 目前只支持 embedding
"chunk_method": {
"strategy": "fixed",
"chunk_size": 100,
"overlap_size": 10
},
"embedding_config": {
"strategy": "openai",
"base_url": "",
"model": "",
"api_key": ""
}
}
```
"""
if name in self.config:
raise ValueError(f"知识库已存在:{name}")
self.config[name] = config
self.astrbot_config["knowledge_db"] = self.config
self.astrbot_config.save_config()
async def insert_record(self, name: str, text: str):
if name not in self.store_insts:
raise ValueError(f"未找到知识库:{name}")
ret = []
match self.config[name]["chunk_method"]["strategy"]:
case "fixed":
chunk_size = self.config[name]["chunk_method"]["chunk_size"]
chunk_overlap = self.config[name]["chunk_method"]["overlap_size"]
ret = self._fixed_chunk(text, chunk_size, chunk_overlap)
case _:
pass
for chunk in ret:
await self.store_insts[name].save(chunk)
async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]:
if name not in self.store_insts:
raise ValueError(f"未找到知识库:{name}")
inst = self.store_insts[name]
return await inst.query(query, top_n)
def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
chunks = []
start = 0
while start < len(text):
end = start + chunk_size
chunks.append(text[start:end])
start += chunk_size - chunk_overlap
return chunks

View File

@@ -1,9 +0,0 @@
from typing import List
class Store:
async def save(self, text: str):
pass
async def query(self, query: str, top_n: int = 3) -> List[str]:
pass

View File

@@ -1,44 +0,0 @@
import chromadb
import uuid
from typing import List, Dict
from astrbot.api import logger
from ..embedding.openai_source import SimpleOpenAIEmbedding
from . import Store
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class ChromaVectorStore(Store):
def __init__(self, name: str, embedding_cfg: Dict) -> None:
import os
self.chroma_client = chromadb.PersistentClient(
path=os.path.join(get_astrbot_data_path(), "long_term_memory_chroma.db")
)
self.collection = self.chroma_client.get_or_create_collection(name=name)
self.embedding = None
if embedding_cfg["strategy"] == "openai":
self.embedding = SimpleOpenAIEmbedding(
model=embedding_cfg["model"],
api_key=embedding_cfg["api_key"],
api_base=embedding_cfg.get("base_url", None),
)
async def save(self, text: str, metadata: Dict = None):
logger.debug(f"Saving text: {text}")
embedding = await self.embedding.get_embedding(text)
self.collection.upsert(
documents=text,
metadatas=metadata,
ids=str(uuid.uuid4()),
embeddings=embedding,
)
async def query(
self, query: str, top_n=3, metadata_filter: Dict = None
) -> List[str]:
embedding = await self.embedding.get_embedding(query)
results = self.collection.query(
query_embeddings=embedding, n_results=top_n, where=metadata_filter
)
return results["documents"][0]

View File

@@ -3,6 +3,7 @@ from typing import List, Union
from astrbot.core import sp
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
from astrbot.core.provider.entities import ProviderType
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.provider.func_tool_manager import FuncCall
@@ -16,7 +17,6 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
@@ -42,6 +42,8 @@ class Context:
platform_manager: PlatformManager = None
registered_web_apis: list = []
# back compatibility
_register_tasks: List[Awaitable] = []
_star_manager = None
@@ -54,14 +56,12 @@ class Context:
provider_manager: ProviderManager = None,
platform_manager: PlatformManager = None,
conversation_manager: ConversationManager = None,
knowledge_db_manager: KnowledgeDBManager = None,
):
self._event_queue = event_queue
self._config = config
self._db = db
self.provider_manager = provider_manager
self.platform_manager = platform_manager
self.knowledge_db_manager = knowledge_db_manager
self.conversation_manager = conversation_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
@@ -126,11 +126,8 @@ class Context:
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
"""通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
for provider in self.provider_manager.provider_insts:
if provider.meta().id == provider_id:
return provider
return None
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
return self.provider_manager.inst_map.get(provider_id)
def get_all_providers(self) -> List[Provider]:
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
@@ -144,24 +141,46 @@ class Context:
"""获取所有用于 STT 任务的 Provider。"""
return self.provider_manager.stt_provider_insts
def get_using_provider(self) -> Provider:
def get_using_provider(self, umo: str = None) -> Provider:
"""
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
通过 /provider 指令切换。
Args:
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
"""
if umo and self._config["provider_settings"]["separate_provider"]:
perf = sp.get("session_provider_perf", {})
prov_id = perf.get(umo, {}).get(ProviderType.CHAT_COMPLETION.value, None)
if inst := self.provider_manager.inst_map.get(prov_id, None):
return inst
return self.provider_manager.curr_provider_inst
def get_using_tts_provider(self) -> TTSProvider:
def get_using_tts_provider(self, umo: str = None) -> TTSProvider:
"""
获取当前使用的用于 TTS 任务的 Provider。
Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
"""
if umo and self._config["provider_settings"]["separate_provider"]:
perf = sp.get("session_provider_perf", {})
prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None)
if inst := self.provider_manager.inst_map.get(prov_id, None):
return inst
return self.provider_manager.curr_tts_provider_inst
def get_using_stt_provider(self) -> STTProvider:
def get_using_stt_provider(self, umo: str = None) -> STTProvider:
"""
获取当前使用的用于 STT 任务的 Provider。
Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
"""
if umo and self._config["provider_settings"]["separate_provider"]:
perf = sp.get("session_provider_perf", {})
prov_id = perf.get(umo, {}).get(ProviderType.SPEECH_TO_TEXT.value, None)
if inst := self.provider_manager.inst_map.get(prov_id, None):
return inst
return self.provider_manager.curr_stt_provider_inst
def get_config(self) -> AstrBotConfig:
@@ -301,3 +320,12 @@ class Context:
注册一个异步任务。
"""
self._register_tasks.append(task)
def register_web_api(
self, route: str, view_handler: Awaitable, methods: list, desc: str
):
for idx, api in enumerate(self.registered_web_apis):
if api[0] == route and methods == api[2]:
self.registered_web_apis[idx] = (route, view_handler, methods, desc)
return
self.registered_web_apis.append((route, view_handler, methods, desc))

View File

@@ -7,6 +7,9 @@ from astrbot.core.config import AstrBotConfig
from .custom_filter import CustomFilter
from ..star_handler import StarHandlerMetadata
class GreedyStr(str):
"""标记指令完成其他参数接收后的所有剩余文本。"""
pass
# 标准指令受到 wake_prefix 的制约。
class CommandFilter(HandlerFilter):
@@ -68,7 +71,22 @@ class CommandFilter(HandlerFilter):
) -> Dict[str, Any]:
"""将参数列表 params 根据 param_type 转换为参数字典。"""
result = {}
for i, (param_name, param_type_or_default_val) in enumerate(param_type.items()):
param_items = list(param_type.items())
for i, (param_name, param_type_or_default_val) in enumerate(param_items):
is_greedy = param_type_or_default_val is GreedyStr
if is_greedy:
# GreedyStr 必须是最后一个参数
if i != len(param_items) - 1:
raise ValueError(
f"参数 '{param_name}' (GreedyStr) 必须是最后一个参数。"
)
# 将剩余的所有部分合并成一个字符串
remaining_params = params[i:]
result[param_name] = " ".join(remaining_params)
break
# 没有 GreedyStr 的情况
if i >= len(params):
if (
isinstance(param_type_or_default_val, Type)

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import enum
import heapq
from dataclasses import dataclass, field
from typing import Awaitable, List, Dict, TypeVar, Generic
from .filter import HandlerFilter
@@ -8,100 +7,66 @@ from .star import star_map
T = TypeVar("T", bound="StarHandlerMetadata")
class StarHandlerRegistry(Generic[T]):
"""用于存储所有的 Star Handler"""
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
"""用于快速查找。key 是 handler_full_name"""
_handlers = []
def __init__(self):
self.star_handlers_map: Dict[str, StarHandlerMetadata] = {}
self._handlers: List[StarHandlerMetadata] = []
def append(self, handler: StarHandlerMetadata):
"""添加一个 Handler"""
"""添加一个 Handler,并保持按优先级有序"""
if "priority" not in handler.extras_configs:
handler.extras_configs["priority"] = 0
heapq.heappush(self._handlers, (-handler.extras_configs["priority"], handler))
self.star_handlers_map[handler.handler_full_name] = handler
self._handlers.append(handler)
self._handlers.sort(key=lambda h: -h.extras_configs["priority"])
def _print_handlers(self):
"""打印所有的 Handler"""
for _, handler in self._handlers:
for handler in self._handlers:
print(handler.handler_full_name)
def get_handlers_by_event_type(
self, event_type: EventType, only_activated=True, platform_id=None
) -> List[StarHandlerMetadata]:
"""通过事件类型获取 Handler
Args:
event_type: 事件类型
only_activated: 是否只返回已激活的插件的处理器
platform_id: 平台ID如果提供此参数将过滤掉在此平台不兼容的处理器
Returns:
List[StarHandlerMetadata]: 处理器列表
"""
handlers = []
for _, handler in self._handlers:
for handler in self._handlers:
if handler.event_type != event_type:
continue
# 只激活的插件处理器
if only_activated:
plugin = star_map.get(handler.handler_module_path)
if not (plugin and plugin.activated):
continue
# 平台兼容性过滤
if platform_id and event_type != EventType.OnAstrBotLoadedEvent:
if not handler.is_enabled_for_platform(platform_id):
continue
handlers.append(handler)
return handlers
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
"""通过 Handler 的全名获取 Handler"""
return self.star_handlers_map.get(full_name, None)
def get_handlers_by_module_name(
self, module_name: str
) -> List[StarHandlerMetadata]:
"""通过模块名获取 Handler"""
return [
handler
for _, handler in self._handlers
handler for handler in self._handlers
if handler.handler_module_path == module_name
]
def clear(self):
"""清空所有的 Handler"""
self.star_handlers_map.clear()
self._handlers.clear()
def remove(self, handler: StarHandlerMetadata):
"""删除一个 Handler"""
# self._handlers.remove(handler)
for i, h in enumerate(self._handlers):
if h[1] == handler:
self._handlers.pop(i)
break
try:
del self.star_handlers_map[handler.handler_full_name]
except KeyError:
pass
self.star_handlers_map.pop(handler.handler_full_name, None)
self._handlers = [h for h in self._handlers if h != handler]
def __iter__(self):
"""使 StarHandlerRegistry 支持迭代"""
return (handler for _, handler in self._handlers)
return iter(self._handlers)
def __len__(self):
"""返回 Handler 的数量"""
return len(self._handlers)
star_handlers_registry = StarHandlerRegistry()

View File

@@ -37,6 +37,12 @@ except ImportError:
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
logger.warning("未安装 watchfiles无法实现插件的热重载。")
try:
import nh3
except ImportError:
logger.warning("未安装 nh3 库,无法清理插件 README.md 中的 HTML 标签。")
nh3 = None
class PluginManager:
def __init__(self, context: Context, config: AstrBotConfig):
@@ -140,11 +146,13 @@ class PluginManager:
if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists(
os.path.join(path, d, d + ".py")
):
modules.append({
"pname": d,
"module": module_str,
"module_path": os.path.join(path, d, module_str),
})
modules.append(
{
"pname": d,
"module": module_str,
"module_path": os.path.join(path, d, module_str),
}
)
return modules
def _get_plugin_modules(self) -> List[dict]:
@@ -158,7 +166,7 @@ class PluginManager:
plugins.extend(_p)
return plugins
def _check_plugin_dept_update(self, target_plugin: str = None):
async def _check_plugin_dept_update(self, target_plugin: str = None):
"""检查插件的依赖
如果 target_plugin 为 None则检查所有插件的依赖
"""
@@ -177,7 +185,7 @@ class PluginManager:
pth = os.path.join(plugin_path, "requirements.txt")
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
try:
pip_installer.install(requirements_path=pth)
await pip_installer.install(requirements_path=pth)
except Exception as e:
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
@@ -399,7 +407,7 @@ class PluginManager:
module = __import__(path, fromlist=[module_str])
except (ModuleNotFoundError, ImportError):
# 尝试安装依赖
self._check_plugin_dept_update(target_plugin=root_dir_name)
await self._check_plugin_dept_update(target_plugin=root_dir_name)
module = __import__(path, fromlist=[module_str])
except Exception as e:
logger.error(traceback.format_exc())
@@ -443,11 +451,11 @@ class PluginManager:
metadata.repo = metadata_yaml.repo
except Exception:
pass
metadata.config = plugin_config
if path not in inactivated_plugins:
# 只有没有禁用插件时才实例化插件类
if plugin_config:
metadata.config = plugin_config
# metadata.config = plugin_config
try:
metadata.star_cls = metadata.star_cls_type(
context=self.context, config=plugin_config
@@ -634,16 +642,17 @@ class PluginManager:
if not os.path.exists(readme_path):
readme_path = os.path.join(plugin_path, "readme.md")
if os.path.exists(readme_path):
if os.path.exists(readme_path) and nh3:
try:
with open(readme_path, "r", encoding="utf-8") as f:
readme_content = f.read()
cleaned_content = nh3.clean(readme_content)
except Exception as e:
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
plugin_info = None
if plugin:
plugin_info = {"repo": plugin.repo, "readme": readme_content}
plugin_info = {"repo": plugin.repo, "readme": cleaned_content}
return plugin_info

View File

@@ -18,7 +18,8 @@ class PluginUpdator(RepoZipUpdator):
return self.plugin_store_path
async def install(self, repo_url: str, proxy="") -> str:
repo_name = self.format_repo_name(repo_url)
_, repo_name, _ = self.parse_github_url(repo_url)
repo_name = self.format_name(repo_name)
plugin_path = os.path.join(self.plugin_store_path, repo_name)
await self.download_from_repo_url(plugin_path, repo_url, proxy)
self.unzip_file(plugin_path + ".zip", plugin_path)
@@ -31,10 +32,6 @@ class PluginUpdator(RepoZipUpdator):
if not repo_url:
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
if proxy:
proxy = proxy.removesuffix("/")
repo_url = f"{proxy}/{repo_url}"
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
@@ -54,7 +51,7 @@ class PluginUpdator(RepoZipUpdator):
def unzip_file(self, zip_path: str, target_dir: str):
os.makedirs(target_dir, exist_ok=True)
update_dir = ""
logger.info(f"解压文件: {zip_path}")
logger.info(f"正在解压压缩包: {zip_path}")
with zipfile.ZipFile(zip_path, "r") as z:
update_dir = z.namelist()[0]
z.extractall(target_dir)

View File

@@ -1,5 +1,5 @@
import logging
from pip import main as pip_main
import asyncio
logger = logging.getLogger("astrbot")
@@ -9,7 +9,7 @@ class PipInstaller:
self.pip_install_arg = pip_install_arg
self.pypi_index_url = pypi_index_url
def install(
async def install(
self,
package_name: str = None,
requirements_path: str = None,
@@ -29,12 +29,29 @@ class PipInstaller:
args.extend(self.pip_install_arg.split())
logger.info(f"Pip 包管理器: pip {' '.join(args)}")
try:
process = await asyncio.create_subprocess_exec(
"pip", *args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
result_code = pip_main(args)
assert process.stdout is not None
async for line in process.stdout:
logger.info(line.decode().strip())
# 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
await process.wait()
if result_code != 0:
raise Exception(f"安装失败,错误码:{result_code}")
if process.returncode != 0:
raise Exception(f"安装失败,错误码:{process.returncode}")
except FileNotFoundError:
# 没有 pip
from pip import main as pip_main
result_code = await asyncio.to_thread(pip_main, args)
# 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
if result_code != 0:
raise Exception(f"安装失败,错误码:{result_code}")

View File

@@ -1,5 +1,10 @@
import base64
import wave
import os
from io import BytesIO
import asyncio
import tempfile
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
@@ -50,3 +55,46 @@ async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
rate = wav.getframerate()
duration = pilk.encode(wav_path, output_path, pcm_rate=rate, tencent=True)
return duration
async def wav_to_tencent_silk_base64(wav_path: str) -> str:
"""
将 WAV 文件转为 Silk并返回 Base64 字符串。
默认采样率为 24000输出临时文件为 temp/output.silk。
参数:
- wav_path: 输入 .wav 文件路径(需为 PCM 16bit
返回:
- Base64 编码的 Silk 字符串
- duration: 音频时长(秒)
"""
try:
import pilk
except ImportError as e:
raise Exception("pysilk 模块未安装,请安装 pysilk") from e
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
with wave.open(wav_path, "rb") as wav:
rate = wav.getframerate()
with tempfile.NamedTemporaryFile(
suffix=".silk", delete=False, dir=temp_dir
) as tmp_file:
silk_path = tmp_file.name
try:
duration = await asyncio.to_thread(
pilk.encode, wav_path, silk_path, pcm_rate=rate, tencent=True
)
with open(silk_path, "rb") as f:
silk_bytes = await asyncio.to_thread(f.read)
silk_b64 = base64.b64encode(silk_bytes).decode("utf-8")
return silk_b64, duration # 已是秒
finally:
if os.path.exists(silk_path):
os.remove(silk_path)

View File

@@ -1,5 +1,6 @@
import aiohttp
import os
import re
import zipfile
import shutil
@@ -119,28 +120,61 @@ class RepoZipUpdator:
)
async def download_from_repo_url(self, target_path: str, repo_url: str, proxy=""):
repo_namespace = repo_url.split("/")[-2:]
author = repo_namespace[0]
repo = repo_namespace[1]
author, repo, branch = self.parse_github_url(repo_url)
logger.info(f"正在下载更新 {repo} ...")
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
releases = await self.fetch_release_info(url=release_url)
if not releases:
# download from the default branch directly.
logger.info(f"正在从默认分支下载 {author}/{repo} ")
if branch:
logger.info(f"正在从指定分支 {branch} 下载 {author}/{repo}")
release_url = (
f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
f"https://github.com/{author}/{repo}/archive/refs/heads/{branch}.zip"
)
else:
release_url = releases[0]["zipball_url"]
try:
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
releases = await self.fetch_release_info(url=release_url)
except Exception as e:
logger.warning(
f"获取 {author}/{repo} 的 GitHub Releases 失败: {e},将尝试下载默认分支"
)
releases = []
if not releases:
# 如果没有最新版本,下载默认分支
logger.info(f"正在从默认分支下载 {author}/{repo}")
release_url = (
f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
)
else:
release_url = releases[0]["zipball_url"]
if proxy:
proxy = proxy.rstrip("/")
release_url = f"{proxy}/{release_url}"
logger.info(f"使用代理下载: {release_url}")
logger.info(
f"检查到设置了镜像站,将使用镜像站下载 {author}/{repo} 仓库源码: {release_url}"
)
await download_file(release_url, target_path + ".zip")
def parse_github_url(self, url: str):
"""使用正则表达式解析 GitHub 仓库 URL支持 `.git` 后缀和 `tree/branch` 结构
Returns:
tuple[str, str, str]: 返回作者名、仓库名和分支名
Raises:
ValueError: 如果 URL 格式不正确
"""
cleaned_url = url.rstrip("/")
pattern = r"^https://github\.com/([a-zA-Z0-9_-]+)/([a-zA-Z0-9_-]+)(\.git)?(?:/tree/([a-zA-Z0-9_-]+))?$"
match = re.match(pattern, cleaned_url)
if match:
author = match.group(1)
repo = match.group(2)
branch = match.group(4)
return author, repo, branch
else:
raise ValueError("无效的 GitHub URL")
def unzip_file(self, zip_path: str, target_dir: str):
"""
解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir
@@ -174,16 +208,5 @@ class RepoZipUpdator:
f"删除更新文件失败,可以手动删除 {zip_path}{os.path.join(target_dir, update_dir)}"
)
def format_repo_name(self, repo_url: str) -> str:
if repo_url.endswith("/"):
repo_url = repo_url[:-1]
repo_namespace = repo_url.split("/")[-2:]
repo = repo_namespace[1]
repo = self.format_name(repo)
return repo
def format_name(self, name: str) -> str:
return name.replace("-", "_").lower()

View File

@@ -1,5 +1,6 @@
import jwt
import datetime
import asyncio
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core import WEBUI_SK, DEMO_MODE
@@ -21,7 +22,11 @@ class AuthRoute(Route):
post_data = await request.json
if post_data["username"] == username and post_data["password"] == password:
change_pwd_hint = False
if username == "astrbot" and password == "77b90590a8945a7d36c963981a307dc9":
if (
username == "astrbot"
and password == "77b90590a8945a7d36c963981a307dc9"
and not DEMO_MODE
):
change_pwd_hint = True
logger.warning("为了保证安全,请尽快修改默认密码。")
@@ -37,6 +42,7 @@ class AuthRoute(Route):
.__dict__
)
else:
await asyncio.sleep(3)
return Response().error("用户名或密码错误").__dict__
async def edit_account(self):
@@ -72,7 +78,7 @@ class AuthRoute(Route):
def generate_jwt(self, username):
payload = {
"username": username,
"exp": datetime.datetime.utcnow() + datetime.timedelta(days=30),
"exp": datetime.datetime.utcnow() + datetime.timedelta(days=7),
}
token = jwt.encode(payload, WEBUI_SK, algorithm="HS256")
return token

View File

@@ -26,6 +26,7 @@ class ChatRoute(Route):
"/chat/conversations": ("GET", self.get_conversations),
"/chat/get_conversation": ("GET", self.get_conversation),
"/chat/delete_conversation": ("GET", self.delete_conversation),
"/chat/rename_conversation": ("POST", self.rename_conversation),
"/chat/get_file": ("GET", self.get_file),
"/chat/post_image": ("POST", self.post_image),
"/chat/post_file": ("POST", self.post_file),
@@ -61,16 +62,25 @@ class ChatRoute(Route):
return Response().error("Missing key: filename").__dict__
try:
with open(os.path.join(self.imgs_dir, filename), "rb") as f:
if filename.endswith(".wav"):
file_path = os.path.join(self.imgs_dir, os.path.basename(filename))
real_file_path = os.path.realpath(file_path)
real_imgs_dir = os.path.realpath(self.imgs_dir)
if not real_file_path.startswith(real_imgs_dir):
return Response().error("Invalid file path").__dict__
with open(real_file_path, "rb") as f:
filename_ext = os.path.splitext(filename)[1].lower()
if filename_ext == ".wav":
return QuartResponse(f.read(), mimetype="audio/wav")
elif filename.split(".")[-1] in self.supported_imgs:
elif filename_ext[1:] in self.supported_imgs:
return QuartResponse(f.read(), mimetype="image/jpeg")
else:
return QuartResponse(f.read())
except FileNotFoundError:
return Response().error("File not found").__dict__
except (FileNotFoundError, OSError):
return Response().error("File access error").__dict__
async def post_image(self):
post_data = await request.files
@@ -91,7 +101,6 @@ class ChatRoute(Route):
file = post_data["file"]
filename = f"{str(uuid.uuid4())}"
print(file)
# 通过文件格式判断文件类型
if file.content_type.startswith("audio"):
filename += ".wav"
@@ -143,7 +152,7 @@ class ChatRoute(Route):
try:
history = json.loads(conversation.history)
except BaseException as e:
print(e)
logger.error(f"Failed to parse conversation history: {e}")
history = []
new_his = {"type": "user", "message": message}
if image_url:
@@ -197,6 +206,9 @@ class ChatRoute(Route):
if streaming and type != "end":
continue
if type == "update_title":
continue
if result_text:
conversation = self.db.get_conversation_by_user_id(
username, cid
@@ -204,7 +216,7 @@ class ChatRoute(Route):
try:
history = json.loads(conversation.history)
except BaseException as e:
print(e)
logger.error(f"Failed to parse conversation history: {e}")
history = []
history.append({"type": "bot", "message": result_text})
self.db.update_conversation(
@@ -242,6 +254,18 @@ class ChatRoute(Route):
self.db.new_conversation(username, conversation_id)
return Response().ok(data={"conversation_id": conversation_id}).__dict__
async def rename_conversation(self):
username = g.get("username", "guest")
post_data = await request.json
if "conversation_id" not in post_data or "title" not in post_data:
return Response().error("Missing key: conversation_id or title").__dict__
conversation_id = post_data["conversation_id"]
title = post_data["title"]
self.db.update_conversation_title(username, conversation_id, title=title)
return Response().ok(message="重命名成功!").__dict__
async def get_conversations(self):
username = g.get("username", "guest")
conversations = self.db.get_conversations(username)

View File

@@ -9,6 +9,7 @@ from astrbot.core.platform.register import platform_registry
from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry
from astrbot.core import logger
import asyncio
def try_cast(value: str, type_: str):
@@ -153,6 +154,7 @@ class ConfigRoute(Route):
) -> None:
super().__init__(context)
self.core_lifecycle = core_lifecycle
self.config: AstrBotConfig = core_lifecycle.astrbot_config
self.routes = {
"/config/get": ("GET", self.get_configs),
"/config/astrbot/update": ("POST", self.post_astrbot_configs),
@@ -164,9 +166,125 @@ class ConfigRoute(Route):
"/config/provider/update": ("POST", self.post_update_provider),
"/config/provider/delete": ("POST", self.post_delete_provider),
"/config/llmtools": ("GET", self.get_llm_tools),
"/config/provider/check_status": ("GET", self.check_all_providers_status),
"/config/provider/list": ("GET", self.get_provider_config_list),
"/config/provider/get_session_seperate": (
"GET",
lambda: Response()
.ok({"enable": self.config["provider_settings"]["separate_provider"]})
.__dict__,
),
"/config/provider/set_session_seperate": (
"POST",
self.post_session_seperate,
),
}
self.register_routes()
async def _test_single_provider(self, provider):
"""辅助函数:测试单个 provider 的可用性"""
meta = provider.meta()
provider_name = provider.provider_config.get("id", "Unknown Provider")
logger.debug(f"Got provider meta: {meta}")
if not provider_name and meta:
provider_name = meta.id
elif not provider_name:
provider_name = "Unknown Provider"
status_info = {
"id": getattr(meta, "id", "Unknown ID"),
"model": getattr(meta, "model", "Unknown Model"),
"type": getattr(meta, "type", "Unknown Type"),
"name": provider_name,
"status": "unavailable", # 默认为不可用
"error": None,
}
logger.debug(
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})"
)
try:
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
response = await asyncio.wait_for(
provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0
)
logger.debug(f"Received response from {status_info['name']}: {response}")
# 只要 text_chat 调用成功返回一个 LLMResponse 对象 (即 response 不为 None),就认为可用
if response is not None:
status_info["status"] = "available"
response_text_snippet = ""
if hasattr(response, "completion_text") and response.completion_text:
response_text_snippet = (
response.completion_text[:70] + "..."
if len(response.completion_text) > 70
else response.completion_text
)
elif hasattr(response, "result_chain") and response.result_chain:
try:
response_text_snippet = (
response.result_chain.get_plain_text()[:70] + "..."
if len(response.result_chain.get_plain_text()) > 70
else response.result_chain.get_plain_text()
)
except Exception as _:
pass
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'"
)
else:
# 这个分支理论上不应该被走到,除非 text_chat 实现可能返回 None
status_info["error"] = (
"Test call returned None, but expected an LLMResponse object."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None."
)
except asyncio.TimeoutError:
status_info["error"] = (
"Connection timed out after 45 seconds during test call."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out."
)
except Exception as e:
error_message = str(e)
status_info["error"] = error_message
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}"
)
logger.debug(
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}"
)
return status_info
async def check_all_providers_status(self):
"""
API 接口: 检查所有 LLM Providers 的状态
"""
logger.info("API call received: /config/provider/check_status")
try:
all_providers: typing.List = (
self.core_lifecycle.star_context.get_all_providers()
)
logger.debug(f"Found {len(all_providers)} providers to check.")
if not all_providers:
logger.info("No providers found to check.")
return Response().ok([]).__dict__
tasks = [self._test_single_provider(p) for p in all_providers]
logger.debug(f"Created {len(tasks)} tasks for concurrent provider checks.")
results = await asyncio.gather(*tasks)
logger.info(f"Provider status check completed. Results: {results}")
return Response().ok(results).__dict__
except Exception as e:
logger.error(f"Critical error in check_all_providers_status: {str(e)}")
logger.error(traceback.format_exc())
return (
Response().error(f"检查 Provider 状态时发生严重错误: {str(e)}").__dict__
)
async def get_configs(self):
# plugin_name 为空时返回 AstrBot 配置
# 否则返回指定 plugin_name 的插件配置
@@ -175,6 +293,32 @@ class ConfigRoute(Route):
return Response().ok(await self._get_astrbot_config()).__dict__
return Response().ok(await self._get_plugin_config(plugin_name)).__dict__
async def post_session_seperate(self):
"""设置提供商会话隔离"""
post_config = await request.json
enable = post_config.get("enable", None)
if enable is None:
return Response().error("缺少参数 enable").__dict__
astrbot_config = self.core_lifecycle.astrbot_config
astrbot_config["provider_settings"]["separate_provider"] = enable
try:
astrbot_config.save_config()
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "设置成功~").__dict__
async def get_provider_config_list(self):
provider_type = request.args.get("provider_type", None)
if not provider_type:
return Response().error("缺少参数 provider_type").__dict__
provider_list = []
astrbot_config = self.core_lifecycle.astrbot_config
for provider in astrbot_config["provider"]:
if provider.get("provider_type", None) == provider_type:
provider_list.append(provider)
return Response().ok(provider_list).__dict__
async def post_astrbot_configs(self):
post_configs = await request.json
try:

View File

@@ -23,6 +23,7 @@ class LogRoute(Route):
**message, # see astrbot/core/log.py
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
await asyncio.sleep(0.07) # 控制发送频率,避免过快
except asyncio.CancelledError:
pass
except BaseException as e:

View File

@@ -18,6 +18,12 @@ from astrbot.core.star.filter.regex import RegexFilter
from astrbot.core.star.star_handler import EventType
from astrbot.core import DEMO_MODE
try:
import nh3
except ImportError:
logger.warning("未安装 nh3 库,无法清理插件 README.md 中的 HTML 标签。")
nh3 = None
class PluginRoute(Route):
def __init__(
@@ -102,7 +108,10 @@ class PluginRoute(Route):
async def get_plugins(self):
_plugin_resp = []
plugin_name = request.args.get("name")
for plugin in self.plugin_manager.context.get_all_stars():
if plugin_name and plugin.name != plugin_name:
continue
_t = {
"name": plugin.name,
"repo": "" if plugin.repo is None else plugin.repo,
@@ -145,9 +154,7 @@ class PluginRoute(Route):
if handler.event_type == EventType.AdapterMessageEvent:
# 处理平台适配器消息事件
has_admin = False
for (
filter
) in (
for filter in (
handler.event_filters
): # 正常handler就只有 1~2 个 filter因此这里时间复杂度不会太高
if isinstance(filter, CommandFilter):
@@ -325,6 +332,9 @@ class PluginRoute(Route):
return Response().error(str(e)).__dict__
async def get_plugin_readme(self):
if not nh3:
return Response().error("未安装 nh3 库").__dict__
plugin_name = request.args.get("name")
logger.debug(f"正在获取插件 {plugin_name} 的README文件内容")
@@ -360,9 +370,11 @@ class PluginRoute(Route):
with open(readme_path, "r", encoding="utf-8") as f:
readme_content = f.read()
cleaned_content = nh3.clean(readme_content)
return (
Response()
.ok({"content": readme_content}, "成功获取README内容")
.ok({"content": cleaned_content}, "成功获取README内容")
.__dict__
)
except Exception as e:
@@ -383,14 +395,12 @@ class PluginRoute(Route):
platform_type = platform.get("type", "")
platform_id = platform.get("id", "")
platforms.append(
{
"name": platform_id, # 使用type作为name这是系统内部使用的平台名称
"id": platform_id, # 保留id字段以便前端可以显示
"type": platform_type,
"display_name": f"{platform_type}({platform_id})",
}
)
platforms.append({
"name": platform_id, # 使用type作为name这是系统内部使用的平台名称
"id": platform_id, # 保留id字段以便前端可以显示
"type": platform_type,
"display_name": f"{platform_type}({platform_id})",
})
adjusted_platform_enable = {}
for platform_id, plugins in platform_enable.items():
@@ -399,13 +409,11 @@ class PluginRoute(Route):
# 获取所有插件,包括系统内部插件
plugins = []
for plugin in self.plugin_manager.context.get_all_stars():
plugins.append(
{
"name": plugin.name,
"desc": plugin.desc,
"reserved": plugin.reserved, # 添加reserved标志
}
)
plugins.append({
"name": plugin.name,
"desc": plugin.desc,
"reserved": plugin.reserved, # 添加reserved标志
})
logger.debug(
f"获取插件平台配置: 原始配置={platform_enable}, 调整后={adjusted_platform_enable}"
@@ -413,13 +421,11 @@ class PluginRoute(Route):
return (
Response()
.ok(
{
"platforms": platforms,
"plugins": plugins,
"platform_enable": adjusted_platform_enable,
}
)
.ok({
"platforms": platforms,
"plugins": plugins,
"platform_enable": adjusted_platform_enable,
})
.__dict__
)
except Exception as e:

View File

@@ -8,6 +8,7 @@ from quart import request
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.config import VERSION
from astrbot.core.utils.io import get_dashboard_version
from astrbot.core import DEMO_MODE
@@ -45,8 +46,27 @@ class StatRoute(Route):
h, m = divmod(m, 60)
return f"{h}小时{m}{s}"
def is_default_cred(self):
username = self.config["dashboard"]["username"]
password = self.config["dashboard"]["password"]
return (
username == "astrbot"
and password == "77b90590a8945a7d36c963981a307dc9"
and not DEMO_MODE
)
async def get_version(self):
return Response().ok({"version": VERSION}).__dict__
return (
Response()
.ok(
{
"version": VERSION,
"dashboard_version": await get_dashboard_version(),
"change_pwd_hint": self.is_default_cred(),
}
)
.__dict__
)
async def get_start_time(self):
return Response().ok({"start_time": self.core_lifecycle.start_time}).__dict__

View File

@@ -12,7 +12,10 @@ class StaticFileRoute(Route):
"/logs",
"/extension",
"/dashboard/default",
"/project-atri",
"/alkaid",
"/alkaid/knowledge-base",
"/alkaid/long-term-memory",
"/alkaid/other",
"/console",
"/chat",
"/settings",

View File

@@ -91,7 +91,7 @@ class UpdateRoute(Route):
# pip 更新依赖
logger.info("更新依赖中...")
try:
pip_installer.install(requirements_path="requirements.txt")
await pip_installer.install(requirements_path="requirements.txt")
except Exception as e:
logger.error(f"更新依赖失败: {e}")
@@ -140,7 +140,7 @@ class UpdateRoute(Route):
if not package:
return Response().error("缺少参数 package 或不合法。").__dict__
try:
pip_installer.install(package, mirror=mirror)
await pip_installer.install(package, mirror=mirror)
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(f"/api/update_pip: {traceback.format_exc()}")

View File

@@ -15,6 +15,8 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.utils.io import get_local_ip_addresses
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
APP: Quart = None
class AstrBotDashboard:
def __init__(
@@ -27,6 +29,7 @@ class AstrBotDashboard:
self.config = core_lifecycle.astrbot_config
self.data_path = os.path.abspath(os.path.join(get_astrbot_data_path(), "dist"))
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
APP = self.app # noqa
self.app.config["MAX_CONTENT_LENGTH"] = (
128 * 1024 * 1024
) # 将 Flask 允许的最大上传文件体大小设置为 128 MB
@@ -51,12 +54,29 @@ class AstrBotDashboard:
self.conversation_route = ConversationRoute(self.context, db, core_lifecycle)
self.file_route = FileRoute(self.context)
self.app.add_url_rule(
"/api/plug/<path:subpath>",
view_func=self.srv_plug_route,
methods=["GET", "POST"],
)
self.shutdown_event = shutdown_event
async def srv_plug_route(self, subpath, *args, **kwargs):
"""
插件路由
"""
registered_web_apis = self.core_lifecycle.star_context.registered_web_apis
for api in registered_web_apis:
route, view_handler, methods, _ = api
if route == f"/{subpath}" and request.method in methods:
return await view_handler(*args, **kwargs)
return jsonify(Response().error("未找到该路由").__dict__)
async def auth_middleware(self):
if not request.path.startswith("/api"):
return
allowed_endpoints = ["/api/auth/login", "/api/chat/get_file", "/api/file"]
allowed_endpoints = ["/api/auth/login", "/api/file"]
if any(request.path.startswith(prefix) for prefix in allowed_endpoints):
return
# claim jwt

7
changelogs/v3.5.11.md Normal file
View File

@@ -0,0 +1,7 @@
# What's Changed
1. 新增:火山引擎 TTS
2. 修复:修复了 WeChatPadPro 在重新登录时为新设备的问题
2. ‼️修复:微信公众号(个人认证或者未认证)的情况下能接收但无法回复消息的问题
3. 修复Minimax TTS 相关问题
4. 优化:登录界面侧边栏、关于页面样式,修复如果此前已经登录但未自行跳转的问题

18
changelogs/v3.5.12.md Normal file
View File

@@ -0,0 +1,18 @@
# What's Changed
1. 新增:支持 MCP 的 Streamable HTTP 传输方式。详见 [#1637](https://github.com/Soulter/AstrBot/issues/1637)
2. 新增:支持 MCP 的 SSE 传输方式的自定义请求头。详见 [#1659](https://github.com/Soulter/AstrBot/issues/1659)
3. 优化:将 /llm 和 /model 和 /provider 指令设置为管理员指令
4. 修复:修复插件的 priority 部分失效的问题
5. 修复:修复 QQ 下合并转发消息内无法发送文件等问题,尽可能修复了各种文件、语音、视频、图片无法发送的问题
6. 优化Telegram 支持长消息分段发送,优化消息编辑的逻辑
7. 优化WebUI 强制默认修改密码
8. 优化:移除了 vpet
9. 新增:插件接口:支持动态路由注册
10. 优化CLI 模式下的插件下载
11. 新增WeChatPadPro 对接获取联系人接口
12. 新增T2I、语音、视频支持文件服务
13. 优化:硅基流动下某些工具调用返回的 argument 格式适配
14. 优化:在使用 /llm 指令关闭后重启 AstrBot 后,模型提供商未被加载
15. 新增:新增基于 FAISS + SQLite 的向量存储接口
16. 新增Alkaid Page

9
changelogs/v3.5.13.md Normal file
View File

@@ -0,0 +1,9 @@
# What's Changed
1. 新增WebUI 支持暗夜模式。
2. 修复:修复 WebUI Chat 接口的未授权访问安全漏洞、插件 README 可能存在的 XSS 注入漏洞。
3. 优化:优化 Vec DB 在 indexing 过程时的数据库事务处理。
4. 修复WebUI 下,插件市场的推荐卡片无法点击帮助文档的问题。
5. 新增:知识库。
6. 新增WebUI 提供商测试功能,一键检测可用性。
7. 新增WebUI 提供商分类功能,按能力分类提供商。

11
changelogs/v3.5.14.md Normal file
View File

@@ -0,0 +1,11 @@
# What's Changed
1. 优化:强化了 WebUI 安全性
2. 修复:测试文本生成提供商时可能出现的误报
3. 修复刷新知识库页面时出现404
4. 新增WeChatPadPro 支持获取引用、语音收发、视频等消息段
5. 优化WebUI 账户修改页面的设计逻辑
6. 优化:插件更新后自动刷新插件列表
7. 新增:支持下载插件的指定分支
8. 修复WeChatPadPro 群聊模式下 @ 不回复等问题
9. 其他更新、优化及修复

13
changelogs/v3.5.15.md Normal file
View File

@@ -0,0 +1,13 @@
# What's Changed
1. 修复:如果设置了 GitHub 加速地址,更新插件会报错
2. 修复:部分场景下,`只@触发等待` 配置项功能无效的问题
3. 新增:增加 `只@触发等待时是否回复` 配置项
4. 新增:**支持模型提供商使用时会话隔离(需要手动开启配置项:提供商会话隔离)**
5. 新增Google Gemini 提供商支持 URL 上下文功能
6. 新增:优化 WebChat 的 UI 显示WebChat 支持修改标题和自动生成标题,支持 WebChatBox
7. 新增:支持可配置是否忽略 @ 全体成员
8. 优化WebUI 顶栏移动端显示
9. 优化:插件/AstrBot 配置项完整性检查的同时也保证**配置项相对顺序一致性**
10. 优化perf: 分段回复时,仅在输出的第一句话带上回复/引用
11. 修复: Windows 下部署项目时可能出现的 UnicodeDecodeError。

View File

@@ -20,6 +20,7 @@
"axios": "^1.6.2",
"axios-mock-adapter": "^1.22.0",
"chance": "1.1.11",
"d3": "^7.9.0",
"date-fns": "2.30.0",
"highlight.js": "^11.11.1",
"js-md5": "^0.8.3",

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

View File

@@ -340,12 +340,12 @@ export default {
.config-title {
font-weight: 600;
font-size: 1rem;
color: var(--v-primary-darken1);
color: var(--v-theme-primaryText);
}
.config-hint {
font-size: 0.75rem;
color: rgba(0, 0, 0, 0.6);
color: var(--v-theme-secondaryText);
margin-top: 2px;
}
@@ -400,12 +400,12 @@ export default {
.property-name {
font-size: 0.875rem;
font-weight: 600;
color: rgba(0, 0, 0, 0.87);
color: var(--v-theme-primaryText);
}
.property-hint {
font-size: 0.75rem;
color: rgba(0, 0, 0, 0.6);
color: var(--v-theme-secondaryText);
margin-top: 2px;
}

View File

@@ -5,7 +5,7 @@ import { useCommonStore } from '@/stores/common';
<template>
<div>
<!-- 添加筛选级别控件 -->
<div class="filter-controls mb-2">
<div class="filter-controls mb-2" v-if="showLevelBtns">
<v-chip-group v-model="selectedLevels" column multiple>
<v-chip v-for="level in logLevels" :key="level" :color="getLevelColor(level)" filter
:text-color="level === 'DEBUG' || level === 'INFO' ? 'black' : 'white'">
@@ -52,6 +52,10 @@ export default {
historyNum: {
type: String,
default: -1
},
showLevelBtns: {
type: Boolean,
default: true
}
},
watch: {

View File

@@ -1,5 +1,6 @@
<script setup lang="ts">
import { ref, computed, inject } from 'vue';
import {useCustomizerStore} from "@/stores/customizer";
const props = defineProps({
extension: {
@@ -75,7 +76,9 @@ const viewReadme = () => {
<template>
<v-card class="mx-auto d-flex flex-column" :elevation="highlight ? 0 : 1"
:style="{ height: $vuetify.display.xs ? '250px' : '220px', backgroundColor: highlight ? '#FAF0DB' : '#ffffff', color: highlight ? '#000' : '#000000' }">
:style="{ height: $vuetify.display.xs ? '250px' : '220px',
backgroundColor: useCustomizerStore().uiTheme==='PurpleTheme' ? marketMode ? '#f8f0dd' : '#ffffff' : '#282833',
color: useCustomizerStore().uiTheme==='PurpleTheme' ? '#000000dd' : '#ffffff'}">
<v-card-text style="padding: 16px; padding-bottom: 0px; display: flex; justify-content: space-between;">
<div class="flex-grow-1">
@@ -128,7 +131,7 @@ const viewReadme = () => {
</div>
</v-card-text>
<v-card-actions style="padding: 0px; margin-top: auto;">
<v-card-actions style="margin-left: 0px; gap: 2px;">
<v-btn color="teal-accent-4" text="查看文档" variant="text" @click="viewReadme"></v-btn>
<v-btn v-if="!marketMode" color="teal-accent-4" text="操作" variant="text" @click="reveal = true"></v-btn>
<v-btn v-if="marketMode && !extension?.installed" color="teal-accent-4" text="安装" variant="text"

View File

@@ -104,11 +104,11 @@ export default {
<style scoped>
.list-config-item {
border: 1px solid #e0e0e0;
border: 1px solid var(--v-theme-border);
padding: 16px;
margin-bottom: 8px;
border-radius: 10px;
background-color: #ffffff;
background-color: var(--v-theme-background);
}
.v-list-item {

View File

@@ -0,0 +1,78 @@
<template>
<div class="logo-container">
<div class="logo-content">
<div class="logo-image">
<img width="110" src="@/assets/images/astrbot_logo_mini.webp" alt="AstrBot Logo">
</div>
<div class="logo-text">
<h2 class="text-secondary">{{ title }}</h2>
<!-- 父子组件传递css变量可能会出错暂时使用十六进制颜色值 -->
<h4 :style="{color: useCustomizerStore().uiTheme === 'PurpleTheme' ? '#000000aa' : '#ffffffcc'}"
class="hint-text">{{ subtitle }}</h4>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { useCustomizerStore } from "@/stores/customizer";
const props = withDefaults(defineProps<{
title?: string;
subtitle?: string;
}>(), {
title: 'AstrBot 仪表盘',
subtitle: '欢迎使用'
})
</script>
<style scoped>
.logo-container {
display: flex;
justify-content: center;
align-items: center;
width: 100%;
margin-bottom: 10px;
}
.logo-content {
display: flex;
align-items: center;
gap: 20px;
padding: 10px;
}
.logo-image {
display: flex;
justify-content: center;
align-items: center;
}
.logo-image img {
transition: transform 0.3s ease;
}
.logo-image img:hover {
transform: scale(1.05);
}
.logo-text {
display: flex;
flex-direction: column;
align-items: flex-start;
}
.logo-text h2 {
margin: 0;
font-size: 1.8rem;
font-weight: 600;
letter-spacing: 0.5px;
}
.logo-text h4 {
margin: 4px 0 0 0;
font-size: 1rem;
font-weight: 400;
letter-spacing: 0.3px;
}
</style>

View File

@@ -3,14 +3,25 @@ export type ConfigProps = {
Customizer_drawer: boolean;
mini_sidebar: boolean;
fontTheme: string;
uiTheme: string;
inputBg: boolean;
};
function checkUITheme() {
/* 检查localStorage有无记忆的主题选项如有则使用否则使用默认值 */
const theme = localStorage.getItem("uiTheme");
if (!theme || !(['PurpleTheme', 'PurpleThemeDark'].includes(theme))) {
localStorage.setItem("uiTheme", "PurpleTheme"); // todo: 这部分可以根据vuetify.ts的默认主题动态调整
return 'PurpleTheme';
} else return theme;
}
const config: ConfigProps = {
Sidebar_drawer: true,
Customizer_drawer: false,
mini_sidebar: false,
fontTheme: 'Roboto',
uiTheme: checkUITheme(),
inputBg: false
};

View File

@@ -2,14 +2,13 @@
import { RouterView } from 'vue-router';
import VerticalSidebarVue from './vertical-sidebar/VerticalSidebar.vue';
import VerticalHeaderVue from './vertical-header/VerticalHeader.vue';
import { useCustomizerStore } from '../../stores/customizer';
import { useCustomizerStore } from '@/stores/customizer';
const customizer = useCustomizerStore();
</script>
<template>
<v-locale-provider>
<v-app
theme="PurpleTheme"
<v-app :theme="useCustomizerStore().uiTheme"
:class="[customizer.fontTheme, customizer.mini_sidebar ? 'mini-sidebar' : '', customizer.inputBg ? 'inputWithbg' : '']"
>
<VerticalHeaderVue />

View File

@@ -1,16 +1,18 @@
<script setup lang="ts">
import { ref } from 'vue';
import { useCustomizerStore } from '../../../stores/customizer';
import {ref, computed} from 'vue';
import {useCustomizerStore} from '@/stores/customizer';
import axios from 'axios';
import { md5 } from 'js-md5';
import { useAuthStore } from '@/stores/auth';
import { useCommonStore } from '@/stores/common';
import { marked } from 'marked';
import Logo from '@/components/shared/Logo.vue';
import {md5} from 'js-md5';
import {useAuthStore} from '@/stores/auth';
import {useCommonStore} from '@/stores/common';
import {marked} from 'marked';
const customizer = useCustomizerStore();
let dialog = ref(false);
let accountWarning = ref(false)
let updateStatusDialog = ref(false);
const username = localStorage.getItem('user');
let password = ref('');
let newPassword = ref('');
let newUsername = ref('');
@@ -23,26 +25,52 @@ let dashboardHasNewVersion = ref(false);
let dashboardCurrentVersion = ref('');
let version = ref('');
let releases = ref([]);
let devCommits = ref([]); // 新增的 ref
let devCommits = ref([]);
let installLoading = ref(false);
let tab = ref(0);
let releasesHeader = [
{ title: '标签', key: 'tag_name' },
{ title: '发布时间', key: 'published_at' },
{ title: '内容', key: 'body' },
{ title: '源码地址', key: 'zipball_url' },
{ title: '操作', key: 'switch' }
{title: '标签', key: 'tag_name'},
{title: '发布时间', key: 'published_at'},
{title: '内容', key: 'body'},
{title: '源码地址', key: 'zipball_url'},
{title: '操作', key: 'switch'}
];
// Form validation
const formValid = ref(true);
const passwordRules = [
(v: string) => !!v || '请输入密码',
(v: string) => v.length >= 8 || '密码长度至少 8 位'
];
const usernameRules = [
(v: string) => !v || v.length >= 3 || '用户名长度至少3位'
];
// 显示密码相关
const showPassword = ref(false);
const showNewPassword = ref(false);
// 账户修改状态
const accountEditStatus = ref({
loading: false,
success: false,
error: false,
message: ''
});
const open = (link: string) => {
window.open(link, '_blank');
};
// 账户修改
function accountEdit() {
accountEditStatus.value.loading = true;
accountEditStatus.value.error = false;
accountEditStatus.value.success = false;
// md5加密
// @ts-ignore
if (password.value != '') {
@@ -54,71 +82,92 @@ function accountEdit() {
axios.post('/api/auth/account/edit', {
password: password.value,
new_password: newPassword.value,
new_username: newUsername.value
new_username: newUsername.value ? newUsername.value : username
})
.then((res) => {
if (res.data.status == 'error') {
status.value = res.data.message;
.then((res) => {
if (res.data.status == 'error') {
accountEditStatus.value.error = true;
accountEditStatus.value.message = res.data.message;
password.value = '';
newPassword.value = '';
return;
}
accountEditStatus.value.success = true;
accountEditStatus.value.message = res.data.message;
setTimeout(() => {
dialog.value = !dialog.value;
const authStore = useAuthStore();
authStore.logout();
}, 2000);
})
.catch((err) => {
console.log(err);
accountEditStatus.value.error = true;
accountEditStatus.value.message = typeof err === 'string' ? err : '修改失败,请重试';
password.value = '';
newPassword.value = '';
return;
}
dialog.value = !dialog.value;
status.value = res.data.message;
setTimeout(() => {
const authStore = useAuthStore();
authStore.logout();
}, 1000);
})
.catch((err) => {
console.log(err);
status.value = err
password.value = '';
newPassword.value = '';
});
})
.finally(() => {
accountEditStatus.value.loading = false;
});
}
function getVersion() {
axios.get('/api/stat/version')
.then((res) => {
botCurrVersion.value = "v" + res.data.data.version;
dashboardCurrentVersion.value = res.data.data?.dashboard_version;
let change_pwd_hint = res.data.data?.change_pwd_hint;
if (change_pwd_hint) {
dialog.value = true;
accountWarning.value = true;
localStorage.setItem('change_pwd_hint', 'true');
} else {
localStorage.removeItem('change_pwd_hint');
}
})
.catch((err) => {
console.log(err);
});
}
function checkUpdate() {
updateStatus.value = '正在检查更新...';
axios.get('/api/update/check')
.then((res) => {
hasNewVersion.value = res.data.data.has_new_version;
.then((res) => {
hasNewVersion.value = res.data.data.has_new_version;
if (res.data.data.has_new_version) {
releaseMessage.value = res.data.message;
updateStatus.value = '有新版本!';
} else {
updateStatus.value = res.data.message;
}
botCurrVersion.value = res.data.data.version;
dashboardCurrentVersion.value = res.data.data.dashboard_version;
dashboardHasNewVersion.value = res.data.data.dashboard_has_new_version;
})
.catch((err) => {
if (err.response.status == 401) {
console.log("401");
const authStore = useAuthStore();
authStore.logout();
return;
}
console.log(err);
updateStatus.value = err
});
if (res.data.data.has_new_version) {
releaseMessage.value = res.data.message;
updateStatus.value = '有新版本!';
} else {
updateStatus.value = res.data.message;
}
dashboardHasNewVersion.value = res.data.data.dashboard_has_new_version;
})
.catch((err) => {
if (err.response.status == 401) {
console.log("401");
const authStore = useAuthStore();
authStore.logout();
return;
}
console.log(err);
updateStatus.value = err
});
}
function getReleases() {
axios.get('/api/update/releases')
.then((res) => {
// releases.value = res.data.data;
// 更新 published_at 的时间为本地时间
releases.value = res.data.data.map((item: any) => {
item.published_at = new Date(item.published_at).toLocaleString();
return item;
.then((res) => {
releases.value = res.data.data.map((item: any) => {
item.published_at = new Date(item.published_at).toLocaleString();
return item;
})
})
})
.catch((err) => {
console.log(err);
});
.catch((err) => {
console.log(err);
});
}
function getDevCommits() {
@@ -128,17 +177,17 @@ function getDevCommits() {
'Referer': 'https://api.github.com'
}
})
.then(response => response.json())
.then(data => {
devCommits.value = data.map((commit: any) => ({
sha: commit.sha,
date: new Date(commit.commit.author.date).toLocaleString(),
message: commit.commit.message
}));
})
.catch(err => {
console.log(err);
});
.then(response => response.json())
.then(data => {
devCommits.value = data.map((commit: any) => ({
sha: commit.sha,
date: new Date(commit.commit.author.date).toLocaleString(),
message: commit.commit.message
}));
})
.catch(err => {
console.log(err);
});
}
function switchVersion(version: string) {
@@ -148,88 +197,111 @@ function switchVersion(version: string) {
version: version,
proxy: localStorage.getItem('selectedGitHubProxy') || ''
})
.then((res) => {
updateStatus.value = res.data.message;
if (res.data.status == 'ok') {
setTimeout(() => {
window.location.reload();
}, 1000);
}
})
.catch((err) => {
console.log(err);
updateStatus.value = err
}).finally(() => {
installLoading.value = false;
});
.then((res) => {
updateStatus.value = res.data.message;
if (res.data.status == 'ok') {
setTimeout(() => {
window.location.reload();
}, 1000);
}
})
.catch((err) => {
console.log(err);
updateStatus.value = err
}).finally(() => {
installLoading.value = false;
});
}
function updateDashboard() {
updateStatus.value = '正在更新...';
axios.post('/api/update/dashboard')
.then((res) => {
updateStatus.value = res.data.message;
if (res.data.status == 'ok') {
setTimeout(() => {
window.location.reload();
}, 1000);
}
})
.catch((err) => {
console.log(err);
updateStatus.value = err
});
.then((res) => {
updateStatus.value = res.data.message;
if (res.data.status == 'ok') {
setTimeout(() => {
window.location.reload();
}, 1000);
}
})
.catch((err) => {
console.log(err);
updateStatus.value = err
});
}
function toggleDarkMode() {
customizer.SET_UI_THEME(customizer.uiTheme === 'PurpleThemeDark' ? 'PurpleTheme' : 'PurpleThemeDark');
}
getVersion();
checkUpdate();
const commonStore = useCommonStore();
commonStore.createEventSource(); // log
commonStore.getStartTime();
if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('change_pwd_hint') == 'true') {
dialog.value = true;
accountWarning.value = true;
localStorage.removeItem('change_pwd_hint');
}
</script>
<template>
<v-app-bar elevation="0" height="55">
<v-btn style="margin-left: 22px;" class="hidden-md-and-down text-secondary" color="lightsecondary" icon rounded="sm"
variant="flat" @click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
<v-btn v-if="useCustomizerStore().uiTheme==='PurpleTheme'" style="margin-left: 22px;" class="hidden-md-and-down text-secondary" color="lightsecondary" icon rounded="sm"
variant="flat" @click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
<v-icon>mdi-menu</v-icon>
</v-btn>
<v-btn class="hidden-lg-and-up text-secondary ms-3" color="lightsecondary" icon rounded="sm" variant="flat"
@click.stop="customizer.SET_SIDEBAR_DRAWER" size="small">
<v-btn v-else style="margin-left: 22px; color: var(--v-theme-primaryText); background-color: var(--v-theme-secondary)" class="hidden-md-and-down" icon rounded="sm"
variant="flat" @click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
<v-icon>mdi-menu</v-icon>
</v-btn>
<v-btn v-if="useCustomizerStore().uiTheme==='PurpleTheme'" class="hidden-lg-and-up ms-3" color="lightsecondary" icon rounded="sm" variant="flat"
@click.stop="customizer.SET_SIDEBAR_DRAWER" size="small">
<v-icon>mdi-menu</v-icon>
</v-btn>
<v-btn v-else class="hidden-lg-and-up ms-3" icon rounded="sm" variant="flat"
@click.stop="customizer.SET_SIDEBAR_DRAWER" size="small">
<v-icon>mdi-menu</v-icon>
</v-btn>
<span style="margin-left: 16px; font-size: 24px; font-weight: 1000;">Astr<span
style="font-weight: normal;">Bot</span></span>
<div class="logo-container" :class="{'mobile-logo': $vuetify.display.xs}">
<span class="logo-text">Astr<span class="logo-text-light">Bot</span></span>
<span class="version-text hidden-xs">{{ botCurrVersion }}</span>
</div>
<v-spacer />
<v-spacer/>
<div class="mr-4">
<!-- 版本提示信息 - 在手机上隐藏 -->
<div class="mr-4 hidden-xs">
<small v-if="hasNewVersion">
有新版本
AstrBot 有新版本
</small>
<small v-else-if="dashboardHasNewVersion">
WebUI 有新版本
</small>
</div>
<!-- 主题切换按钮 -->
<v-btn size="small" @click="toggleDarkMode();" class="action-btn"
color="var(--v-theme-surface)" variant="flat" rounded="sm">
<v-icon v-if="useCustomizerStore().uiTheme === 'PurpleThemeDark'">mdi-weather-night</v-icon>
<v-icon v-else>mdi-white-balance-sunny</v-icon>
</v-btn>
<v-dialog v-model="updateStatusDialog" width="1000">
<!-- 更新对话框 -->
<v-dialog v-model="updateStatusDialog" :width="$vuetify.display.smAndDown ? '100%' : '1000'" :fullscreen="$vuetify.display.xs">
<template v-slot:activator="{ props }">
<v-btn @click="checkUpdate(); getReleases(); getDevCommits();" class="text-primary mr-4" color="lightprimary"
variant="flat" rounded="sm" v-bind="props">
更新 🔄
<v-btn size="small" @click="checkUpdate(); getReleases(); getDevCommits();" class="action-btn"
color="var(--v-theme-surface)" variant="flat" rounded="sm" v-bind="props">
<v-icon class="hidden-sm-and-up">mdi-update</v-icon>
<span class="hidden-xs">更新</span>
</v-btn>
</template>
<v-card>
<v-card-title>
<v-card-title class="mobile-card-title">
<span class="text-h5">更新 AstrBot</span>
<v-btn v-if="$vuetify.display.xs" icon @click="updateStatusDialog = false">
<v-icon>mdi-close</v-icon>
</v-btn>
</v-card-title>
<v-card-text>
<v-container>
@@ -240,16 +312,16 @@ if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('cha
<small style="margin-left: 4px;">{{ updateStatus }}</small>
</div>
<div
style="background-color: #646cff24; padding: 16px; border-radius: 10px; font-size: 14px; max-height: 400px; overflow-y: auto;"
v-html="marked(releaseMessage)" class="markdown-content">
<div v-if="releaseMessage"
style="background-color: #646cff24; padding: 16px; border-radius: 10px; font-size: 14px; max-height: 400px; overflow-y: auto;"
v-html="marked(releaseMessage)" class="markdown-content">
</div>
<div class="mb-4 mt-4">
<small>💡 TIP: 跳到旧版本或者切换到某个版本不会重新下载管理面板文件这可能会造成部分数据显示错误您可在 <a
href="https://github.com/Soulter/AstrBot/releases">此处</a>
找到对应的面板文件 dist.zip解压后替换 data/dist 文件夹即可当然前端源代码在 dashboard 目录下你也可以自己使用 npm install npm build
找到对应的面板文件 dist.zip解压后替换 data/dist 文件夹即可当然前端源代码在 dashboard 目录下你也可以自己使用
npm install npm build
构建</small>
</div>
@@ -262,12 +334,13 @@ if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('cha
<!-- 发行版 -->
<v-tabs-window-item key="0" v-show="tab == 0">
<v-btn class="mt-4 mb-4" @click="switchVersion('latest')" color="primary" style="border-radius: 10px;"
:disabled="!hasNewVersion">
:disabled="!hasNewVersion">
更新到最新版本
</v-btn>
<div class="mb-4">
<small>`更新到最新版本` 按钮会同时尝试更新机器人主程序和管理面板如果您正在使用 Docker 部署也可以重新拉取镜像或者使用 <a
href="https://containrrr.dev/watchtower/usage-overview/">watchtower</a> 来自动监控拉取</small>
<small>`更新到最新版本` 按钮会同时尝试更新机器人主程序和管理面板如果您正在使用 Docker
部署也可以重新拉取镜像或者使用 <a
href="https://containrrr.dev/watchtower/usage-overview/">watchtower</a> 来自动监控拉取</small>
</div>
<v-data-table :headers="releasesHeader" :items="releases" item-key="name">
@@ -290,8 +363,8 @@ if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('cha
<v-tabs-window-item key="1" v-show="tab == 1">
<div style="margin-top: 16px;">
<v-data-table
:headers="[{ title: 'SHA', key: 'sha' }, { title: '日期', key: 'date' }, { title: '信息', key: 'message' }, { title: '操作', key: 'switch' }]"
:items="devCommits" item-key="sha">
:headers="[{ title: 'SHA', key: 'sha' }, { title: '日期', key: 'date' }, { title: '信息', key: 'message' }, { title: '操作', key: 'switch' }]"
:items="devCommits" item-key="sha">
<template v-slot:item.switch="{ item }: { item: { sha: string } }">
<v-btn @click="switchVersion(item.sha)" rounded="xl" variant="plain" color="primary">
切换
@@ -306,12 +379,13 @@ if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('cha
<h3 class="mb-4">手动输入版本号或 Commit SHA</h3>
<v-text-field label="输入版本号或 master 分支下的 commit hash。" v-model="version" required
variant="outlined"></v-text-field>
variant="outlined"></v-text-field>
<div class="mb-4">
<small> v3.3.16 (不带 SHA) 42e5ec5d80b93b6bfe8b566754d45ffac4c3fe0b</small>
<br>
<a href="https://github.com/Soulter/AstrBot/commits/master"><small>查看 master 分支提交记录点击右边的 copy
即可复制</small></a>
<a href="https://github.com/Soulter/AstrBot/commits/master"><small>查看 master 分支提交记录点击右边的
copy
即可复制</small></a>
</div>
<v-btn color="error" style="border-radius: 10px;" @click="switchVersion(version)">
确定切换
@@ -336,7 +410,7 @@ if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('cha
</div>
<v-btn color="primary" style="border-radius: 10px;" @click="updateDashboard()"
:disabled="!dashboardHasNewVersion">
:disabled="!dashboardHasNewVersion">
下载并更新
</v-btn>
</div>
@@ -351,46 +425,119 @@ if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('cha
</v-card>
</v-dialog>
<v-dialog v-model="dialog" persistent width="700">
<!-- 账户对话框 -->
<v-dialog v-model="dialog" persistent :max-width="$vuetify.display.xs ? '90%' : '500'">
<template v-slot:activator="{ props }">
<v-btn class="text-primary mr-4" color="lightprimary" variant="flat" rounded="sm" v-bind="props">
账户 📰
<v-btn size="small" class="action-btn mr-4" color="var(--v-theme-surface)" variant="flat" rounded="sm" v-bind="props">
<v-icon>mdi-account</v-icon>
<span class="hidden-xs ml-1">账户</span>
</v-btn>
</template>
<v-card>
<v-card-title>
<span class="text-h5">账户</span>
</v-card-title>
<v-card-text>
<v-container>
<v-row>
<v-col cols="12">
<v-card class="account-dialog">
<v-card-text class="py-6">
<div class="d-flex flex-column align-center mb-6">
<logo title="AstrBot 仪表盘" subtitle="修改账户"></logo>
</div>
<v-alert
v-if="accountWarning"
type="warning"
variant="tonal"
border="start"
class="mb-4"
>
<strong>安全提醒:</strong> 请修改默认密码以确保账户安全
</v-alert>
<v-alert v-if="accountWarning" color="warning" style="margin-bottom: 16px;">
<div>为了安全请尽快修改默认密码</div>
</v-alert>
<v-alert
v-if="accountEditStatus.success"
type="success"
variant="tonal"
border="start"
class="mb-4"
>
{{ accountEditStatus.message }}
</v-alert>
<v-text-field label="原密码*" type="password" v-model="password" required
variant="outlined"></v-text-field>
<v-alert
v-if="accountEditStatus.error"
type="error"
variant="tonal"
border="start"
class="mb-4"
>
{{ accountEditStatus.message }}
</v-alert>
<v-text-field label="新用户名" v-model="newUsername" required variant="outlined"></v-text-field>
<v-form v-model="formValid" @submit.prevent="accountEdit">
<v-text-field
v-model="password"
:append-inner-icon="showPassword ? 'mdi-eye-off' : 'mdi-eye'"
:type="showPassword ? 'text' : 'password'"
label="当前密码"
variant="outlined"
required
clearable
@click:append-inner="showPassword = !showPassword"
prepend-inner-icon="mdi-lock-outline"
hide-details="auto"
class="mb-4"
></v-text-field>
<v-text-field label="新密码" type="password" v-model="newPassword" required
variant="outlined"></v-text-field>
</v-col>
</v-row>
</v-container>
<small>默认用户名和密码是 astrbot</small>
<br>
<small>{{ status }}</small>
<v-text-field
v-model="newPassword"
:append-inner-icon="showNewPassword ? 'mdi-eye-off' : 'mdi-eye'"
:type="showNewPassword ? 'text' : 'password'"
:rules="passwordRules"
label="新密码"
variant="outlined"
required
clearable
@click:append-inner="showNewPassword = !showNewPassword"
prepend-inner-icon="mdi-lock-plus-outline"
hint="密码长度至少 8 位"
persistent-hint
class="mb-4"
></v-text-field>
<v-text-field
v-model="newUsername"
:rules="usernameRules"
label="新用户名 (可选)"
variant="outlined"
clearable
prepend-inner-icon="mdi-account-edit-outline"
hint="留空表示不修改用户名"
persistent-hint
class="mb-3"
></v-text-field>
</v-form>
<div class="text-caption text-medium-emphasis mt-2">
默认用户名和密码均为 astrbot
</div>
</v-card-text>
<v-card-actions>
<v-divider></v-divider>
<v-card-actions class="pa-4">
<v-spacer></v-spacer>
<v-btn color="blue-darken-1" variant="text" @click="dialog = false">
关闭
<v-btn
v-if="!accountWarning"
variant="tonal"
color="secondary"
@click="dialog = false"
:disabled="accountEditStatus.loading"
>
取消
</v-btn>
<v-btn color="blue-darken-1" variant="text" @click="accountEdit">
提交
<v-btn
color="primary"
@click="accountEdit"
:loading="accountEditStatus.loading"
:disabled="!formValid"
prepend-icon="mdi-content-save"
>
保存修改
</v-btn>
</v-card-actions>
</v-card>
@@ -416,4 +563,91 @@ if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('cha
margin-top: 8px;
margin-bottom: 8px;
}
.account-dialog .v-card-text {
padding-top: 24px;
padding-bottom: 24px;
}
.account-dialog .v-alert {
margin-bottom: 20px;
}
.account-dialog .v-btn {
text-transform: none;
font-weight: 500;
border-radius: 8px;
}
.account-dialog .v-avatar {
transition: transform 0.3s ease;
}
.account-dialog .v-avatar:hover {
transform: scale(1.05);
}
/* 响应式布局样式 */
.logo-container {
margin-left: 16px;
display: flex;
align-items: center;
gap: 8px;
}
.mobile-logo {
margin-left: 8px;
gap: 4px;
}
.logo-text {
font-size: 24px;
font-weight: 1000;
}
.logo-text-light {
font-weight: normal;
}
.version-text {
font-size: 12px;
color: var(--v-theme-secondaryText);
}
.action-btn {
margin-right: 6px;
}
/* 移动端对话框标题样式 */
.mobile-card-title {
display: flex;
justify-content: space-between;
align-items: center;
}
/* 移动端样式优化 */
@media (max-width: 600px) {
.logo-text {
font-size: 20px;
}
.action-btn {
margin-right: 4px;
min-width: 32px !important;
width: 32px;
}
.v-card-title {
padding: 12px 16px;
}
.v-card-text {
padding: 16px;
}
.v-tabs .v-tab {
padding: 0 10px;
font-size: 0.9rem;
}
}
</style>

View File

@@ -9,9 +9,6 @@ const customizer = useCustomizerStore();
const sidebarMenu = shallowRef(sidebarItems);
const showIframe = ref(false);
const version = ref("");
const buildVer = ref("");
const hasWebUIUpdate = ref(false);
// 默认桌面端 iframe 样式
const iframeStyle = ref({
@@ -68,9 +65,10 @@ function toggleIframe() {
showIframe.value = !showIframe.value;
}
function openIframeLink() {
function openIframeLink(url) {
if (typeof window !== 'undefined') {
window.open("https://astrbot.app", "_blank");
let url_ = url || "https://astrbot.app";
window.open(url_, "_blank");
}
}
@@ -149,25 +147,6 @@ function endDrag() {
document.removeEventListener('touchend', onTouchEnd);
}
// 获取版本和更新信息
onMounted(() => {
axios.get('/api/stat/version')
.then((res) => {
version.value = "v" + res.data.data.version;
})
.catch((err) => {
console.log(err);
});
axios.get('/api/update/check?type=dashboard')
.then((res) => {
hasWebUIUpdate.value = res.data.data.has_new_version;
buildVer.value = res.data.data.current_version;
})
.catch((err) => {
console.log(err);
});
});
</script>
<template>
@@ -186,27 +165,23 @@ onMounted(() => {
<NavItem :item="item" class="leftPadding" />
</template>
</v-list>
<div class="text-center">
<v-chip color="inputBorder" size="small"> {{ version }} </v-chip>
</div>
<div style="position: absolute; bottom: 32px; width: 100%; font-size: 13px;" class="text-center">
<v-list-item v-if="!customizer.mini_sidebar" @click="toggleIframe">
<v-btn variant="plain" size="small">
🤔 点击此处 查看/关闭 悬浮文档
</v-btn>
</v-list-item>
<small style="display: block;" v-if="buildVer">WebUI 版本: {{ buildVer }}</small>
<small style="display: block;" v-else>构建: embedded</small>
<v-tooltip text="使用 /dashboard_update 指令更新管理面板">
<template v-slot:activator="{ props }">
<small v-bind="props" v-if="hasWebUIUpdate" style="display: block; margin-top: 4px;">面板有更新</small>
</template>
</v-tooltip>
<small style="display: block; margin-top: 8px;">AGPL-3.0</small>
<div style="position: absolute; bottom: 16px; width: 100%; font-size: 13px;" class="text-center">
<v-btn style="margin-bottom: 8px;" size="small" variant="primary" v-if="!customizer.mini_sidebar" to="/settings">
🔧 设置
</v-btn>
<br/>
<v-btn style="margin-bottom: 8px;" size="small" variant="plain" v-if="!customizer.mini_sidebar" @click="toggleIframe">
官方文档
</v-btn>
<br/>
<v-btn style="margin-bottom: 8px;" size="small" variant="plain" v-if="!customizer.mini_sidebar" @click="openIframeLink('https://github.com/AstrBotDevs/AstrBot')">
GitHub
</v-btn>
<br/>
</div>
</v-navigation-drawer>
<!-- 优化后的悬浮 iframe -->
<div
v-if="showIframe"
id="draggable-iframe"

View File

@@ -66,9 +66,9 @@ const sidebarItem: menu[] = [
to: '/console'
},
{
title: '设置',
icon: 'mdi-wrench',
to: '/settings'
title: 'Alkaid',
icon: 'mdi-test-tube',
to: '/alkaid'
},
{
title: '关于',

View File

@@ -3,6 +3,7 @@ import '@mdi/font/css/materialdesignicons.css';
import * as components from 'vuetify/components';
import * as directives from 'vuetify/directives';
import { PurpleTheme } from '@/theme/LightTheme';
import { PurpleThemeDark } from "@/theme/DarkTheme";
export default createVuetify({
components,
@@ -11,7 +12,8 @@ export default createVuetify({
theme: {
defaultTheme: 'PurpleTheme',
themes: {
PurpleTheme
PurpleTheme,
PurpleThemeDark
}
},
defaults: {

View File

@@ -0,0 +1,21 @@
const ChatBoxRoutes = {
path: '/chatbox',
component: () => import('@/layouts/blank/BlankLayout.vue'),
children: [
{
name: 'ChatBox',
path: '/chatbox',
component: () => import('@/views/ChatBoxPage.vue'),
children: [
{
path: ':conversationId',
name: 'ChatBoxDetail',
component: () => import('@/views/ChatBoxPage.vue'),
props: true
}
]
}
]
};
export default ChatBoxRoutes;

View File

@@ -57,14 +57,39 @@ const MainRoutes = {
component: () => import('@/views/ConsolePage.vue')
},
{
name: 'Project ATRI',
path: '/project-atri',
component: () => import('@/views/ATRIProject.vue')
name: 'Alkaid',
path: '/alkaid',
component: () => import('@/views/AlkaidPage.vue'),
children: [
{
path: 'knowledge-base',
name: 'KnowledgeBase',
component: () => import('@/views/alkaid/KnowledgeBase.vue')
},
{
path: 'long-term-memory',
name: 'LongTermMemory',
component: () => import('@/views/alkaid/LongTermMemory.vue')
},
{
path: 'other',
name: 'OtherFeatures',
component: () => import('@/views/alkaid/Other.vue')
}
]
},
{
name: 'Chat',
path: '/chat',
component: () => import('@/views/ChatPage.vue')
component: () => import('@/views/ChatPage.vue'),
children: [
{
path: ':conversationId',
name: 'ChatDetail',
component: () => import('@/views/ChatPage.vue'),
props: true
}
]
},
{
name: 'Settings',

View File

@@ -1,13 +1,15 @@
import { createRouter, createWebHistory } from 'vue-router';
import MainRoutes from './MainRoutes';
import AuthRoutes from './AuthRoutes';
import ChatBoxRoutes from './ChatBoxRoutes';
import { useAuthStore } from '@/stores/auth';
export const router = createRouter({
history: createWebHistory(import.meta.env.BASE_URL),
routes: [
MainRoutes,
AuthRoutes
AuthRoutes,
ChatBoxRoutes
]
});
@@ -24,6 +26,11 @@ router.beforeEach(async (to, from, next) => {
const authRequired = !publicPages.includes(to.path);
const auth: AuthStore = useAuthStore();
// 如果用户已登录且试图访问登录页面,则重定向到首页或之前尝试访问的页面
if (to.path === '/auth/login' && auth.has_token()) {
return next(auth.returnUrl || '/');
}
if (to.matched.some((record) => record.meta.requiresAuth)) {
if (authRequired && !auth.has_token()) {
auth.returnUrl = to.fullPath;

View File

@@ -6,7 +6,7 @@
.listitem {
height: calc(100vh - 100px);
.v-list {
color: rgb(var(--v-theme-lightText));
color: rgb(var(--v-theme-secondaryText));
}
.v-list-group__items .v-list-item,
.v-list-item {

View File

@@ -32,7 +32,7 @@ export const useAuthStore = defineStore({
},
logout() {
this.username = '';
localStorage.removeItem('username');
localStorage.removeItem('user');
localStorage.removeItem('token');
router.push('/auth/login');
},

View File

@@ -8,6 +8,7 @@ export const useCustomizerStore = defineStore({
Customizer_drawer: config.Customizer_drawer,
mini_sidebar: config.mini_sidebar,
fontTheme: "Poppins",
uiTheme: config.uiTheme,
inputBg: config.inputBg
}),
@@ -21,6 +22,10 @@ export const useCustomizerStore = defineStore({
},
SET_FONT(payload: string) {
this.fontTheme = payload;
}
},
SET_UI_THEME(payload: string) {
this.uiTheme = payload;
localStorage.setItem("uiTheme", payload);
},
}
});

View File

@@ -0,0 +1,46 @@
import type { ThemeTypes } from '@/types/themeTypes/ThemeType';
const PurpleThemeDark: ThemeTypes = {
name: 'PurpleThemeDark',
dark: true,
variables: {
'border-color': '#1677ff',
'carousel-control-size': 10
},
colors: {
primary: '#1677ff',
secondary: '#722ed1',
info: '#03c9d7',
success: '#52c41a',
accent: '#FFAB91',
warning: '#faad14',
error: '#ff4d4f',
lightprimary: '#eef2f6',
lightsecondary: '#ede7f6',
lightsuccess: '#b9f6ca',
lighterror: '#f9d8d8',
lightwarning: '#fff8e1',
primaryText: '#ffffff',
secondaryText: '#ffffffcc',
darkprimary: '#1565c0',
darksecondary: '#4527a0',
borderLight: '#d0d0d0',
border: '#333333ee',
inputBorder: '#787878',
containerBg: '#1a1a1a',
surface: '#1f1f1f',
'on-surface-variant': '#000',
facebook: '#4267b2',
twitter: '#1da1f2',
linkedin: '#0e76a8',
gray100: '#cccccccc',
primary200: '#90caf9',
secondary200: '#b39ddb',
background: '#111111',
overlay: '#111111aa',
codeBg: '#282833',
code: '#ffffffdd'
}
};
export { PurpleThemeDark };

View File

@@ -20,11 +20,12 @@ const PurpleTheme: ThemeTypes = {
lightsuccess: '#b9f6ca',
lighterror: '#f9d8d8',
lightwarning: '#fff8e1',
darkText: '#212121',
lightText: '#616161',
primaryText: '#000000dd',
secondaryText: '#000000aa',
darkprimary: '#1565c0',
darksecondary: '#4527a0',
borderLight: '#d0d0d0',
border: '#d0d0d0',
inputBorder: '#787878',
containerBg: '#eef2f6',
surface: '#fff',
@@ -32,9 +33,13 @@ const PurpleTheme: ThemeTypes = {
facebook: '#4267b2',
twitter: '#1da1f2',
linkedin: '#0e76a8',
gray100: '#fafafa',
gray100: '#fafafacc',
primary200: '#90caf9',
secondary200: '#b39ddb'
secondary200: '#b39ddb',
background: '#f9fafcf4',
overlay: '#ffffffaa',
codeBg: '#f5f0ff',
code: '#673ab7'
}
};

View File

@@ -17,13 +17,15 @@ export type ThemeTypes = {
lightwarning?: string;
darkprimary?: string;
darksecondary?: string;
darkText?: string;
lightText?: string;
primaryText?: string;
secondaryText?: string;
borderLight?: string;
border?: string;
inputBorder?: string;
containerBg?: string;
surface?: string;
background?: string;
overlay?: string;
'on-surface-variant'?: string;
facebook?: string;
twitter?: string;
@@ -31,5 +33,7 @@ export type ThemeTypes = {
gray100?: string;
primary200?: string;
secondary200?: string;
codeBg?: string;
code?: string;
};
};

View File

@@ -1,87 +0,0 @@
<script setup>
</script>
<template>
<v-alert style="margin-bottom: 16px"
text="这是一个长期实验性功能,目标是实现更具人类机能的 LLM 对话。推荐使用 gpt-4o-mini 作为文本生成和视觉理解模型,成本很低。推荐使用 text-embedding-3-small 作为 Embedding 模型,成本忽略不计。"
title="💡实验性功能" type="info" variant="tonal">
</v-alert>
<v-card>
<v-card-text>
<v-container fluid>
<AstrBotConfig :metadata="project_atri_config_metadata" :iterable="project_atri_config?.project_atri"
metadataKey="project_atri">
</AstrBotConfig>
</v-container>
</v-card-text>
</v-card>
<v-btn icon="mdi-content-save" size="x-large" style="position: fixed; right: 52px; bottom: 52px;" color="darkprimary"
@click="updateConfig">
</v-btn>
<v-snackbar :timeout="3000" elevation="24" :color="save_message_success" v-model="save_message_snack">
{{ save_message }}
</v-snackbar>
<WaitingForRestart ref="wfr"></WaitingForRestart>
</template>
<script>
import axios from 'axios';
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
export default {
name: 'AtriProject',
components: {
AstrBotConfig,
WaitingForRestart
},
data() {
return {
project_atri_config: {},
fetched: false,
project_atri_config_metadata: {},
save_message_snack: false,
save_message: "",
save_message_success: "",
}
},
mounted() {
this.getConfig();
},
methods: {
getConfig() {
// 获取配置
axios.get('/api/config/get').then((res) => {
this.project_atri_config = res.data.data.config;
this.fetched = true
this.project_atri_config_metadata = res.data.data.metadata;
}).catch((err) => {
save_message = err;
save_message_snack = true;
save_message_success = "error";
});
},
updateConfig() {
if (!this.fetched) return;
axios.post('/api/config/astrbot/update', this.project_atri_config).then((res) => {
if (res.data.status === "ok") {
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
this.$refs.wfr.check();
} else {
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "error";
}
}).catch((err) => {
this.save_message = err;
this.save_message_snack = true;
this.save_message_success = "error";
});
},
},
}
</script>

View File

@@ -1,55 +1,95 @@
<template>
<v-card style="height: 100%;">
<v-card-text style="padding: 0; height: 100%; overflow-y: auto;">
<div
style="display: flex; justify-content: center; align-items: center; height: 100%; flex-direction: column;">
<div @click="selectedLogo = selectedLogo == 0 ? 1 : 0" style="height: 300px;">
<img v-if="selectedLogo == 0" width="300" src="@/assets/images/logo-waifu.png" alt="AstrBot Logo"
class="fade-in">
<img v-if="selectedLogo == 1" width="300" src="@/assets/images/logo-normal.svg" alt="AstrBot Logo"
class="fade-in">
</div>
<v-card style="height: 100%;" elevation="0" class="bg-surface">
<v-card-text style="padding: 0; height: 100%; overflow-y: hidden;">
<div class="about-wrapper">
<!-- Hero Section -->
<section class="hero-section">
<div class="logo-title-container">
<div @click="selectedLogo = selectedLogo == 0 ? 1 : 0" class="logo-container">
<img v-if="selectedLogo == 0" width="280" src="@/assets/images/logo-waifu.png" alt="AstrBot Logo" class="fade-in">
<img v-if="selectedLogo == 1" width="280" src="@/assets/images/logo-normal.svg" alt="AstrBot Logo" class="fade-in">
</div>
<div class="title-container">
<h1 class="text-h2 font-weight-bold">AstrBot</h1>
<p class="text-subtitle-1" style="color: var(--v-theme-secondaryText);">A project out of interests and loves </p>
<div class="action-buttons">
<v-btn @click="open('https://github.com/Soulter/AstrBot')"
color="primary" variant="elevated" prepend-icon="mdi-star">
Star 这个项目! 🌟
</v-btn>
<v-btn class="ml-4" @click="open('https://github.com/Soulter/AstrBot/issues')"
color="secondary" variant="elevated" prepend-icon="mdi-comment-question">
提交 Issue
</v-btn>
</div>
</div>
</div>
</section>
<h1 class="mt-8">AstrBot</h1>
<!-- Contributors Section -->
<section class="contributors-section">
<v-container>
<v-row justify="center" align="center">
<v-col cols="12" md="6" class="pr-md-8 contributors-info">
<h2 class="text-h4 font-weight-medium">贡献者</h2>
<p class="mb-4 text-body-1" style="color: var(--v-theme-secondaryText);">
本项目由众多开源社区成员共同维护感谢每一位贡献者的付出
</p>
<p class="text-body-1" style="color: var(--v-theme-secondaryText);">
<a href="https://github.com/Soulter/AstrBot/graphs/contributors" class="text-decoration-none custom-link">查看 AstrBot 贡献者</a>
</p>
</v-col>
<v-col cols="12" md="6">
<v-card variant="outlined" class="overflow-hidden" elevation="2">
<v-img v-if="useCustomizerStore().uiTheme==='PurpleThemeDark'"
alt="Active Contributors of Soulter/AstrBot"
src="https://next.ossinsight.io/widgets/official/compose-recent-active-contributors/thumbnail.png?repo_id=575865240&limit=365&image_size=auto&color_scheme=dark">
</v-img>
<v-img v-else
alt="Active Contributors of Soulter/AstrBot"
src="https://next.ossinsight.io/widgets/official/compose-recent-active-contributors/thumbnail.png?repo_id=575865240&limit=365&image_size=auto&color_scheme=light">
</v-img>
</v-card>
</v-col>
</v-row>
</v-container>
</section>
<span class="mt-2" style="color: #777;">A project out of interests and loves </span>
<span style="color: #777; margin-left: 32px; margin-right: 32px" class="mt-4">By <a
href="https://soulter.top">Soulter</a>, <a
href="https://github.com/Soulter/AstrBot/graphs/contributors">AstrBot Contributors</a>
and <a href="https://github.com/Soulter/AstrBot_Plugins_Collection/graphs/contributors">AstrBot
Plugin Authors</a>
</span>
<!-- Copy-paste in your Readme.md file -->
<img style="margin-top: 16px; width: 50%; max-width: 500px; margin-left: 32px; margin-right: 32px"
alt="Active Contributors of Soulter/AstrBot - Last 28 days"
src="https://next.ossinsight.io/widgets/official/compose-recent-active-contributors/thumbnail.png?repo_id=575865240&limit=365&image_size=auto&color_scheme=light">
<img style="margin-top: 16px; width: 50%; max-width: 500px; margin-left: 32px; margin-right: 32px"
alt="Active Contributors of Soulter/AstrBot - Last 28 days"
src="https://next.ossinsight.io/widgets/official/analyze-repo-stars-map/thumbnail.png?activity=stars&repo_id=575865240&image_size=auto&color_scheme=light
">
<!-- Made with [OSS Insight](https://ossinsight.io/) -->
<v-btn class="text-primary mt-8" @click="open('https://github.com/Soulter/AstrBot')"
color="lightprimary" variant="flat" rounded="sm">
Star 这个项目! 🌟
</v-btn>
<v-btn class="text-primary mt-4" @click="open('https://github.com/Soulter/AstrBot/issues')"
color="lightprimary" variant="flat" rounded="sm">
有使用问题或者功能建议提交 Issue
</v-btn>
<!-- Stats Section -->
<section class="stats-section">
<v-container>
<v-row justify="center" align="center" class="flex-md-row-reverse">
<v-col cols="12" md="6" class="pl-md-8 stats-info">
<h2 class="text-h4 font-weight-medium">全球部署</h2>
<div class="license-container mt-8">
<img v-bind="props" src="https://www.gnu.org/graphics/agplv3-with-text-100x42.png" style="cursor: pointer;"/>
<p class="text-caption mt-2" style="color: var(--v-theme-secondaryText);">AstrBot 采用 AGPL v3 协议开源</p>
</div>
</v-col>
<v-col cols="12" md="6">
<v-card variant="outlined" class="overflow-hidden" elevation="2">
<v-img v-if="useCustomizerStore().uiTheme==='PurpleThemeDark'"
alt="Stars Map of Soulter/AstrBot"
src="https://next.ossinsight.io/widgets/official/analyze-repo-stars-map/thumbnail.png?activity=stars&repo_id=575865240&image_size=auto&color_scheme=dark">
</v-img>
<v-img v-else
alt="Stars Map of Soulter/AstrBot"
src="https://next.ossinsight.io/widgets/official/analyze-repo-stars-map/thumbnail.png?activity=stars&repo_id=575865240&image_size=auto&color_scheme=light">
</v-img>
</v-card>
</v-col>
</v-row>
</v-container>
</section>
</div>
</v-card-text>
</v-card>
</template>
<script>
import {useCustomizerStore} from "@/stores/customizer";
export default {
name: 'AboutPage',
data() {
@@ -59,26 +99,141 @@ export default {
},
methods: {
useCustomizerStore,
open(url) {
window.open(url, '_blank');
}
}
}
</script>
<style>
@keyframes fadeIn {
from {
opacity: 0;
}
<style scoped>
.about-wrapper {
min-height: 100%;
}
to {
opacity: 1;
}
.hero-section {
padding: 40px 20px;
background: linear-gradient(to right bottom, rgba(255,255,255,0.7), rgba(240,240,250,0.3));
display: flex;
justify-content: center;
align-items: center;
text-align: center;
}
.logo-title-container {
display: flex;
align-items: center;
flex-direction: row;
max-width: 900px;
gap: 20px;
}
.logo-container {
cursor: pointer;
transition: all 0.3s ease;
flex-shrink: 0;
}
.logo-container:hover {
transform: scale(1.05);
}
.title-container {
text-align: left;
}
.contributors-section, .stats-section {
padding: 60px 20px;
}
.contributors-section {
background-color: var(--v-theme-containerBg, #f9f9fb);
}
.contributors-info, .stats-info {
display: flex;
flex-direction: column;
justify-content: center;
}
.custom-link {
display: inline-block;
padding: 5px 0;
position: relative;
color: var(--v-primary-base);
font-weight: 500;
}
.custom-link::after {
content: '';
position: absolute;
width: 100%;
transform: scaleX(0);
height: 2px;
bottom: 0;
left: 0;
background-color: var(--v-primary-base);
transform-origin: bottom right;
transition: transform 0.25s ease-out;
}
.custom-link:hover::after {
transform: scaleX(1);
transform-origin: bottom left;
}
.license-container {
display: flex;
flex-direction: column;
align-items: flex-start;
}
.action-buttons {
display: flex;
margin-top: 24px;
}
@keyframes fadeIn {
from { opacity: 0; }
to { opacity: 1; }
}
.fade-in {
animation: fadeIn 0.2s ease-in-out;
}
@media (max-width: 960px) {
.logo-title-container {
flex-direction: column;
text-align: center;
}
.title-container {
text-align: center;
}
.action-buttons {
justify-content: center;
}
.license-container {
align-items: center;
}
.contributors-section, .stats-section {
padding: 40px 20px;
}
}
@media (max-width: 600px) {
.action-buttons {
flex-direction: column;
gap: 12px;
}
.action-buttons .v-btn + .v-btn {
margin-left: 0 !important;
}
}
</style>

View File

@@ -0,0 +1,80 @@
<template>
<v-card style="height: 100%; width: 100%;">
<v-card-text class="pa-4" style="height: 100%;">
<v-container fluid class="d-flex flex-column" style="height: 100%;">
<div style="margin-bottom: 32px;">
<h1 class="gradient-text">The Alkaid Project.</h1>
<small style="color: #a3a3a3;">AstrBot Alpha 项目</small>
</div>
<div style="display: flex; gap: 8px; margin-bottom: 16px;">
<v-btn size="large" :variant="isActive('knowledge-base') ? 'flat' : 'tonal'"
:color="isActive('knowledge-base') ? '#9b72cb' : ''" rounded="lg"
@click="navigateTo('knowledge-base')">
<v-icon start>mdi-text-box-search</v-icon>
知识库
</v-btn>
<v-btn size="large" :variant="isActive('long-term-memory') ? 'flat' : 'tonal'"
:color="isActive('long-term-memory') ? '#9b72cb' : ''" rounded="lg"
@click="navigateTo('long-term-memory')">
<v-icon start>mdi-dots-hexagon</v-icon>
长期记忆层
</v-btn>
<v-btn size="large" :variant="isActive('other') ? 'flat' : 'tonal'"
:color="isActive('other') ? '#9b72cb' : ''" rounded="lg"
@click="navigateTo('other')">
<v-icon start>mdi-tools</v-icon>
...
</v-btn>
</div>
<div id="sub-view" class="flex-grow-1" style="max-height: 100%;">
<router-view></router-view>
</div>
</v-container>
</v-card-text>
</v-card>
</template>
<script>
export default {
name: 'AlkaidPage',
components: {},
data() {
return {}
},
methods: {
navigateTo(tab) {
this.$router.push(`/alkaid/${tab}`);
},
isActive(tab) {
return this.$route.path.includes(`/alkaid/${tab}`);
}
},
mounted() {
// 如果在根路径 /alkaid默认跳转到知识库页面
if (this.$route.path === '/alkaid') {
this.navigateTo('knowledge-base');
}
}
}
</script>
<style scoped>
.gradient-text {
background: linear-gradient(74deg, #2abfe1 0, #9b72cb 25%, #b55908 50%, #d93025 100%);
-webkit-background-clip: text;
background-clip: text;
color: transparent;
font-weight: bold;
}
#subview {
display: flex;
flex-direction: column;
flex-grow: 1;
width: 100%;
height: 100%;
}
</style>

View File

@@ -0,0 +1,432 @@
<script setup>
import Graph from "graphology";
import Sigma from "sigma";
import ForceSupervisor from "graphology-layout-force/worker";
</script>
<template>
<v-card style="height: 100%; width: 100%;">
<v-card-text class="pa-4" style="height: 100%;">
<v-container fluid class="d-flex flex-column" style="height: 100%;">
<div style="margin-bottom: 32px;">
<h1 class="gradient-text">The Alkaid Project.</h1>
<small style="color: #a3a3a3;">AstrBot 实验性项目</small>
</div>
<div style="display: flex; gap: 8px; margin-bottom: 16px;">
<v-btn size="large" :variant="activeTab === 'long-term-memory' ? 'flat' : 'tonal'"
:color="activeTab === 'long-term-memory' ? '#9b72cb' : ''" rounded="lg"
@click="activeTab = 'long-term-memory'">
<v-icon start>mdi-dots-hexagon</v-icon>
长期记忆层
</v-btn>
<v-btn size="large" :variant="activeTab === 'other' ? 'flat' : 'tonal'"
:color="activeTab === 'other' ? '#9b72cb' : ''" rounded="lg" @click="activeTab = 'other'">
<v-icon start>mdi-dots-horizontal</v-icon>
其他
</v-btn>
</div>
<div v-if="activeTab === 'long-term-memory'" id="long-term-memory" class="flex-grow-1"
style="display: flex; flex-direction: row;">
<div id="graph-container" style="flex-grow: 1; width: 100%; border: 1px solid #eee; border-radius: 8px;">
</div>
<div id="graph-control-panel"
style="min-width: 450px; border: 1px solid #eee; border-radius: 8px; padding: 16px; margin-left: 16px;">
<div>
<span style="color: #333333;">可视化</span>
<div style="margin-top: 8px;">
<v-autocomplete v-model="searchUserId" :items="userIdList" variant="outlined"
label="筛选用户 ID"></v-autocomplete>
<v-btn color="primary" @click="onNodeSelect" variant="tonal" style="margin-top: 8px;">
<v-icon start>mdi-magnify</v-icon>
筛选
</v-btn>
<v-btn color="secondary" @click="resetFilter" variant="tonal"
style="margin-top: 8px; margin-left: 8px;">
<v-icon start>mdi-filter-remove</v-icon>
重置筛选
</v-btn>
</div>
<div style="margin-top: 16px;">
<v-btn color="primary" @click="refreshGraph" variant="tonal">
<v-icon start>mdi-refresh</v-icon>
刷新图形
</v-btn>
</div>
</div>
<v-divider class="my-4"></v-divider>
<div v-if="selectedNode" class="mt-4">
<h3>节点详情</h3>
<v-card variant="outlined" class="mt-2 pa-3">
<div v-if="selectedNode.id">
<div class="d-flex justify-space-between">
<span class="text-subtitle-2">ID:</span>
<span>{{ selectedNode.id }}</span>
</div>
</div>
<div v-if="selectedNode._label">
<div class="d-flex justify-space-between">
<span class="text-subtitle-2">类型:</span>
<span>{{ selectedNode._label }}</span>
</div>
</div>
<div v-if="selectedNode.name">
<div class="d-flex justify-space-between">
<span class="text-subtitle-2">名称:</span>
<span>{{ selectedNode.name }}</span>
</div>
</div>
<div v-if="selectedNode.user_id">
<div class="d-flex justify-space-between">
<span class="text-subtitle-2">用户ID:</span>
<span>{{ selectedNode.user_id }}</span>
</div>
</div>
<div v-if="selectedNode.ts">
<div class="d-flex justify-space-between">
<span class="text-subtitle-2">时间戳:</span>
<span>{{ selectedNode.ts }}</span>
</div>
</div>
<div v-if="selectedNode.type">
<div class="d-flex justify-space-between">
<span class="text-subtitle-2">类型:</span>
<span>{{ selectedNode.type }}</span>
</div>
</div>
</v-card>
</div>
<div v-if="graphStats" class="mt-4">
<h3>图形统计</h3>
<v-card variant="outlined" class="mt-2 pa-3">
<div class="d-flex justify-space-between">
<span class="text-subtitle-2">节点数:</span>
<span>{{ graphStats.nodeCount }}</span>
</div>
<div class="d-flex justify-space-between">
<span class="text-subtitle-2">边数:</span>
<span>{{ graphStats.edgeCount }}</span>
</div>
</v-card>
</div>
</div>
</div>
<div v-if="activeTab === 'other'" class="flex-grow-1" style="display: flex; flex-direction: column;">
<div class="d-flex align-center justify-center"
style="flex-grow: 1; width: 100%; border: 1px solid #eee; border-radius: 8px;">
<v-icon size="64" color="grey-lighten-1">mdi-tools</v-icon>
<p class="text-h6 text-grey ml-4">功能开发中</p>
</div>
</div>
</v-container>
</v-card-text>
</v-card>
</template>
<script>
import axios from 'axios';
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
export default {
name: 'AlkaidPage',
components: {
AstrBotConfig,
WaitingForRestart
},
data() {
return {
renderer: null,
graph: null,
layout: null,
activeTab: 'long-term-memory',
node_data: [],
edge_data: [],
searchUserId: null,
userIdList: [],
selectedNode: null,
graphStats: null,
nodeColors: {
'PhaseNode': '#4CAF50', // 绿色
'PassageNode': '#2196F3', // 蓝色
'FactNode': '#FF9800', // 橙色
'default': '#9C27B0' // 紫色作为默认
},
edgeColors: {
'_include_': '#607D8B',
'_related_': '#9E9E9E',
'default': '#BDBDBD'
},
isLoading: false
}
},
mounted() {
this.initSigma();
this.ltmGetGraph();
this.ltmGetUserIds();
},
beforeUnmount() {
if (this.renderer) {
this.renderer.kill();
}
if (this.layout) {
this.layout.stop();
}
},
watch: {
activeTab(newVal) {
if (newVal === 'long-term-memory') {
this.$nextTick(() => {
if (!this.renderer) {
this.initSigma();
}
});
} else {
if (this.renderer) {
this.renderer.kill();
this.renderer = null;
}
if (this.layout) {
this.layout.stop();
this.layout = null;
}
}
}
},
methods: {
ltmGetGraph(userId = null) {
this.isLoading = true;
const params = userId ? { user_id: userId } : {};
axios.get('/api/plug/alkaid/ltm/graph', { params })
.then(response => {
let nodes = response.data.data.nodes;
let edges = response.data.data.edges;
this.node_data = nodes;
this.edge_data = edges;
if (this.graph) {
this.graph.clear();
}
nodes.forEach(node => {
const nodeId = node[0];
const nodeData = node[1];
if (!this.graph.hasNode(nodeId)) {
const nodeType = nodeData._label || 'default';
const color = this.nodeColors[nodeType] || this.nodeColors['default'];
this.graph.addNode(nodeId, {
x: Math.random(),
y: Math.random(),
size: 5,
label: nodeData.name || nodeId.split('_')[0],
color: color,
originalData: nodeData
});
}
});
// 添加边
edges.forEach(edge => {
const sourceId = edge[0];
const targetId = edge[1];
const edgeData = edge[2];
if (this.graph.hasNode(sourceId) && this.graph.hasNode(targetId)) {
const edgeId = `${sourceId}->${targetId}`;
const relationType = edgeData.relation_type || 'default';
const color = this.edgeColors[relationType] || this.edgeColors['default'];
this.graph.addEdge(sourceId, targetId, {
size: 1,
color: color,
originalData: edgeData,
label: relationType,
type: "line"
});
} else {
console.warn(`Edge ${sourceId} -> ${targetId} has missing nodes.`);
}
});
this.updateGraphStats();
console.log('Graph initialized with', nodes.length, 'nodes and', edges.length, 'edges');
})
.catch(error => {
console.error('Error fetching graph data:', error);
})
.finally(() => {
this.isLoading = false;
});
if (this.layout) {
this.layout.start();
}
},
ltmGetUserIds() {
axios.get('/api/plug/alkaid/ltm/user_ids')
.then(response => {
this.userIdList = response.data.data;
})
.catch(error => {
console.error('Error fetching user IDs:', error);
});
},
updateGraphStats() {
if (this.graph) {
this.graphStats = {
nodeCount: this.graph.order,
edgeCount: this.graph.size
};
}
},
refreshGraph() {
this.ltmGetGraph(this.searchUserId);
},
onNodeSelect() {
console.log('Selected user ID:', this.searchUserId);
if (!this.searchUserId || !this.graph) return;
// 使用API的user_id参数筛选数据
this.ltmGetGraph(this.searchUserId);
},
resetFilter() {
this.searchUserId = null;
this.ltmGetGraph();
},
initSigma() {
const container = document.getElementById("graph-container");
if (!container) return;
if (this.renderer) {
this.renderer.kill();
this.renderer = null;
}
if (this.layout) {
this.layout.stop();
this.layout = null;
}
const graph = new Graph({
multi: true,
});
const layout = new ForceSupervisor(graph, {
isNodeFixed: (_, attr) => attr.highlighted, settings: {
gravity: 0.0001,
repulsion: 0.001
}
});
layout.start();
this.layout = layout;
this.graph = graph;
const renderer = new Sigma(graph, container, {
minCameraRatio: 0.01,
maxCameraRatio: 2,
labelRenderedSizeThreshold: 1,
renderLabels: true,
renderEdgeLabels: true,
labelSize: 14,
labelColor: "#333333",
});
this.renderer = renderer;
let draggedNode = null;
let isDragging = false;
renderer.on("downNode", (e) => {
isDragging = true;
draggedNode = e.node;
graph.setNodeAttribute(draggedNode, "highlighted", true);
if (!renderer.getCustomBBox()) renderer.setCustomBBox(renderer.getBBox());
});
renderer.on("moveBody", ({ event }) => {
if (!isDragging || !draggedNode) return;
const pos = renderer.viewportToGraph(event);
graph.setNodeAttribute(draggedNode, "x", pos.x);
graph.setNodeAttribute(draggedNode, "y", pos.y);
event.preventSigmaDefault();
event.original.preventDefault();
event.original.stopPropagation();
});
const handleUp = () => {
if (draggedNode) {
graph.removeNodeAttribute(draggedNode, "highlighted");
}
isDragging = false;
draggedNode = null;
};
renderer.on("upNode", handleUp);
renderer.on("upStage", handleUp);
renderer.on("clickNode", (e) => {
const nodeId = e.node;
const nodeAttributes = graph.getNodeAttributes(nodeId);
this.selectedNode = nodeAttributes.originalData;
});
renderer.on("clickStage", () => {
this.selectedNode = null;
});
},
getRandomColor() {
const letters = '0123456789ABCDEF';
let color = '#';
for (let i = 0; i < 6; i++) {
color += letters[Math.floor(Math.random() * 16)];
}
return color;
}
},
}
</script>
<style scoped>
.gradient-text {
background: linear-gradient(74deg, #2abfe1 0, #9b72cb 25%, #b55908 50%, #d93025 100%);
-webkit-background-clip: text;
background-clip: text;
color: transparent;
font-weight: bold;
}
#graph-container {
position: relative;
background-color: #f2f6f9;
overflow: hidden;
min-height: 200px;
}
#graph-container:hover {
cursor: pointer;
}
.memory-header {
padding: 0 8px;
}
</style>

View File

@@ -0,0 +1,36 @@
<script setup>
import ChatPage from './ChatPage.vue';
</script>
<template>
<div style="height: 100%; width: 100%; display: flex; flex-direction: column; align-items: center; justify-content: center;">
<div id="container">
<ChatPage chatbox-mode="true"></ChatPage>
</div>
</div>
</template>
<style scoped>
#container {
width: 100%;
height: 100%;
}
@media (min-width: 768px) {
#container {
min-width: 600px;
min-height: 370px;
max-width: 1100px;
max-height: 860px;
padding: 36px;
}
}
@media (max-width: 767px) {
#container {
width: 100%;
height: 100%;
padding: 0;
}
}
</style>

Some files were not shown because too many files have changed in this diff Show More