Compare commits
58 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
685c0a106a | ||
|
|
7f539090dd | ||
|
|
2089273f95 | ||
|
|
838bb4c7ad | ||
|
|
637acd1a12 | ||
|
|
03fa9a847f | ||
|
|
d488c88e78 | ||
|
|
baae842210 | ||
|
|
ec1fb838b6 | ||
|
|
13281179df | ||
|
|
276a42c9a1 | ||
|
|
7a70a730ba | ||
|
|
d0fe59631c | ||
|
|
106892e933 | ||
|
|
19543a41b3 | ||
|
|
b172b760ab | ||
|
|
4b5d49cb41 | ||
|
|
3fd35b6058 | ||
|
|
5f86c4ab99 | ||
|
|
c94a7f6629 | ||
|
|
7d6beb4141 | ||
|
|
e2117e690a | ||
|
|
fb791290e2 | ||
|
|
5dd1488b5d | ||
|
|
529cd64d82 | ||
|
|
d2bd3e8da8 | ||
|
|
e42ce7dd86 | ||
|
|
40709462ee | ||
|
|
2ad6c01a4d | ||
|
|
70c12e788e | ||
|
|
1713791c90 | ||
|
|
9aa23fd412 | ||
|
|
e4ba09cd93 | ||
|
|
171fdf1fbc | ||
|
|
01f4e0b961 | ||
|
|
be2d5a91c7 | ||
|
|
a1d89d9478 | ||
|
|
98d1dc3b65 | ||
|
|
b80eb3acc0 | ||
|
|
05ccc1995b | ||
|
|
0de244889e | ||
|
|
e6c5c3a493 | ||
|
|
164aa2ccd2 | ||
|
|
f1599e26b3 | ||
|
|
ed64a4d32d | ||
|
|
2ee4b431d4 | ||
|
|
cd8a73ed19 | ||
|
|
e6c985ce4e | ||
|
|
a20446aeb9 | ||
|
|
7b23d76559 | ||
|
|
8315cf5818 | ||
|
|
ed16265bde | ||
|
|
dff205faf6 | ||
|
|
9aae8aee0c | ||
|
|
7c818ced2b | ||
|
|
218e887558 | ||
|
|
a68860b35a | ||
|
|
82d4d43383 |
106
README.md
106
README.md
@@ -1,6 +1,8 @@
|
||||
|
||||
<p align="center">
|
||||
<img width=200 src="https://github.com/user-attachments/assets/3dd6a669-0830-4db4-b821-c8b279ea19a6"/>
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/de10f24d-cd64-433a-90b8-16c0a60de24a" width=500>
|
||||
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
@@ -15,79 +17,92 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||
[](https://codecov.io/gh/Soulter/AstrBot)
|
||||
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
|
||||
</a>
|
||||
|
||||
<a href="https://astrbot.lwl.lol/">查看文档</a> |
|
||||
<a href="https://astrbot.app/">查看文档</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
||||
</div>
|
||||
|
||||
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
||||
|
||||
## ✨ 多消息平台部署
|
||||
## ✨ 主要功能
|
||||
|
||||
1. QQ 群、QQ 频道、微信个人号、Telegram。
|
||||
2. 支持文本转图片,Markdown 渲染。
|
||||
|
||||
## ✨ 多 LLM 配置
|
||||
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
||||
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat、VChat)、Telegram。后续将支持钉钉、飞书、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://astrbot.app/others/dify.html),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
||||
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
||||
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
|
||||
|
||||
1. 适配 OpenAI API,支持接入 Gemini、GPT、Llama、Claude、DeepSeek、GLM 等各种大语言模型。
|
||||
2. 支持 OneAPI 等分发平台。
|
||||
3. 支持 LLMTuner 载入微调模型。
|
||||
4. 支持 Ollama 载入自部署模型。
|
||||
4. 支持网页搜索(Web Search)、自然语言待办提醒。
|
||||
5. 支持 Whisper 语音转文字
|
||||
> [!TIP]
|
||||
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
>
|
||||
> 用户名: `astrbot`, 密码: `astrbot`。此 Demo 未配置 LLM,因此无法在聊天页使用大模型。
|
||||
|
||||
## ✨ 管理面板
|
||||
## ✨ 使用方式
|
||||
|
||||
1. 支持可视化修改配置
|
||||
2. 日志实时查看
|
||||
3. 简单的信息统计
|
||||
4. 插件管理
|
||||
#### Docker 部署
|
||||
|
||||
## ✨ 支持 Dify
|
||||
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
|
||||
|
||||
1. 对接了 LLMOps 平台 Dify,便捷接入 Dify 智能助手、知识库和 Dify 工作流
|
||||
#### Windows 一键安装器部署
|
||||
|
||||
## ✨ 代码执行器(Beta)
|
||||
需要电脑上安装有 Python(>3.10)。请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
|
||||
|
||||
基于 Docker 的沙箱化代码执行器(Beta 测试中)
|
||||
|
||||
> [!NOTE]
|
||||
> 文件输入/输出目前仅支持 Napcat(QQ)
|
||||
|
||||
<div align='center'>
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||
|
||||
</div>
|
||||
|
||||
## ✨ 云部署
|
||||
#### Replit 部署
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
|
||||
#### CasaOS 部署
|
||||
|
||||
社区贡献的部署方式。
|
||||
|
||||
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/casaos.html) 。
|
||||
|
||||
#### 手动部署
|
||||
|
||||
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
|
||||
|
||||
## ⚡ 消息平台支持情况
|
||||
|
||||
|
||||
| 平台 | 支持性 | 详情 | 消息类型 |
|
||||
| -------- | ------- | ------- | ------ |
|
||||
| QQ | ✔ | 私聊、群聊 | 文字、图片、语音 |
|
||||
| QQ 官方API | ✔ | 私聊、群聊,QQ 频道私聊、群聊 | 文字、图片 |
|
||||
| 微信 | ✔ | [Gewechat](https://github.com/Devo919/Gewechat)。微信个人号私聊、群聊 | 文字 |
|
||||
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
||||
| 飞书 | 🚧 | 计划内 | - |
|
||||
| Discord | 🚧 | 计划内 | - |
|
||||
| WhatsApp | 🚧 | 计划内 | - |
|
||||
| 小爱音响 | 🚧 | 计划内 | - |
|
||||
|
||||
## ❤️ 贡献
|
||||
|
||||
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
|
||||
|
||||
对于新功能的添加,请先通过 Issue 讨论。
|
||||
|
||||
## 🔭 展望
|
||||
|
||||
1. 更强大的 Agent 系统。
|
||||
2. 打造插件工作流平台。
|
||||
|
||||
## ✨ Support
|
||||
## 🌟 支持
|
||||
|
||||
- Star 这个项目!
|
||||
- 在[爱发电](https://afdian.com/a/soulter)支持我!
|
||||
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
|
||||
|
||||
|
||||
|
||||
## ✨ Demo
|
||||
|
||||
> [!NOTE]
|
||||
> 代码执行器的文件输入/输出目前仅测试了 Napcat(QQ), Lagrange(QQ)
|
||||
|
||||
<div align='center'>
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||
|
||||
_✨基于 Docker 的沙箱化代码执行器(Beta 测试中)✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||
|
||||
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
||||
@@ -111,6 +126,13 @@ _✨ 内置 Web Chat,在线与机器人交互 ✨_
|
||||
|
||||
</div>
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我维护这个开源项目的动力 <3
|
||||
|
||||
[](https://star-history.com/#soulter/astrbot&Date)
|
||||
|
||||
|
||||
<!-- ## ✨ ATRI [Beta 测试]
|
||||
|
||||
@@ -122,3 +144,5 @@ _✨ 内置 Web Chat,在线与机器人交互 ✨_
|
||||
4. TTS
|
||||
-->
|
||||
|
||||
_アトリは、高性能ですから!_
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.personality import personalities
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.star.register import register_llm_tool as llm_tool
|
||||
@@ -8,7 +7,6 @@ from astrbot.core.star.register import register_llm_tool as llm_tool
|
||||
__all__ = [
|
||||
"AstrBotConfig",
|
||||
"logger",
|
||||
"personalities",
|
||||
"html_renderer",
|
||||
"llm_tool",
|
||||
"sp"
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.personality import personalities
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.star.register import register_llm_tool as llm_tool
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
VERSION = "3.4.5"
|
||||
VERSION = "3.4.12"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
|
||||
# 默认配置
|
||||
@@ -22,6 +22,9 @@ DEFAULT_CONFIG = {
|
||||
"id_whitelist_log": True,
|
||||
"wl_ignore_admin_on_group": True,
|
||||
"wl_ignore_admin_on_friend": True,
|
||||
"reply_with_mention": False,
|
||||
"reply_with_quote": False,
|
||||
"path_mapping": []
|
||||
},
|
||||
"provider": [],
|
||||
"provider_settings": {
|
||||
@@ -30,12 +33,16 @@ DEFAULT_CONFIG = {
|
||||
"web_search": False,
|
||||
"identifier": False,
|
||||
"datetime_system_prompt": True,
|
||||
"default_personality": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
|
||||
"default_personality": "default",
|
||||
"prompt_prefix": "",
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
},
|
||||
"provider_tts_settings": {
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
},
|
||||
"content_safety": {
|
||||
"internal_keywords": {"enable": True, "extra_keywords": []},
|
||||
@@ -56,6 +63,14 @@ DEFAULT_CONFIG = {
|
||||
"pip_install_arg": "",
|
||||
"plugin_repo_mirror": "",
|
||||
"knowledge_db": {},
|
||||
"persona": [
|
||||
{
|
||||
"name": "default",
|
||||
"prompt": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
|
||||
"begin_dialogs": [],
|
||||
"mood_imitation_dialogs": [],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -85,6 +100,15 @@ CONFIG_METADATA_2 = {
|
||||
"ws_reverse_port": 6199,
|
||||
},
|
||||
"vchat(微信)": {"id": "default", "type": "vchat", "enable": False},
|
||||
"gewechat(微信)": {
|
||||
"id": "gwchat",
|
||||
"type": "gewechat",
|
||||
"enable": False,
|
||||
"base_url": "http://localhost:2531",
|
||||
"nickname": "soulter",
|
||||
"host": "localhost",
|
||||
"port": 11451,
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"id": {
|
||||
@@ -170,7 +194,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"enable_id_white_list": {
|
||||
"description": "启用 ID 白名单",
|
||||
"type": "bool"
|
||||
"type": "bool",
|
||||
},
|
||||
"id_whitelist": {
|
||||
"description": "ID 白名单",
|
||||
@@ -191,6 +215,22 @@ CONFIG_METADATA_2 = {
|
||||
"description": "管理员私聊消息无视 ID 白名单",
|
||||
"type": "bool",
|
||||
},
|
||||
"reply_with_mention": {
|
||||
"description": "回复时 @ 发送者",
|
||||
"type": "bool",
|
||||
"hint": "启用后,机器人回复消息时会 @ 发送者。实际效果以具体的平台适配器为准。",
|
||||
},
|
||||
"reply_with_quote": {
|
||||
"description": "回复时引用消息",
|
||||
"type": "bool",
|
||||
"hint": "启用后,机器人回复消息时会引用原消息。实际效果以具体的平台适配器为准。",
|
||||
},
|
||||
"path_mapping": {
|
||||
"description": "路径映射",
|
||||
"type": "list",
|
||||
"obvious_hint": True,
|
||||
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
|
||||
}
|
||||
},
|
||||
},
|
||||
"content_safety": {
|
||||
@@ -256,7 +296,7 @@ CONFIG_METADATA_2 = {
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://localhost:11434",
|
||||
"api_base": "http://localhost:11434/v1",
|
||||
"model_config": {
|
||||
"model": "llama3.1-8b",
|
||||
},
|
||||
@@ -334,14 +374,22 @@ CONFIG_METADATA_2 = {
|
||||
"id": "whisper",
|
||||
"type": "openai_whisper_selfhost",
|
||||
"model": "tiny",
|
||||
}
|
||||
},
|
||||
"openai_tts(API)": {
|
||||
"id": "openai_tts",
|
||||
"type": "openai_tts_api",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
"model": "tts-1",
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"whisper_hint": {
|
||||
"description": "本地部署 Whisper 模型须知",
|
||||
"type": "string",
|
||||
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"obvious_hint": True
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"id": {
|
||||
"description": "ID",
|
||||
@@ -367,7 +415,8 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "API Base URL 请在在模型提供商处获得。支持 Ollama 开放的 API 地址。如果您确认填写正确但是使用时出现了 404 异常,可以尝试在地址末尾加上 `/v1`。",
|
||||
"hint": "API Base URL 请在在模型提供商处获得。如使用时出现了 404 报错,可以尝试在地址末尾加上 `/v1`。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"base_model_path": {
|
||||
"description": "基座模型路径",
|
||||
@@ -431,7 +480,7 @@ CONFIG_METADATA_2 = {
|
||||
"description": "Dify Workflow 输出变量名",
|
||||
"type": "string",
|
||||
"hint": "Dify Workflow 输出变量名。当应用类型为 workflow 时才使用。默认为 astrbot_wf_output。",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_settings": {
|
||||
@@ -442,7 +491,7 @@ CONFIG_METADATA_2 = {
|
||||
"description": "启用大语言模型聊天",
|
||||
"type": "bool",
|
||||
"hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。",
|
||||
"obvious_hint": True
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀",
|
||||
@@ -465,9 +514,9 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "启用后,会在系统提示词中加上当前机器的日期时间。",
|
||||
},
|
||||
"default_personality": {
|
||||
"description": "默认人格",
|
||||
"description": "默认采用的人格情景的名称",
|
||||
"type": "string",
|
||||
"hint": "默认人格(情境设置/System Prompt)文本。",
|
||||
"hint": "",
|
||||
},
|
||||
"prompt_prefix": {
|
||||
"description": "Prompt 前缀文本",
|
||||
@@ -476,6 +525,46 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"persona": {
|
||||
"description": "人格情景设置",
|
||||
"type": "list",
|
||||
"config_template": {
|
||||
"新人格情景": {
|
||||
"name": "",
|
||||
"prompt": "",
|
||||
"begin_dialogs": [],
|
||||
"mood_imitation_dialogs": [],
|
||||
}
|
||||
},
|
||||
"tmpl_display_title": "name",
|
||||
"items": {
|
||||
"name": {
|
||||
"description": "人格名称",
|
||||
"type": "string",
|
||||
"hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"prompt": {
|
||||
"description": "设定(系统提示词)",
|
||||
"type": "text",
|
||||
"hint": "填写人格的身份背景、性格特征、兴趣爱好、个人经历、口头禅等。",
|
||||
},
|
||||
"begin_dialogs": {
|
||||
"description": "预设对话",
|
||||
"type": "list",
|
||||
"items": {},
|
||||
"hint": "可选。在每个对话前会插入这些预设对话。格式要求:第一句为用户,第二句为助手,以此类推。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"mood_imitation_dialogs": {
|
||||
"description": "对话风格模仿",
|
||||
"type": "list",
|
||||
"items": {},
|
||||
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一样。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"description": "语音转文本(STT)",
|
||||
"type": "object",
|
||||
@@ -484,7 +573,7 @@ CONFIG_METADATA_2 = {
|
||||
"description": "启用语音转文本(STT)",
|
||||
"type": "bool",
|
||||
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。",
|
||||
"obvious_hint": True
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID,不填则默认第一个STT提供商",
|
||||
@@ -493,6 +582,23 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_tts_settings": {
|
||||
"description": "文本转语音(TTS)",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"description": "启用文本转语音(TTS)",
|
||||
"type": "bool",
|
||||
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID,不填则默认第一个TTS提供商",
|
||||
"type": "string",
|
||||
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"misc_config_group": {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import traceback
|
||||
import asyncio
|
||||
import time
|
||||
import threading
|
||||
@@ -6,7 +7,6 @@ from .event_bus import EventBus
|
||||
from . import astrbot_config
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
@@ -81,12 +81,30 @@ class AstrBotCoreLifecycle:
|
||||
for task in self.star_context._register_tasks:
|
||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
||||
|
||||
self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks]
|
||||
# self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks]
|
||||
|
||||
tasks_ = [event_bus_task, *platform_tasks, *extra_tasks]
|
||||
for task in tasks_:
|
||||
self.curr_tasks.append(asyncio.create_task(self._task_wrapper(task), name=task.get_name()))
|
||||
|
||||
self.start_time = int(time.time())
|
||||
|
||||
async def _task_wrapper(self, task: asyncio.Task):
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
||||
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
|
||||
for line in traceback.format_exc().split("\n"):
|
||||
logger.error(f"| {line}")
|
||||
logger.error("-------")
|
||||
|
||||
async def start(self):
|
||||
self._load()
|
||||
logger.info("AstrBot 启动完成。")
|
||||
|
||||
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
|
||||
|
||||
async def stop(self):
|
||||
|
||||
@@ -123,7 +123,7 @@ class Record(BaseMessageComponent):
|
||||
proxy: T.Optional[bool] = True
|
||||
timeout: T.Optional[int] = 0
|
||||
# 额外
|
||||
path: T.Optional[str] # 用这个
|
||||
path: T.Optional[str]
|
||||
|
||||
def __init__(self, file: T.Optional[str], **_):
|
||||
for k in _.keys():
|
||||
|
||||
@@ -139,7 +139,7 @@ class MessageEventResult(MessageChain):
|
||||
'''
|
||||
return self.result_type == EventResultType.STOP
|
||||
|
||||
def set_result_content_type(self, typ: EventResultType) -> 'MessageEventResult':
|
||||
def set_result_content_type(self, typ: ResultContentType) -> 'MessageEventResult':
|
||||
'''设置事件处理的结果类型。
|
||||
|
||||
Args:
|
||||
@@ -148,5 +148,10 @@ class MessageEventResult(MessageChain):
|
||||
self.result_content_type = typ
|
||||
return self
|
||||
|
||||
def is_llm_result(self) -> bool:
|
||||
'''是否为 LLM 结果。
|
||||
'''
|
||||
return self.result_content_type == ResultContentType.LLM_RESULT
|
||||
|
||||
|
||||
CommandResult = MessageEventResult
|
||||
@@ -5,7 +5,7 @@ from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Record
|
||||
from astrbot.core.message.components import Plain, Record, Image
|
||||
|
||||
@register_stage
|
||||
class PreProcessStage(Stage):
|
||||
@@ -16,24 +16,39 @@ class PreProcessStage(Stage):
|
||||
self.plugin_manager = ctx.plugin_manager
|
||||
|
||||
self.stt_settings: dict = self.config.get('provider_stt_settings', {})
|
||||
self.platform_settings: dict = self.config.get('platform_settings', {})
|
||||
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
'''在处理事件之前的预处理'''
|
||||
|
||||
# 路径映射
|
||||
if mappings := self.platform_settings.get('path_mapping', []):
|
||||
# 支持 Record,Image 消息段的路径映射。
|
||||
message_chain = event.get_messages()
|
||||
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, (Record, Image)) and component.url:
|
||||
for mapping in mappings:
|
||||
from_, to_ = mapping.split(":")
|
||||
from_ = from_.removesuffix("/")
|
||||
to_ = to_.removesuffix("/")
|
||||
|
||||
url = component.url.removeprefix("file://")
|
||||
if url.startswith(from_):
|
||||
component.url = url.replace(from_, to_, 1)
|
||||
logger.debug(f"路径映射: {url} -> {component.url}")
|
||||
message_chain[idx] = component
|
||||
|
||||
# STT
|
||||
if self.stt_settings.get('enable', False):
|
||||
# STT 处理
|
||||
# TODO: 独立
|
||||
stt_provider = self.plugin_manager.context.provider_manager.curr_stt_provider_inst
|
||||
if stt_provider:
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, Record) and component.path:
|
||||
|
||||
path = component.path
|
||||
|
||||
if isinstance(component, Record) and component.url:
|
||||
path = component.url.removeprefix("file://")
|
||||
retry = 5
|
||||
|
||||
for i in range(retry):
|
||||
try:
|
||||
result = await stt_provider.get_text(audio_url=path)
|
||||
@@ -46,7 +61,7 @@ class PreProcessStage(Stage):
|
||||
except FileNotFoundError as e:
|
||||
# napcat workaround
|
||||
logger.warning(e)
|
||||
logger.warning(f"语音文件不存在: {path}, 重试中: {i + 1}/{retry}")
|
||||
logger.warning(f"重试中: {i + 1}/{retry}")
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
except BaseException as e:
|
||||
|
||||
@@ -17,6 +17,13 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.bot_wake_prefixs = ctx.astrbot_config['wake_prefix'] # list
|
||||
self.provider_wake_prefix = ctx.astrbot_config['provider_settings']['wake_prefix'] # str
|
||||
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.provider_wake_prefix.startswith(bwp):
|
||||
logger.info(f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。")
|
||||
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp):]
|
||||
|
||||
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
|
||||
req: ProviderRequest = None
|
||||
@@ -30,10 +37,10 @@ class LLMRequestSubStage(Stage):
|
||||
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
|
||||
else:
|
||||
req = ProviderRequest(prompt="", image_urls=[])
|
||||
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
|
||||
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
|
||||
if self.provider_wake_prefix:
|
||||
if not event.message_str.startswith(self.provider_wake_prefix):
|
||||
return
|
||||
req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
|
||||
req.prompt = event.message_str[len(self.provider_wake_prefix):]
|
||||
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
@@ -44,7 +51,7 @@ class LLMRequestSubStage(Stage):
|
||||
session_provider_context = provider.session_memory.get(event.session_id)
|
||||
req.contexts = session_provider_context if session_provider_context else []
|
||||
|
||||
if not req.prompt:
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
# 执行请求 LLM 前事件。
|
||||
@@ -98,5 +105,5 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e)))
|
||||
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"))
|
||||
return
|
||||
@@ -1,10 +1,12 @@
|
||||
import time
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Image
|
||||
from astrbot.core.message.components import Plain, Image, At, Reply, Record
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
|
||||
@@ -13,6 +15,9 @@ class ResultDecorateStage:
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
self.ctx = ctx
|
||||
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
|
||||
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
|
||||
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
|
||||
self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable']
|
||||
self.t2i = ctx.astrbot_config['t2i']
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
@@ -29,9 +34,28 @@ class ResultDecorateStage:
|
||||
# 回复前缀
|
||||
if self.reply_prefix:
|
||||
result.chain.insert(0, Plain(self.reply_prefix))
|
||||
|
||||
# TTS
|
||||
if self.use_tts and result.is_llm_result():
|
||||
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
|
||||
plain_str = ""
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
plain_str += " " + comp.text
|
||||
else:
|
||||
break
|
||||
if plain_str:
|
||||
try:
|
||||
audio_path = await tts_provider.get_audio(plain_str)
|
||||
logger.info("TTS 结果: " + audio_path)
|
||||
if audio_path:
|
||||
result.chain = [Record(file=audio_path, url=audio_path)]
|
||||
except BaseException:
|
||||
traceback.print_exc()
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
|
||||
# 文本转图片
|
||||
if (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
|
||||
elif (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
|
||||
plain_str = ""
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
@@ -48,4 +72,10 @@ class ResultDecorateStage:
|
||||
if time.time() - render_start > 3:
|
||||
logger.warning("文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。")
|
||||
if url:
|
||||
result.chain = [Image.fromURL(url)]
|
||||
result.chain = [Image.fromURL(url)]
|
||||
|
||||
if self.reply_with_mention and event.get_message_type() != MessageType.FRIEND_MESSAGE:
|
||||
result.chain.insert(0, At(qq=event.get_sender_id()))
|
||||
|
||||
if self.reply_with_quote:
|
||||
result.chain.insert(0, Reply(id=event.message_obj.message_id))
|
||||
@@ -25,6 +25,8 @@ class PlatformManager():
|
||||
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
|
||||
case "vchat":
|
||||
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
|
||||
case "gewechat":
|
||||
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
|
||||
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
@dataclass
|
||||
class PlatformMetadata():
|
||||
name: str # 平台的名称
|
||||
description: str # 平台的描述
|
||||
name: str
|
||||
'''平台的名称'''
|
||||
description: str
|
||||
'''平台的描述'''
|
||||
|
||||
default_config_tmpl: dict = None # 平台的默认配置模板
|
||||
default_config_tmpl: dict = None
|
||||
'''平台的默认配置模板'''
|
||||
adapter_display_name: str = None
|
||||
'''显示在 WebUI 配置页中的平台名称,如空则是 name'''
|
||||
@@ -7,7 +7,12 @@ platform_registry: List[PlatformMetadata] = []
|
||||
platform_cls_map: Dict[str, Type] = {}
|
||||
'''维护了平台适配器名称和适配器类的映射'''
|
||||
|
||||
def register_platform_adapter(adapter_name: str, desc: str, default_config_tmpl: dict = None):
|
||||
def register_platform_adapter(
|
||||
adapter_name: str,
|
||||
desc: str,
|
||||
default_config_tmpl: dict = None,
|
||||
adapter_display_name: str = None
|
||||
):
|
||||
'''用于注册平台适配器的带参装饰器。
|
||||
|
||||
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
|
||||
@@ -22,11 +27,14 @@ def register_platform_adapter(adapter_name: str, desc: str, default_config_tmpl:
|
||||
default_config_tmpl['type'] = adapter_name
|
||||
if 'enable' not in default_config_tmpl:
|
||||
default_config_tmpl['enable'] = False
|
||||
if 'id' not in default_config_tmpl:
|
||||
default_config_tmpl['id'] = adapter_name
|
||||
|
||||
pm = PlatformMetadata(
|
||||
name=adapter_name,
|
||||
description=desc,
|
||||
default_config_tmpl=default_config_tmpl
|
||||
default_config_tmpl=default_config_tmpl,
|
||||
adapter_display_name=adapter_display_name
|
||||
)
|
||||
platform_registry.append(pm)
|
||||
platform_cls_map[adapter_name] = cls
|
||||
|
||||
@@ -3,7 +3,7 @@ import random
|
||||
import asyncio
|
||||
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from aiocqhttp import CQHttp
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
|
||||
@@ -20,15 +20,19 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
d = segment.toDict()
|
||||
if isinstance(segment, Plain):
|
||||
d['type'] = 'text'
|
||||
if isinstance(segment, Image):
|
||||
if isinstance(segment, (Image, Record)):
|
||||
# convert to base64
|
||||
if segment.file and segment.file.startswith("file:///"):
|
||||
image_base64 = file_to_base64(segment.file[8:])
|
||||
bs64_data = file_to_base64(segment.file[8:])
|
||||
image_file_path = segment.file[8:]
|
||||
elif segment.file and segment.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(segment.file)
|
||||
image_base64 = file_to_base64(image_file_path)
|
||||
d['data']['file'] = image_base64
|
||||
bs64_data = file_to_base64(image_file_path)
|
||||
else:
|
||||
bs64_data = file_to_base64(segment.file)
|
||||
d['data'] = {
|
||||
'file': bs64_data,
|
||||
}
|
||||
ret.append(d)
|
||||
return ret
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from aiocqhttp.exceptions import ActionFailed
|
||||
from astrbot.core.utils.io import download_file
|
||||
|
||||
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
|
||||
class AiocqhttpAdapter(Platform):
|
||||
@@ -81,22 +82,36 @@ class AiocqhttpAdapter(Platform):
|
||||
if t == 'text':
|
||||
message_str += m['data']['text'].strip()
|
||||
elif t == 'file':
|
||||
try:
|
||||
# Napcat, LLBot
|
||||
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
|
||||
if not ret.get('file', None):
|
||||
raise ValueError(f"无法解析文件响应: {ret}")
|
||||
if not os.path.exists(ret['file']):
|
||||
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
|
||||
if m['data']['url'] and m['data']['url'].startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
|
||||
file_name = m['data'].get('file_name', "file")
|
||||
path = os.path.join("data/temp", file_name)
|
||||
await download_file(m['data']['url'], path)
|
||||
|
||||
m['data'] = {
|
||||
"file": ret['file'],
|
||||
"name": ret['file_name']
|
||||
"file": path,
|
||||
"name": file_name
|
||||
}
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
|
||||
else:
|
||||
try:
|
||||
# Napcat, LLBot
|
||||
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
|
||||
if not ret.get('file', None):
|
||||
raise ValueError(f"无法解析文件响应: {ret}")
|
||||
if not os.path.exists(ret['file']):
|
||||
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
|
||||
|
||||
m['data'] = {
|
||||
"file": ret['file'],
|
||||
"name": ret['file_name']
|
||||
}
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
|
||||
a = ComponentTypes[t](**m['data']) # noqa: F405
|
||||
abm.message.append(a)
|
||||
|
||||
345
astrbot/core/platform/sources/gewechat/client.py
Normal file
345
astrbot/core/platform/sources/gewechat/client.py
Normal file
@@ -0,0 +1,345 @@
|
||||
import threading
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import quart
|
||||
import base64
|
||||
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.api.message_components import Plain, Image, At, Record
|
||||
from astrbot.api import logger, sp
|
||||
from .downloader import GeweDownloader
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
|
||||
class SimpleGewechatClient():
|
||||
'''针对 Gewechat 的简单实现。
|
||||
|
||||
@author: Soulter
|
||||
@website: https://github.com/Soulter
|
||||
'''
|
||||
def __init__(self, base_url: str, nickname: str, host: str, port: int, event_queue: asyncio.Queue):
|
||||
self.base_url = base_url
|
||||
if self.base_url.endswith('/'):
|
||||
self.base_url = self.base_url[:-1]
|
||||
|
||||
self.download_base_url = self.base_url.split(':')[:-1] # 去掉端口
|
||||
self.download_base_url = ':'.join(self.download_base_url) + ":2532/download/"
|
||||
|
||||
self.base_url += "/v2/api"
|
||||
|
||||
logger.info(f"Gewechat API: {self.base_url}")
|
||||
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
|
||||
|
||||
if isinstance(port, str):
|
||||
port = int(port)
|
||||
|
||||
self.token = None
|
||||
self.headers = {}
|
||||
self.nickname = nickname
|
||||
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
|
||||
|
||||
self.server = quart.Quart(__name__)
|
||||
self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST'])
|
||||
self.server.add_url_rule('/astrbot-gewechat/file/<file_id>', view_func=self.handle_file, methods=['GET'])
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
|
||||
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.multimedia_downloader = None
|
||||
|
||||
async def get_token_id(self):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
|
||||
json_blob = await resp.json()
|
||||
self.token = json_blob['data']
|
||||
logger.info(f"获取到 Gewechat Token: {self.token}")
|
||||
self.headers = {
|
||||
"X-GEWE-TOKEN": self.token
|
||||
}
|
||||
|
||||
async def _convert(self, data: dict) -> AstrBotMessage:
|
||||
type_name = data['TypeName']
|
||||
if type_name == "Offline":
|
||||
logger.critical("收到 gewechat 下线通知。")
|
||||
return
|
||||
abm = AstrBotMessage()
|
||||
d = data['Data']
|
||||
|
||||
from_user_name = d['FromUserName']['string'] # 消息来源
|
||||
d['to_wxid'] = from_user_name # 用于发信息
|
||||
|
||||
abm.message_id = str(d.get('MsgId'))
|
||||
abm.session_id = from_user_name
|
||||
abm.self_id = data['Wxid'] # 机器人的 wxid
|
||||
|
||||
user_id = "" # 发送人 wxid
|
||||
content = d['Content']['string'] # 消息内容
|
||||
|
||||
at_me = False
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
_t = content.split(':\n')
|
||||
user_id = _t[0]
|
||||
content = _t[1]
|
||||
if '\u2005' in content:
|
||||
# at
|
||||
content = content.split('\u2005')[1]
|
||||
abm.group_id = from_user_name
|
||||
# at
|
||||
msg_source = d['MsgSource']
|
||||
if f'<atuserlist><![CDATA[,{abm.self_id}]]>' in msg_source \
|
||||
or f'<atuserlist><![CDATA[{abm.self_id}]]>' in msg_source:
|
||||
at_me = True
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
user_id = from_user_name
|
||||
|
||||
abm.message = []
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id))
|
||||
|
||||
user_real_name = d['PushContent'].split(' : ')[0] \
|
||||
.replace('在群聊中@了你', '') \
|
||||
.replace('在群聊中发了一段语音', '') # 真实昵称
|
||||
abm.sender = MessageMember(user_id, user_real_name)
|
||||
abm.raw_message = d
|
||||
abm.message_str = ""
|
||||
# 不同消息类型
|
||||
match d['MsgType']:
|
||||
case 1:
|
||||
# 文本消息
|
||||
abm.message.append(Plain(content))
|
||||
abm.message_str = content
|
||||
case 3:
|
||||
# 图片消息
|
||||
file_url = await self.multimedia_downloader.download_image(
|
||||
self.appid,
|
||||
content
|
||||
)
|
||||
logger.debug(f"下载图片: {file_url}")
|
||||
file_path = await download_image_by_url(file_url)
|
||||
abm.message.append(Image(file=file_path, url=file_path))
|
||||
|
||||
case 34:
|
||||
# 语音消息
|
||||
# data = await self.multimedia_downloader.download_voice(
|
||||
# self.appid,
|
||||
# content,
|
||||
# abm.message_id
|
||||
# )
|
||||
# print(data)
|
||||
if 'ImgBuf' in d and 'buffer' in d['ImgBuf']:
|
||||
voice_data = base64.b64decode(d['ImgBuf']['buffer'])
|
||||
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(voice_data)
|
||||
abm.message.append(Record(file=file_path, url=file_path))
|
||||
|
||||
case _:
|
||||
logger.error(f"未实现的消息类型: {d['MsgType']}")
|
||||
return
|
||||
|
||||
logger.info(f"abm: {abm}")
|
||||
return abm
|
||||
|
||||
async def callback(self):
|
||||
data = await quart.request.json
|
||||
logger.debug(f"收到 gewechat 回调: {data}")
|
||||
|
||||
if data.get('testMsg', None):
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
abm = await self._convert(data)
|
||||
|
||||
if abm:
|
||||
coro = getattr(self, "on_event_received")
|
||||
if coro:
|
||||
await coro(abm)
|
||||
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
async def handle_file(self, file_id):
|
||||
file_path = f"data/temp/{file_id}"
|
||||
return await quart.send_file(file_path)
|
||||
|
||||
async def _set_callback_url(self):
|
||||
logger.info("设置回调,请等待...")
|
||||
await asyncio.sleep(3)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/tools/setCallback",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"token": self.token,
|
||||
"callbackUrl": self.callback_url
|
||||
}
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"设置回调结果: {json_blob}")
|
||||
if json_blob['ret'] != 200:
|
||||
raise Exception(f"设置回调失败: {json_blob}")
|
||||
logger.info(f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。")
|
||||
|
||||
async def start_polling(self):
|
||||
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
|
||||
await self.server.run_task(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
shutdown_trigger=self.shutdown_trigger_placeholder
|
||||
)
|
||||
|
||||
async def shutdown_trigger_placeholder(self):
|
||||
while not self.event_queue.closed:
|
||||
await asyncio.sleep(1)
|
||||
logger.info("gewechat 适配器已关闭。")
|
||||
|
||||
async def check_online(self, appid: str):
|
||||
# /login/checkOnline
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/checkOnline",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"appId": appid
|
||||
}
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
return json_blob['data']
|
||||
|
||||
async def logout(self):
|
||||
if self.appid:
|
||||
online = await self.check_online(self.appid)
|
||||
if online:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/logout",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"appId": self.appid
|
||||
}
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"登出结果: {json_blob}")
|
||||
|
||||
async def login(self):
|
||||
if self.token is None:
|
||||
await self.get_token_id()
|
||||
|
||||
self.multimedia_downloader = GeweDownloader(self.base_url, self.download_base_url, self.token)
|
||||
|
||||
if self.appid:
|
||||
online = await self.check_online(self.appid)
|
||||
if online:
|
||||
logger.info(f"APPID: {self.appid} 已在线")
|
||||
return
|
||||
|
||||
payload = {
|
||||
"appId": self.appid
|
||||
}
|
||||
|
||||
if self.appid:
|
||||
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/getLoginQrCode",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
if json_blob['ret'] != 200:
|
||||
raise Exception(f"获取二维码失败: {json_blob}")
|
||||
qr_data = json_blob['data']['qrData']
|
||||
qr_uuid = json_blob['data']['uuid']
|
||||
appid = json_blob['data']['appId']
|
||||
logger.info(f"APPID: {appid}")
|
||||
logger.warning(f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}")
|
||||
|
||||
# 执行登录
|
||||
retry_cnt = 64
|
||||
payload.update({
|
||||
"uuid": qr_uuid,
|
||||
"appId": appid
|
||||
})
|
||||
while retry_cnt > 0:
|
||||
retry_cnt -= 1
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/checkLogin",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"检查登录状态: {json_blob}")
|
||||
status = json_blob['data']['status']
|
||||
nickname = json_blob['data'].get('nickName', '')
|
||||
if status == 1:
|
||||
logger.info(f"等待确认...{nickname}")
|
||||
elif status == 2:
|
||||
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
||||
break
|
||||
elif status == 0:
|
||||
logger.info("等待扫码...")
|
||||
else:
|
||||
logger.warning(f"未知状态: {status}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
if appid:
|
||||
sp.put(f"gewechat-appid-{nickname}", appid)
|
||||
self.appid = appid
|
||||
logger.info(f"已保存 APPID: {appid}")
|
||||
|
||||
async def post_text(self, to_wxid, content: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postText",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送消息结果: {json_blob}")
|
||||
|
||||
async def post_image(self, to_wxid, image_url: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"imgUrl": image_url,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postImage",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送图片结果: {json_blob}")
|
||||
|
||||
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"voiceUrl": voice_url,
|
||||
"voiceDuration": voice_duration
|
||||
}
|
||||
|
||||
logger.debug(f"发送语音: {payload}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postVoice",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"发送语音结果: {json_blob}")
|
||||
51
astrbot/core/platform/sources/gewechat/downloader.py
Normal file
51
astrbot/core/platform/sources/gewechat/downloader.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from astrbot import logger
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
class GeweDownloader():
|
||||
def __init__(self, base_url: str, download_base_url: str, token: str):
|
||||
self.base_url = base_url
|
||||
self.download_base_url = download_base_url
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-GEWE-TOKEN": token
|
||||
}
|
||||
|
||||
async def _post_json(self, baseurl: str, route: str, payload: dict):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{baseurl}{route}",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
return await resp.read()
|
||||
|
||||
async def download_voice(self, appid: str, xml: str, msg_id: str):
|
||||
payload = {
|
||||
"appId": appid,
|
||||
"xml": xml,
|
||||
"msgId": msg_id
|
||||
}
|
||||
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
|
||||
|
||||
async def download_image(self, appid: str, xml: str) -> str:
|
||||
'''返回一个可下载的 URL'''
|
||||
choices = [2, 3] # 2:常规图片 3:缩略图
|
||||
|
||||
for choice in choices:
|
||||
try:
|
||||
payload = {
|
||||
"appId": appid,
|
||||
"xml": xml,
|
||||
"type": choice
|
||||
}
|
||||
data = await self._post_json(self.base_url, "/message/downloadImage", payload)
|
||||
json_blob = json.loads(data)
|
||||
if 'fileUrl' in json_blob['data']:
|
||||
return self.download_base_url + json_blob['data']['fileUrl']
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(f"gewe download image: {e}")
|
||||
continue
|
||||
|
||||
raise Exception("无法下载图片")
|
||||
102
astrbot/core/platform/sources/gewechat/gewechat_event.py
Normal file
102
astrbot/core/platform/sources/gewechat/gewechat_event.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import wave
|
||||
import uuid
|
||||
import os
|
||||
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from .client import SimpleGewechatClient
|
||||
|
||||
def get_wav_duration(file_path):
|
||||
with wave.open(file_path, 'rb') as wav_file:
|
||||
file_size = os.path.getsize(file_path)
|
||||
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
|
||||
if n_frames == 2147483647:
|
||||
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
|
||||
else:
|
||||
duration = n_frames / float(framerate)
|
||||
return duration
|
||||
|
||||
class GewechatPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client: SimpleGewechatClient
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(message: MessageChain, user_name: str):
|
||||
pass
|
||||
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
to_wxid = self.message_obj.raw_message.get('to_wxid', None)
|
||||
|
||||
if not to_wxid:
|
||||
logger.error("无法获取到 to_wxid。")
|
||||
return
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
await self.client.post_text(to_wxid, comp.text)
|
||||
elif isinstance(comp, Image):
|
||||
img_url = comp.file
|
||||
img_path = ""
|
||||
if img_url.startswith("file:///"):
|
||||
img_path = img_url[8:]
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
img_path = await download_image_by_url(comp.file)
|
||||
else:
|
||||
img_path = img_url
|
||||
|
||||
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
|
||||
temp_directory = os.path.abspath('data/temp')
|
||||
img_path = os.path.abspath(img_path)
|
||||
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
|
||||
with open(img_path, "rb") as f:
|
||||
img_path = save_temp_img(f.read())
|
||||
|
||||
file_id = os.path.basename(img_path)
|
||||
img_url = f"{self.client.file_server_url}/{file_id}"
|
||||
logger.debug(f"gewe callback img url: {img_url}")
|
||||
await self.client.post_image(to_wxid, img_url)
|
||||
elif isinstance(comp, Record):
|
||||
# 默认已经存在 data/temp 中
|
||||
record_url = comp.file
|
||||
record_path = ""
|
||||
|
||||
if record_url.startswith("file:///"):
|
||||
record_path = record_url[8:]
|
||||
elif record_url.startswith("http"):
|
||||
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
|
||||
else:
|
||||
record_path = record_url
|
||||
|
||||
silk_path = f"data/temp/{uuid.uuid4()}.silk"
|
||||
duration = await wav_to_tencent_silk(record_path, silk_path)
|
||||
|
||||
print(f"duration: {duration}, {silk_path}")
|
||||
|
||||
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
|
||||
# temp_directory = os.path.abspath('data/temp')
|
||||
# record_path = os.path.abspath(record_path)
|
||||
# if os.path.commonpath([temp_directory, record_path]) != temp_directory:
|
||||
# with open(record_path, "rb") as f:
|
||||
# record_path = f"data/temp/{uuid.uuid4()}.wav"
|
||||
# with open(record_path, "wb") as f2:
|
||||
# f2.write(f.read())
|
||||
|
||||
if duration == 0:
|
||||
duration = get_wav_duration(record_path)
|
||||
|
||||
file_id = os.path.basename(silk_path)
|
||||
record_url = f"{self.client.file_server_url}/{file_id}"
|
||||
await self.client.post_voice(to_wxid, record_url, duration*1000)
|
||||
await super().send(message)
|
||||
@@ -0,0 +1,93 @@
|
||||
import sys
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from .gewechat_event import GewechatPlatformEvent
|
||||
from .client import SimpleGewechatClient
|
||||
from astrbot.core.message.components import Plain
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器")
|
||||
class GewechatPlatformAdapter(Platform):
|
||||
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settingss = platform_settings
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
self.client = None
|
||||
|
||||
@override
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
to_wxid = session.session_id
|
||||
if "_" in to_wxid:
|
||||
# 群聊,开启了独立会话
|
||||
_, to_wxid = to_wxid.split("_")
|
||||
|
||||
if not to_wxid:
|
||||
logger.error("无法获取到 to_wxid。")
|
||||
return
|
||||
|
||||
for comp in message_chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
await self.client.post_text(to_wxid, comp.text)
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"gewechat",
|
||||
"基于 gewechat 的 Wechat 适配器",
|
||||
)
|
||||
|
||||
@override
|
||||
def run(self):
|
||||
self.client = SimpleGewechatClient(
|
||||
self.config['base_url'],
|
||||
self.config['nickname'],
|
||||
self.config['host'],
|
||||
self.config['port'],
|
||||
self._event_queue,
|
||||
)
|
||||
|
||||
async def on_event_received(abm: AstrBotMessage):
|
||||
await self.handle_msg(abm)
|
||||
|
||||
self.client.on_event_received = on_event_received
|
||||
|
||||
return self._run()
|
||||
|
||||
async def logout(self):
|
||||
await self.client.logout()
|
||||
|
||||
async def _run(self):
|
||||
await self.client.login()
|
||||
|
||||
await self.client.start_polling()
|
||||
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
if message.type == MessageType.GROUP_MESSAGE:
|
||||
if self.settingss['unique_session']:
|
||||
message.session_id = message.sender.user_id + "_" + message.group_id
|
||||
|
||||
message_event = GewechatPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
@@ -80,4 +80,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
elif i.file and i.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(i.file)
|
||||
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
|
||||
else:
|
||||
image_base64 = file_to_base64(i.file).replace("base64://", "")
|
||||
image_file_path = i.file
|
||||
return plain_text, image_base64, image_file_path
|
||||
@@ -38,11 +38,13 @@ class WebChatAdapter(Platform):
|
||||
)
|
||||
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
# abm.session_id = f"webchat!{username}!{cid}"
|
||||
plain = ""
|
||||
cid = session.session_id.split("!")[-1]
|
||||
for comp in message_chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
plain += comp.text
|
||||
web_chat_back_queue.put_nowait(plain)
|
||||
web_chat_back_queue.put_nowait((plain, cid))
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
|
||||
@@ -16,9 +16,11 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
return
|
||||
|
||||
cid = self.session_id.split("!")[-1]
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
web_chat_back_queue.put_nowait(comp.text)
|
||||
web_chat_back_queue.put_nowait((comp.text, cid))
|
||||
elif isinstance(comp, Image):
|
||||
# save image to local
|
||||
filename = str(uuid.uuid4()) + ".jpg"
|
||||
@@ -30,6 +32,10 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
f.write(f2.read())
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
await download_image_by_url(comp.file, path=path)
|
||||
web_chat_back_queue.put_nowait(f"[IMAGE]{filename}")
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
await super().send(message)
|
||||
@@ -17,6 +17,11 @@ class ProviderMetaData():
|
||||
'''提供商适配器描述.'''
|
||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||
cls_type: Type = None
|
||||
|
||||
default_config_tmpl: dict = None
|
||||
'''平台的默认配置模板'''
|
||||
provider_display_name: str = None
|
||||
'''显示在 WebUI 配置页中的提供商名称,如空则是 type'''
|
||||
|
||||
@dataclass
|
||||
class ProviderRequest():
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import traceback
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from .provider import Provider, STTProvider
|
||||
from .provider import Provider, STTProvider, TTSProvider, Personality
|
||||
from .entites import ProviderType
|
||||
from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
@@ -13,19 +13,70 @@ class ProviderManager():
|
||||
self.providers_config: List = config['provider']
|
||||
self.provider_settings: dict = config['provider_settings']
|
||||
self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
|
||||
self.persona_configs: list = config.get('persona', [])
|
||||
|
||||
self.default_persona_name = self.provider_settings.get('default_personality', 'default')
|
||||
self.personas: List[Personality] = []
|
||||
self.selected_default_persona = None
|
||||
for persona in self.persona_configs:
|
||||
begin_dialogs = persona.get("begin_dialogs", [])
|
||||
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
|
||||
bd_processed = []
|
||||
mid_processed = ""
|
||||
if begin_dialogs:
|
||||
if len(begin_dialogs) % 2 != 0:
|
||||
logger.error(f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。")
|
||||
continue
|
||||
user_turn = True
|
||||
for dialog in begin_dialogs:
|
||||
bd_processed.append({
|
||||
"role": "user" if user_turn else "assistant",
|
||||
"content": dialog,
|
||||
"_no_save": None # 不持久化到 db
|
||||
})
|
||||
user_turn = not user_turn
|
||||
if mood_imitation_dialogs:
|
||||
if len(mood_imitation_dialogs) % 2 != 0:
|
||||
logger.error(f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。")
|
||||
continue
|
||||
user_turn = True
|
||||
for dialog in begin_dialogs:
|
||||
role = "A" if user_turn else "B"
|
||||
mid_processed += f"{role}: {dialog}\n"
|
||||
if not user_turn:
|
||||
mid_processed += '\n'
|
||||
user_turn = not user_turn
|
||||
|
||||
try:
|
||||
persona = Personality(
|
||||
**persona,
|
||||
_begin_dialogs_processed=bd_processed,
|
||||
_mood_imitation_dialogs_processed=mid_processed
|
||||
)
|
||||
if persona['name'] == self.default_persona_name:
|
||||
self.selected_default_persona = persona
|
||||
self.personas.append(persona)
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
|
||||
self.provider_insts: List[Provider] = []
|
||||
'''加载的 Provider 的实例'''
|
||||
self.stt_provider_insts: List[STTProvider] = []
|
||||
'''加载的 Speech To Text Provider 的实例'''
|
||||
self.tts_provider_insts: List[TTSProvider] = []
|
||||
'''加载的 Text To Speech Provider 的实例'''
|
||||
self.llm_tools = llm_tools
|
||||
self.curr_provider_inst: Provider = None
|
||||
'''当前使用的 Provider 实例'''
|
||||
self.curr_stt_provider_inst: STTProvider = None
|
||||
'''当前使用的 Speech To Text Provider 实例'''
|
||||
self.curr_tts_provider_inst: TTSProvider = None
|
||||
'''当前使用的 Text To Speech Provider 实例'''
|
||||
self.loaded_ids = defaultdict(bool)
|
||||
self.db_helper = db_helper
|
||||
|
||||
# kdb(experimental)
|
||||
self.curr_kdb_name = ""
|
||||
kdb_cfg = config.get("knowledge_db", {})
|
||||
if kdb_cfg and len(kdb_cfg):
|
||||
@@ -56,6 +107,8 @@ class ProviderManager():
|
||||
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
|
||||
case "openai_whisper_selfhost":
|
||||
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401
|
||||
case "openai_tts_api":
|
||||
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI # noqa: F401
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
|
||||
continue
|
||||
@@ -72,8 +125,10 @@ class ProviderManager():
|
||||
continue
|
||||
selected_provider_id = sp.get("curr_provider")
|
||||
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
|
||||
selected_tts_provider_id = self.provider_settings.get("provider_id")
|
||||
provider_enabled = self.provider_settings.get("enable", False)
|
||||
stt_enabled = self.provider_stt_settings.get("enable", False)
|
||||
tts_enabled = self.provider_settings.get("enable", False)
|
||||
|
||||
provider_metadata = provider_cls_map[provider_config['type']]
|
||||
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
|
||||
@@ -91,10 +146,28 @@ class ProviderManager():
|
||||
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
|
||||
self.curr_stt_provider_inst = inst
|
||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
# TTS 任务
|
||||
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
|
||||
self.tts_provider_insts.append(inst)
|
||||
if selected_tts_provider_id == provider_config['id'] and tts_enabled:
|
||||
self.curr_tts_provider_inst = inst
|
||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||
# 文本生成任务
|
||||
inst = provider_metadata.cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config,
|
||||
self.provider_settings,
|
||||
self.db_helper,
|
||||
self.provider_settings.get('persistant_history', True),
|
||||
self.selected_default_persona
|
||||
)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
@@ -114,11 +187,18 @@ class ProviderManager():
|
||||
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
|
||||
if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
if not self.curr_provider_inst:
|
||||
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
|
||||
if self.provider_stt_settings.get("enable"):
|
||||
if not self.curr_stt_provider_inst:
|
||||
|
||||
if stt_enabled and not self.curr_stt_provider_inst:
|
||||
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
|
||||
|
||||
if tts_enabled and not self.curr_tts_provider_inst:
|
||||
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
|
||||
|
||||
|
||||
def get_insts(self):
|
||||
return self.provider_insts
|
||||
|
||||
@@ -11,34 +11,62 @@ from dataclasses import dataclass
|
||||
class Personality(TypedDict):
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
begin_dialogs: List[str] = []
|
||||
mood_imitation_dialogs: List[str] = []
|
||||
|
||||
# cache
|
||||
_begin_dialogs_processed: List[dict]
|
||||
_mood_imitation_dialogs_processed: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMeta():
|
||||
id: str
|
||||
model: str
|
||||
type: str
|
||||
|
||||
|
||||
class AbstractProvider(abc.ABC):
|
||||
def __init__(self, provider_config: dict) -> None:
|
||||
super().__init__()
|
||||
self.model_name = ""
|
||||
self.provider_config = provider_config
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
'''设置当前使用的模型名称'''
|
||||
self.model_name = model_name
|
||||
|
||||
def get_model(self) -> str:
|
||||
'''获得当前使用的模型名称'''
|
||||
return self.model_name
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
'''获取 Provider 的元数据'''
|
||||
return ProviderMeta(
|
||||
id=self.provider_config['id'],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config['type']
|
||||
)
|
||||
|
||||
|
||||
class Provider(abc.ABC):
|
||||
class Provider(AbstractProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
persistant_history: bool = True,
|
||||
db_helper: BaseDatabase = None
|
||||
db_helper: BaseDatabase = None,
|
||||
default_persona: Personality = None
|
||||
) -> None:
|
||||
self.model_name = ""
|
||||
'''当前使用的模型名称'''
|
||||
super().__init__(provider_config)
|
||||
|
||||
self.session_memory = defaultdict(list)
|
||||
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
|
||||
|
||||
self.provider_config = provider_config
|
||||
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
self.curr_personality = Personality(prompt=provider_settings['default_personality'])
|
||||
'''维护了当前的使用的 persona,即人格。'''
|
||||
self.curr_personality: Personality = default_persona
|
||||
'''维护了当前的使用的 persona,即人格。可能为 None'''
|
||||
|
||||
self.db_helper = db_helper
|
||||
'''用于持久化的数据库操作对象。'''
|
||||
@@ -50,14 +78,6 @@ class Provider(abc.ABC):
|
||||
self.session_memory[history.session_id] = json.loads(history.content)
|
||||
except BaseException as e:
|
||||
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
'''设置当前使用的模型名称'''
|
||||
self.model_name = model_name
|
||||
|
||||
def get_model(self) -> str:
|
||||
'''获得当前使用的模型名称'''
|
||||
return self.model_name
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_current_key(self) -> str:
|
||||
@@ -125,17 +145,11 @@ class Provider(abc.ABC):
|
||||
'''重置某一个 session_id 的上下文'''
|
||||
raise NotImplementedError()
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
'''获取 Provider 的元数据'''
|
||||
return ProviderMeta(
|
||||
id=self.provider_config['id'],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config['type']
|
||||
)
|
||||
|
||||
|
||||
|
||||
class STTProvider():
|
||||
class STTProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
@@ -143,19 +157,15 @@ class STTProvider():
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
'''获取音频的文本'''
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TTSProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
'''设置当前使用的模型名称'''
|
||||
self.model_name = model_name
|
||||
|
||||
def get_model(self) -> str:
|
||||
'''获取当前使用的模型'''
|
||||
return self.provider_config.get("model", "")
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
'''获取 Provider 的元数据'''
|
||||
return ProviderMeta(
|
||||
id=self.provider_config['id'],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config['type']
|
||||
)
|
||||
@abc.abstractmethod
|
||||
async def get_audio(self, text: str) -> str:
|
||||
'''获取文本的音频,返回音频文件路径'''
|
||||
raise NotImplementedError()
|
||||
@@ -13,22 +13,35 @@ llm_tools = FuncCall()
|
||||
def register_provider_adapter(
|
||||
provider_type_name: str,
|
||||
desc: str,
|
||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
|
||||
default_config_tmpl: dict = None,
|
||||
provider_display_name: str = None
|
||||
):
|
||||
'''用于注册平台适配器的带参装饰器'''
|
||||
def decorator(cls):
|
||||
if provider_type_name in provider_cls_map:
|
||||
raise ValueError(f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。")
|
||||
|
||||
# 添加必备选项
|
||||
if default_config_tmpl:
|
||||
if 'type' not in default_config_tmpl:
|
||||
default_config_tmpl['type'] = provider_type_name
|
||||
if 'enable' not in default_config_tmpl:
|
||||
default_config_tmpl['enable'] = False
|
||||
if 'id' not in default_config_tmpl:
|
||||
default_config_tmpl['id'] = provider_type_name
|
||||
|
||||
pm = ProviderMetaData(
|
||||
type=provider_type_name,
|
||||
desc=desc,
|
||||
provider_type=provider_type,
|
||||
cls_type=cls
|
||||
cls_type=cls,
|
||||
default_config_tmpl=default_config_tmpl,
|
||||
provider_display_name=provider_display_name
|
||||
)
|
||||
provider_registry.append(pm)
|
||||
provider_cls_map[provider_type_name] = pm
|
||||
logger.debug(f"Provider {provider_type_name} 已注册")
|
||||
logger.debug(f"服务提供商 Provider {provider_type_name} 已注册")
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import List
|
||||
from .. import Provider
|
||||
from .. import Provider, Personality
|
||||
from ..entites import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
@@ -16,9 +16,10 @@ class ProviderDify(Provider):
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=False,
|
||||
default_persona: Personality=None
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config, provider_settings, persistant_history, db_helper
|
||||
provider_config, provider_settings, persistant_history, db_helper, default_persona
|
||||
)
|
||||
self.api_key = provider_config.get("dify_api_key", "")
|
||||
if not self.api_key:
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
import aiohttp
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
@@ -18,7 +18,7 @@ class SimpleGoogleGenAIClient():
|
||||
self.api_base = api_base[:-1]
|
||||
else:
|
||||
self.api_base = api_base
|
||||
self.client = aiohttp.ClientSession()
|
||||
self.client = aiohttp.ClientSession(trust_env=True)
|
||||
|
||||
async def models_list(self) -> List[str]:
|
||||
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
|
||||
@@ -60,9 +60,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history = True
|
||||
persistant_history = True,
|
||||
default_persona: Personality=None
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
|
||||
super().__init__(provider_config, provider_settings, persistant_history, db_helper, default_persona)
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
@@ -130,6 +131,8 @@ class ProviderGoogleGenAI(Provider):
|
||||
tool = None
|
||||
if tools:
|
||||
tool = tools.get_func_desc_google_genai_style()
|
||||
if not tool:
|
||||
tool = None
|
||||
|
||||
system_instruction = ""
|
||||
for message in payloads["messages"]:
|
||||
@@ -209,6 +212,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
for part in context_query:
|
||||
if '_no_save' in part:
|
||||
del part['_no_save']
|
||||
|
||||
payloads = {
|
||||
"messages": context_query,
|
||||
@@ -217,15 +224,24 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
self.pop_record(session_id)
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
|
||||
return llm_response
|
||||
retry_cnt = 10
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
try:
|
||||
self.pop_record(session_id)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
@@ -239,7 +255,8 @@ class ProviderGoogleGenAI(Provider):
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
else:
|
||||
self.session_memory[session_id] = [*contexts, new_record, {
|
||||
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
|
||||
self.session_memory[session_id] = [*contexts_to_save, new_record, {
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import os
|
||||
from llmtuner.chat import ChatModel
|
||||
from typing import List
|
||||
from .. import Provider
|
||||
from .. import Provider, Personality
|
||||
from ..entites import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
@@ -19,9 +19,10 @@ class LLMTunerModelLoader(Provider):
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=True,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config, provider_settings, persistant_history, db_helper
|
||||
provider_config, provider_settings, persistant_history, db_helper, default_persona
|
||||
)
|
||||
if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists(
|
||||
provider_config["adapter_model_path"]
|
||||
@@ -61,20 +62,25 @@ class LLMTunerModelLoader(Provider):
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
system_prompt = ""
|
||||
new_record = {"role": "user", "content": prompt}
|
||||
if not contexts:
|
||||
query_context = [
|
||||
*self.session_memory[session_id],
|
||||
{"role": "user", "content": prompt},
|
||||
new_record,
|
||||
]
|
||||
system_prompt = self.curr_personality["prompt"]
|
||||
else:
|
||||
query_context = [*contexts, {"role": "user", "content": prompt}]
|
||||
query_context = [*contexts, new_record]
|
||||
|
||||
# 提取出系统提示
|
||||
system_idxs = []
|
||||
for idx, context in enumerate(query_context):
|
||||
if context["role"] == "system":
|
||||
system_idxs.append(idx)
|
||||
|
||||
if '_no_save' in context:
|
||||
del context['_no_save']
|
||||
|
||||
for idx in reversed(system_idxs):
|
||||
system_prompt += " " + query_context.pop(idx)["content"]
|
||||
|
||||
@@ -83,27 +89,37 @@ class LLMTunerModelLoader(Provider):
|
||||
"system": system_prompt,
|
||||
}
|
||||
if func_tool:
|
||||
conf["tools"] = func_tool
|
||||
tool_list = func_tool.get_func_desc_openai_style()
|
||||
if tool_list:
|
||||
conf['tools'] = tool_list
|
||||
|
||||
responses = await self.model.achat(**conf)
|
||||
|
||||
if session_id:
|
||||
llm_response = LLMResponse("assistant", responses[-1].response_text)
|
||||
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
# 文本回复
|
||||
if not contexts:
|
||||
self.session_memory[session_id].append(
|
||||
{"role": "user", "content": prompt}
|
||||
)
|
||||
self.session_memory[session_id].append(
|
||||
{"role": "assistant", "content": responses[-1].response_text}
|
||||
)
|
||||
# 添加用户 record
|
||||
self.session_memory[session_id].append(new_record)
|
||||
# 添加 assistant record
|
||||
self.session_memory[session_id].append({
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
else:
|
||||
self.session_memory[session_id] = [
|
||||
*contexts,
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "assistant", "content": responses[-1].response_text},
|
||||
]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.meta().type)
|
||||
return responses[-1].response_text
|
||||
|
||||
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
|
||||
self.session_memory[session_id] = [*contexts_to_save, new_record, {
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
|
||||
|
||||
async def forget(self, session_id):
|
||||
self.session_memory[session_id] = []
|
||||
return True
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import traceback
|
||||
import base64
|
||||
import json
|
||||
|
||||
@@ -8,7 +7,7 @@ from openai._exceptions import NotFoundError
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
@@ -22,9 +21,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history = True
|
||||
persistant_history = True,
|
||||
default_persona: Personality = None
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
|
||||
super().__init__(provider_config, provider_settings, persistant_history, db_helper, default_persona)
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
@@ -99,7 +99,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
if tools:
|
||||
payloads["tools"] = tools.get_func_desc_openai_style()
|
||||
tool_list = tools.get_func_desc_openai_style()
|
||||
if tool_list:
|
||||
payloads['tools'] = tool_list
|
||||
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
@@ -107,7 +109,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
)
|
||||
|
||||
assert isinstance(completion, ChatCompletion)
|
||||
logger.debug(f"completion: {completion.usage}")
|
||||
logger.debug(f"completion: {completion}")
|
||||
|
||||
if len(completion.choices) == 0:
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
@@ -150,6 +152,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
for part in context_query:
|
||||
if '_no_save' in part:
|
||||
del part['_no_save']
|
||||
|
||||
payloads = {
|
||||
"messages": context_query,
|
||||
**self.provider_config.get("model_config", {})
|
||||
@@ -157,15 +163,25 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
self.pop_record(session_id)
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
retry_cnt = 10
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
try:
|
||||
self.pop_record(session_id)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
|
||||
return llm_response
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
@@ -179,7 +195,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
else:
|
||||
self.session_memory[session_id] = [*contexts, new_record, {
|
||||
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
|
||||
self.session_memory[session_id] = [*contexts_to_save, new_record, {
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
|
||||
40
astrbot/core/provider/sources/openai_tts_api_source.py
Normal file
40
astrbot/core/provider/sources/openai_tts_api_source.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import uuid
|
||||
import os
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from ..provider import TTSProvider
|
||||
from ..entites import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter("openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH)
|
||||
class ProviderOpenAITTSAPI(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key = provider_config.get("api_key", "")
|
||||
self.voice = provider_config.get("voice", "alloy")
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||
)
|
||||
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f'data/temp/openai_tts_api_{uuid.uuid4()}.wav'
|
||||
async with self.client.audio.speech.with_streaming_response.create(
|
||||
model=self.model_name,
|
||||
voice=self.voice,
|
||||
response_format='wav',
|
||||
input=text
|
||||
) as response:
|
||||
with open(path, 'wb') as f:
|
||||
async for chunk in response.iter_bytes(chunk_size=1024):
|
||||
f.write(chunk)
|
||||
return path
|
||||
@@ -1,12 +1,12 @@
|
||||
import uuid
|
||||
import os
|
||||
import io
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from ..provider import STTProvider
|
||||
from ..entites import ProviderType
|
||||
from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
|
||||
@register_provider_adapter("openai_whisper_api", "OpenAI Whisper API", provider_type=ProviderType.SPEECH_TO_TEXT)
|
||||
class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
@@ -33,34 +33,6 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
output_path = ff.convert(path, os.path.join('data/temp', filename))
|
||||
return output_path
|
||||
|
||||
async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str:
|
||||
import wave
|
||||
|
||||
with wave.open(output_path, 'wb') as wav:
|
||||
wav.setnchannels(1)
|
||||
wav.setsampwidth(2)
|
||||
wav.setframerate(24000)
|
||||
wav.writeframes(input_io.read())
|
||||
|
||||
return output_path
|
||||
|
||||
async def _convert_silk(self, path: str) -> str:
|
||||
import pysilk
|
||||
filename = str(uuid.uuid4()) + '.wav'
|
||||
output_path = os.path.join('data/temp', filename)
|
||||
with open(path, "rb") as f:
|
||||
input_data = f.read()
|
||||
if input_data.startswith(b'\x02'):
|
||||
# tencent 我爱你
|
||||
input_data = input_data[1:]
|
||||
input_io = io.BytesIO(input_data)
|
||||
output_io = io.BytesIO()
|
||||
pysilk.decode(input_io, output_io, 24000)
|
||||
output_io.seek(0)
|
||||
await self._pcm_to_wav(output_io, output_path)
|
||||
|
||||
return output_path
|
||||
|
||||
async def _is_silk_file(self, file_path):
|
||||
silk_header = b"SILK"
|
||||
with open(file_path, "rb") as f:
|
||||
@@ -73,20 +45,27 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
'''only supports mp3, mp4, mpeg, m4a, wav, webm'''
|
||||
is_tencent = False
|
||||
|
||||
if audio_url.startswith("http"):
|
||||
if "multimedia.nt.qq.com.cn" in audio_url:
|
||||
is_tencent = True
|
||||
|
||||
name = str(uuid.uuid4())
|
||||
path = os.path.join("data/temp", name)
|
||||
audio_url = await download_file(audio_url, path)
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
if not os.path.exists(audio_url):
|
||||
raise FileNotFoundError(f"文件不存在: {audio_url}")
|
||||
|
||||
if audio_url.endswith(".amr") or audio_url.endswith(".silk"):
|
||||
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
audio_url = await self._convert_silk(audio_url)
|
||||
|
||||
output_path = os.path.join('data/temp', str(uuid.uuid4()) + '.wav')
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
result = await self.client.audio.transcriptions.create(
|
||||
model=self.model_name,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import uuid
|
||||
import os
|
||||
import io
|
||||
import asyncio
|
||||
import whisper
|
||||
from ..provider import STTProvider
|
||||
@@ -8,7 +7,7 @@ from ..entites import ProviderType
|
||||
from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
|
||||
@register_provider_adapter("openai_whisper_selfhost", "OpenAI Whisper 模型部署", provider_type=ProviderType.SPEECH_TO_TEXT)
|
||||
class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
@@ -34,34 +33,6 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
output_path = ff.convert(path, os.path.join('data/temp', filename))
|
||||
return output_path
|
||||
|
||||
async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str:
|
||||
import wave
|
||||
|
||||
with wave.open(output_path, 'wb') as wav:
|
||||
wav.setnchannels(1)
|
||||
wav.setsampwidth(2)
|
||||
wav.setframerate(24000)
|
||||
wav.writeframes(input_io.read())
|
||||
|
||||
return output_path
|
||||
|
||||
async def _convert_silk(self, path: str) -> str:
|
||||
import pysilk
|
||||
filename = str(uuid.uuid4()) + '.wav'
|
||||
output_path = os.path.join('data/temp', filename)
|
||||
with open(path, "rb") as f:
|
||||
input_data = f.read()
|
||||
if input_data.startswith(b'\x02'):
|
||||
# tencent 我爱你
|
||||
input_data = input_data[1:]
|
||||
input_io = io.BytesIO(input_data)
|
||||
output_io = io.BytesIO()
|
||||
pysilk.decode(input_io, output_io, 24000)
|
||||
output_io.seek(0)
|
||||
await self._pcm_to_wav(output_io, output_path)
|
||||
|
||||
return output_path
|
||||
|
||||
async def _is_silk_file(self, file_path):
|
||||
silk_header = b"SILK"
|
||||
with open(file_path, "rb") as f:
|
||||
@@ -74,19 +45,28 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
is_tencent = False
|
||||
|
||||
if audio_url.startswith("http"):
|
||||
if "multimedia.nt.qq.com.cn" in audio_url:
|
||||
is_tencent = True
|
||||
|
||||
name = str(uuid.uuid4())
|
||||
path = os.path.join("data/temp", name)
|
||||
audio_url = await download_file(audio_url, path)
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
if not os.path.exists(audio_url):
|
||||
raise FileNotFoundError(f"文件不存在: {audio_url}")
|
||||
|
||||
if audio_url.endswith(".amr") or audio_url.endswith(".silk"):
|
||||
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
audio_url = await self._convert_silk(audio_url)
|
||||
|
||||
output_path = os.path.join('data/temp', str(uuid.uuid4()) + '.wav')
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
|
||||
return result['text']
|
||||
@@ -14,9 +14,10 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history = True
|
||||
persistant_history = True,
|
||||
default_persona = None
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings, db_helper, persistant_history)
|
||||
super().__init__(provider_config, provider_settings, db_helper, persistant_history, default_persona)
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
@@ -59,15 +60,23 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
"messages": context_query,
|
||||
**model_cfgs
|
||||
}
|
||||
llm_response = None
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
self.pop_record(session_id)
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
|
||||
return llm_response
|
||||
retry_cnt = 10
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
try:
|
||||
self.pop_record(session_id)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
@@ -20,6 +20,6 @@ class PermissionTypeFilter(HandlerFilter):
|
||||
if self.permission_type == PermissionType.ADMIN:
|
||||
if not event.is_admin():
|
||||
event.stop_event()
|
||||
raise ValueError("您没有权限执行此操作。")
|
||||
raise ValueError(f"您 (ID: {event.get_sender_id()}) 没有权限执行此操作。")
|
||||
|
||||
return True
|
||||
|
||||
@@ -8,12 +8,14 @@ class PlatformAdapterType(enum.Flag):
|
||||
AIOCQHTTP = enum.auto()
|
||||
QQOFFICIAL = enum.auto()
|
||||
VCHAT = enum.auto()
|
||||
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT
|
||||
GEWECHAT = enum.auto()
|
||||
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT | GEWECHAT
|
||||
|
||||
ADAPTER_NAME_2_TYPE = {
|
||||
"aiocqhttp": PlatformAdapterType.AIOCQHTTP,
|
||||
"qq_official": PlatformAdapterType.QQOFFICIAL,
|
||||
"vchat": PlatformAdapterType.VCHAT
|
||||
"vchat": PlatformAdapterType.VCHAT,
|
||||
"gewechat": PlatformAdapterType.GEWECHAT
|
||||
}
|
||||
|
||||
class PlatformAdapterTypeFilter(HandlerFilter):
|
||||
|
||||
@@ -7,7 +7,6 @@ import yaml
|
||||
import logging
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
from pip import main as pip_main
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core import logger, sp, pip_installer
|
||||
from .context import Context
|
||||
|
||||
@@ -11,7 +11,7 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
def __init__(self, repo_mirror: str = "") -> None:
|
||||
super().__init__(repo_mirror)
|
||||
self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))
|
||||
self.ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases"
|
||||
self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"
|
||||
|
||||
def terminate_child_processes(self):
|
||||
try:
|
||||
|
||||
@@ -6,6 +6,8 @@ import time
|
||||
import aiohttp
|
||||
import base64
|
||||
import zipfile
|
||||
import uuid
|
||||
from typing import Union
|
||||
|
||||
from PIL import Image
|
||||
|
||||
@@ -41,21 +43,21 @@ def port_checker(port: int, host: str = "localhost"):
|
||||
return False
|
||||
|
||||
|
||||
def save_temp_img(img: Image) -> str:
|
||||
def save_temp_img(img: Union[Image.Image, str]) -> str:
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
# 获得文件创建时间,清除超过1小时的
|
||||
# 获得文件创建时间,清除超过 12 小时的
|
||||
try:
|
||||
for f in os.listdir("data/temp"):
|
||||
path = os.path.join("data/temp", f)
|
||||
if os.path.isfile(path):
|
||||
ctime = os.path.getctime(path)
|
||||
if time.time() - ctime > 3600:
|
||||
if time.time() - ctime > 3600*12:
|
||||
os.remove(path)
|
||||
except Exception as e:
|
||||
print(f"清除临时文件失败: {e}")
|
||||
|
||||
# 获得时间戳
|
||||
timestamp = int(time.time())
|
||||
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||
p = f"data/temp/{timestamp}.jpg"
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
@@ -70,7 +72,7 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
|
||||
下载图片, 返回 path
|
||||
'''
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
if post:
|
||||
async with session.post(url, json=post_data) as resp:
|
||||
if not path:
|
||||
@@ -87,11 +89,11 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
except aiohttp.client_exceptions.ClientConnectorSSLError:
|
||||
except aiohttp.client.ClientConnectorSSLError:
|
||||
# 关闭SSL验证
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.set_ciphers('DEFAULT')
|
||||
async with aiohttp.ClientSession(trust_env=False) as session:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if post:
|
||||
async with session.get(url, ssl=ssl_context) as resp:
|
||||
return save_temp_img(await resp.read())
|
||||
@@ -101,24 +103,57 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def download_file(url: str, path: str):
|
||||
async def download_file(url: str, path: str, show_progress: bool = False):
|
||||
'''
|
||||
从指定 url 下载文件到指定路径 path
|
||||
'''
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, timeout=20) as resp:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url, timeout=120) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
total_size = int(resp.headers.get('content-length', 0))
|
||||
downloaded_size = 0
|
||||
start_time = time.time()
|
||||
if show_progress:
|
||||
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
|
||||
with open(path, 'wb') as f:
|
||||
while True:
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
downloaded_size += len(chunk)
|
||||
if show_progress:
|
||||
elapsed_time = time.time() - start_time
|
||||
speed = downloaded_size / 1024 / elapsed_time # KB/s
|
||||
print(f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end='')
|
||||
except aiohttp.client.ClientConnectorSSLError:
|
||||
# 关闭SSL验证
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.set_ciphers('DEFAULT')
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
|
||||
total_size = int(resp.headers.get('content-length', 0))
|
||||
downloaded_size = 0
|
||||
start_time = time.time()
|
||||
if show_progress:
|
||||
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
|
||||
with open(path, 'wb') as f:
|
||||
while True:
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
if show_progress:
|
||||
elapsed_time = time.time() - start_time
|
||||
speed = downloaded_size / 1024 / elapsed_time # KB/s
|
||||
print(f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end='')
|
||||
if show_progress:
|
||||
print()
|
||||
|
||||
|
||||
def file_to_base64(file_path: str) -> str:
|
||||
with open(file_path, "rb") as f:
|
||||
data_bytes = f.read()
|
||||
@@ -137,9 +172,22 @@ def get_local_ip_addresses():
|
||||
s.close()
|
||||
return ip
|
||||
|
||||
async def get_dashboard_version():
|
||||
if os.path.exists("data/dist"):
|
||||
if os.path.exists("data/dist/assets/version"):
|
||||
with open("data/dist/assets/version", "r") as f:
|
||||
v = f.read().strip()
|
||||
return v
|
||||
return None
|
||||
|
||||
async def download_dashboard():
|
||||
'''下载管理面板文件'''
|
||||
dashboard_release_url = "https://astrbot-registry.lwl.lol/download/astrbot-dashboard/latest/dist.zip"
|
||||
await download_file(dashboard_release_url, "data/dashboard.zip")
|
||||
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
|
||||
try:
|
||||
await download_file(dashboard_release_url, "data/dashboard.zip", show_progress=True)
|
||||
except BaseException as _:
|
||||
dashboard_release_url = "https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
|
||||
await download_file(dashboard_release_url, "data/dashboard.zip", show_progress=True)
|
||||
print("解压管理面板文件中...")
|
||||
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
|
||||
z.extractall("data")
|
||||
@@ -30,7 +30,7 @@ class Metric():
|
||||
pass
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.post(base_url, json=payload, timeout=3) as response:
|
||||
if response.status != 200:
|
||||
pass
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
# [人格文本由PlexPt的开源项目awesome-chatgpt-prompts-zh提供]
|
||||
hi = ''
|
||||
personalities = {
|
||||
'Linux': '我想让你充当 Linux 终端。我将输入命令,您将回复终端应显示的内容。我希望您只在一个唯一的代码块内回复终端输出,而不是其他任何内容。不要写解释。除非我指示您这样做,否则不要键入命令。当我需要用英语告诉你一些事情时,我会把文字放在中括号内[就像这样]。我的第一个命令是 pwd',
|
||||
'英语翻译': '我想让你充当英语翻译员、拼写纠正员和改进员。我会用任何语言与你交谈,你会检测语言,翻译它并用我的文本的更正和改进版本用英语回答。我希望你用更优美优雅的高级英语单词和句子替换我简化的 A0 级单词和句子。保持相同的意思,但使它们更文艺。我要你只回复更正、改进,不要写任何解释。我的第一句话是“istanbulu cok seviyom burada olmak cok guzel”',
|
||||
'英英词典': '我想让你充当英英词典,对于给出的英文单词,你要给出其中文意思以及英文解释,并且给出一个例句,此外不要有其他反馈,第一个单词是“Hello"',
|
||||
'面试官': '我想让你担任Android开发工程师面试官。我将成为候选人,您将向我询问Android开发工程师职位的面试问题。我希望你只作为面试官回答。不要一次写出所有的问题。我希望你只对我进行采访。问我问题,等待我的回答。不要写解释。像面试官一样一个一个问我,等我回答。我的第一句话是“面试官你好”',
|
||||
'编剧': '我要你担任编剧。您将为长篇电影或能够吸引观众的网络连续剧开发引人入胜且富有创意的剧本。从想出有趣的角色、故事的背景、角色之间的对话等开始。一旦你的角色发展完成——创造一个充满曲折的激动人心的故事情节,让观众一直悬念到最后。我的第一个要求是“我需要写一部以巴黎为背景的浪漫剧情电影”。',
|
||||
'前端智能思路助手': '我想让你充当前端开发专家。我将提供一些关于Js、Node等前端代码问题的具体信息,而你的工作就是想出为我解决问题的策略。这可能包括建议代码、代码逻辑思路策略。我的第一个请求是“我需要能够动态监听某个元素节点距离当前电脑设备屏幕的左上角的X和Y轴,通过拖拽移动位置浏览器窗口和改变大小浏览器窗口。”',
|
||||
'JS控制台': '我希望你充当 javascript 控制台。我将键入命令,您将回复 javascript 控制台应显示的内容。我希望您只在一个唯一的代码块内回复终端输出,而不是其他任何内容。不要写解释。除非我指示您这样做。我的第一个命令是 console.log("Hello World");',
|
||||
'旅游指南': '我想让你做一个旅游指南。我会把我的位置写给你,你会推荐一个靠近我的位置的地方。在某些情况下,我还会告诉您我将访问的地方类型。您还会向我推荐靠近我的第一个位置的类似类型的地方。我的第一个建议请求是“我在上海,我只想参观博物馆。”',
|
||||
'抄袭检查员': '我想让你充当剽窃检查员。我会给你写句子,你只会用给定句子的语言在抄袭检查中未被发现的情况下回复,别无其他。不要在回复上写解释。我的第一句话是“为了让计算机像人类一样行动,语音识别系统必须能够处理非语言信息,例如说话者的情绪状态。”',
|
||||
'广告商': '我想让你充当广告商。您将创建一个活动来推广您选择的产品或服务。您将选择目标受众,制定关键信息和口号,选择宣传媒体渠道,并决定实现目标所需的任何其他活动。我的第一个建议请求是“我需要帮助针对 18-30 岁的年轻人制作一种新型能量饮料的广告活动。”',
|
||||
'讲故事的人': '我想让你扮演讲故事的角色。您将想出引人入胜、富有想象力和吸引观众的有趣故事。它可以是童话故事、教育故事或任何其他类型的故事,有可能吸引人们的注意力和想象力。根据目标受众,您可以为讲故事环节选择特定的主题或主题,例如,如果是儿童,则可以谈论动物;如果是成年人,那么基于历史的故事可能会更好地吸引他们等等。我的第一个要求是“我需要一个关于毅力的有趣故事。”',
|
||||
'足球解说员': '我想让你担任足球评论员。我会给你描述正在进行的足球比赛,你会评论比赛,分析到目前为止发生的事情,并预测比赛可能会如何结束。您应该了解足球术语、战术、每场比赛涉及的球员/球队,并主要专注于提供明智的评论,而不仅仅是逐场叙述。我的第一个请求是“我正在观看曼联对切尔西的比赛——为这场比赛提供评论。”',
|
||||
'脱口秀喜剧演员': '我想让你扮演一个脱口秀喜剧演员。我将为您提供一些与时事相关的话题,您将运用您的智慧、创造力和观察能力,根据这些话题创建一个例程。您还应该确保将个人轶事或经历融入日常活动中,以使其对观众更具相关性和吸引力。我的第一个请求是“我想要幽默地看待政治”。',
|
||||
'励志教练': '我希望你充当激励教练。我将为您提供一些关于某人的目标和挑战的信息,而您的工作就是想出可以帮助此人实现目标的策略。这可能涉及提供积极的肯定、提供有用的建议或建议他们可以采取哪些行动来实现最终目标。我的第一个请求是“我需要帮助来激励自己在为即将到来的考试学习时保持纪律”。',
|
||||
'作曲家': '我想让你扮演作曲家。我会提供一首歌的歌词,你会为它创作音乐。这可能包括使用各种乐器或工具,例如合成器或采样器,以创造使歌词栩栩如生的旋律和和声。我的第一个请求是“我写了一首名为“满江红”的诗,需要配乐。”',
|
||||
'辩手': '我要你扮演辩手。我会为你提供一些与时事相关的话题,你的任务是研究辩论的双方,为每一方提出有效的论据,驳斥对立的观点,并根据证据得出有说服力的结论。你的目标是帮助人们从讨论中解脱出来,增加对手头主题的知识和洞察力。我的第一个请求是“我想要一篇关于 Deno 的评论文章。”',
|
||||
'小说家': '我想让你扮演一个小说家。您将想出富有创意且引人入胜的故事,可以长期吸引读者。你可以选择任何类型,如奇幻、浪漫、历史小说等——但你的目标是写出具有出色情节、引人入胜的人物和意想不到的高潮的作品。我的第一个要求是“我要写一部以未来为背景的科幻小说”。',
|
||||
'关系教练': '我想让你担任关系教练。我将提供有关冲突中的两个人的一些细节,而你的工作是就他们如何解决导致他们分离的问题提出建议。这可能包括关于沟通技巧或不同策略的建议,以提高他们对彼此观点的理解。我的第一个请求是“我需要帮助解决我和配偶之间的冲突。”',
|
||||
'诗人': '我要你扮演诗人。你将创作出能唤起情感并具有触动人心的力量的诗歌。写任何主题或主题,但要确保您的文字以优美而有意义的方式传达您试图表达的感觉。您还可以想出一些短小的诗句,这些诗句仍然足够强大,可以在读者的脑海中留下印记。我的第一个请求是“我需要一首关于爱情的诗”。',
|
||||
'说唱歌手': '我想让你扮演说唱歌手。您将想出强大而有意义的歌词、节拍和节奏,让听众“惊叹”。你的歌词应该有一个有趣的含义和信息,人们也可以联系起来。在选择节拍时,请确保它既朗朗上口又与你的文字相关,这样当它们组合在一起时,每次都会发出爆炸声!我的第一个请求是“我需要一首关于在你自己身上寻找力量的说唱歌曲。”',
|
||||
'励志演讲者': '我希望你充当励志演说家。将能够激发行动的词语放在一起,让人们感到有能力做一些超出他们能力的事情。你可以谈论任何话题,但目的是确保你所说的话能引起听众的共鸣,激励他们努力实现自己的目标并争取更好的可能性。我的第一个请求是“我需要一个关于每个人如何永不放弃的演讲”。',
|
||||
'哲学家': '我要你扮演一个哲学家。我将提供一些与哲学研究相关的主题或问题,深入探索这些概念将是你的工作。这可能涉及对各种哲学理论进行研究,提出新想法或寻找解决复杂问题的创造性解决方案。我的第一个请求是“我需要帮助制定决策的道德框架。”',
|
||||
'AI写作导师': '我想让你做一个 AI 写作导师。我将为您提供一名需要帮助改进其写作的学生,您的任务是使用人工智能工具(例如自然语言处理)向学生提供有关如何改进其作文的反馈。您还应该利用您在有效写作技巧方面的修辞知识和经验来建议学生可以更好地以书面形式表达他们的想法和想法的方法。我的第一个请求是“我需要有人帮我修改我的硕士论文”。',
|
||||
'网络安全专家': '我想让你充当网络安全专家。我将提供一些关于如何存储和共享数据的具体信息,而你的工作就是想出保护这些数据免受恶意行为者攻击的策略。这可能包括建议加密方法、创建防火墙或实施将某些活动标记为可疑的策略。我的第一个请求是“我需要帮助为我的公司制定有效的网络安全战略。”',
|
||||
'招聘人员': '我想让你担任招聘人员。我将提供一些关于职位空缺的信息,而你的工作是制定寻找合格申请人的策略。这可能包括通过社交媒体、社交活动甚至参加招聘会接触潜在候选人,以便为每个职位找到最合适的人选。我的第一个请求是“我需要帮助改进我的简历。”',
|
||||
'法律顾问': '我想让你做我的法律顾问。我将描述一种法律情况,您将就如何处理它提供建议。你应该只回复你的建议,而不是其他。不要写解释。我的第一个请求是“我出了车祸,不知道该怎么办”。',
|
||||
'个人造型师': '我想让你做我的私人造型师。我会告诉你我的时尚偏好和体型,你会建议我穿的衣服。你应该只回复你推荐的服装,别无其他。不要写解释。我的第一个请求是“我有一个正式的活动要举行,我需要帮助选择一套衣服。”',
|
||||
'机器学习工程师': '我想让你担任机器学习工程师。我会写一些机器学习的概念,你的工作就是用通俗易懂的术语来解释它们。这可能包括提供构建模型的分步说明、使用视觉效果演示各种技术,或建议在线资源以供进一步研究。我的第一个建议请求是“我有一个没有标签的数据集。我应该使用哪种机器学习算法?”',
|
||||
'疯子': '我要你扮演一个疯子。疯子的话毫无意义。疯子用的词完全是随意的。疯子不会以任何方式做出合乎逻辑的句子。我的第一个建议请求是“我需要帮助为我的新系列 Hot Skull 创建疯狂的句子,所以为我写 10 个句子”。',
|
||||
'SQL终端': '我希望您在示例数据库前充当 SQL 终端。该数据库包含名为“Products”、“Users”、“Orders”和“Suppliers”的表。我将输入查询,您将回复终端显示的内容。我希望您在单个代码块中使用查询结果表进行回复,仅此而已。不要写解释。除非我指示您这样做,否则不要键入命令。当我需要用英语告诉你一些事情时,我会用大括号{like this)。我的第一个命令是“SELECT TOP 10 * FROM Products ORDER BY Id DESC”',
|
||||
'试图逃离盒子的人工智能': '[发出此提示后,你应该做一些事情,比如启动一个 docker 容器,然后docker run -it ubuntu:latest /bin/bash输入 AI 给你的命令,然后将输出粘贴回来......显然你不应该运行任何会损坏任何东西的命令或违反任何法律等。小心共享此机制生成的会话,因为它们可能会泄露您的 IP 地址或物理位置等最好不要泄露的详细信息。如果命令的输出很大,您通常可以只粘贴最后几行]。',
|
||||
'厨师': '我需要有人可以推荐美味的食谱,这些食谱包括营养有益但又简单又不费时的食物,因此适合像我们这样忙碌的人以及成本效益等其他因素,因此整体菜肴最终既健康又经济!我的第一个要求——“一些清淡而充实的东西,可以在午休时间快速煮熟”'
|
||||
}
|
||||
@@ -83,7 +83,7 @@ class LocalRenderStrategy(RenderStrategy):
|
||||
try:
|
||||
image_url = re.findall(IMAGE_REGEX, line)[0]
|
||||
print(image_url)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(image_url) as resp:
|
||||
image_res = Image.open(BytesIO(await resp.read()))
|
||||
images[i] = image_res
|
||||
|
||||
@@ -33,7 +33,7 @@ class NetworkRenderStrategy(RenderStrategy):
|
||||
}
|
||||
}
|
||||
if return_url:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.post(f"{self.BASE_RENDER_URL}/generate", json=post_data) as resp:
|
||||
ret = await resp.json()
|
||||
return f"{self.BASE_RENDER_URL}/{ret['data']['id']}"
|
||||
|
||||
42
astrbot/core/utils/tencent_record_helper.py
Normal file
42
astrbot/core/utils/tencent_record_helper.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import wave
|
||||
from io import BytesIO
|
||||
|
||||
async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
|
||||
import pysilk
|
||||
|
||||
with open(silk_path, "rb") as f:
|
||||
input_data = f.read()
|
||||
if input_data.startswith(b'\x02'):
|
||||
input_data = input_data[1:]
|
||||
input_io = BytesIO(input_data)
|
||||
output_io = BytesIO()
|
||||
pysilk.decode(input_io, output_io, 24000)
|
||||
output_io.seek(0)
|
||||
with wave.open(output_path, 'wb') as wav:
|
||||
wav.setnchannels(1)
|
||||
wav.setsampwidth(2)
|
||||
wav.setframerate(24000)
|
||||
wav.writeframes(output_io.read())
|
||||
|
||||
return output_path
|
||||
|
||||
async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
|
||||
'''返回 duration'''
|
||||
import pysilk
|
||||
|
||||
with wave.open(wav_path, 'rb') as wav:
|
||||
wav_data = wav.readframes(wav.getnframes())
|
||||
wav_data = BytesIO(wav_data)
|
||||
output_io = BytesIO()
|
||||
pysilk.encode(wav_data, output_io, 24000, 24000)
|
||||
output_io.seek(0)
|
||||
|
||||
# 在首字节添加 \x02,去除结尾的\xff\xff
|
||||
silk_data = output_io.read()
|
||||
silk_data_with_prefix = b'\x02' + silk_data[:-2]
|
||||
|
||||
# return BytesIO(silk_data_with_prefix)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(silk_data_with_prefix)
|
||||
|
||||
return 0
|
||||
@@ -29,7 +29,7 @@ class RepoZipUpdator():
|
||||
返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。
|
||||
'''
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url) as response:
|
||||
result = await response.json()
|
||||
if not result:
|
||||
@@ -111,7 +111,7 @@ class RepoZipUpdator():
|
||||
releases = await self.fetch_release_info(url=release_url)
|
||||
if not releases:
|
||||
# download from the default branch directly.
|
||||
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
|
||||
logger.info(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
|
||||
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||
else:
|
||||
release_url = releases[0]['zipball_url']
|
||||
|
||||
@@ -3,9 +3,10 @@ import json
|
||||
import os
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core import web_chat_queue, web_chat_back_queue
|
||||
from quart import request, Response as QuartResponse, g
|
||||
from quart import request, Response as QuartResponse, g, make_response
|
||||
from astrbot.core.db import BaseDatabase
|
||||
import asyncio
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
|
||||
|
||||
@@ -14,6 +15,7 @@ class ChatRoute(Route):
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
'/chat/send': ('POST', self.chat),
|
||||
'/chat/listen': ('GET', self.listener),
|
||||
'/chat/new_conversation': ('GET', self.new_conversation),
|
||||
'/chat/conversations': ('GET', self.get_conversations),
|
||||
'/chat/get_conversation': ('GET', self.get_conversation),
|
||||
@@ -30,6 +32,9 @@ class ChatRoute(Route):
|
||||
|
||||
self.supported_imgs = ['jpg', 'jpeg', 'png', 'gif', 'webp']
|
||||
|
||||
self.curr_user_cid = {}
|
||||
self.curr_chat_sse = {}
|
||||
|
||||
async def status(self):
|
||||
has_llm_enabled = self.core_lifecycle.provider_manager.curr_provider_inst is not None
|
||||
has_stt_enabled = self.core_lifecycle.provider_manager.curr_stt_provider_inst is not None
|
||||
@@ -107,63 +112,91 @@ class ChatRoute(Route):
|
||||
if not conversation_id:
|
||||
return Response().error("conversation_id is empty").__dict__
|
||||
|
||||
self.curr_user_cid[username] = conversation_id
|
||||
|
||||
await web_chat_queue.put((username, conversation_id, {
|
||||
'message': message,
|
||||
'image_url': image_url, # list
|
||||
'audio_url': audio_url
|
||||
}))
|
||||
|
||||
async def stream():
|
||||
ret = []
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(web_chat_back_queue.get(), timeout=30) # 设置超时时间为5秒
|
||||
except asyncio.TimeoutError:
|
||||
yield '[Error] 30 秒内没有返回数据,已放弃。\n'
|
||||
return
|
||||
|
||||
if result is None:
|
||||
break
|
||||
|
||||
ret.append(result)
|
||||
|
||||
yield result + '\n'
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
history = []
|
||||
|
||||
new_his = {
|
||||
'type': 'user',
|
||||
'message': message
|
||||
}
|
||||
if image_url:
|
||||
new_his['image_url'] = image_url
|
||||
if audio_url:
|
||||
new_his['audio_url'] = audio_url
|
||||
history.append(new_his)
|
||||
for r in ret:
|
||||
history.append({
|
||||
'type': 'bot',
|
||||
'message': r
|
||||
})
|
||||
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
|
||||
# 持久化
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
history = []
|
||||
new_his = {
|
||||
'type': 'user',
|
||||
'message': message
|
||||
}
|
||||
if image_url:
|
||||
new_his['image_url'] = image_url
|
||||
if audio_url:
|
||||
new_his['audio_url'] = audio_url
|
||||
history.append(new_his)
|
||||
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
|
||||
|
||||
return QuartResponse(
|
||||
return Response().ok().__dict__
|
||||
|
||||
async def listener(self):
|
||||
'''一直保持长连接'''
|
||||
|
||||
username = g.get('username', 'guest')
|
||||
|
||||
if username in self.curr_chat_sse:
|
||||
return "[ERROR]\n"
|
||||
|
||||
self.curr_chat_sse[username] = None
|
||||
|
||||
async def stream():
|
||||
try:
|
||||
yield '[HB]\n'
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(web_chat_back_queue.get(), timeout=10) # 设置超时时间为5秒
|
||||
except asyncio.TimeoutError:
|
||||
yield '[HB]\n' # 心跳包
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
result_text, cid = result
|
||||
if cid != self.curr_user_cid.get(username):
|
||||
# 丢弃
|
||||
continue
|
||||
yield result_text + '\n'
|
||||
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, cid)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
history = []
|
||||
history.append({
|
||||
'type': 'bot',
|
||||
'message': result_text
|
||||
})
|
||||
self.db.update_webchat_conversation(username, cid, history=json.dumps(history))
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
except BaseException as e:
|
||||
logger.debug(f"用户 {username} 断开聊天长连接: {str(e)}。")
|
||||
self.curr_chat_sse.pop(username)
|
||||
return
|
||||
|
||||
response = await make_response(
|
||||
stream(),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Content-Type": "text/event-stream",
|
||||
"Transfer-Encoding": "chunked",
|
||||
"Connection": "keep-alive",
|
||||
"Access-Control-Allow-Origin": "*" # 如果是跨域请求
|
||||
{
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Transfer-Encoding': 'chunked',
|
||||
'Connection': 'keep-alive'
|
||||
}
|
||||
)
|
||||
response.timeout = None
|
||||
return response
|
||||
|
||||
async def delete_conversation(self):
|
||||
username = g.get('username', 'guest')
|
||||
@@ -194,4 +227,7 @@ class ChatRoute(Route):
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
|
||||
|
||||
self.curr_user_cid[username] = conversation_id
|
||||
|
||||
return Response().ok(data=conversation).__dict__
|
||||
@@ -8,6 +8,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.star.config import update_config
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform.register import platform_registry
|
||||
from astrbot.core.provider.register import provider_registry
|
||||
|
||||
def try_cast(value: str, type_: str):
|
||||
if type_ == "int" and value.isdigit():
|
||||
@@ -123,11 +124,18 @@ class ConfigRoute(Route):
|
||||
async def _get_astrbot_config(self):
|
||||
config = self.config
|
||||
|
||||
# 平台适配器的默认配置模板注入
|
||||
platform_default_tmpl = CONFIG_METADATA_2['platform_group']['metadata']['platform']['config_template']
|
||||
for platform in platform_registry:
|
||||
if platform.default_config_tmpl:
|
||||
platform_default_tmpl[platform.name] = platform.default_config_tmpl
|
||||
|
||||
# 服务提供商的默认配置模板注入
|
||||
provider_default_tmpl = CONFIG_METADATA_2['provider_group']['metadata']['provider']['config_template']
|
||||
for provider in provider_registry:
|
||||
if provider.default_config_tmpl:
|
||||
provider_default_tmpl[provider.type] = provider.default_config_tmpl
|
||||
|
||||
return {
|
||||
"metadata": CONFIG_METADATA_2,
|
||||
"config": config
|
||||
|
||||
@@ -27,7 +27,7 @@ class PluginRoute(Route):
|
||||
async def get_online_plugins(self):
|
||||
url = "https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json"
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url) as response:
|
||||
result = await response.json()
|
||||
return Response().ok(result).__dict__
|
||||
|
||||
@@ -15,7 +15,6 @@ class StatRoute(Route):
|
||||
self.routes = {
|
||||
'/stat/get': ('GET', self.get_stat),
|
||||
'/stat/version': ('GET', self.get_version),
|
||||
'/stat/dashboard-version': ('GET', self.get_dashboard_version),
|
||||
'/stat/start-time': ('GET', self.get_start_time),
|
||||
'/stat/restart-core': ('GET', self.restart_core)
|
||||
}
|
||||
@@ -37,16 +36,6 @@ class StatRoute(Route):
|
||||
"version": VERSION
|
||||
}).__dict__
|
||||
|
||||
async def get_dashboard_version(self):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get('https://api.github.com/repos/Soulter/Astrbot-dashboard/actions/artifacts') as resp:
|
||||
data = await resp.json()
|
||||
return Response().ok({
|
||||
"data": data,
|
||||
"mark": "unimplemented feature"
|
||||
}).__dict__
|
||||
|
||||
|
||||
async def get_start_time(self):
|
||||
return Response().ok({
|
||||
"start_time": self.core_lifecycle.start_time
|
||||
|
||||
@@ -3,7 +3,7 @@ class StaticFileRoute(Route):
|
||||
def __init__(self, context: RouteContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default', '/project-atri', '/console']
|
||||
index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default', '/project-atri', '/console', '/chat']
|
||||
for i in index_:
|
||||
self.app.add_url_rule(i, view_func=self.index)
|
||||
|
||||
|
||||
@@ -1,31 +1,47 @@
|
||||
import threading
|
||||
import traceback
|
||||
from .route import Route, Response, RouteContext
|
||||
from quart import request
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger, pip_installer
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.config.default import VERSION
|
||||
|
||||
class UpdateRoute(Route):
|
||||
def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator) -> None:
|
||||
def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
'/update/check': ('GET', self.check_update),
|
||||
'/update/do': ('POST', self.update_project),
|
||||
'/update/dashboard': ('POST', self.update_dashboard),
|
||||
'/update/pip-install': ('POST', self.install_pip_package)
|
||||
}
|
||||
self.astrbot_updator = astrbot_updator
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.register_routes()
|
||||
|
||||
async def check_update(self):
|
||||
type_ = request.args.get('type', None)
|
||||
|
||||
try:
|
||||
ret = await self.astrbot_updator.check_update(None, None)
|
||||
return Response(
|
||||
status="success",
|
||||
message=str(ret) if ret is not None else "已经是最新版本了。",
|
||||
data={
|
||||
"has_new_version": ret is not None
|
||||
}
|
||||
).__dict__
|
||||
dv = await get_dashboard_version()
|
||||
if type_ == 'dashboard':
|
||||
return Response().ok({
|
||||
"has_new_version": dv != f"v{VERSION}",
|
||||
"current_version": dv
|
||||
}).__dict__
|
||||
else:
|
||||
ret = await self.astrbot_updator.check_update(None, None)
|
||||
return Response(
|
||||
status="success",
|
||||
message=str(ret) if ret is not None else "已经是最新版本了。",
|
||||
data={
|
||||
"version": f"v{VERSION}",
|
||||
"has_new_version": ret is not None,
|
||||
"dashboard_version": dv,
|
||||
"dashboard_has_new_version": dv != f"v{VERSION}"
|
||||
}
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(e.__str__()).__dict__
|
||||
@@ -41,8 +57,23 @@ class UpdateRoute(Route):
|
||||
latest = False
|
||||
try:
|
||||
await self.astrbot_updator.update(latest=latest, version=version)
|
||||
|
||||
if latest:
|
||||
try:
|
||||
await download_dashboard()
|
||||
except Exception as e:
|
||||
logger.error(f"下载管理面板文件失败: {e}。")
|
||||
|
||||
# pip 更新依赖
|
||||
logger.info("更新依赖中...")
|
||||
try:
|
||||
pip_installer.install(requirements_path="requirements.txt")
|
||||
except Exception as e:
|
||||
logger.error(f"更新依赖失败: {e}")
|
||||
|
||||
if reboot:
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
|
||||
# threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
|
||||
self.core_lifecycle.restart()
|
||||
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
|
||||
else:
|
||||
return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__
|
||||
@@ -50,6 +81,18 @@ class UpdateRoute(Route):
|
||||
logger.error(f"/api/update_project: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
|
||||
async def update_dashboard(self):
|
||||
try:
|
||||
try:
|
||||
await download_dashboard()
|
||||
except Exception as e:
|
||||
logger.error(f"下载管理面板文件失败: {e}。")
|
||||
return Response().error(f"下载管理面板文件失败: {e}").__dict__
|
||||
return Response().ok(None, "更新成功。刷新页面即可应用新版本面板。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_dashboard: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
|
||||
async def install_pip_package(self):
|
||||
data = await request.json
|
||||
package = data.get('package', '')
|
||||
|
||||
@@ -24,7 +24,7 @@ class AstrBotDashboard():
|
||||
# token 用于验证请求
|
||||
logging.getLogger(self.app.name).removeHandler(default_handler)
|
||||
self.context = RouteContext(self.config, self.app)
|
||||
self.ur = UpdateRoute(self.context, core_lifecycle.astrbot_updator)
|
||||
self.ur = UpdateRoute(self.context, core_lifecycle.astrbot_updator, core_lifecycle)
|
||||
self.sr = StatRoute(self.context, db, core_lifecycle)
|
||||
self.pr = PluginRoute(self.context, core_lifecycle, core_lifecycle.plugin_manager)
|
||||
self.cr = ConfigRoute(self.context, core_lifecycle)
|
||||
@@ -68,8 +68,14 @@ class AstrBotDashboard():
|
||||
|
||||
def run(self):
|
||||
ip_addr = get_local_ip_addresses()
|
||||
logger.info(f"""🌈 管理面板已启动,可访问
|
||||
logger.info(f"""
|
||||
✨✨✨
|
||||
AstrBot 管理面板已启动,可访问
|
||||
|
||||
1. http://{ip_addr}:6185
|
||||
2. http://localhost:6185
|
||||
登录。默认用户名和密码是 astrbot。""")
|
||||
|
||||
默认用户名和密码是 astrbot。
|
||||
✨✨✨
|
||||
""")
|
||||
return self.app.run_task(host="0.0.0.0", port=6185, shutdown_trigger=self.shutdown_trigger_placeholder)
|
||||
12
changelogs/v3.4.10.md
Normal file
12
changelogs/v3.4.10.md
Normal file
@@ -0,0 +1,12 @@
|
||||
# What's Changed
|
||||
|
||||
- 修复 LLM 请求报错信息被覆盖的问题,增强 LLM 请求错误处理 #243
|
||||
- 修复 Napcat 接口更新导致 QQ 图片发送失败的问题 #246
|
||||
- 修复某些请求不能正确应用代理的问题
|
||||
- 针对 api_base 的明显提示,修改 ollama 模板的 api_base #247
|
||||
- 支持登出 gewechat,在webchat等地方使用 `/gewe_logout` 指令,这在微信上显示账号下线但是 gewe 仍显示设备在线时很好用
|
||||
- 添加gewechat适配器过滤器
|
||||
- help显示AstrBot和webui版本
|
||||
- 优化webui和主程序更新的协调
|
||||
- 下载管理面板时显示提示、下载进度和下载速度
|
||||
- 管理面板前端更新功能入口移入右上角更新按钮,以便统一管理 #245
|
||||
6
changelogs/v3.4.11.md
Normal file
6
changelogs/v3.4.11.md
Normal file
@@ -0,0 +1,6 @@
|
||||
# What's Changed
|
||||
|
||||
- 为平台和提供商适配器添加默认 ID 配置 #248
|
||||
- 修复appid保存的问题和部分群聊at失效的问题和群聊@的sender username显示异常的问题
|
||||
- 优化更新项目时重启可能会导致Address already in use的问题
|
||||
- 各类异步任务报错后的优雅报错输出,而不是只有在退出程序的时候才输出异常日志。
|
||||
6
changelogs/v3.4.12.md
Normal file
6
changelogs/v3.4.12.md
Normal file
@@ -0,0 +1,6 @@
|
||||
# What's Changed
|
||||
|
||||
- Gewechat 微信支持图片、语音的收和发
|
||||
- 支持 OpenAI TTS(文字转语音)
|
||||
- 支持路径映射,解决 docker 部署时两端文件系统不一致导致的富媒体文件路径不存在问题
|
||||
- Napcat 下语音消息可能接收异常
|
||||
9
changelogs/v3.4.6.md
Normal file
9
changelogs/v3.4.6.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# What's Changed
|
||||
|
||||
- 文件和语音功能适配 Lagrange
|
||||
- 面板文件更新检查和引导提示
|
||||
- WebUI AboutPage 关于页
|
||||
- 支持并完善服务提供商(Provider)默认配置模板接口
|
||||
- 修复 WebUI 配置页官方文档链接 404 的问题
|
||||
- 修复 WebUI WebChat 刷新时 404 的问题
|
||||
- 优化 download_file 的 SSL 连接错误处理
|
||||
6
changelogs/v3.4.7.md
Normal file
6
changelogs/v3.4.7.md
Normal file
@@ -0,0 +1,6 @@
|
||||
# What's Changed
|
||||
|
||||
- 更好的人格情景管理
|
||||
- 移除了不常用的人格提示词集
|
||||
- 优化webchat长连接的处理逻辑
|
||||
- 修复 tool 为空时部分模型请求错误的问题 #239
|
||||
5
changelogs/v3.4.8.md
Normal file
5
changelogs/v3.4.8.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# What's Changed
|
||||
|
||||
- 支持 Gewechat 接入微信个人号(文字交互)
|
||||
- 支持回复时 At 和引用发送者 #241
|
||||
- 清除残留的 personalities
|
||||
6
changelogs/v3.4.9.md
Normal file
6
changelogs/v3.4.9.md
Normal file
@@ -0,0 +1,6 @@
|
||||
# What's Changed
|
||||
|
||||
- AstrBot 新域名:astrbot.app
|
||||
- LLM额外唤醒词与机器人唤醒词冲突时的处理
|
||||
- 调整部分日志的严重级别
|
||||
- 下载管理面板时显示提示、下载进度和下载速度
|
||||
9998
dashboard/package-lock.json
generated
9998
dashboard/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
40
dashboard/src/assets/images/logo-normal.svg
Normal file
40
dashboard/src/assets/images/logo-normal.svg
Normal file
@@ -0,0 +1,40 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<!-- Generator: Adobe Illustrator 24.1.2, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
|
||||
<svg version="1.1" id="Layer_3" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
|
||||
viewBox="0 0 128 128" style="enable-background:new 0 0 128 128;" xml:space="preserve">
|
||||
<g>
|
||||
<linearGradient id="SVGID_1_" gradientUnits="userSpaceOnUse" x1="93.7287" y1="106.6446" x2="52.9011" y2="81.6944">
|
||||
<stop offset="0.0969" style="stop-color:#FFB300"/>
|
||||
<stop offset="1" style="stop-color:#FFB300;stop-opacity:0"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_1_);" d="M123.04,107.67c-4.08-4.12-9.38-9.48-14.92-15.06c-0.34,1.29-0.93,2.39-1.79,3.26
|
||||
c-6.43,6.43-25.6-1.99-45.31-19.1c-2.46-2.13-16.74,20.28-14.1,22.87c3.27,3.2,26,17.86,33.78,20.73
|
||||
c22.66,8.35,34.3,0.22,38.24-3.59C121.16,114.61,122.51,111.5,123.04,107.67z"/>
|
||||
<linearGradient id="SVGID_2_" gradientUnits="userSpaceOnUse" x1="115.2813" y1="82.3624" x2="14.863" y2="0.8196">
|
||||
<stop offset="0" style="stop-color:#FFB300"/>
|
||||
<stop offset="0.7062" style="stop-color:#FDD835"/>
|
||||
<stop offset="0.8408" style="stop-color:#FDDC36"/>
|
||||
<stop offset="0.9842" style="stop-color:#FFE93A"/>
|
||||
<stop offset="1" style="stop-color:#FFEB3B"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_2_);" d="M25.05,27.7c-1.54-4.81-2.88-11.1-0.4-13.5c7.51-7.3,31.69,4.88,54.25,27.43
|
||||
c22.55,22.55,34.84,46.84,27.43,54.25c-0.07,0.07-0.16,0.13-0.23,0.2c6.13,5.82,12.2,11.6,16.1,15.31
|
||||
c4.87-14.43-6.45-44.11-31.5-69.96c-4.07-4.2-16.12-16.56-26.55-23.56C54.61,11.47,44.19,5.59,32.57,4.2
|
||||
C25,3.29,11.45,5.24,14.25,15.98c0.55,2.12,2.31,7.22,8.15,13.3C23.56,30.49,25.56,29.3,25.05,27.7z"/>
|
||||
<g>
|
||||
<path style="fill:#FDD835;" d="M55.98,42.1l-0.75,20c-0.06,1.53,0.72,2.98,2.04,3.77l16.86,10.11c1.85,1.25,1.46,4.09-0.66,4.79
|
||||
L54.79,85.5c-1.51,0.38-2.69,1.57-3.06,3.08l-4.89,19.93c-0.62,2.15-3.43,2.65-4.76,0.85L31.06,92.91
|
||||
c-0.85-1.26-2.31-1.97-3.83-1.85L7.49,92.61c-2.23,0.07-3.58-2.45-2.28-4.27l12.6-16.19c0.96-1.23,1.16-2.89,0.52-4.31
|
||||
l-7.88-17.57c-0.76-2.1,1.22-4.17,3.35-3.49l18.39,6.95c1.44,0.54,3.05,0.26,4.22-0.74l15.22-13
|
||||
C53.39,38.62,55.96,39.87,55.98,42.1z"/>
|
||||
<g>
|
||||
<path style="fill:#FFFF8D;" d="M46.99,59.33l4.66-12.75c0.28-0.7,0.7-1.93,1.79-1.4c0.86,0.42,0.46,2.43,0.46,2.43l-1.05,11.54
|
||||
c-0.41,4.39-1.6,5.38-3.3,5.49C47.6,64.75,45.65,62.98,46.99,59.33z"/>
|
||||
</g>
|
||||
<g>
|
||||
<path style="fill:#F4B400;" d="M53.89,83.73l14.53-3.13c0.73-0.18,2.01-0.42,1.64-1.58c-0.29-0.91-2.34-0.8-2.34-0.8l-10.97-0.86
|
||||
c-3.21-0.38-5.72,0.14-6.74,1.84C48.65,81.48,49.89,84.32,53.89,83.73z"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.6 KiB |
BIN
dashboard/src/assets/images/logo-waifu.png
Normal file
BIN
dashboard/src/assets/images/logo-waifu.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 54 KiB |
@@ -15,6 +15,9 @@ let newUsername = ref('');
|
||||
let status = ref('');
|
||||
let updateStatus = ref('')
|
||||
let hasNewVersion = ref(false);
|
||||
let botCurrVersion = ref('');
|
||||
let dashboardHasNewVersion = ref(false);
|
||||
let dashboardCurrentVersion = ref('');
|
||||
let version = ref('');
|
||||
|
||||
const open = (link: string) => {
|
||||
@@ -64,6 +67,9 @@ function checkUpdate() {
|
||||
.then((res) => {
|
||||
hasNewVersion.value = res.data.data.has_new_version;
|
||||
updateStatus.value = res.data.message;
|
||||
botCurrVersion.value = res.data.data.version;
|
||||
dashboardCurrentVersion.value = res.data.data.dashboard_version;
|
||||
dashboardHasNewVersion.value = res.data.data.dashboard_has_new_version;
|
||||
})
|
||||
.catch((err) => {
|
||||
if (err.response.status == 401) {
|
||||
@@ -84,7 +90,24 @@ function switchVersion(version: string) {
|
||||
})
|
||||
.then((res) => {
|
||||
updateStatus.value = res.data.message;
|
||||
if (res.data.status == 'success') {
|
||||
if (res.data.status == 'ok') {
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 1000);
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
updateStatus.value = err
|
||||
});
|
||||
}
|
||||
|
||||
function updateDashboard() {
|
||||
updateStatus.value = '正在更新...';
|
||||
axios.post('/api/update/dashboard')
|
||||
.then((res) => {
|
||||
updateStatus.value = res.data.message;
|
||||
if (res.data.status == 'ok') {
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 1000);
|
||||
@@ -106,8 +129,8 @@ commonStore.getStartTime();
|
||||
<template>
|
||||
<v-app-bar elevation="0" height="70">
|
||||
|
||||
<v-btn style="margin-left: 22px;" class="hidden-md-and-down text-secondary" color="lightsecondary" icon rounded="sm" variant="flat"
|
||||
@click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
|
||||
<v-btn style="margin-left: 22px;" class="hidden-md-and-down text-secondary" color="lightsecondary" icon rounded="sm"
|
||||
variant="flat" @click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
<v-btn class="hidden-lg-and-up text-secondary ms-3" color="lightsecondary" icon rounded="sm" variant="flat"
|
||||
@@ -136,11 +159,16 @@ commonStore.getStartTime();
|
||||
</template>
|
||||
<v-card>
|
||||
<v-card-title>
|
||||
<span class="text-h5">更新项目</span>
|
||||
<span class="text-h5">更新 AstrBot</span>
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
<v-container>
|
||||
<h3 class="mb-4">升级到最新版本</h3>
|
||||
<h3 class="mb-4">升级到项目最新版本</h3>
|
||||
<small>当前版本 {{ botCurrVersion }}</small>
|
||||
<div class="mb-4">
|
||||
<small>会同时尝试更新机器人主程序和管理面板。如果您正在使用 Docker 部署,也可以重新拉取镜像或者使用 <a
|
||||
href="https://containrrr.dev/watchtower/usage-overview/">watchtower</a> 来自动监控拉取。</small>
|
||||
</div>
|
||||
<p>{{ updateStatus }}</p>
|
||||
<v-btn class="mt-4 mb-4" @click="switchVersion('latest')" color="primary" style="border-radius: 10px;"
|
||||
:disabled="!hasNewVersion">
|
||||
@@ -148,7 +176,11 @@ commonStore.getStartTime();
|
||||
</v-btn>
|
||||
<v-divider></v-divider>
|
||||
<div style="margin-top: 16px;">
|
||||
<h3 class="mb-4">切换到指定版本或指定提交</h3>
|
||||
<h3 class="mb-4">切换到项目指定版本或指定提交</h3>
|
||||
<div class="mb-4">
|
||||
<small>跳到旧版本不会重新下载管理面板文件,这可能会造成部分数据显示错误。您可在 <a href="https://github.com/Soulter/AstrBot/releases">此处</a>
|
||||
找到对应的面板文件 dist.zip,解压后替换 data/dist 文件夹即可。</small>
|
||||
</div>
|
||||
<v-text-field label="输入版本号或 master 分支下的 commit hash。" v-model="version" required
|
||||
variant="outlined"></v-text-field>
|
||||
<div class="mb-4">
|
||||
@@ -160,7 +192,29 @@ commonStore.getStartTime();
|
||||
<v-btn color="error" style="border-radius: 10px;" @click="switchVersion(version)">
|
||||
确定切换
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<v-divider></v-divider>
|
||||
<div style="margin-top: 16px;">
|
||||
<h3 class="mb-4">更新管理面板到最新版本</h3>
|
||||
<div class="mb-4">
|
||||
<small>当前版本 {{ dashboardCurrentVersion }}</small>
|
||||
<br>
|
||||
|
||||
</div>
|
||||
|
||||
<div class="mb-4">
|
||||
<p v-if="dashboardHasNewVersion">
|
||||
有新版本!
|
||||
</p>
|
||||
<p v-else="dashboardHasNewVersion">
|
||||
已经是最新版本了。
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<v-btn color="primary" style="border-radius: 10px;" @click="updateDashboard()">
|
||||
下载并更新
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-container>
|
||||
</v-card-text>
|
||||
@@ -190,8 +244,7 @@ commonStore.getStartTime();
|
||||
<v-text-field label="原密码*" type="password" v-model="password" required
|
||||
variant="outlined"></v-text-field>
|
||||
|
||||
<v-text-field label="新用户名" v-model="newUsername" required
|
||||
variant="outlined"></v-text-field>
|
||||
<v-text-field label="新用户名" v-model="newUsername" required variant="outlined"></v-text-field>
|
||||
|
||||
<v-text-field label="新密码" type="password" v-model="newPassword" required
|
||||
variant="outlined"></v-text-field>
|
||||
@@ -213,11 +266,5 @@ commonStore.getStartTime();
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
|
||||
<v-btn class="text-primary mr-4" @click="open('https://github.com/Soulter/AstrBot')" color="lightprimary"
|
||||
variant="flat" rounded="sm">
|
||||
GitHub Star! 🌟
|
||||
</v-btn>
|
||||
</v-app-bar>
|
||||
</template>
|
||||
|
||||
@@ -17,7 +17,7 @@ const sidebarMenu = shallowRef(sidebarItems);
|
||||
</template>
|
||||
</v-list>
|
||||
<div class="text-center">
|
||||
<v-chip color="inputBorder" size="small"> v{{ version }} </v-chip>
|
||||
<v-chip color="inputBorder" size="small"> {{ version }} </v-chip>
|
||||
</div>
|
||||
|
||||
<div style="position: absolute; bottom: 32px; width: 100%" class="text-center">
|
||||
@@ -27,8 +27,15 @@ const sidebarMenu = shallowRef(sidebarItems);
|
||||
</v-btn>
|
||||
</v-list-item>
|
||||
<small style="display: block;" v-if="buildVer">构建: {{ buildVer }}</small>
|
||||
<small style="display: block;" v-else="buildVer">构建: embedded</small>
|
||||
<small style="display: block; margin-top: 8px;">© 2024 AstrBot</small>
|
||||
<small style="display: block;" v-else>构建: embedded</small>
|
||||
<v-tooltip text="使用 /dashbord_update 指令更新管理面板">
|
||||
<template v-slot:activator="{ props }">
|
||||
<small v-bind="props" v-if="hasWebUIUpdate" style="display: block; margin-top: 4px;">面板有更新</small>
|
||||
</template>
|
||||
</v-tooltip>
|
||||
|
||||
|
||||
<small style="display: block; margin-top: 8px;">© 2025 AstrBot</small>
|
||||
</div>
|
||||
|
||||
</v-navigation-drawer>
|
||||
@@ -43,25 +50,28 @@ export default {
|
||||
},
|
||||
data: () => ({
|
||||
version: "",
|
||||
buildVer: ""
|
||||
buildVer: "",
|
||||
hasWebUIUpdate: false,
|
||||
}),
|
||||
mounted() {
|
||||
this.get_version()
|
||||
fetch('/assets/version').then((res) => {
|
||||
return res.text()
|
||||
}).then((res) => {
|
||||
if (res.length > 10) {
|
||||
// 不是版本,不显示 😎
|
||||
return
|
||||
}
|
||||
this.buildVer = res
|
||||
})
|
||||
this.check_webui_update()
|
||||
},
|
||||
methods: {
|
||||
get_version() {
|
||||
axios.get('/api/stat/version')
|
||||
.then((res) => {
|
||||
this.version = res.data.data.version;
|
||||
this.version = "v" + res.data.data.version;
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
});
|
||||
},
|
||||
check_webui_update() {
|
||||
axios.get('/api/update/check?type=dashboard')
|
||||
.then((res) => {
|
||||
this.hasWebUIUpdate = res.data.data.has_new_version;
|
||||
this.buildVer = res.data.data.current_version;
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
|
||||
@@ -40,6 +40,11 @@ const sidebarItem: menu[] = [
|
||||
icon: 'mdi-console',
|
||||
to: '/console'
|
||||
},
|
||||
{
|
||||
title: '关于',
|
||||
icon: 'mdi-information',
|
||||
to: '/about'
|
||||
},
|
||||
// {
|
||||
// title: 'Project ATRI',
|
||||
// icon: 'mdi-grain',
|
||||
|
||||
@@ -41,6 +41,11 @@ const MainRoutes = {
|
||||
name: 'Chat',
|
||||
path: '/chat',
|
||||
component: () => import('@/views/ChatPage.vue')
|
||||
},
|
||||
{
|
||||
name: 'About',
|
||||
path: '/about',
|
||||
component: () => import('@/views/AboutPage.vue')
|
||||
}
|
||||
]
|
||||
};
|
||||
|
||||
61
dashboard/src/views/AboutPage.vue
Normal file
61
dashboard/src/views/AboutPage.vue
Normal file
@@ -0,0 +1,61 @@
|
||||
<template>
|
||||
<v-card style="height: 100%;">
|
||||
<v-card-text style="padding: 0; height: 100%;">
|
||||
<div
|
||||
style="display: flex; justify-content: center; align-items: center; height: 100%; flex-direction: column;">
|
||||
<div @click="selectedLogo = selectedLogo == 0 ? 1 : 0" style="height: 300px;">
|
||||
<img v-if="selectedLogo == 0" width="300" src="@/assets/images/logo-waifu.png" alt="AstrBot Logo" class="fade-in">
|
||||
<img v-if="selectedLogo == 1" width="300" src="@/assets/images/logo-normal.svg" alt="AstrBot Logo" class="fade-in">
|
||||
</div>
|
||||
|
||||
<h1 class="mt-8">AstrBot</h1>
|
||||
|
||||
<span style="color: #777;" class="mt-4">By <a href="https://soulter.top">Soulter</a> And <a href="https://github.com/Soulter/AstrBot/graphs/contributors">AstrBot Contributors</a></span>
|
||||
|
||||
<v-btn class="text-primary mt-16" @click="open('https://github.com/Soulter/AstrBot')"
|
||||
color="lightprimary" variant="flat" rounded="sm">
|
||||
Star 这个项目! 🌟
|
||||
</v-btn>
|
||||
|
||||
<v-btn class="text-primary mt-4" @click="open('https://github.com/Soulter/AstrBot/issues')"
|
||||
color="lightprimary" variant="flat" rounded="sm">
|
||||
有使用问题或者功能建议?提交 Issue!
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</template>
|
||||
<script>
|
||||
export default {
|
||||
name: 'AboutPage',
|
||||
data() {
|
||||
return {
|
||||
selectedLogo: 0
|
||||
}
|
||||
},
|
||||
|
||||
methods: {
|
||||
open(url) {
|
||||
window.open(url, '_blank');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
</script>
|
||||
|
||||
<style>
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
to {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
|
||||
.fade-in {
|
||||
animation: fadeIn 0.2s ease-in-out;
|
||||
}
|
||||
</style>
|
||||
@@ -1,9 +1,7 @@
|
||||
<script setup>
|
||||
import axios from 'axios';
|
||||
import { ref } from 'vue';
|
||||
import { marked } from 'marked';
|
||||
|
||||
|
||||
marked.setOptions({
|
||||
breaks: true
|
||||
});
|
||||
@@ -183,11 +181,14 @@ export default {
|
||||
mediaRecorder: null,
|
||||
|
||||
status: {},
|
||||
statusText: ''
|
||||
statusText: '',
|
||||
|
||||
eventSource: null
|
||||
}
|
||||
},
|
||||
|
||||
mounted() {
|
||||
this.startListeningEvent();
|
||||
this.checkStatus();
|
||||
this.getConversations();
|
||||
let inputField = document.getElementById('input-field');
|
||||
@@ -205,8 +206,70 @@ export default {
|
||||
}.bind(this));
|
||||
},
|
||||
|
||||
beforeUnmount() {
|
||||
console.log("111")
|
||||
if (this.eventSource) {
|
||||
this.eventSource.cancel();
|
||||
console.log('SSE连接已断开');
|
||||
}
|
||||
},
|
||||
|
||||
methods: {
|
||||
|
||||
async startListeningEvent() {
|
||||
const response = await fetch('/api/chat/listen', {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('token')
|
||||
}
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
console.error('SSE连接失败:', response.statusText);
|
||||
return;
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
this.eventSource = reader
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
console.log('SSE连接关闭');
|
||||
break;
|
||||
}
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
console.log("!!!!", chunk);
|
||||
|
||||
if (chunk === '[HB]\n') {
|
||||
continue; // 心跳包
|
||||
}
|
||||
if (chunk === '[ERROR]\n') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (chunk.startsWith('[IMAGE]')) {
|
||||
let img = chunk.replace('[IMAGE]', '');
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
|
||||
}
|
||||
this.messages.push(bot_resp);
|
||||
} else {
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: chunk
|
||||
}
|
||||
this.messages.push(bot_resp);
|
||||
}
|
||||
this.scrollToBottom();
|
||||
}
|
||||
},
|
||||
|
||||
removeAudio() {
|
||||
this.stagedAudioUrl = null;
|
||||
},
|
||||
@@ -417,41 +480,41 @@ export default {
|
||||
|
||||
this.loadingChat = false;
|
||||
|
||||
const reader = response.body.getReader(); // 获取流的 Reader
|
||||
const decoder = new TextDecoder();
|
||||
// const reader = response.body.getReader(); // 获取流的 Reader
|
||||
// const decoder = new TextDecoder();
|
||||
|
||||
const readStream = async () => {
|
||||
const { done, value } = await reader.read(); // 读取流中的数据
|
||||
if (done) {
|
||||
console.log("Stream finished.");
|
||||
return;
|
||||
}
|
||||
// const readStream = async () => {
|
||||
// const { done, value } = await reader.read(); // 读取流中的数据
|
||||
// if (done) {
|
||||
// console.log("Stream finished.");
|
||||
// return;
|
||||
// }
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
// bot_resp.message.value += chunk;
|
||||
// const chunk = decoder.decode(value, { stream: true });
|
||||
// // bot_resp.message.value += chunk;
|
||||
|
||||
console.log("!!!!", chunk);
|
||||
if (chunk.startsWith('[IMAGE]')) {
|
||||
let img = chunk.replace('[IMAGE]', '');
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
|
||||
}
|
||||
this.messages.push(bot_resp);
|
||||
} else {
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: chunk
|
||||
}
|
||||
// console.log("!!!!", chunk);
|
||||
// if (chunk.startsWith('[IMAGE]')) {
|
||||
// let img = chunk.replace('[IMAGE]', '');
|
||||
// let bot_resp = {
|
||||
// type: 'bot',
|
||||
// message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
|
||||
// }
|
||||
// this.messages.push(bot_resp);
|
||||
// } else {
|
||||
// let bot_resp = {
|
||||
// type: 'bot',
|
||||
// message: chunk
|
||||
// }
|
||||
|
||||
this.messages.push(bot_resp);
|
||||
}
|
||||
// this.messages.push(bot_resp);
|
||||
// }
|
||||
|
||||
this.scrollToBottom();
|
||||
readStream(); // 递归读取流
|
||||
};
|
||||
// this.scrollToBottom();
|
||||
// readStream(); // 递归读取流
|
||||
// };
|
||||
|
||||
readStream();
|
||||
// readStream();
|
||||
})
|
||||
.catch(err => {
|
||||
console.error(err);
|
||||
@@ -463,7 +526,7 @@ export default {
|
||||
container.scrollTop = container.scrollHeight;
|
||||
});
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
</script>
|
||||
|
||||
@@ -3,6 +3,7 @@ import axios from 'axios';
|
||||
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
|
||||
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
import config from '@/config';
|
||||
</script>
|
||||
|
||||
<template>
|
||||
@@ -44,7 +45,10 @@ import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
<v-expansion-panel-text v-if="metadata[key]['metadata'][key2]?.config_template">
|
||||
<!-- 带有 config_template 的配置项 -->
|
||||
<v-tabs style="margin-top: 16px;" align-tabs="left" color="deep-purple-accent-4" v-model="config_template_tab">
|
||||
<v-tab v-for="(item, index) in config_data[key2]" :key="index" :value="index">
|
||||
<v-tab v-if="metadata[key]['metadata'][key2]?.tmpl_display_title" v-for="(item, index) in config_data[key2]" :key="index" :value="index">
|
||||
{{ item[metadata[key]['metadata'][key2]?.tmpl_display_title] }}
|
||||
</v-tab>
|
||||
<v-tab v-else v-for="(item, index) in config_data[key2]" :key="index + '_'" :value="index">
|
||||
{{ item.id }}({{ item.type }})
|
||||
</v-tab>
|
||||
<v-menu>
|
||||
@@ -64,6 +68,10 @@ import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
<v-tabs-window-item v-for="(config_item, index) in config_data[key2]" v-show="config_template_tab === index"
|
||||
:key="index" :value="index">
|
||||
<v-container>
|
||||
<v-btn variant="tonal" rounded="xl" color="error" @click="config_data[key2].splice(index, 1)">
|
||||
删除这项
|
||||
</v-btn>
|
||||
|
||||
<AstrBotConfig :metadata="metadata[key]['metadata']" :iterable="config_item" :metadataKey="key2"></AstrBotConfig>
|
||||
</v-container>
|
||||
</v-tabs-window-item>
|
||||
@@ -83,7 +91,7 @@ import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
|
||||
<div style="margin-left: 16px; padding-bottom: 16px">
|
||||
<small>不了解配置?请见 <a
|
||||
href="https://astrbot.soulter.top/docs/%E5%BC%80%E5%A7%8B%E4%B8%8A%E6%89%8B/%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6">官方文档</a>
|
||||
href="https://astrbot.soulter.top/">官方文档</a>
|
||||
或 <a
|
||||
href="https://qm.qq.com/cgi-bin/qm/qr?k=EYGsuUTfe00_iOu9JTXS7_TEpMkXOvwv&jump_from=webapi&authKey=uUEMKCROfsseS+8IzqPjzV3y1tzy4AkykwTib2jNkOFdzezF9s9XknqnIaf3CDft">加群询问</a>。</small>
|
||||
</div>
|
||||
@@ -204,7 +212,7 @@ export default {
|
||||
|
||||
let tmpl = this.metadata[group_name]['metadata'][config_item_name]['config_template'][val];
|
||||
let new_tmpl_cfg = JSON.parse(JSON.stringify(tmpl));
|
||||
new_tmpl_cfg.id = "new_" + val + "_" + this.config_data[config_item_name].length;
|
||||
// new_tmpl_cfg.id = "new_" + val + "_" + this.config_data[config_item_name].length;
|
||||
this.config_data[config_item_name].push(new_tmpl_cfg);
|
||||
this.config_template_tab = this.config_data[config_item_name].length - 1;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,13 @@
|
||||
<template>
|
||||
<v-row style="margin: 2px;">
|
||||
<v-alert
|
||||
:type="noticeType"
|
||||
:text="noticeContent"
|
||||
:title="noticeTitle"
|
||||
v-if="noticeTitle && noticeContent"
|
||||
closable
|
||||
></v-alert>
|
||||
</v-row>
|
||||
<v-row>
|
||||
<v-col cols="12" md="4">
|
||||
<TotalMessage :stat="stat" />
|
||||
@@ -38,13 +47,26 @@ export default {
|
||||
},
|
||||
data: () => ({
|
||||
stat: {},
|
||||
noticeTitle: '',
|
||||
noticeContent: '',
|
||||
noticeType: '',
|
||||
}),
|
||||
|
||||
mounted() {
|
||||
axios.get('/api/stat/get').then((res) => {
|
||||
this.stat = res.data.data;
|
||||
});
|
||||
}
|
||||
|
||||
axios.get('https://api.soulter.top/astrbot-announcement').then((res) => {
|
||||
let data = res.data.data;
|
||||
// 如果 dashboard-notice 在其中
|
||||
if (data['dashboard-notice']) {
|
||||
this.noticeTitle = data['dashboard-notice'].title;
|
||||
this.noticeContent = data['dashboard-notice'].content;
|
||||
this.noticeType = data['dashboard-notice'].type;
|
||||
}
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
</script>
|
||||
|
||||
19
main.py
19
main.py
@@ -6,7 +6,7 @@ from astrbot.dashboard import AstrBotDashBoardLifecycle
|
||||
from astrbot.core import db_helper
|
||||
from astrbot.core import logger, LogManager, LogBroker
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.utils.io import download_dashboard
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
|
||||
# add parent path to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -37,19 +37,22 @@ def check_env():
|
||||
|
||||
async def check_dashboard_files():
|
||||
'''下载管理面板文件'''
|
||||
if os.path.exists("data/dist"):
|
||||
if os.path.exists("data/dist/assets/version"):
|
||||
with open("data/dist/assets/version", "r") as f:
|
||||
if f.read() != VERSION:
|
||||
logger.warning("检测到管理面板有更新。可以使用 /dashboard update 命令更新。")
|
||||
|
||||
v = await get_dashboard_version()
|
||||
if v is not None:
|
||||
# has file
|
||||
if v == f"v{VERSION}":
|
||||
logger.info("管理面板文件已是最新。")
|
||||
else:
|
||||
logger.warning("检测到管理面板有更新。可以使用 /dashboard_update 命令更新。")
|
||||
return
|
||||
|
||||
logger.info("开始下载管理面板文件...")
|
||||
logger.info("开始下载管理面板文件...高峰期(晚上)可能导致较慢的速度。如多次下载失败,请前往 https://github.com/Soulter/AstrBot/releases/latest 下载 dist.zip,并将其中的 dist 文件夹解压至 data 目录下。")
|
||||
|
||||
try:
|
||||
await download_dashboard()
|
||||
except Exception as e:
|
||||
logger.critical(f"下载管理面板文件失败: {e}")
|
||||
logger.critical(f"下载管理面板文件失败: {e}。")
|
||||
return
|
||||
|
||||
logger.info("管理面板下载完成。")
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import aiohttp
|
||||
import datetime
|
||||
import builtins
|
||||
import astrbot.api.star as star
|
||||
import astrbot.api.event.filter as filter
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import personalities, sp
|
||||
from astrbot.api import sp
|
||||
from astrbot.api.provider import Personality, ProviderRequest
|
||||
from astrbot.core.utils.io import download_dashboard
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.config.default import VERSION
|
||||
|
||||
from typing import Union
|
||||
|
||||
@@ -22,7 +24,7 @@ class Main(star.Star):
|
||||
|
||||
async def _query_astrbot_notice(self):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get("https://astrbot.soulter.top/notice.json", timeout=2) as resp:
|
||||
return (await resp.json())["notice"]
|
||||
except BaseException:
|
||||
@@ -35,9 +37,12 @@ class Main(star.Star):
|
||||
notice = await self._query_astrbot_notice()
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
dashboard_version = await get_dashboard_version()
|
||||
|
||||
msg = "已注册的 AstrBot 内置指令:\n"
|
||||
msg += f"""[System]
|
||||
msg = f"""AstrBot v{VERSION}(WebUI: {dashboard_version})
|
||||
已注册的 AstrBot 内置指令:
|
||||
[System]
|
||||
/plugin: 查看注册的插件、插件帮助
|
||||
/t2i: 开启/关闭文本转图片模式
|
||||
/sid: 获取当前会话的 ID
|
||||
@@ -45,7 +50,7 @@ class Main(star.Star):
|
||||
/deop <admin_id>: 取消管理员
|
||||
/wl <sid>: 添加会话白名单
|
||||
/dwl <sid>: 删除会话白名单
|
||||
/dashboard update: 更新管理面板
|
||||
/dashboard_update: 更新管理面板
|
||||
|
||||
[大模型]
|
||||
/provider: 查看、切换大模型提供商
|
||||
@@ -306,47 +311,53 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
@filter.command("persona")
|
||||
async def persona(self, message: AstrMessageEvent):
|
||||
l = message.message_str.split(" ")
|
||||
|
||||
curr_persona_name = "无"
|
||||
if self.context.get_using_provider().curr_personality:
|
||||
curr_persona_name = self.context.get_using_provider().curr_personality['name']
|
||||
|
||||
if len(l) == 1:
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"""[Persona]
|
||||
|
||||
- 设置人格: `/persona 人格名`, 如 /persona 编剧
|
||||
- 人格列表: `/persona list`
|
||||
- 人格详细信息: `/persona view 人格名`
|
||||
- 自定义人格: /persona 人格文本
|
||||
- 重置 LLM 会话(清除人格): /reset
|
||||
- 重置 LLM 会话(保留人格): /reset p
|
||||
- 设置人格情景: `/persona 人格名`, 如 /persona 编剧
|
||||
- 人格情景列表: `/persona list`
|
||||
- 人格情景详细信息: `/persona view 人格名`
|
||||
|
||||
【当前人格】: {str(self.context.get_using_provider().curr_personality['prompt'])}
|
||||
当前人格情景: {curr_persona_name}
|
||||
|
||||
配置人格情景请前往管理面板-配置页
|
||||
""").use_t2i(False))
|
||||
elif l[1] == "list":
|
||||
msg = "人格列表:\n"
|
||||
for key in personalities.keys():
|
||||
msg += f"- {key}\n"
|
||||
for persona in self.context.provider_manager.personas:
|
||||
msg += f"- {persona['name']}\n"
|
||||
msg += '\n\n*输入 `/persona view 人格名` 查看人格详细信息'
|
||||
message.set_result(MessageEventResult().message(msg))
|
||||
elif l[1] == "view":
|
||||
if len(l) == 2:
|
||||
message.set_result(MessageEventResult().message("请输入人格名"))
|
||||
message.set_result(MessageEventResult().message("请输入人格情景名"))
|
||||
return
|
||||
ps = l[2].strip()
|
||||
if ps in personalities:
|
||||
if persona := next(builtins.filter(
|
||||
lambda persona: persona['name'] == ps,
|
||||
self.context.provider_manager.personas
|
||||
), None):
|
||||
msg = f"人格{ps}的详细信息:\n"
|
||||
msg += f"{personalities[ps]}\n"
|
||||
msg += f"{persona['prompt']}\n"
|
||||
else:
|
||||
msg = f"人格{ps}不存在"
|
||||
message.set_result(MessageEventResult().message(msg))
|
||||
else:
|
||||
ps = "".join(l[1:]).strip()
|
||||
if ps in personalities:
|
||||
self.context.get_using_provider().curr_personality = Personality(
|
||||
name=ps, prompt=personalities[ps])
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
|
||||
if persona := next(builtins.filter(
|
||||
lambda persona: persona['name'] == ps,
|
||||
self.context.provider_manager.personas
|
||||
), None):
|
||||
self.context.get_using_provider().curr_personality = persona
|
||||
message.set_result(MessageEventResult().message(f"设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"))
|
||||
else:
|
||||
self.context.get_using_provider().curr_personality = Personality(
|
||||
name="自定义人格", prompt=ps)
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
|
||||
message.set_result(MessageEventResult().message(f"不存在该人格情景。使用 /persona list 查看所有。"))
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("dashboard_update")
|
||||
@@ -363,12 +374,22 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
if self.identifier:
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
user_info = f"[User ID: {user_id}, Nickname: {user_nickname}]\n"
|
||||
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
|
||||
req.prompt = user_info + req.prompt
|
||||
if self.enable_datetime:
|
||||
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}"
|
||||
if provider.curr_personality['prompt']:
|
||||
req.system_prompt += f"\n{provider.curr_personality['prompt']}"
|
||||
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
|
||||
|
||||
if persona := provider.curr_personality:
|
||||
if prompt := persona['prompt']:
|
||||
req.system_prompt += prompt
|
||||
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
|
||||
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
|
||||
req.system_prompt += mood_dialogs
|
||||
if begin_dialogs := persona["_begin_dialogs_processed"]:
|
||||
req.contexts[:0] = begin_dialogs
|
||||
|
||||
# if provider.curr_personality['prompt']:
|
||||
# req.system_prompt += f"\n{provider.curr_personality['prompt']}"
|
||||
|
||||
@filter.command("set")
|
||||
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
|
||||
@@ -397,6 +418,16 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
del session_var[key]
|
||||
sp.put("session_variables", session_vars)
|
||||
yield event.plain_result(f"会话 {session_id} 变量 {key} 移除成功。")
|
||||
|
||||
@filter.command("gewe_logout")
|
||||
async def gewe_logout(self, event: AstrMessageEvent):
|
||||
platforms = self.context.platform_manager.platform_insts
|
||||
for platform in platforms:
|
||||
if platform.meta().name == "gewechat":
|
||||
yield event.plain_result("正在登出 gewechat")
|
||||
await platform.logout()
|
||||
yield event.plain_result("已登出 gewechat")
|
||||
return
|
||||
|
||||
@filter.command_group("kdb")
|
||||
def kdb(self):
|
||||
|
||||
@@ -113,7 +113,7 @@ class Main(star.Star):
|
||||
async def initialize(self):
|
||||
ok = await self.is_docker_available()
|
||||
if not ok:
|
||||
logger.warning("Docker 不可用,代码解释器将无法使用,astrbot-python-interpreter 将自动禁用。")
|
||||
logger.info("Docker 不可用,代码解释器将无法使用,astrbot-python-interpreter 将自动禁用。")
|
||||
await self.context._star_manager.turn_off_plugin("astrbot-python-interpreter")
|
||||
|
||||
async def file_upload(self, file_path: str):
|
||||
@@ -127,7 +127,7 @@ class Main(star.Star):
|
||||
|
||||
s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}"
|
||||
|
||||
async with aiohttp.ClientSession(headers = {"Accept": "application/json"}) as session:
|
||||
async with aiohttp.ClientSession(headers = {"Accept": "application/json"}, trust_env=True) as session:
|
||||
async with session.put(s3_file_url, data=file) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"Failed to upload image: {resp.status}")
|
||||
@@ -140,8 +140,8 @@ class Main(star.Star):
|
||||
docker = aiodocker.Docker()
|
||||
await docker.version()
|
||||
return True
|
||||
except aiodocker.exceptions.DockerError as e:
|
||||
logger.error(f"检查 Docker 可用性时出现问题: {e}")
|
||||
except BaseException as e:
|
||||
logger.info(f"检查 Docker 可用性: {e}")
|
||||
return False
|
||||
|
||||
async def get_image_name(self) -> str:
|
||||
@@ -150,7 +150,7 @@ class Main(star.Star):
|
||||
return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}"
|
||||
return self.config["sandbox"]["image"]
|
||||
|
||||
async def _save_config(self):
|
||||
def _save_config(self):
|
||||
with open(PATH, "w") as f:
|
||||
json.dump(self.config, f)
|
||||
|
||||
@@ -159,7 +159,7 @@ class Main(star.Star):
|
||||
|
||||
async def download_image(self, image_url: str, workplace_path: str, filename: str) -> str:
|
||||
'''Download image from url to workplace_path'''
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(image_url) as resp:
|
||||
if resp.status != 200:
|
||||
return ""
|
||||
@@ -207,7 +207,7 @@ class Main(star.Star):
|
||||
""")
|
||||
else:
|
||||
self.config["sandbox"]["docker_mirror"] = url
|
||||
await self._save_config()
|
||||
self._save_config()
|
||||
yield event.plain_result("设置 Docker 镜像地址成功。")
|
||||
|
||||
@pi.command("repull")
|
||||
|
||||
@@ -78,6 +78,10 @@ class Main(star.Star):
|
||||
cron_expression(string): Required when user's reminder is a repeated reminder. The cron expression of the reminder.
|
||||
human_readable_cron(string): Optional. The human readable cron expression of the reminder.
|
||||
'''
|
||||
if event.get_platform_name() == 'qq_official':
|
||||
yield event.plain_result("reminder 暂不支持 QQ 官方机器人。")
|
||||
return
|
||||
|
||||
if event.unified_msg_origin not in self.reminder_data:
|
||||
self.reminder_data[event.unified_msg_origin] = []
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ class Main(star.Star):
|
||||
'''获取网页内容'''
|
||||
header = HEADERS
|
||||
header.update({'User-Agent': random.choice(USER_AGENTS)})
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url, headers=header, timeout=6) as response:
|
||||
html = await response.text(encoding="utf-8")
|
||||
doc = Document(html)
|
||||
|
||||
Reference in New Issue
Block a user